Skip to content

Commit

Permalink
Merge pull request #537 from benluddy/map-sort-refactor
Browse files Browse the repository at this point in the history
Refactor sorted map encode to use fewer buffers for nested maps.
  • Loading branch information
fxamacker authored May 20, 2024
2 parents 367b524 + 6396be3 commit 6d407ed
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 76 deletions.
103 changes: 52 additions & 51 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -1217,25 +1217,58 @@ func (me mapEncodeFunc) encode(e *bytes.Buffer, em *encMode, v reflect.Value) er
if mlen == 0 {
return e.WriteByte(byte(cborTypeMap))
}
switch em.sort {
case SortNone, SortFastShuffle:
default:
if mlen > 1 {
return me.encodeCanonical(e, em, v)
}
}

encodeHead(e, byte(cborTypeMap), uint64(mlen))
if em.sort == SortNone || em.sort == SortFastShuffle || mlen <= 1 {
return me.e(e, em, v, nil)
}

kvsp := getKeyValues(v.Len()) // for sorting keys
defer putKeyValues(kvsp)
kvs := *kvsp

kvBeginOffset := e.Len()
if err := me.e(e, em, v, kvs); err != nil {
return err
}
kvTotalLen := e.Len() - kvBeginOffset

// Use the capacity at the tail of the encode buffer as a staging area to rearrange the
// encoded pairs into sorted order.
e.Grow(kvTotalLen)
tmp := e.Bytes()[e.Len() : e.Len()+kvTotalLen] // Can use e.AvailableBuffer() in Go 1.21+.
dst := e.Bytes()[kvBeginOffset:]

if em.sort == SortBytewiseLexical {
sort.Sort(&bytewiseKeyValueSorter{kvs: kvs, data: dst})
} else {
sort.Sort(&lengthFirstKeyValueSorter{kvs: kvs, data: dst})
}

// This is where the encoded bytes are actually rearranged in the output buffer to reflect
// the desired order.
sortedOffset := 0
for _, kv := range kvs {
copy(tmp[sortedOffset:], dst[kv.offset:kv.nextOffset])
sortedOffset += kv.nextOffset - kv.offset
}
copy(dst, tmp[:kvTotalLen])

return nil

return me.e(e, em, v, nil)
}

// keyValue is the position of an encoded pair in a buffer. All offsets are zero-based and relative
// to the first byte of the first encoded pair.
type keyValue struct {
keyCBORData, keyValueCBORData []byte
keyLen, keyValueLen int
offset int
valueOffset int
nextOffset int
}

type bytewiseKeyValueSorter struct {
kvs []keyValue
kvs []keyValue
data []byte
}

