From 950368c2bc303fb2d861bb496ceb85728c4672c5 Mon Sep 17 00:00:00 2001 From: Faye Amacker <33205765+fxamacker@users.noreply.github.com> Date: Wed, 28 Dec 2022 20:05:45 -0600 Subject: [PATCH] Support decoding byte string map key into Go map Added support for decoding CBOR map with byte string keys to Go map with empty interface key type. This supports decoding these CBOR map key types: - byte string - tagged byte string - nested tagged byte string Previous commit already added support for decoding CBOR maps with byte string keys to empty Go interface value. --- decode.go | 27 +++++++++++++--- decode_test.go | 84 ++++++++++++++++++++++++++++---------------------- 2 files changed, 69 insertions(+), 42 deletions(-) diff --git a/decode.go b/decode.go index 2a0d69db..c444451e 100644 --- a/decode.go +++ b/decode.go @@ -250,7 +250,14 @@ func (idm IntDecMode) valid() bool { } // MapKeyByteStringMode specifies how to decode CBOR byte string (major type 2) -// as Go map key when decoding CBOR map into an empty Go interface value. +// as Go map key when decoding CBOR map key into an empty Go interface value. +// Specifically, this option applies when decoding CBOR map into +// - Go empty interface, or +// - Go map with empty interface as key type. +// The CBOR map key types handled by this option are +// - byte string +// - tagged byte string +// - nested tagged byte string type MapKeyByteStringMode int const ( @@ -1357,11 +1364,21 @@ func (d *decoder) parseMapToMap(v reflect.Value, tInfo *typeInfo) error { //noli // Detect if CBOR map key can be used as Go map key. if keyIsInterfaceType && keyValue.Elem().IsValid() { if !isHashableValue(keyValue.Elem()) { - if err == nil { - err = &InvalidMapKeyTypeError{keyValue.Elem().Type().String()} + var converted bool + if d.dm.mapKeyByteString == MapKeyByteStringAllowed { + var k interface{} + k, converted = convertByteSliceToByteString(keyValue.Elem().Interface()) + if converted { + keyValue.Set(reflect.ValueOf(k)) + } + } + if !converted { + if err == nil { + err = &InvalidMapKeyTypeError{keyValue.Elem().Type().String()} + } + d.skip() + continue } - d.skip() - continue } } diff --git a/decode_test.go b/decode_test.go index b5ae2baf..f0dcc9e8 100644 --- a/decode_test.go +++ b/decode_test.go @@ -35,6 +35,7 @@ var ( typeByteArray = reflect.TypeOf([5]byte{}) typeIntSlice = reflect.TypeOf([]int{}) typeStringSlice = reflect.TypeOf([]string{}) + typeMapIntfIntf = reflect.TypeOf(map[interface{}]interface{}{}) typeMapStringInt = reflect.TypeOf(map[string]int{}) typeMapStringString = reflect.TypeOf(map[string]string{}) typeMapStringIntf = reflect.TypeOf(map[string]interface{}{}) @@ -4669,30 +4670,40 @@ func TestDecModeInvalidMapKeyByteString(t *testing.T) { } func TestMapKeyByteString(t *testing.T) { + bsForbiddenMode, err := DecOptions{MapKeyByteString: MapKeyByteStringForbidden}.DecMode() + if err != nil { + t.Errorf("DecMode() returned an error %+v", err) + } + + bsAllowedMode, err := DecOptions{MapKeyByteString: MapKeyByteStringAllowed}.DecMode() + if err != nil { + t.Errorf("DecMode() returned an error %+v", err) + } + testCases := []struct { - name string - cborData []byte - wantObj interface{} - wantErrorMsg string - mapKeyByteString MapKeyByteStringMode + name string + cborData []byte + wantObj interface{} + wantErrorMsg string + dm DecMode }{ { - name: "byte string map key with MapKeyByteStringForbidden", - cborData: hexDecode("a143abcdef187b"), - wantErrorMsg: "cbor: invalid map key type: []uint8", - mapKeyByteString: MapKeyByteStringForbidden, + name: "byte string map key with MapKeyByteStringForbidden", + cborData: hexDecode("a143abcdef187b"), + wantErrorMsg: "cbor: invalid map key type: []uint8", + dm: bsForbiddenMode, }, { - name: "tagged byte string map key with MapKeyByteStringForbidden", - cborData: hexDecode("a1d86443abcdef187b"), - wantErrorMsg: "cbor: invalid map key type: cbor.Tag", - mapKeyByteString: MapKeyByteStringForbidden, + name: "tagged byte string map key with MapKeyByteStringForbidden", + cborData: hexDecode("a1d86443abcdef187b"), + wantErrorMsg: "cbor: invalid map key type: cbor.Tag", + dm: bsForbiddenMode, }, { - name: "nested tagged byte string map key with MapKeyByteStringForbidden", - cborData: hexDecode("a1d865d86443abcdef187b"), - wantErrorMsg: "cbor: invalid map key type: cbor.Tag", - mapKeyByteString: MapKeyByteStringForbidden, + name: "nested tagged byte string map key with MapKeyByteStringForbidden", + cborData: hexDecode("a1d865d86443abcdef187b"), + wantErrorMsg: "cbor: invalid map key type: cbor.Tag", + dm: bsForbiddenMode, }, { name: "byte string map key with MapKeyByteStringAllowed", @@ -4700,7 +4711,7 @@ func TestMapKeyByteString(t *testing.T) { wantObj: map[interface{}]interface{}{ ByteString("\xab\xcd\xef"): uint64(123), }, - mapKeyByteString: MapKeyByteStringAllowed, + dm: bsAllowedMode, }, { name: "tagged byte string map key with MapKeyByteStringAllowed", @@ -4708,7 +4719,7 @@ func TestMapKeyByteString(t *testing.T) { wantObj: map[interface{}]interface{}{ Tag{Number: 100, Content: ByteString("\xab\xcd\xef")}: uint64(123), }, - mapKeyByteString: MapKeyByteStringAllowed, + dm: bsAllowedMode, }, { name: "nested tagged byte string map key with MapKeyByteStringAllowed", @@ -4716,28 +4727,27 @@ func TestMapKeyByteString(t *testing.T) { wantObj: map[interface{}]interface{}{ Tag{Number: 101, Content: Tag{Number: 100, Content: ByteString("\xab\xcd\xef")}}: uint64(123), }, - mapKeyByteString: MapKeyByteStringAllowed, + dm: bsAllowedMode, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - dm, err := DecOptions{MapKeyByteString: tc.mapKeyByteString}.DecMode() - if err != nil { - t.Errorf("DecMode() returned an error %+v", err) - } - var v interface{} - err = dm.Unmarshal(tc.cborData, &v) - if err == nil { - if tc.wantErrorMsg != "" { - t.Errorf("Unmarshal(0x%x) didn't return an error, want %q", tc.cborData, tc.wantErrorMsg) - } else if !reflect.DeepEqual(v, tc.wantObj) { - t.Errorf("Unmarshal(0x%x) return %v (%T), want %v (%T)", tc.cborData, v, v, tc.wantObj, tc.wantObj) - } - } else { - if tc.wantErrorMsg == "" { - t.Errorf("Unmarshal(0x%x) returned error %q", tc.cborData, err) - } else if !strings.Contains(err.Error(), tc.wantErrorMsg) { - t.Errorf("Unmarshal(0x%x) returned error %q, want %q", tc.cborData, err.Error(), tc.wantErrorMsg) + for _, typ := range []reflect.Type{typeIntf, typeMapIntfIntf} { + v := reflect.New(typ) + vPtr := v.Interface() + err = tc.dm.Unmarshal(tc.cborData, vPtr) + if err == nil { + if tc.wantErrorMsg != "" { + t.Errorf("Unmarshal(0x%x) didn't return an error, want %q", tc.cborData, tc.wantErrorMsg) + } else if !reflect.DeepEqual(v.Elem().Interface(), tc.wantObj) { + t.Errorf("Unmarshal(0x%x) return %v (%T), want %v (%T)", tc.cborData, v.Elem().Interface(), v.Elem().Interface(), tc.wantObj, tc.wantObj) + } + } else { + if tc.wantErrorMsg == "" { + t.Errorf("Unmarshal(0x%x) returned error %q", tc.cborData, err) + } else if !strings.Contains(err.Error(), tc.wantErrorMsg) { + t.Errorf("Unmarshal(0x%x) returned error %q, want %q", tc.cborData, err.Error(), tc.wantErrorMsg) + } } } })