Skip to content

Commit

Permalink
Merge pull request #376 from fxamacker/feature/bytestring
Browse files Browse the repository at this point in the history
Define ByteString type to support CBOR byte string as map keys and other uses.

Go doesn't allow []byte as map key, so ByteString can be used to support data formats
having CBOR map with byte string keys. Another use for ByteString is to encode
invalid UTF-8 string as CBOR byte string.
  • Loading branch information
fxamacker authored Dec 29, 2022
2 parents 1c59246 + 950368c commit 7c3a599
Show file tree
Hide file tree
Showing 5 changed files with 467 additions and 118 deletions.
62 changes: 62 additions & 0 deletions bytestring.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (c) Faye Amacker. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

package cbor

import (
"errors"
)

// ByteString represents CBOR byte string (major type 2). ByteString can be used
// when using a Go []byte is not possible or convenient. For example, Go doesn't
// allow []byte as map key, so ByteString can be used to support data formats
// having CBOR map with byte string keys. ByteString can also be used to
// encode invalid UTF-8 string as CBOR byte string.
// See DecOption.MapKeyByteStringMode for more details.
type ByteString string

// Bytes returns bytes representing ByteString.
func (bs ByteString) Bytes() []byte {
return []byte(bs)
}

// MarshalCBOR encodes ByteString as CBOR byte string (major type 2).
func (bs ByteString) MarshalCBOR() ([]byte, error) {
e := getEncoderBuffer()
defer putEncoderBuffer(e)

// Encode length
encodeHead(e, byte(cborTypeByteString), uint64(len(bs)))

// Encode data
buf := make([]byte, e.Len()+len(bs))
n := copy(buf, e.Bytes())
copy(buf[n:], bs)

return buf, nil
}

// UnmarshalCBOR decodes CBOR byte string (major type 2) to ByteString.
// Decoding CBOR null and CBOR undefined sets ByteString to be empty.
func (bs *ByteString) UnmarshalCBOR(data []byte) error {
if bs == nil {
return errors.New("cbor.ByteString: UnmarshalCBOR on nil pointer")
}

// Decoding CBOR null and CBOR undefined to ByteString resets data.
// This behavior is similar to decoding CBOR null and CBOR undefined to []byte.
if len(data) == 1 && (data[0] == 0xf6 || data[0] == 0xf7) {
*bs = ""
return nil
}

d := decoder{data: data, dm: defaultDecMode}

// Check if CBOR data type is byte string
if typ := d.nextCBORType(); typ != cborTypeByteString {
return &UnmarshalTypeError{CBORType: typ.String(), GoType: typeByteString.String()}
}

*bs = ByteString(d.parseByteString())
return nil
}
101 changes: 101 additions & 0 deletions bytestring_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) Faye Amacker. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

package cbor

import "testing"