func (x *bytewiseKeyValueSorter) Len() int {
Expand All @@ -1247,11 +1280,13 @@ func (x *bytewiseKeyValueSorter) Swap(i, j int) {
}

func (x *bytewiseKeyValueSorter) Less(i, j int) bool {
return bytes.Compare(x.kvs[i].keyCBORData, x.kvs[j].keyCBORData) <= 0
kvi, kvj := x.kvs[i], x.kvs[j]
return bytes.Compare(x.data[kvi.offset:kvi.valueOffset], x.data[kvj.offset:kvj.valueOffset]) <= 0
}

type lengthFirstKeyValueSorter struct {
kvs []keyValue
kvs []keyValue
data []byte
}

func (x *lengthFirstKeyValueSorter) Len() int {
Expand All @@ -1263,10 +1298,11 @@ func (x *lengthFirstKeyValueSorter) Swap(i, j int) {
}

func (x *lengthFirstKeyValueSorter) Less(i, j int) bool {
if len(x.kvs[i].keyCBORData) != len(x.kvs[j].keyCBORData) {
return len(x.kvs[i].keyCBORData) < len(x.kvs[j].keyCBORData)
kvi, kvj := x.kvs[i], x.kvs[j]
if keyLengthDifference := (kvi.valueOffset - kvi.offset) - (kvj.valueOffset - kvj.offset); keyLengthDifference != 0 {
return keyLengthDifference < 0
}
return bytes.Compare(x.kvs[i].keyCBORData, x.kvs[j].keyCBORData) <= 0
return bytes.Compare(x.data[kvi.offset:kvi.valueOffset], x.data[kvj.offset:kvj.valueOffset]) <= 0
}

var keyValuePool = sync.Pool{}
Expand Down Expand Up @@ -1294,41 +1330,6 @@ func putKeyValues(x *[]keyValue) {
keyValuePool.Put(x)
}

func (me mapEncodeFunc) encodeCanonical(e *bytes.Buffer, em *encMode, v reflect.Value) error {
kve := getEncodeBuffer() // accumulated cbor encoded key-values
defer putEncodeBuffer(kve)

kvsp := getKeyValues(v.Len()) // for sorting keys
defer putKeyValues(kvsp)

kvs := *kvsp

err := me.e(kve, em, v, kvs)
if err != nil {
return err
}

b := kve.Bytes()
for i, off := 0, 0; i < len(kvs); i++ {
kvs[i].keyCBORData = b[off : off+kvs[i].keyLen]
kvs[i].keyValueCBORData = b[off : off+kvs[i].keyValueLen]
off += kvs[i].keyValueLen
}

if em.sort == SortBytewiseLexical {
sort.Sort(&bytewiseKeyValueSorter{kvs})
} else {
sort.Sort(&lengthFirstKeyValueSorter{kvs})
}

encodeHead(e, byte(cborTypeMap), uint64(len(kvs)))
for i := 0; i < len(kvs); i++ {
e.Write(kvs[i].keyValueCBORData)
}

return nil
}

func encodeStructToArray(e *bytes.Buffer, em *encMode, v reflect.Value) (err error) {
structType, err := getEncodingStructType(v.Type())
if err != nil {
Expand Down
34 changes: 24 additions & 10 deletions encode_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ type mapKeyValueEncodeFunc struct {
}

func (me *mapKeyValueEncodeFunc) encodeKeyValues(e *bytes.Buffer, em *encMode, v reflect.Value, kvs []keyValue) error {
trackKeyValueLength := len(kvs) == v.Len()
iterk := me.kpool.Get().(*reflect.Value)
defer func() {
iterk.SetZero()
Expand All @@ -28,24 +27,39 @@ func (me *mapKeyValueEncodeFunc) encodeKeyValues(e *bytes.Buffer, em *encMode, v
iterv.SetZero()
me.vpool.Put(iterv)
}()
iter := v.MapRange()
for i := 0; iter.Next(); i++ {
off := e.Len()

if kvs == nil {
for i, iter := 0, v.MapRange(); iter.Next(); i++ {
iterk.SetIterKey(iter)
iterv.SetIterValue(iter)

if err := me.kf(e, em, *iterk); err != nil {
return err
}
if err := me.ef(e, em, *iterv); err != nil {
return err
}
}
return nil
}

initial := e.Len()
for i, iter := 0, v.MapRange(); iter.Next(); i++ {
iterk.SetIterKey(iter)
iterv.SetIterValue(iter)

offset := e.Len()
if err := me.kf(e, em, *iterk); err != nil {
return err
}
if trackKeyValueLength {
kvs[i].keyLen = e.Len() - off
}

valueOffset := e.Len()
if err := me.ef(e, em, *iterv); err != nil {
return err
}
if trackKeyValueLength {
kvs[i].keyValueLen = e.Len() - off
kvs[i] = keyValue{
offset: offset - initial,
valueOffset: valueOffset - initial,
nextOffset: e.Len() - initial,
}
}

Expand Down
30 changes: 19 additions & 11 deletions encode_map_go117.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,32 @@ type mapKeyValueEncodeFunc struct {
}

func (me *mapKeyValueEncodeFunc) encodeKeyValues(e *bytes.Buffer, em *encMode, v reflect.Value, kvs []keyValue) error {
trackKeyValueLength := len(kvs) == v.Len()

iter := v.MapRange()
for i := 0; iter.Next(); i++ {
off := e.Len()
if kvs == nil {
for i, iter := 0, v.MapRange(); iter.Next(); i++ {
if err := me.kf(e, em, iter.Key()); err != nil {
return err
}
if err := me.ef(e, em, iter.Value()); err != nil {
return err
}
}
return nil
}

initial := e.Len()
for i, iter := 0, v.MapRange(); iter.Next(); i++ {
offset := e.Len()
if err := me.kf(e, em, iter.Key()); err != nil {
return err
}
if trackKeyValueLength {
kvs[i].keyLen = e.Len() - off
}

valueOffset := e.Len()
if err := me.ef(e, em, iter.Value()); err != nil {
return err
}
if trackKeyValueLength {
kvs[i].keyValueLen = e.Len() - off
kvs[i] = keyValue{
offset: offset - initial,
valueOffset: valueOffset - initial,
nextOffset: e.Len() - initial,
}
}

Expand Down
8 changes: 4 additions & 4 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2429,18 +2429,18 @@ func TestMarshalUnmarshalStructToArray(t *testing.T) {
}

func TestMapSort(t *testing.T) {
m := make(map[interface{}]bool)
m := make(map[interface{}]interface{})
m[10] = true
m[100] = true
m[-1] = true
m["z"] = true
m["z"] = "zzz"
m["aa"] = true
m[[1]int{100}] = true
m[[1]int{-1}] = true
m[false] = true

lenFirstSortedCborData := hexDecode("a80af520f5f4f51864f5617af58120f5626161f5811864f5") // sorted keys: 10, -1, false, 100, "z", [-1], "aa", [100]
bytewiseSortedCborData := hexDecode("a80af51864f520f5617af5626161f5811864f58120f5f4f5") // sorted keys: 10, 100, -1, "z", "aa", [100], [-1], false
lenFirstSortedCborData := hexDecode("a80af520f5f4f51864f5617a637a7a7a8120f5626161f5811864f5") // sorted keys: 10, -1, false, 100, "z", [-1], "aa", [100]
bytewiseSortedCborData := hexDecode("a80af51864f520f5617a637a7a7a626161f5811864f58120f5f4f5") // sorted keys: 10, 100, -1, "z", "aa", [100], [-1], false

testCases := []struct {
name string
Expand Down

0 comments on commit 6d407ed

Please sign in to comment.