Skip to content

Commit 81973d9

Browse files
committed
feat: replace custom joined error with errors.Join
Signed-off-by: Mark Sagi-Kazar <mark.sagikazar@gmail.com>
1 parent 114dfbb commit 81973d9

File tree

4 files changed

+58
-54
lines changed

4 files changed

+58
-54
lines changed

internal/errors/errors.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,7 @@ import "errors"
55
func New(text string) error {
66
return errors.New(text)
77
}
8+
9+
func As(err error, target interface{}) bool {
10+
return errors.As(err, target)
11+
}

mapstructure.go

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,15 @@ func NewDecoder(config *DecoderConfig) (*Decoder, error) {
415415
// Decode decodes the given raw interface to the target pointer specified
416416
// by the configuration.
417417
func (d *Decoder) Decode(input interface{}) error {
418-
return d.decode("", input, reflect.ValueOf(d.config.Result).Elem())
418+
err := d.decode("", input, reflect.ValueOf(d.config.Result).Elem())
419+
420+
// Retain some of the original behavior when multiple errors ocurr
421+
var joinedErr interface{ Unwrap() []error }
422+
if errors.As(err, &joinedErr) {
423+
return fmt.Errorf("decoding failed due to the following error(s):\n\n%w", err)
424+
}
425+
426+
return err
419427
}
420428

421429
// Decodes an unknown data type into a specific reflection value.
@@ -882,7 +890,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle
882890
valElemType := valType.Elem()
883891

884892
// Accumulate errors
885-
errors := make([]string, 0)
893+
var errs []error
886894

887895
// If the input data is empty, then we just match what the input data is.
888896
if dataVal.Len() == 0 {
@@ -904,15 +912,15 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle
904912
// First decode the key into the proper type
905913
currentKey := reflect.Indirect(reflect.New(valKeyType))
906914
if err := d.decode(fieldName, k.Interface(), currentKey); err != nil {
907-
errors = appendErrors(errors, err)
915+
errs = append(errs, err)
908916
continue
909917
}
910918

911919
// Next decode the data into the proper type
912920
v := dataVal.MapIndex(k).Interface()
913921
currentVal := reflect.Indirect(reflect.New(valElemType))
914922
if err := d.decode(fieldName, v, currentVal); err != nil {
915-
errors = appendErrors(errors, err)
923+
errs = append(errs, err)
916924
continue
917925
}
918926

@@ -922,12 +930,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle
922930
// Set the built up map to the value
923931
val.Set(valMap)
924932

925-
// If we had errors, return those
926-
if len(errors) > 0 {
927-
return &joinedError{errors}
928-
}
929-
930-
return nil
933+
return errors.Join(errs...)
931934
}
932935

933936
func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error {
@@ -1165,7 +1168,7 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value)
11651168
}
11661169

11671170
// Accumulate any errors
1168-
errors := make([]string, 0)
1171+
var errs []error
11691172

11701173
for i := 0; i < dataVal.Len(); i++ {
11711174
currentData := dataVal.Index(i).Interface()
@@ -1176,19 +1179,14 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value)
11761179

11771180
fieldName := name + "[" + strconv.Itoa(i) + "]"
11781181
if err := d.decode(fieldName, currentData, currentField); err != nil {
1179-
errors = appendErrors(errors, err)
1182+
errs = append(errs, err)
11801183
}
11811184
}
11821185

11831186
// Finally, set the value to the slice we built up
11841187
val.Set(valSlice)
11851188

1186-
// If there were errors, we return those
1187-
if len(errors) > 0 {
1188-
return &joinedError{errors}
1189-
}
1190-
1191-
return nil
1189+
return errors.Join(errs...)
11921190
}
11931191

11941192
func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value) error {
@@ -1234,27 +1232,22 @@ func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value)
12341232
}
12351233

12361234
// Accumulate any errors
1237-
errors := make([]string, 0)
1235+
var errs []error
12381236

12391237
for i := 0; i < dataVal.Len(); i++ {
12401238
currentData := dataVal.Index(i).Interface()
12411239
currentField := valArray.Index(i)
12421240

12431241
fieldName := name + "[" + strconv.Itoa(i) + "]"
12441242
if err := d.decode(fieldName, currentData, currentField); err != nil {
1245-
errors = appendErrors(errors, err)
1243+
errs = append(errs, err)
12461244
}
12471245
}
12481246

12491247
// Finally, set the value to the array we built up
12501248
val.Set(valArray)
12511249

1252-
// If there were errors, we return those
1253-
if len(errors) > 0 {
1254-
return &joinedError{errors}
1255-
}
1256-
1257-
return nil
1250+
return errors.Join(errs...)
12581251
}
12591252

12601253
func (d *Decoder) decodeStruct(name string, data interface{}, val reflect.Value) error {
@@ -1316,7 +1309,8 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
13161309
}
13171310

13181311
targetValKeysUnused := make(map[interface{}]struct{})
1319-
errors := make([]string, 0)
1312+
1313+
var errs []error
13201314

13211315
// This slice will keep track of all the structs we'll be decoding.
13221316
// There can be more than one struct if there are embedded structs
@@ -1370,8 +1364,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
13701364

13711365
if squash {
13721366
if fieldVal.Kind() != reflect.Struct {
1373-
errors = appendErrors(errors,
1374-
fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind()))
1367+
errs = append(errs, fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind()))
13751368
} else {
13761369
structs = append(structs, fieldVal)
13771370
}
@@ -1450,7 +1443,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
14501443
}
14511444

