From 2eee0b00f1c2e9b11a7517d2228844393ed9cfba Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Thu, 20 Jun 2024 00:38:01 -0400 Subject: [PATCH] Add additional methods to Batch similar to what exists on Queryx --- batchx.go | 34 ++++++++++ batchx_test.go | 172 +++++++++++++++++++++++++++++++++++++------------ 2 files changed, 164 insertions(+), 42 deletions(-) diff --git a/batchx.go b/batchx.go index d3e4dcbd..c75cf2f9 100644 --- a/batchx.go +++ b/batchx.go @@ -1,6 +1,8 @@ package gocqlx import ( + "fmt" + "github.com/gocql/gocql" ) @@ -27,6 +29,38 @@ func (b *Batch) BindStruct(qry *Queryx, arg interface{}) error { return nil } +// Bind binds query parameters to values from args. +// If value cannot be found an error is reported. +func (b *Batch) Bind(qry *Queryx, args ...interface{}) error { + if len(qry.Names) != len(args) { + return fmt.Errorf("query requires %d arguments, but %d provided", len(qry.Names), len(args)) + } + b.Query(qry.Statement(), args...) + return nil +} + +// BindMap binds query named parameters to values from arg using a mapper. +// If value cannot be found an error is reported. +func (b *Batch) BindMap(qry *Queryx, arg map[string]interface{}) error { + args, err := qry.bindMapArgs(arg) + if err != nil { + return err + } + b.Query(qry.Statement(), args...) + return nil +} + +// BindStructMap binds query named parameters to values from arg0 and arg1 using a mapper. +// If value cannot be found an error is reported. +func (b *Batch) BindStructMap(qry *Queryx, arg0 interface{}, arg1 map[string]interface{}) error { + args, err := qry.bindStructArgs(arg0, arg1) + if err != nil { + return err + } + b.Query(qry.Statement(), args...) + return nil +} + // ExecuteBatch executes a batch operation and returns nil if successful // otherwise an error describing the failure. func (s *Session) ExecuteBatch(batch *Batch) error { diff --git a/batchx_test.go b/batchx_test.go index 9e5f3ddf..d2b87b9f 100644 --- a/batchx_test.go +++ b/batchx_test.go @@ -52,52 +52,140 @@ func TestBatch(t *testing.T) { SongID: mustParseUUID("60fc234a-8481-4343-93bb-72ecab404863"), } - insertSong := qb.Insert("batch_test.songs"). - Columns("id", "title", "album", "artist", "tags", "data").Query(session) - insertPlaylist := qb.Insert("batch_test.playlists"). - Columns("id", "title", "album", "artist", "song_id").Query(session) - selectSong := qb.Select("batch_test.songs").Where(qb.Eq("id")).Query(session) - selectPlaylist := qb.Select("batch_test.playlists").Where(qb.Eq("id")).Query(session) - t.Run("batch inserts", func(t *testing.T) { t.Parallel() - type batchQry struct { - qry *gocqlx.Queryx - arg interface{} - } - - qrys := []batchQry{ - {qry: insertSong, arg: song}, - {qry: insertPlaylist, arg: playlist}, - } - - b := session.NewBatch(gocql.LoggedBatch) - for _, qry := range qrys { - if err := b.BindStruct(qry.qry, qry.arg); err != nil { - t.Fatal("BindStruct failed:", err) - } - } - if err := session.ExecuteBatch(b); err != nil { - t.Fatal("batch execution:", err) - } - - // verify song was inserted - var gotSong Song - if err := selectSong.BindStruct(song).Get(&gotSong); err != nil { - t.Fatal("select song:", err) - } - if diff := cmp.Diff(gotSong, song); diff != "" { - t.Errorf("expected %v song, got %v, diff: %q", song, gotSong, diff) - } - - // verify playlist item was inserted - var gotPlayList PlaylistItem - if err := selectPlaylist.BindStruct(playlist).Get(&gotPlayList); err != nil { - t.Fatal("select song:", err) + tcases := []struct { + name string + methodSong func(*gocqlx.Batch, *gocqlx.Queryx, Song) error + methodPlaylist func(*gocqlx.Batch, *gocqlx.Queryx, PlaylistItem) error + }{ + { + name: "BindStruct", + methodSong: func(b *gocqlx.Batch, q *gocqlx.Queryx, song Song) error { + return b.BindStruct(q, song) + }, + methodPlaylist: func(b *gocqlx.Batch, q *gocqlx.Queryx, playlist PlaylistItem) error { + return b.BindStruct(q, playlist) + }, + }, + { + name: "BindMap", + methodSong: func(b *gocqlx.Batch, q *gocqlx.Queryx, song Song) error { + return b.BindMap(q, map[string]interface{}{ + "id": song.ID, + "title": song.Title, + "album": song.Album, + "artist": song.Artist, + "tags": song.Tags, + "data": song.Data, + }) + }, + methodPlaylist: func(b *gocqlx.Batch, q *gocqlx.Queryx, playlist PlaylistItem) error { + return b.BindMap(q, map[string]interface{}{ + "id": playlist.ID, + "title": playlist.Title, + "album": playlist.Album, + "artist": playlist.Artist, + "song_id": playlist.SongID, + }) + }, + }, + { + name: "Bind", + methodSong: func(b *gocqlx.Batch, q *gocqlx.Queryx, song Song) error { + return b.Bind(q, song.ID, song.Title, song.Album, song.Artist, song.Tags, song.Data) + }, + methodPlaylist: func(b *gocqlx.Batch, q *gocqlx.Queryx, playlist PlaylistItem) error { + return b.Bind(q, playlist.ID, playlist.Title, playlist.Album, playlist.Artist, playlist.SongID) + }, + }, + { + name: "BindStructMap", + methodSong: func(b *gocqlx.Batch, q *gocqlx.Queryx, song Song) error { + in := map[string]interface{}{ + "title": song.Title, + "album": song.Album, + } + return b.BindStructMap(q, struct { + ID gocql.UUID + Artist string + Tags []string + Data []byte + }{ + ID: song.ID, + Artist: song.Artist, + Tags: song.Tags, + Data: song.Data, + }, in) + }, + methodPlaylist: func(b *gocqlx.Batch, q *gocqlx.Queryx, playlist PlaylistItem) error { + in := map[string]interface{}{ + "title": playlist.Title, + "album": playlist.Album, + } + return b.BindStructMap(q, struct { + ID gocql.UUID + Artist string + SongID gocql.UUID + }{ + ID: playlist.ID, + Artist: playlist.Artist, + SongID: playlist.SongID, + }, + in, + ) + }, + }, } - if diff := cmp.Diff(gotPlayList, playlist); diff != "" { - t.Errorf("expected %v playList, got %v, diff: %q", playlist, gotPlayList, diff) + for _, tcase := range tcases { + t.Run(tcase.name, func(t *testing.T) { + insertSong := qb.Insert("batch_test.songs"). + Columns("id", "title", "album", "artist", "tags", "data").Query(session) + insertPlaylist := qb.Insert("batch_test.playlists"). + Columns("id", "title", "album", "artist", "song_id").Query(session) + selectSong := qb.Select("batch_test.songs").Where(qb.Eq("id")).Query(session) + selectPlaylist := qb.Select("batch_test.playlists").Where(qb.Eq("id")).Query(session) + deleteSong := qb.Delete("batch_test.songs").Where(qb.Eq("id")).Query(session) + deletePlaylist := qb.Delete("batch_test.playlists").Where(qb.Eq("id")).Query(session) + + b := session.NewBatch(gocql.LoggedBatch) + + if err = tcase.methodSong(b, insertSong, song); err != nil { + t.Fatal("insert song:", err) + } + if err = tcase.methodPlaylist(b, insertPlaylist, playlist); err != nil { + t.Fatal("insert playList:", err) + } + + if err := session.ExecuteBatch(b); err != nil { + t.Fatal("batch execution:", err) + } + + // verify song was inserted + var gotSong Song + if err := selectSong.BindStruct(song).Get(&gotSong); err != nil { + t.Fatal("select song:", err) + } + if diff := cmp.Diff(gotSong, song); diff != "" { + t.Errorf("expected %v song, got %v, diff: %q", song, gotSong, diff) + } + + // verify playlist item was inserted + var gotPlayList PlaylistItem + if err := selectPlaylist.BindStruct(playlist).Get(&gotPlayList); err != nil { + t.Fatal("select playList:", err) + } + if diff := cmp.Diff(gotPlayList, playlist); diff != "" { + t.Errorf("expected %v playList, got %v, diff: %q", playlist, gotPlayList, diff) + } + if err = deletePlaylist.BindStruct(playlist).Exec(); err != nil { + t.Error("delete playlist:", err) + } + if err = deleteSong.BindStruct(song).Exec(); err != nil { + t.Error("delete song:", err) + } + }) } }) }