func TestByteString(t *testing.T) {
type s1 struct {
A ByteString `cbor:"a"`
}
type s2 struct {
A *ByteString `cbor:"a"`
}
type s3 struct {
A ByteString `cbor:"a,omitempty"`
}
type s4 struct {
A *ByteString `cbor:"a,omitempty"`
}

emptybs := ByteString("")
bs := ByteString("\x01\x02\x03\x04")

testCases := []roundTripTest{
{
name: "empty",
obj: emptybs,
wantCborData: hexDecode("40"),
},
{
name: "not empty",
obj: bs,
wantCborData: hexDecode("4401020304"),
},
{
name: "array",
obj: []ByteString{bs},
wantCborData: hexDecode("814401020304"),
},
{
name: "map with ByteString key",
obj: map[ByteString]bool{bs: true},
wantCborData: hexDecode("a14401020304f5"),
},
{
name: "empty ByteString field",
obj: s1{},
wantCborData: hexDecode("a1616140"),
},
{
name: "not empty ByteString field",
obj: s1{A: bs},
wantCborData: hexDecode("a161614401020304"),
},
{
name: "nil *ByteString field",
obj: s2{},
wantCborData: hexDecode("a16161f6"),
},
{
name: "empty *ByteString field",
obj: s2{A: &emptybs},
wantCborData: hexDecode("a1616140"),
},
{
name: "not empty *ByteString field",
obj: s2{A: &bs},
wantCborData: hexDecode("a161614401020304"),
},
{
name: "empty ByteString field with omitempty option",
obj: s3{},
wantCborData: hexDecode("a0"),
},
{
name: "not empty ByteString field with omitempty option",
obj: s3{A: bs},
wantCborData: hexDecode("a161614401020304"),
},
{
name: "nil *ByteString field with omitempty option",
obj: s4{},
wantCborData: hexDecode("a0"),
},
{
name: "empty *ByteString field with omitempty option",
obj: s4{A: &emptybs},
wantCborData: hexDecode("a1616140"),
},
{
name: "not empty *ByteString field with omitempty option",
obj: s4{A: &bs},
wantCborData: hexDecode("a161614401020304"),
},
}

em, _ := EncOptions{}.EncMode()
dm, _ := DecOptions{}.DecMode()
testRoundTrip(t, testCases, em, dm)
}
103 changes: 95 additions & 8 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,16 @@ func (e *UnmarshalTypeError) Error() string {
return s
}

// InvalidMapKeyTypeError describes invalid Go map key type when decoding CBOR map.
// For example, Go doesn't allow slice as map key.
type InvalidMapKeyTypeError struct {
GoType string
}

func (e *InvalidMapKeyTypeError) Error() string {
return "cbor: invalid map key type: " + e.GoType
}

