Skip to content

Commit

Permalink
driver: context support (#997)
Browse files Browse the repository at this point in the history
* wip

* poc: context support

* move context management into driver
Close instead of SetTimeout hack
implement more interfaces, with assertions

* lint

* avoid excessive spinning when context canceled & waiting on contexts channel to close

* compressedHeader tweak

* mysql doesn't have named parameters

---------

Co-authored-by: lance6716 <[email protected]>
  • Loading branch information
serprex and lance6716 authored Feb 24, 2025
1 parent 08630ce commit 5fc8c87
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 20 deletions.
15 changes: 15 additions & 0 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,21 @@ func (c *Conn) Begin() error {
return errors.Trace(err)
}

func (c *Conn) BeginTx(readOnly bool, txIsolation string) error {
if txIsolation != "" {
if _, err := c.exec("SET TRANSACTION ISOLATION LEVEL " + txIsolation); err != nil {
return errors.Trace(err)
}
}
var err error
if readOnly {
_, err = c.exec("START TRANSACTION READ ONLY")
} else {
_, err = c.exec("START TRANSACTION")
}
return errors.Trace(err)
}

func (c *Conn) Commit() error {
_, err := c.exec("COMMIT")
return errors.Trace(err)
Expand Down
193 changes: 175 additions & 18 deletions driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package driver

import (
"context"
"crypto/tls"
"database/sql"
sqldriver "database/sql/driver"
Expand All @@ -21,6 +22,23 @@ import (
"github.com/pingcap/errors"
)

var (
_ sqldriver.Driver = &driver{}
_ sqldriver.DriverContext = &driver{}
_ sqldriver.Connector = &connInfo{}
_ sqldriver.NamedValueChecker = &conn{}
_ sqldriver.Validator = &conn{}
_ sqldriver.Conn = &conn{}
_ sqldriver.Pinger = &conn{}
_ sqldriver.ConnBeginTx = &conn{}
_ sqldriver.ConnPrepareContext = &conn{}
_ sqldriver.ExecerContext = &conn{}
_ sqldriver.QueryerContext = &conn{}
_ sqldriver.Stmt = &stmt{}
_ sqldriver.StmtExecContext = &stmt{}
_ sqldriver.StmtQueryContext = &stmt{}
)

var customTLSMutex sync.Mutex

// Map of dsn address (makes more sense than full dsn?) to tls Config
Expand Down Expand Up @@ -101,16 +119,18 @@ func parseDSN(dsn string) (connInfo, error) {
// Open takes a supplied DSN string and opens a connection
// See ParseDSN for more information on the form of the DSN
func (d driver) Open(dsn string) (sqldriver.Conn, error) {
var (
c *client.Conn
// by default database/sql driver retries will be enabled
retries = true
)

ci, err := parseDSN(dsn)
if err != nil {
return nil, err
}
return ci.Connect(context.Background())
}

func (ci connInfo) Connect(ctx context.Context) (sqldriver.Conn, error) {
var c *client.Conn
var err error
// by default database/sql driver retries will be enabled
retries := true

if ci.standardDSN {
var timeout time.Duration
Expand Down Expand Up @@ -159,45 +179,86 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
}
}

if timeout > 0 {
c, err = client.ConnectWithTimeout(ci.addr, ci.user, ci.password, ci.db, timeout, configuredOptions...)
} else {
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, configuredOptions...)
if timeout <= 0 {
timeout = 10 * time.Second
}
c, err = client.ConnectWithContext(ctx, ci.addr, ci.user, ci.password, ci.db, timeout, configuredOptions...)
} else {
// No more processing here. Let's only support url parameters with the newer style DSN
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db)
c, err = client.ConnectWithContext(ctx, ci.addr, ci.user, ci.password, ci.db, 10*time.Second)
}
if err != nil {
return nil, err
}

contexts := make(chan context.Context)
go func() {
ctx := context.Background()
for {
var ok bool
select {
case <-ctx.Done():
ctx = context.Background()
_ = c.Conn.Close()
case ctx, ok = <-contexts:
if !ok {
return
}
}
}
}()

// if retries are 'on' then return sqldriver.ErrBadConn which will trigger up to 3
// retries by the database/sql package. If retries are 'off' then we'll return
// the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry.
// In this case the sqldriver.Validator interface is implemented and will return
// false for IsValid() signaling the connection is bad and should be discarded.
return &conn{Conn: c, state: &state{valid: true, useStdLibErrors: retries}}, nil
return &conn{
Conn: c,
state: &state{contexts: contexts, valid: true, useStdLibErrors: retries},
}, nil
}

type CheckNamedValueFunc func(*sqldriver.NamedValue) error
func (d driver) OpenConnector(name string) (sqldriver.Connector, error) {
return parseDSN(name)
}

var (
_ sqldriver.NamedValueChecker = &conn{}
_ sqldriver.Validator = &conn{}
)
func (ci connInfo) Driver() sqldriver.Driver {
return driver{}
}

type CheckNamedValueFunc func(*sqldriver.NamedValue) error

type state struct {
valid bool
contexts chan context.Context
valid bool
// when true, the driver connection will return ErrBadConn from the golang Standard Library
useStdLibErrors bool
}

func (s *state) watchCtx(ctx context.Context) func() {
s.contexts <- ctx
return func() {
s.contexts <- context.Background()
}
}

func (s *state) Close() {
if s.contexts != nil {
close(s.contexts)
s.contexts = nil
}
}

type conn struct {
*client.Conn
state *state
}

func (c *conn) watchCtx(ctx context.Context) func() {
return c.state.watchCtx(ctx)
}

func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error {
for _, nvChecker := range namedValueCheckers {
err := nvChecker(nv)
Expand All @@ -220,6 +281,17 @@ func (c *conn) IsValid() bool {
return c.state.valid
}

func (c *conn) Ping(ctx context.Context) error {
defer c.watchCtx(ctx)()
if err := c.Conn.Ping(); err != nil {
if err == context.DeadlineExceeded || err == context.Canceled {
return err
}
return sqldriver.ErrBadConn
}
return nil
}

func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
st, err := c.Conn.Prepare(query)
if err != nil {
Expand All @@ -229,7 +301,13 @@ func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
return &stmt{Stmt: st, connectionState: c.state}, nil
}

func (c *conn) PrepareContext(ctx context.Context, query string) (sqldriver.Stmt, error) {
defer c.watchCtx(ctx)()
return c.Prepare(query)
}

func (c *conn) Close() error {
c.state.Close()
return c.Conn.Close()
}

Expand All @@ -242,6 +320,29 @@ func (c *conn) Begin() (sqldriver.Tx, error) {
return &tx{c.Conn}, nil
}

var isolationLevelTransactionIsolation = map[sql.IsolationLevel]string{
sql.LevelDefault: "",
sql.LevelRepeatableRead: "REPEATABLE READ",
sql.LevelReadCommitted: "READ COMMITTED",
sql.LevelReadUncommitted: "READ UNCOMMITTED",
sql.LevelSerializable: "SERIALIZABLE",
}

func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver.Tx, error) {
defer c.watchCtx(ctx)()

isolation := sql.IsolationLevel(opts.Isolation)
txIsolation, ok := isolationLevelTransactionIsolation[isolation]
if !ok {
return nil, fmt.Errorf("invalid mysql transaction isolation level %s", isolation)
}
err := c.Conn.BeginTx(opts.ReadOnly, txIsolation)
if err != nil {
return nil, errors.Trace(err)
}
return &tx{c.Conn}, nil
}

func buildArgs(args []sqldriver.Value) []interface{} {
a := make([]interface{}, len(args))

Expand All @@ -252,6 +353,16 @@ func buildArgs(args []sqldriver.Value) []interface{} {
return a
}

func buildNamedArgs(args []sqldriver.NamedValue) []interface{} {
a := make([]interface{}, len(args))

for i, arg := range args {
a[i] = arg.Value
}

return a
}

func (st *state) replyError(err error) error {
isBadConnection := mysql.ErrorEqual(err, mysql.ErrBadConn)

Expand All @@ -275,6 +386,16 @@ func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, err
return &result{r}, nil
}

func (c *conn) ExecContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Result, error) {
defer c.watchCtx(ctx)()
a := buildNamedArgs(args)
r, err := c.Conn.Execute(query, a...)
if err != nil {
return nil, c.state.replyError(err)
}
return &result{r}, nil
}

func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, error) {
a := buildArgs(args)
r, err := c.Conn.Execute(query, a...)
Expand All @@ -284,11 +405,25 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro
return newRows(r.Resultset)
}

func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
defer c.watchCtx(ctx)()
a := buildNamedArgs(args)
r, err := c.Conn.Execute(query, a...)
if err != nil {
return nil, c.state.replyError(err)
}
return newRows(r.Resultset)
}