14521445
if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil {
1453-
errors = appendErrors(errors, err)
1446+
errs = append(errs, err)
14541447
}
14551448
}
14561449

@@ -1465,7 +1458,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
14651458

14661459
// Decode it as-if we were just decoding this map onto our map.
14671460
if err := d.decodeMap(name, remain, remainField.val); err != nil {
1468-
errors = appendErrors(errors, err)
1461+
errs = append(errs, err)
14691462
}
14701463

14711464
// Set the map to nil so we have none so that the next check will
@@ -1481,7 +1474,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
14811474
sort.Strings(keys)
14821475

14831476
err := fmt.Errorf("'%s' has invalid keys: %s", name, strings.Join(keys, ", "))
1484-
errors = appendErrors(errors, err)
1477+
errs = append(errs, err)
14851478
}
14861479

14871480
if d.config.ErrorUnset && len(targetValKeysUnused) > 0 {
@@ -1492,11 +1485,11 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
14921485
sort.Strings(keys)
14931486

14941487
err := fmt.Errorf("'%s' has unset fields: %s", name, strings.Join(keys, ", "))
1495-
errors = appendErrors(errors, err)
1488+
errs = append(errs, err)
14961489
}
14971490

1498-
if len(errors) > 0 {
1499-
return &joinedError{errors}
1491+
if err := errors.Join(errs...); err != nil {
1492+
return err
15001493
}
15011494

15021495
// Add the unused keys to the list of unused keys if we're tracking metadata

mapstructure_examples_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,13 @@ func ExampleDecode_errors() {
6363

6464
fmt.Println(err.Error())
6565
// Output:
66-
// 5 error(s) decoding:
66+
// decoding failed due to the following error(s):
6767
//
68-
// * 'Age' expected type 'int', got unconvertible type 'string', value: 'bad value'
69-
// * 'Emails[0]' expected type 'string', got unconvertible type 'int', value: '1'
70-
// * 'Emails[1]' expected type 'string', got unconvertible type 'int', value: '2'
71-
// * 'Emails[2]' expected type 'string', got unconvertible type 'int', value: '3'
72-
// * 'Name' expected type 'string', got unconvertible type 'int', value: '123'
68+
// 'Name' expected type 'string', got unconvertible type 'int', value: '123'
69+
// 'Age' expected type 'int', got unconvertible type 'string', value: 'bad value'
70+
// 'Emails[0]' expected type 'string', got unconvertible type 'int', value: '1'
71+
// 'Emails[1]' expected type 'string', got unconvertible type 'int', value: '2'
72+
// 'Emails[2]' expected type 'string', got unconvertible type 'int', value: '3'
7373
}
7474

7575
func ExampleDecode_metadata() {

mapstructure_test.go

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package mapstructure
22

33
import (
44
"encoding/json"
5+
"errors"
56
"io"
67
"reflect"
78
"sort"
@@ -2323,13 +2324,17 @@ func TestInvalidType(t *testing.T) {
23232324
t.Fatal("error should exist")
23242325
}
23252326

2326-
derr, ok := err.(*joinedError)
2327-
if !ok {
2328-
t.Fatalf("error should be kind of joinedError, instead: %#v", err)
2327+
var derr interface {
2328+
Unwrap() []error
2329+
}
2330+
2331+
if !errors.As(err, &derr) {
2332+
t.Fatalf("error should be a type implementing Unwrap() []error, instead: %#v", err)
23292333
}
23302334

2331-
if derr.Errors[0] !=
2332-
"'Vstring' expected type 'string', got unconvertible type 'int', value: '42'" {
2335+
errs := derr.Unwrap()
2336+
2337+
if errs[0].Error() != "'Vstring' expected type 'string', got unconvertible type 'int', value: '42'" {
23332338
t.Errorf("got unexpected error: %s", err)
23342339
}
23352340

@@ -2342,12 +2347,13 @@ func TestInvalidType(t *testing.T) {
23422347
t.Fatal("error should exist")
23432348
}
23442349

2345-
derr, ok = err.(*joinedError)
2346-
if !ok {
2347-
t.Fatalf("error should be kind of joinedError, instead: %#v", err)
2350+
if !errors.As(err, &derr) {
2351+
t.Fatalf("error should be a type implementing Unwrap() []error, instead: %#v", err)
23482352
}
23492353

2350-
if derr.Errors[0] != "cannot parse 'Vuint', -42 overflows uint" {
2354+
errs = derr.Unwrap()
2355+
2356+
if errs[0].Error() != "cannot parse 'Vuint', -42 overflows uint" {
23512357
t.Errorf("got unexpected error: %s", err)
23522358
}
23532359

@@ -2360,12 +2366,13 @@ func TestInvalidType(t *testing.T) {
23602366
t.Fatal("error should exist")
23612367
}
23622368

2363-
derr, ok = err.(*joinedError)
2364-
if !ok {
2365-
t.Fatalf("error should be kind of joinedError, instead: %#v", err)
2369+
if !errors.As(err, &derr) {
2370+
t.Fatalf("error should be a type implementing Unwrap() []error, instead: %#v", err)
23662371
}
23672372

2368-
if derr.Errors[0] != "cannot parse 'Vuint', -42.000000 overflows uint" {
2373+
errs = derr.Unwrap()
2374+
2375+
if errs[0].Error() != "cannot parse 'Vuint', -42.000000 overflows uint" {
23692376
t.Errorf("got unexpected error: %s", err)
23702377
}
23712378
}

0 commit comments

Comments
 (0)