// DupMapKeyError describes detected duplicate map key in CBOR map.
type DupMapKeyError struct {
Key interface{}
Expand Down Expand Up @@ -239,6 +249,36 @@ func (idm IntDecMode) valid() bool {
return idm < maxIntDec
}

// MapKeyByteStringMode specifies how to decode CBOR byte string (major type 2)
// as Go map key when decoding CBOR map key into an empty Go interface value.
// Specifically, this option applies when decoding CBOR map into
// - Go empty interface, or
// - Go map with empty interface as key type.
// The CBOR map key types handled by this option are
// - byte string
// - tagged byte string
// - nested tagged byte string
type MapKeyByteStringMode int

const (
// MapKeyByteStringAllowed allows CBOR byte string to be decoded as Go map key.
// Since Go doesn't allow []byte as map key, CBOR byte string is decoded to
// ByteString which has underlying string type.
// This is the default setting.
MapKeyByteStringAllowed MapKeyByteStringMode = iota

// MapKeyByteStringForbidden forbids CBOR byte string being decoded as Go map key.
// Attempting to decode CBOR byte string as map key into empty interface value
// returns a decoding error.
MapKeyByteStringForbidden

maxMapKeyByteStringMode
)

func (mkbsm MapKeyByteStringMode) valid() bool {
return mkbsm < maxMapKeyByteStringMode
}

// ExtraDecErrorCond specifies extra conditions that should be treated as errors.
type ExtraDecErrorCond uint

Expand Down Expand Up @@ -309,6 +349,12 @@ type DecOptions struct {
// when decoding CBOR int (major type 0 and 1) to Go interface{}.
IntDec IntDecMode

// MapKeyByteString specifies how to decode CBOR byte string as map key
// when decoding CBOR map with byte string key into an empty interface value.
// By default, an error is returned when attempting to decode CBOR byte string
// as map key because Go doesn't allow []byte as map key.
MapKeyByteString MapKeyByteStringMode

// ExtraReturnErrors specifies extra conditions that should be treated as errors.
ExtraReturnErrors ExtraDecErrorCond

Expand Down Expand Up @@ -401,6 +447,9 @@ func (opts DecOptions) decMode() (*decMode, error) {
if !opts.IntDec.valid() {
return nil, errors.New("cbor: invalid IntDec " + strconv.Itoa(int(opts.IntDec)))
}
if !opts.MapKeyByteString.valid() {
return nil, errors.New("cbor: invalid MapKeyByteString " + strconv.Itoa(int(opts.MapKeyByteString)))
}
if opts.MaxNestedLevels == 0 {
opts.MaxNestedLevels = 32
} else if opts.MaxNestedLevels < 4 || opts.MaxNestedLevels > 65535 {
Expand Down Expand Up @@ -434,6 +483,7 @@ func (opts DecOptions) decMode() (*decMode, error) {
indefLength: opts.IndefLength,
tagsMd: opts.TagsMd,
intDec: opts.IntDec,
mapKeyByteString: opts.MapKeyByteString,
extraReturnErrors: opts.ExtraReturnErrors,
defaultMapType: opts.DefaultMapType,
utf8: opts.UTF8,
Expand Down Expand Up @@ -467,6 +517,7 @@ type decMode struct {
indefLength IndefLengthMode
tagsMd TagsMode
intDec IntDecMode
mapKeyByteString MapKeyByteStringMode
extraReturnErrors ExtraDecErrorCond
defaultMapType reflect.Type
utf8 UTF8Mode
Expand All @@ -485,6 +536,7 @@ func (dm *decMode) DecOptions() DecOptions {
IndefLength: dm.indefLength,
TagsMd: dm.tagsMd,
IntDec: dm.intDec,
MapKeyByteString: dm.mapKeyByteString,
ExtraReturnErrors: dm.extraReturnErrors,
UTF8: dm.utf8,
}
Expand Down Expand Up @@ -1220,11 +1272,17 @@ func (d *decoder) parseMap() (interface{}, error) {
// Detect if CBOR map key can be used as Go map key.
rv := reflect.ValueOf(k)
if !isHashableValue(rv) {
if err == nil {
err = errors.New("cbor: invalid map key type: " + rv.Type().String())
var converted bool
if d.dm.mapKeyByteString == MapKeyByteStringAllowed {
k, converted = convertByteSliceToByteString(k)
}
if !converted {
if err == nil {
err = &InvalidMapKeyTypeError{rv.Type().String()}
}
d.skip()
continue
}
d.skip()
continue
}

// Parse CBOR map value.
Expand Down Expand Up @@ -1306,11 +1364,21 @@ func (d *decoder) parseMapToMap(v reflect.Value, tInfo *typeInfo) error { //noli
// Detect if CBOR map key can be used as Go map key.
if keyIsInterfaceType && keyValue.Elem().IsValid() {
if !isHashableValue(keyValue.Elem()) {
if err == nil {
err = errors.New("cbor: invalid map key type: " + keyValue.Elem().Type().String())
var converted bool
if d.dm.mapKeyByteString == MapKeyByteStringAllowed {
var k interface{}
k, converted = convertByteSliceToByteString(keyValue.Elem().Interface())
if converted {
keyValue.Set(reflect.ValueOf(k))
}
}
if !converted {
if err == nil {
err = &InvalidMapKeyTypeError{keyValue.Elem().Type().String()}
}
d.skip()
continue
}
d.skip()
continue
}
}

Expand Down Expand Up @@ -1936,6 +2004,25 @@ func isHashableValue(rv reflect.Value) bool {
return true
}

// convertByteSliceToByteString converts []byte to ByteString if
// - v is []byte type, or
// - v is Tag type and tag content type is []byte
// This function also handles nested tags.
// CBOR data is already verified to be well-formed before this function is used,
// so the recursion won't exceed max nested levels.
func convertByteSliceToByteString(v interface{}) (interface{}, bool) {
switch v := v.(type) {
case []byte:
return ByteString(v), true
case Tag:
content, converted := convertByteSliceToByteString(v.Content)
if converted {
return Tag{Number: v.Number, Content: content}, true
}
}
return v, false
}

// validBuiltinTag checks that supported built-in tag numbers are followed by expected content types.
func validBuiltinTag(tagNum uint64, contentHead byte) error {
t := cborType(contentHead & 0xe0)
Expand Down
Loading

0 comments on commit 7c3a599

Please sign in to comment.