Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor to reuse functions and improve code coverage #531

Merged
merged 2 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 2 additions & 20 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,6 @@ type encodingStructType struct {
omitEmptyFieldsIdx []int
err error
toArray bool
fixedLength bool // Struct type doesn't have any omitempty or anonymous fields.
maxHeadLen int
}

func (st *encodingStructType) getFields(em *encMode) fields {
Expand Down Expand Up @@ -232,13 +230,10 @@ func getEncodingStructType(t reflect.Type) (*encodingStructType, error) {
return getEncodingStructToArrayType(t, flds)
}

nOptional := 0

var err error
var hasKeyAsInt bool
var hasKeyAsStr bool
var omitEmptyIdx []int
fixedLength := true
e := getEncoderBuffer()
for i := 0; i < len(flds); i++ {
// Get field's encodeFunc
Expand Down Expand Up @@ -286,20 +281,10 @@ func getEncodingStructType(t reflect.Type) (*encodingStructType, error) {
hasKeyAsStr = true
}

// Check if field is from embedded struct
if len(flds[i].idx) > 1 {
fixedLength = false
}

// Check if field can be omitted when empty
if flds[i].omitEmpty {
fixedLength = false
omitEmptyIdx = append(omitEmptyIdx, i)
}

if len(flds[i].idx) > 1 || flds[i].omitEmpty {
nOptional++
}
}
putEncoderBuffer(e)

Expand All @@ -326,8 +311,6 @@ func getEncodingStructType(t reflect.Type) (*encodingStructType, error) {
bytewiseFields: bytewiseFields,
lengthFirstFields: lengthFirstFields,
omitEmptyFieldsIdx: omitEmptyIdx,
fixedLength: fixedLength,
maxHeadLen: encodedHeadLen(uint64(len(flds))),
}

encodingStructTypeCache.Store(t, structType)
Expand All @@ -346,9 +329,8 @@ func getEncodingStructToArrayType(t reflect.Type, flds fields) (*encodingStructT
}

structType := &encodingStructType{
fields: flds,
toArray: true,
fixedLength: true,
fields: flds,
toArray: true,
}
encodingStructTypeCache.Store(t, structType)
return structType, structType.err
Expand Down
80 changes: 19 additions & 61 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -1250,36 +1250,6 @@ func encodeStructToArray(e *encoderBuffer, em *encMode, v reflect.Value) (err er
return nil
}

func encodeFixedLengthStruct(e *encoderBuffer, em *encMode, v reflect.Value, flds fields) error {
if b := em.encTagBytes(v.Type()); b != nil {
e.Write(b)
}

encodeHead(e, byte(cborTypeMap), uint64(len(flds)))

start := 0
if em.sort == SortFastShuffle {
start = rand.Intn(len(flds)) //nolint:gosec // Don't need a CSPRNG for deck cutting.
}

for offset := 0; offset < len(flds); offset++ {
i := (start + offset) % len(flds)
f := flds[i]
if !f.keyAsInt && em.fieldName == FieldNameToByteString {
e.Write(f.cborNameByteString)
} else { // int or text string
e.Write(f.cborName)
}

fv := v.Field(f.idx[0])
if err := f.ef(e, em, fv); err != nil {
return err
}
}

return nil
}

func encodeStruct(e *encoderBuffer, em *encMode, v reflect.Value) (err error) {
structType, err := getEncodingStructType(v.Type())
if err != nil {
Expand All @@ -1288,10 +1258,6 @@ func encodeStruct(e *encoderBuffer, em *encMode, v reflect.Value) (err error) {

flds := structType.getFields(em)

if structType.fixedLength {
return encodeFixedLengthStruct(e, em, v, flds)
}

start := 0
if em.sort == SortFastShuffle {
start = rand.Intn(len(flds)) //nolint:gosec // Don't need a CSPRNG for deck cutting.
Expand All @@ -1301,8 +1267,9 @@ func encodeStruct(e *encoderBuffer, em *encMode, v reflect.Value) (err error) {
e.Write(b)
}

// Reserve space in the output buffer for the head if its encoded size is fixed.
encodeHead(e, byte(cborTypeMap), uint64(len(flds)))
// Encode head with struct field count.
// Head is rewritten later if actual encoded field count is different from struct field count.
encodedHeadLen := encodeHead(e, byte(cborTypeMap), uint64(len(flds)))

kvbegin := e.Len()
kvcount := 0
Expand Down Expand Up @@ -1345,14 +1312,19 @@ func encodeStruct(e *encoderBuffer, em *encMode, v reflect.Value) (err error) {
kvcount++
}

if len(flds) == kvcount {
// Encoded element count in head is the same as actual element count.
return nil
}

// Overwrite the bytes that were reserved for the head before encoding the map entries.
var actualHeadLen int
{
headbuf := encoderBuffer{Buffer: *bytes.NewBuffer(e.Bytes()[kvbegin-structType.maxHeadLen : kvbegin-structType.maxHeadLen : kvbegin])}
encodeHead(&headbuf, byte(cborTypeMap), uint64(kvcount))
headbuf := encoderBuffer{Buffer: *bytes.NewBuffer(e.Bytes()[kvbegin-encodedHeadLen : kvbegin-encodedHeadLen : kvbegin])}
actualHeadLen = encodeHead(&headbuf, byte(cborTypeMap), uint64(kvcount))
}

actualHeadLen := encodedHeadLen(uint64(kvcount))
if actualHeadLen == structType.maxHeadLen {
if actualHeadLen == encodedHeadLen {
// The bytes reserved for the encoded head were exactly the right size, so the
// encoded entries are already in their final positions.
return nil
Expand All @@ -1361,7 +1333,7 @@ func encodeStruct(e *encoderBuffer, em *encMode, v reflect.Value) (err error) {
// We reserved more bytes than needed for the encoded head, based on the number of fields
// encoded. The encoded entries are offset to the right by the number of excess reserved
// bytes. Shift the entries left to remove the gap.
excessReservedBytes := structType.maxHeadLen - actualHeadLen
excessReservedBytes := encodedHeadLen - actualHeadLen
dst := e.Bytes()[kvbegin-excessReservedBytes : e.Len()-excessReservedBytes]
src := e.Bytes()[kvbegin:e.Len()]
copy(dst, src)
Expand Down Expand Up @@ -1519,47 +1491,33 @@ func encodeTag(e *encoderBuffer, em *encMode, v reflect.Value) error {
return encode(e, em, reflect.ValueOf(t.Content))
}

func encodeHead(e *encoderBuffer, t byte, n uint64) {
// encodeHead writes CBOR head of specified type t and returns number of bytes written.
func encodeHead(e *encoderBuffer, t byte, n uint64) int {
if n <= 23 {
e.WriteByte(t | byte(n))
return
return 1
}
if n <= math.MaxUint8 {
e.scratch[0] = t | byte(24)
e.scratch[1] = byte(n)
e.Write(e.scratch[:2])
return
return 2
}
if n <= math.MaxUint16 {
e.scratch[0] = t | byte(25)
binary.BigEndian.PutUint16(e.scratch[1:], uint16(n))
e.Write(e.scratch[:3])
return
return 3
}
if n <= math.MaxUint32 {
e.scratch[0] = t | byte(26)
binary.BigEndian.PutUint32(e.scratch[1:], uint32(n))
e.Write(e.scratch[:5])
return
return 5
}
e.scratch[0] = t | byte(27)
binary.BigEndian.PutUint64(e.scratch[1:], n)
e.Write(e.scratch[:9])
}

// encodedHeadLen returns the number of bytes that will be written by a call to encodeHead with the
// given argument. This must be kept in sync with encodeHead.
func encodedHeadLen(arg uint64) int {
switch {
case arg <= 23:
return 1
case arg <= math.MaxUint8:
return 2
case arg <= math.MaxUint16:
return 3
case arg <= math.MaxUint32:
return 5
}
return 9
}

Expand Down
Loading