From 6396be350f2bf264db10be0f285a3feac0cf8514 Mon Sep 17 00:00:00 2001 From: Ben Luddy Date: Fri, 12 Apr 2024 13:16:23 -0400 Subject: [PATCH] Refactor sorted map encode to use fewer buffers for nested maps. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Runs a bit faster, but more importantly, only needs a single buffer to encode nested, sorted maps instead of using multiple temporary buffers. │ before.txt │ after.txt │ │ sec/op │ sec/op vs base │ MarshalCanonical/Go_map[string]string_to_CBOR_map_canonical 1.464µ ± 0% 1.395µ ± 0% -4.68% (p=0.000 n=10) MarshalCanonical/Go_map[int]int_to_CBOR_map_canonical 192.1n ± 0% 186.2n ± 1% -3.10% (p=0.000 n=10) geomean 530.2n 509.6n -3.89% │ before.txt │ after.txt │ │ B/op │ B/op vs base │ MarshalCanonical/Go_map[string]string_to_CBOR_map_canonical 88.00 ± 0% 112.00 ± 0% +27.27% (p=0.000 n=10) MarshalCanonical/Go_map[int]int_to_CBOR_map_canonical 3.000 ± 0% 3.000 ± 0% ~ (p=1.000 n=10) ¹ geomean 16.25 18.33 +12.82% ¹ all samples are equal │ before.txt │ after.txt │ │ allocs/op │ allocs/op vs base │ MarshalCanonical/Go_map[string]string_to_CBOR_map_canonical 2.000 ± 0% 2.000 ± 0% ~ (p=1.000 n=10) ¹ MarshalCanonical/Go_map[int]int_to_CBOR_map_canonical 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=10) ¹ geomean 1.414 1.414 +0.00% ¹ all samples are equal Signed-off-by: Ben Luddy --- encode.go | 103 ++++++++++++++++++++++---------------------- encode_map.go | 34 ++++++++++----- encode_map_go117.go | 30 ++++++++----- encode_test.go | 8 ++-- 4 files changed, 99 insertions(+), 76 deletions(-) diff --git a/encode.go b/encode.go index f068a9c8..cd9430cb 100644 --- a/encode.go +++ b/encode.go @@ -1217,25 +1217,58 @@ func (me mapEncodeFunc) encode(e *bytes.Buffer, em *encMode, v reflect.Value) er if mlen == 0 { return e.WriteByte(byte(cborTypeMap)) } - switch em.sort { - case SortNone, SortFastShuffle: - default: - if mlen > 1 { - return me.encodeCanonical(e, em, v) - } - } + encodeHead(e, byte(cborTypeMap), uint64(mlen)) + if em.sort == SortNone || em.sort == SortFastShuffle || mlen <= 1 { + return me.e(e, em, v, nil) + } + + kvsp := getKeyValues(v.Len()) // for sorting keys + defer putKeyValues(kvsp) + kvs := *kvsp + + kvBeginOffset := e.Len() + if err := me.e(e, em, v, kvs); err != nil { + return err + } + kvTotalLen := e.Len() - kvBeginOffset + + // Use the capacity at the tail of the encode buffer as a staging area to rearrange the + // encoded pairs into sorted order. + e.Grow(kvTotalLen) + tmp := e.Bytes()[e.Len() : e.Len()+kvTotalLen] // Can use e.AvailableBuffer() in Go 1.21+. + dst := e.Bytes()[kvBeginOffset:] + + if em.sort == SortBytewiseLexical { + sort.Sort(&bytewiseKeyValueSorter{kvs: kvs, data: dst}) + } else { + sort.Sort(&lengthFirstKeyValueSorter{kvs: kvs, data: dst}) + } + + // This is where the encoded bytes are actually rearranged in the output buffer to reflect + // the desired order. + sortedOffset := 0 + for _, kv := range kvs { + copy(tmp[sortedOffset:], dst[kv.offset:kv.nextOffset]) + sortedOffset += kv.nextOffset - kv.offset + } + copy(dst, tmp[:kvTotalLen]) + + return nil - return me.e(e, em, v, nil) } +// keyValue is the position of an encoded pair in a buffer. All offsets are zero-based and relative +// to the first byte of the first encoded pair. type keyValue struct { - keyCBORData, keyValueCBORData []byte - keyLen, keyValueLen int + offset int + valueOffset int + nextOffset int } type bytewiseKeyValueSorter struct { - kvs []keyValue + kvs []keyValue + data []byte } func (x *bytewiseKeyValueSorter) Len() int { @@ -1247,11 +1280,13 @@ func (x *bytewiseKeyValueSorter) Swap(i, j int) { } func (x *bytewiseKeyValueSorter) Less(i, j int) bool { - return bytes.Compare(x.kvs[i].keyCBORData, x.kvs[j].keyCBORData) <= 0 + kvi, kvj := x.kvs[i], x.kvs[j] + return bytes.Compare(x.data[kvi.offset:kvi.valueOffset], x.data[kvj.offset:kvj.valueOffset]) <= 0 } type lengthFirstKeyValueSorter struct { - kvs []keyValue + kvs []keyValue + data []byte } func (x *lengthFirstKeyValueSorter) Len() int { @@ -1263,10 +1298,11 @@ func (x *lengthFirstKeyValueSorter) Swap(i, j int) { } func (x *lengthFirstKeyValueSorter) Less(i, j int) bool { - if len(x.kvs[i].keyCBORData) != len(x.kvs[j].keyCBORData) { - return len(x.kvs[i].keyCBORData) < len(x.kvs[j].keyCBORData) + kvi, kvj := x.kvs[i], x.kvs[j] + if keyLengthDifference := (kvi.valueOffset - kvi.offset) - (kvj.valueOffset - kvj.offset); keyLengthDifference != 0 { + return keyLengthDifference < 0 } - return bytes.Compare(x.kvs[i].keyCBORData, x.kvs[j].keyCBORData) <= 0 + return bytes.Compare(x.data[kvi.offset:kvi.valueOffset], x.data[kvj.offset:kvj.valueOffset]) <= 0 } var keyValuePool = sync.Pool{} @@ -1294,41 +1330,6 @@ func putKeyValues(x *[]keyValue) { keyValuePool.Put(x) } -func (me mapEncodeFunc) encodeCanonical(e *bytes.Buffer, em *encMode, v reflect.Value) error { - kve := getEncodeBuffer() // accumulated cbor encoded key-values - defer putEncodeBuffer(kve) - - kvsp := getKeyValues(v.Len()) // for sorting keys - defer putKeyValues(kvsp) - - kvs := *kvsp - - err := me.e(kve, em, v, kvs) - if err != nil { - return err - } - - b := kve.Bytes() - for i, off := 0, 0; i < len(kvs); i++ { - kvs[i].keyCBORData = b[off : off+kvs[i].keyLen] - kvs[i].keyValueCBORData = b[off : off+kvs[i].keyValueLen] - off += kvs[i].keyValueLen - } - - if em.sort == SortBytewiseLexical { - sort.Sort(&bytewiseKeyValueSorter{kvs}) - } else { - sort.Sort(&lengthFirstKeyValueSorter{kvs}) - } - - encodeHead(e, byte(cborTypeMap), uint64(len(kvs))) - for i := 0; i < len(kvs); i++ { - e.Write(kvs[i].keyValueCBORData) - } - - return nil -} - func encodeStructToArray(e *bytes.Buffer, em *encMode, v reflect.Value) (err error) { structType, err := getEncodingStructType(v.Type()) if err != nil { diff --git a/encode_map.go b/encode_map.go index 42ec26a9..8b4b4bbc 100644 --- a/encode_map.go +++ b/encode_map.go @@ -17,7 +17,6 @@ type mapKeyValueEncodeFunc struct { } func (me *mapKeyValueEncodeFunc) encodeKeyValues(e *bytes.Buffer, em *encMode, v reflect.Value, kvs []keyValue) error { - trackKeyValueLength := len(kvs) == v.Len() iterk := me.kpool.Get().(*reflect.Value) defer func() { iterk.SetZero() @@ -28,24 +27,39 @@ func (me *mapKeyValueEncodeFunc) encodeKeyValues(e *bytes.Buffer, em *encMode, v iterv.SetZero() me.vpool.Put(iterv) }() - iter := v.MapRange() - for i := 0; iter.Next(); i++ { - off := e.Len() + + if kvs == nil { + for i, iter := 0, v.MapRange(); iter.Next(); i++ { + iterk.SetIterKey(iter) + iterv.SetIterValue(iter) + + if err := me.kf(e, em, *iterk); err != nil { + return err + } + if err := me.ef(e, em, *iterv); err != nil { + return err + } + } + return nil + } + + initial := e.Len() + for i, iter := 0, v.MapRange(); iter.Next(); i++ { iterk.SetIterKey(iter) iterv.SetIterValue(iter) + offset := e.Len() if err := me.kf(e, em, *iterk); err != nil { return err } - if trackKeyValueLength { - kvs[i].keyLen = e.Len() - off - } - + valueOffset := e.Len() if err := me.ef(e, em, *iterv); err != nil { return err } - if trackKeyValueLength { - kvs[i].keyValueLen = e.Len() - off + kvs[i] = keyValue{ + offset: offset - initial, + valueOffset: valueOffset - initial, + nextOffset: e.Len() - initial, } } diff --git a/encode_map_go117.go b/encode_map_go117.go index 6cbb8b09..31c39336 100644 --- a/encode_map_go117.go +++ b/encode_map_go117.go @@ -15,24 +15,32 @@ type mapKeyValueEncodeFunc struct { } func (me *mapKeyValueEncodeFunc) encodeKeyValues(e *bytes.Buffer, em *encMode, v reflect.Value, kvs []keyValue) error { - trackKeyValueLength := len(kvs) == v.Len() - - iter := v.MapRange() - for i := 0; iter.Next(); i++ { - off := e.Len() + if kvs == nil { + for i, iter := 0, v.MapRange(); iter.Next(); i++ { + if err := me.kf(e, em, iter.Key()); err != nil { + return err + } + if err := me.ef(e, em, iter.Value()); err != nil { + return err + } + } + return nil + } + initial := e.Len() + for i, iter := 0, v.MapRange(); iter.Next(); i++ { + offset := e.Len() if err := me.kf(e, em, iter.Key()); err != nil { return err } - if trackKeyValueLength { - kvs[i].keyLen = e.Len() - off - } - + valueOffset := e.Len() if err := me.ef(e, em, iter.Value()); err != nil { return err } - if trackKeyValueLength { - kvs[i].keyValueLen = e.Len() - off + kvs[i] = keyValue{ + offset: offset - initial, + valueOffset: valueOffset - initial, + nextOffset: e.Len() - initial, } } diff --git a/encode_test.go b/encode_test.go index 7a6cd657..9ca2a432 100644 --- a/encode_test.go +++ b/encode_test.go @@ -2429,18 +2429,18 @@ func TestMarshalUnmarshalStructToArray(t *testing.T) { } func TestMapSort(t *testing.T) { - m := make(map[interface{}]bool) + m := make(map[interface{}]interface{}) m[10] = true m[100] = true m[-1] = true - m["z"] = true + m["z"] = "zzz" m["aa"] = true m[[1]int{100}] = true m[[1]int{-1}] = true m[false] = true - lenFirstSortedCborData := hexDecode("a80af520f5f4f51864f5617af58120f5626161f5811864f5") // sorted keys: 10, -1, false, 100, "z", [-1], "aa", [100] - bytewiseSortedCborData := hexDecode("a80af51864f520f5617af5626161f5811864f58120f5f4f5") // sorted keys: 10, 100, -1, "z", "aa", [100], [-1], false + lenFirstSortedCborData := hexDecode("a80af520f5f4f51864f5617a637a7a7a8120f5626161f5811864f5") // sorted keys: 10, -1, false, 100, "z", [-1], "aa", [100] + bytewiseSortedCborData := hexDecode("a80af51864f520f5617a637a7a7a626161f5811864f58120f5f4f5") // sorted keys: 10, 100, -1, "z", "aa", [100], [-1], false testCases := []struct { name string