type stmt struct {
*client.Stmt
connectionState *state
}

func (s *stmt) watchCtx(ctx context.Context) func() {
return s.connectionState.watchCtx(ctx)
}

func (s *stmt) Close() error {
return s.Stmt.Close()
}
Expand All @@ -306,6 +441,17 @@ func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) {
return &result{r}, nil
}

func (s *stmt) ExecContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Result, error) {
defer s.watchCtx(ctx)()

a := buildNamedArgs(args)
r, err := s.Stmt.Execute(a...)
if err != nil {
return nil, s.connectionState.replyError(err)
}
return &result{r}, nil
}

func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
a := buildArgs(args)
r, err := s.Stmt.Execute(a...)
Expand All @@ -315,6 +461,17 @@ func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
return newRows(r.Resultset)
}

func (s *stmt) QueryContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
defer s.watchCtx(ctx)()

a := buildNamedArgs(args)
r, err := s.Stmt.Execute(a...)
if err != nil {
return nil, s.connectionState.replyError(err)
}
return newRows(r.Resultset)
}

type tx struct {
*client.Conn
}
Expand Down
4 changes: 2 additions & 2 deletions packet/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ func (c *Conn) writeCompressed(data []byte) (n int, err error) {
var (
compressedLength, uncompressedLength int
payload *bytes.Buffer
compressedHeader = make([]byte, 7)
compressedHeader [7]byte
)

if len(data) > MinCompressionLength {
Expand Down Expand Up @@ -406,7 +406,7 @@ func (c *Conn) writeCompressed(data []byte) (n int, err error) {
compressedHeader[4] = byte(uncompressedLength)
compressedHeader[5] = byte(uncompressedLength >> 8)
compressedHeader[6] = byte(uncompressedLength >> 16)
if _, err = compressedPacket.Write(compressedHeader); err != nil {
if _, err = compressedPacket.Write(compressedHeader[:]); err != nil {
return 0, err
}
c.CompressedSequence++
Expand Down

0 comments on commit 5fc8c87

Please sign in to comment.