Skip to content

Commit 0f7e438

Browse files
authored
Merge pull request #35 from deeglaze/fwerr
Fix SevFirmwareErr implementation and clean up error conditionals
2 parents 0d57edf + 6a17176 commit 0f7e438

File tree

8 files changed

+50
-33
lines changed

8 files changed

+50
-33
lines changed

abi/amdsp.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,10 @@ const GuestRequestInvalidLength SevFirmwareStatus = 0x100000000
106106

107107
// SevFirmwareErr is an error that interprets firmware status codes from the AMD secure processor.
108108
type SevFirmwareErr struct {
109-
error
110109
Status SevFirmwareStatus
111110
}
112111

113-
func (e SevFirmwareErr) Error() string {
112+
func (e *SevFirmwareErr) Error() string {
114113
if e.Status == Success {
115114
return "success"
116115
}

client/client.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ func message(d Device, command uintptr, req *labi.SnpUserGuestRequest) error {
4646
// indicates a problem certificate length. We need to
4747
// communicate that specifically.
4848
if req.FwErr != 0 {
49-
return abi.SevFirmwareErr{Status: abi.SevFirmwareStatus(req.FwErr)}
49+
return &abi.SevFirmwareErr{Status: abi.SevFirmwareStatus(req.FwErr)}
5050
}
5151
return err
5252
}
5353
if result != uintptr(labi.EsOk) {
54-
return labi.SevEsErr{Result: labi.EsResult(result)}
54+
return &labi.SevEsErr{Result: labi.EsResult(result)}
5555
}
5656
return nil
5757
}
@@ -113,7 +113,7 @@ func getExtendedReportIn(d Device, reportData [64]byte, vmpl int, certs []byte)
113113
}
114114
// Query the length required for certs.
115115
if err := message(d, labi.IocSnpGetExtendedReport, &userGuestReq); err != nil {
116-
var fwErr abi.SevFirmwareErr
116+
var fwErr *abi.SevFirmwareErr
117117
if errors.As(err, &fwErr) && fwErr.Status == abi.GuestRequestInvalidLength {
118118
return nil, snpExtReportReq.CertsLength, nil
119119
}

client/client_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func initDevice() {
4444
now := time.Date(2022, time.May, 3, 9, 0, 0, 0, time.UTC)
4545
for _, tc := range test.TestCases() {
4646
// Don't test faked errors when running real hardware tests.
47-
if !UseDefaultSevGuest() && tc.WantErr != nil {
47+
if !UseDefaultSevGuest() && tc.WantErr != "" {
4848
continue
4949
}
5050
tests = append(tests, tc)
@@ -137,11 +137,11 @@ func TestOpenGetReportClose(t *testing.T) {
137137

138138
// Does the proto report match expectations?
139139
got, err := GetReport(device, tc.Input)
140-
if err != tc.WantErr {
140+
if !test.Match(err, tc.WantErr) {
141141
t.Fatalf("GetReport(device, %v) = %v, %v. Want err: %v", tc.Input, got, err, tc.WantErr)
142142
}
143143

144-
if tc.WantErr == nil {
144+
if tc.WantErr == "" {
145145
cleanReport(got)
146146
want := reportProto
147147
want.Signature = got.Signature // Zeros were placeholders.
@@ -156,10 +156,10 @@ func TestOpenGetRawExtendedReportClose(t *testing.T) {
156156
devMu.Do(initDevice)
157157
for _, tc := range tests {
158158
raw, certs, err := GetRawExtendedReport(device, tc.Input)
159-
if err != tc.WantErr {
159+
if !test.Match(err, tc.WantErr) {
160160
t.Fatalf("%s: GetRawExtendedReport(device, %v) = %v, %v, %v. Want err: %v", tc.Name, tc.Input, raw, certs, err, tc.WantErr)
161161
}
162-
if tc.WantErr == nil {
162+
if tc.WantErr == "" {
163163
if err := cleanRawReport(raw); err != nil {
164164
t.Fatal(err)
165165
}
@@ -189,10 +189,10 @@ func TestOpenGetExtendedReportClose(t *testing.T) {
189189
devMu.Do(initDevice)
190190
for _, tc := range tests {
191191
ereport, err := GetExtendedReport(device, tc.Input)
192-
if err != tc.WantErr {
192+
if !test.Match(err, tc.WantErr) {
193193
t.Fatalf("%s: GetExtendedReport(device, %v) = %v, %v. Want err: %v", tc.Name, tc.Input, ereport, err, tc.WantErr)
194194
}
195-
if tc.WantErr == nil {
195+
if tc.WantErr == "" {
196196
reportProto := &spb.Report{}
197197
if err := prototext.Unmarshal([]byte(tc.OutputProto), reportProto); err != nil {
198198
t.Fatalf("test failure: %v", err)

client/linuxabi/linux_abi.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,10 @@ const (
8282

8383
// SevEsErr is an error that interprets SEV-ES guest-host communication results.
8484
type SevEsErr struct {
85-
error
8685
Result EsResult
8786
}
8887

89-
func (err SevEsErr) Error() string {
88+
func (err *SevEsErr) Error() string {
9089
if err.Result == EsUnsupported {
9190
return "requested operation not supported"
9291
}

kds/kds_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"encoding/hex"
1919
"fmt"
2020
"net/url"
21+
"strings"
2122
"testing"
2223

2324
"github.com/google/go-cmp/cmp"
@@ -80,7 +81,7 @@ func TestParseProductBaseURL(t *testing.T) {
8081
for _, tc := range tcs {
8182
t.Run(tc.name, func(t *testing.T) {
8283
gotProduct, gotURL, err := parseBaseProductURL(tc.url)
83-
if (err != nil && err.Error() != tc.wantErr) || (err == nil && tc.wantErr != "") {
84+
if (err == nil && tc.wantErr != "") || (err != nil && !strings.Contains(err.Error(), tc.wantErr)) {
8485
t.Fatalf("parseBaseProductURL(%q) = _, _, %v, want %q", tc.url, err, tc.wantErr)
8586
}
8687
if err == nil {
@@ -144,7 +145,7 @@ func TestParseVCEKCertURL(t *testing.T) {
144145
for _, tc := range tcs {
145146
t.Run(tc.name, func(t *testing.T) {
146147
got, err := ParseVCEKCertURL(tc.url)
147-
if (err != nil && err.Error() != tc.wantErr) || (err == nil && tc.wantErr != "") {
148+
if (err == nil && tc.wantErr != "") || (err != nil && !strings.Contains(err.Error(), tc.wantErr)) {
148149
t.Fatalf("ParseVCEKCertURL(%q) = _, %v, want %q", tc.url, err, tc.wantErr)
149150
}
150151
if err == nil {

testing/match.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package testing
16+
17+
import "strings"
18+
19+
// Match returns true iff both errors match expectations closely enough
20+
func Match(got error, want string) bool {
21+
if got == nil {
22+
return want == ""
23+
}
24+
return strings.Contains(got.Error(), want)
25+
}

testing/test_cases.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ type TestCase struct {
134134
OutputProto string
135135
FwErr abi.SevFirmwareStatus
136136
EsResult labi.EsResult
137-
WantErr error
137+
WantErr string
138138
}
139139

140140
// TestCases returns common test cases for get_report.
@@ -158,7 +158,7 @@ func TestCases() []TestCase {
158158
Name: "fw oom",
159159
Input: userZeros11,
160160
FwErr: abi.ResourceLimit,
161-
WantErr: abi.SevFirmwareErr{Status: abi.ResourceLimit},
161+
WantErr: (&abi.SevFirmwareErr{Status: abi.ResourceLimit}).Error(),
162162
},
163163
}
164164
}

verify/verify_test.go

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323
"math/big"
2424
"math/rand"
2525
"os"
26-
"strings"
2726
"sync"
2827
"testing"
2928
"time"
@@ -151,10 +150,10 @@ func TestSnpReportSignature(t *testing.T) {
151150
}
152151
// Does the Raw report match expectations?
153152
raw, err := sg.GetRawReport(d, tc.Input)
154-
if err != tc.WantErr {
153+
if !test.Match(err, tc.WantErr) {
155154
t.Fatalf("GetRawReport(d, %v) = %v, %v. Want err: %v", tc.Input, raw, err, tc.WantErr)
156155
}
157-
if tc.WantErr == nil {
156+
if tc.WantErr == "" {
158157
got := abi.SignedComponent(raw)
159158
want := abi.SignedComponent(tc.Output[:])
160159
if !bytes.Equal(got, want) {
@@ -303,14 +302,8 @@ func TestKdsMetadataLogic(t *testing.T) {
303302
options = &Options{}
304303
}
305304
vcek, _, err := VcekDER(newSigner.Vcek.Raw, newSigner.Ask.Raw, newSigner.Ark.Raw, options)
306-
if err == nil && tc.wantErr != "" {
307-
t.Errorf("%s: VcekDER(...) = %+v did not error as expected.", tc.name, vcek)
308-
}
309-
if err != nil && tc.wantErr == "" {
310-
t.Errorf("%s: VcekDER(...) errored unexpectedly: %v", tc.name, err)
311-
}
312-
if err != nil && tc.wantErr != "" && !strings.Contains(err.Error(), tc.wantErr) {
313-
t.Errorf("%s: VcekDER(...) did not error as expected. Got %v, want %s", tc.name, err, tc.wantErr)
305+
if !test.Match(err, tc.wantErr) {
306+
t.Errorf("%s: VcekDER(...) = %+v, %v did not error as expected. Want %q", tc.name, vcek, err, tc.wantErr)
314307
}
315308
}
316309
}
@@ -374,7 +367,7 @@ func TestCRLRootValidity(t *testing.T) {
374367
},
375368
}
376369
wantErr := "CRL is not signed by ARK"
377-
if err := VcekNotRevoked(root, g2, signer2.Vcek); err == nil || !strings.Contains(err.Error(), wantErr) {
370+
if err := VcekNotRevoked(root, g2, signer2.Vcek); !test.Match(err, wantErr) {
378371
t.Errorf("Bad Root: VcekNotRevoked(%v) did not error as expected. Got %v, want %v", signer.Vcek, err, wantErr)
379372
}
380373

@@ -385,7 +378,7 @@ func TestCRLRootValidity(t *testing.T) {
385378
AskX509: signer2.Ask,
386379
}
387380
wantErr2 := "ASK was revoked at 2022-06-14 12:01:00 +0000 UTC"
388-
if err := VcekNotRevoked(root2, g2, signer2.Vcek); err == nil || !strings.Contains(err.Error(), wantErr2) {
381+
if err := VcekNotRevoked(root2, g2, signer2.Vcek); !test.Match(err, wantErr2) {
389382
t.Errorf("Bad ASK: VcekNotRevoked(%v) did not error as expected. Got %v, want %v", signer.Vcek, err, wantErr2)
390383
}
391384
}
@@ -422,10 +415,10 @@ func TestOpenGetExtendedReportVerifyClose(t *testing.T) {
422415
}
423416
for _, getReport := range reportGetters {
424417
ereport, err := getReport.getter(d, tc.Input)
425-
if err != tc.WantErr {
418+
if !test.Match(err, tc.WantErr) {
426419
t.Fatalf("%s: %s(d, %v) = %v, %v. Want err: %v", tc.Name, getReport.name, tc.Input, ereport, err, tc.WantErr)
427420
}
428-
if tc.WantErr == nil {
421+
if tc.WantErr == "" {
429422
if err := SnpAttestation(ereport, options); err != nil {
430423
t.Errorf("SnpAttestation(%v) errored unexpectedly: %v", ereport, err)
431424
}

0 commit comments

Comments
 (0)