Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

textsplitter: add an optional lenFunc to MarkdownTextSplitter #1096

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions textsplitter/markdown_splitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"
"reflect"
"strings"
"unicode/utf8"

"gitlab.com/golang-commonmark/markdown"
)
Expand All @@ -25,6 +24,7 @@ func NewMarkdownTextSplitter(opts ...Option) *MarkdownTextSplitter {
ReferenceLinks: options.ReferenceLinks,
HeadingHierarchy: options.KeepHeadingHierarchy,
JoinTableRows: options.JoinTableRows,
LenFunc: options.LenFunc,
}

if sp.SecondSplitter == nil {
Expand All @@ -36,6 +36,7 @@ func NewMarkdownTextSplitter(opts ...Option) *MarkdownTextSplitter {
"\n", // new line
" ", // space
}),
WithLenFunc(options.LenFunc),
)
}

Expand All @@ -57,6 +58,7 @@ type MarkdownTextSplitter struct {
ReferenceLinks bool
HeadingHierarchy bool
JoinTableRows bool
LenFunc func(string) int
}

// SplitText splits a text into multiple texts.
Expand All @@ -76,6 +78,7 @@ func (sp MarkdownTextSplitter) SplitText(text string) ([]string, error) {
joinTableRows: sp.JoinTableRows,
hTitleStack: []string{},
hTitlePrependHierarchy: sp.HeadingHierarchy,
lenFunc: sp.LenFunc,
}

chunks := mc.splitText()
Expand Down Expand Up @@ -133,6 +136,9 @@ type markdownContext struct {
// joinTableRows determines whether a chunk should contain multiple table rows,
// or if each row in a table should be split into a separate chunk.
joinTableRows bool

// lenFunc represents the function to calculate the length of a string.
lenFunc func(string) int
}

// splitText splits Markdown text.
Expand Down Expand Up @@ -193,6 +199,8 @@ func (mc *markdownContext) clone(startAt, endAt int) *markdownContext {
chunkSize: mc.chunkSize,
chunkOverlap: mc.chunkOverlap,
secondSplitter: mc.secondSplitter,

lenFunc: mc.lenFunc,
}
}

Expand Down Expand Up @@ -438,7 +446,7 @@ func (mc *markdownContext) splitTableRows(header []string, bodies [][]string) {
// If we're at the start of the current snippet, or adding the current line would
// overflow the chunk size, prepend the header to the line (so that the new chunk
// will include the table header).
if len(mc.curSnippet) == 0 || utf8.RuneCountInString(mc.curSnippet)+utf8.RuneCountInString(line) >= mc.chunkSize {
if len(mc.curSnippet) == 0 || mc.lenFunc(mc.curSnippet+line) >= mc.chunkSize {
line = fmt.Sprintf("%s\n%s", headerMD, line)
}

Expand Down Expand Up @@ -617,7 +625,7 @@ func (mc *markdownContext) joinSnippet(snippet string) {
}

// check whether current chunk exceeds chunk size, if so, apply to chunks
if utf8.RuneCountInString(mc.curSnippet)+utf8.RuneCountInString(snippet) >= mc.chunkSize {
if mc.lenFunc(mc.curSnippet+snippet) >= mc.chunkSize {
mc.applyToChunks()
mc.curSnippet = snippet
} else {
Expand All @@ -634,7 +642,7 @@ func (mc *markdownContext) applyToChunks() {
var chunks []string
if mc.curSnippet != "" {
// check whether current chunk is over ChunkSize; if so, re-split current chunk
if utf8.RuneCountInString(mc.curSnippet) <= mc.chunkSize+mc.chunkOverlap {
if mc.lenFunc(mc.curSnippet) <= mc.chunkSize+mc.chunkOverlap {
chunks = []string{mc.curSnippet}
} else {
// split current snippet to chunks
Expand Down
45 changes: 45 additions & 0 deletions textsplitter/markdown_splitter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"os"
"testing"

"github.com/pkoukk/tiktoken-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/schema"
Expand Down Expand Up @@ -579,3 +580,47 @@ func TestMarkdownHeaderTextSplitter_SplitInline(t *testing.T) {
})
}
}

// TestMarkdownHeaderTextSplitter_LenFunc checks that a length function passed
// via WithLenFunc is used by the markdown splitter when sizing chunks.
//
// The length function counts cl100k_base tokens, and the chunk size is set to
// one more than the token count of a single sample sentence. A document made
// of a title followed by two copies of that sentence is expected to produce
// two documents, each holding the title and one sentence.
func TestMarkdownHeaderTextSplitter_LenFunc(t *testing.T) {
	t.Parallel()

	// Fail fast with a clear message if the encoding cannot be loaded;
	// otherwise the Encode calls below would dereference a nil encoder.
	tokenEncoder, err := tiktoken.GetEncoding("cl100k_base")
	require.NoError(t, err)

	sampleText := "The quick brown fox jumped over the lazy dog."
	tokensPerChunk := len(tokenEncoder.Encode(sampleText, nil, nil))

	type testCase struct {
		markdown     string
		expectedDocs []schema.Document
	}

	testCases := []testCase{
		{
			markdown: `# Title` + "\n" + sampleText + "\n" + sampleText,
			expectedDocs: []schema.Document{
				{
					PageContent: "# Title" + "\n" + sampleText,
					Metadata:    map[string]any{},
				},
				{
					PageContent: "# Title" + "\n" + sampleText,
					Metadata:    map[string]any{},
				},
			},
		},
	}

	// The chunk size is expressed in tokens, matching the unit returned by
	// the custom length function below.
	splitter := NewMarkdownTextSplitter(
		WithChunkSize(tokensPerChunk+1),
		WithChunkOverlap(0),
		WithLenFunc(func(s string) int {
			return len(tokenEncoder.Encode(s, nil, nil))
		}),
	)

	for _, tc := range testCases {
		docs, err := CreateDocuments(splitter, []string{tc.markdown}, nil)
		require.NoError(t, err)
		assert.Equal(t, tc.expectedDocs, docs)
	}
}
Loading