diff --git a/decode.go b/decode.go index 0c8b398c..ca1921b6 100644 --- a/decode.go +++ b/decode.go @@ -496,6 +496,23 @@ func (tttam TimeTagToAnyMode) valid() bool { return tttam >= 0 && tttam < maxTimeTagToAnyMode } +// ByteStringToTimeMode specifies the behavior when decoding a CBOR byte string into a Go time.Time. +type ByteStringToTimeMode int + +const ( + // ByteStringToTimeForbidden generates an error on an attempt to decode a CBOR byte string into a Go time.Time. + ByteStringToTimeForbidden ByteStringToTimeMode = iota + + // ByteStringToTimeAllowed permits decoding a CBOR byte string into a Go time.Time. + ByteStringToTimeAllowed + + maxByteStringToTimeMode +) + +func (bttm ByteStringToTimeMode) valid() bool { + return bttm >= 0 && bttm < maxByteStringToTimeMode +} + // DecOptions specifies decoding options. type DecOptions struct { // DupMapKey specifies whether to enforce duplicate map key. @@ -592,6 +609,9 @@ type DecOptions struct { // TimeTagToAnyMode specifies how to decode CBOR tag 0 and 1 into an empty interface (any). // Based on the specified mode, Unmarshal can return a time.Time value or a time string in a specific format. TimeTagToAny TimeTagToAnyMode + + // ByteStringToTimeMode specifies the behavior when decoding a CBOR byte string into a Go time.Time. + ByteStringToTime ByteStringToTimeMode } // DecMode returns DecMode with immutable options and no tags (safe for concurrency). @@ -749,6 +769,10 @@ func (opts DecOptions) decMode() (*decMode, error) { return nil, errors.New("cbor: invalid TimeTagToAny " + strconv.Itoa(int(opts.TimeTagToAny))) } + if !opts.ByteStringToTime.valid() { + return nil, errors.New("cbor: invalid ByteStringToTime " + strconv.Itoa(int(opts.ByteStringToTime))) + } + dm := decMode{ dupMapKey: opts.DupMapKey, timeTag: opts.TimeTag, @@ -769,6 +793,7 @@ func (opts DecOptions) decMode() (*decMode, error) { fieldNameByteString: opts.FieldNameByteString, unrecognizedTagToAny: opts.UnrecognizedTagToAny, timeTagToAny: opts.TimeTagToAny, + byteStringToTime: opts.ByteStringToTime, } return &dm, nil @@ -841,6 +866,7 @@ type decMode struct { fieldNameByteString FieldNameByteStringMode unrecognizedTagToAny UnrecognizedTagToAnyMode timeTagToAny TimeTagToAnyMode + byteStringToTime ByteStringToTimeMode } var defaultDecMode, _ = DecOptions{}.decMode() @@ -867,6 +893,7 @@ func (dm *decMode) DecOptions() DecOptions { FieldNameByteString: dm.fieldNameByteString, UnrecognizedTagToAny: dm.unrecognizedTagToAny, TimeTagToAny: dm.timeTagToAny, + ByteStringToTime: dm.byteStringToTime, } } @@ -1328,6 +1355,16 @@ func (d *decoder) parseToTime() (time.Time, bool, error) { } switch t := d.nextCBORType(); t { + case cborTypeByteString: + if d.dm.byteStringToTime == ByteStringToTimeAllowed { + b, _ := d.parseByteString() + t, err := time.Parse(time.RFC3339, string(b)) + if err != nil { + return time.Time{}, false, errors.New("cbor: cannot set " + string(b) + " for time.Time: " + err.Error()) + } + return t, true, nil + } + return time.Time{}, false, &UnmarshalTypeError{CBORType: t.String(), GoType: typeTime.String()} case cborTypeTextString: s, err := d.parseTextString() if err != nil { diff --git a/decode_test.go b/decode_test.go index 9bc4c881..5c66ada6 100644 --- a/decode_test.go +++ b/decode_test.go @@ -4913,6 +4913,7 @@ func TestDecOptions(t *testing.T) { FieldNameByteString: FieldNameByteStringAllowed, UnrecognizedTagToAny: UnrecognizedTagContentToAny, TimeTagToAny: TimeTagToRFC3339, + ByteStringToTime: ByteStringToTimeAllowed, } ov := reflect.ValueOf(opts1) for i := 0; i < ov.NumField(); i++ { @@ -8748,3 +8749,83 @@ func TestDecModeTimeTagToAny(t *testing.T) { }) } } + +func TestDecModeInvalidByteStringToTimeMode(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + wantErrorMsg string + }{ + { + name: "below range of valid modes", + opts: DecOptions{ByteStringToTime: -1}, + wantErrorMsg: "cbor: invalid ByteStringToTime -1", + }, + { + name: "above range of valid modes", + opts: DecOptions{ByteStringToTime: 4}, + wantErrorMsg: "cbor: invalid ByteStringToTime 4", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.DecMode() + if err == nil { + t.Errorf("Expected non nil error from DecMode()") + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("Expected error: %q, want: %q \n", tc.wantErrorMsg, err.Error()) + } + }) + } +} + +func TestDecModeByteStringToTime(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + in []byte + want time.Time + wantErrorMsg string + }{ + { + name: "Unmarshal byte string to time.Time when ByteStringToTime is not set", + opts: DecOptions{}, + in: hexDecode("54323031332D30332D32315432303A30343A30305A"), + wantErrorMsg: "cbor: cannot unmarshal byte string into Go value of type time.Time", + }, + { + name: "Unmarshal byte string to time.Time when ByteStringToTime is set to ByteStringToTimeAllowed", + opts: DecOptions{ByteStringToTime: ByteStringToTimeAllowed}, + in: hexDecode("54323031332D30332D32315432303A30343A30305A"), // '2013-03-21T20:04:00Z' + want: time.Date(2013, 3, 21, 20, 4, 0, 0, time.UTC), + }, + { + name: "Unmarshal byte string to time.Time with nano when ByteStringToTime is set to ByteStringToTimeAllowed", + opts: DecOptions{ByteStringToTime: ByteStringToTimeAllowed}, + in: hexDecode("56323031332D30332D32315432303A30343A30302E355A"), // '2013-03-21T20:04:00.5Z' + want: time.Date(2013, 3, 21, 20, 4, 0, 500000000, time.UTC), + }, + { + name: "Unmarshal an invalid byte string to time.Time when ByteStringToTime is set to ByteStringToTimeAllowed", + opts: DecOptions{ByteStringToTime: ByteStringToTimeAllowed}, + in: hexDecode("4B696E76616C696454657874"), // 'invalidText' + wantErrorMsg: "cbor: cannot set invalidText for time.Time: parsing time \"invalidText\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"invalidText\" as \"2006\"", + }, + } { + t.Run(tc.name, func(t *testing.T) { + dm, err := tc.opts.DecMode() + if err != nil { + t.Fatal(err) + } + + var got time.Time + if err := dm.Unmarshal(tc.in, &got); err != nil { + if tc.wantErrorMsg != err.Error() { + t.Errorf("unexpected error: %v", err) + } + } else { + compareNonFloats(t, tc.in, got, tc.want) + } + + }) + } +}