diff --git a/pkg/executor/insert.go b/pkg/executor/insert.go index 22dd28c673c1e..132ebbc979170 100644 --- a/pkg/executor/insert.go +++ b/pkg/executor/insert.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/table/tables" "github.com/pingcap/tidb/pkg/tablecodec" @@ -186,7 +187,15 @@ func (e *InsertValues) prefetchDataCache(ctx context.Context, txn kv.Transaction } // updateDupRow updates a duplicate row to a new row. -func (e *InsertExec) updateDupRow(ctx context.Context, idxInBatch int, txn kv.Transaction, row toBeCheckedRow, handle kv.Handle, _ []*expression.Assignment, autoColIdx int) error { +func (e *InsertExec) updateDupRow( + ctx context.Context, + idxInBatch int, + txn kv.Transaction, + row toBeCheckedRow, + handle kv.Handle, + _ []*expression.Assignment, + autoColIdx int, +) error { oldRow, err := getOldRow(ctx, e.Ctx(), txn, row.t, handle, e.GenExprs) if err != nil { return err @@ -385,8 +394,14 @@ func (e *InsertExec) initEvalBuffer4Dup() { } // doDupRowUpdate updates the duplicate row. -func (e *InsertExec) doDupRowUpdate(ctx context.Context, handle kv.Handle, oldRow []types.Datum, newRow []types.Datum, - extraCols []types.Datum, cols []*expression.Assignment, idxInBatch int, autoColIdx int) error { +func (e *InsertExec) doDupRowUpdate( + ctx context.Context, + handle kv.Handle, + oldRow, newRow, extraCols []types.Datum, + assigns []*expression.Assignment, + idxInBatch int, + autoColIdx int, +) error { assignFlag := make([]bool, len(e.Table.WritableCols())) // See http://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values e.curInsertVals.SetDatums(newRow...) @@ -400,40 +415,72 @@ func (e *InsertExec) doDupRowUpdate(ctx context.Context, handle kv.Handle, oldRo e.row4Update = append(e.row4Update, extraCols...) e.row4Update = append(e.row4Update, newRow...) - // Update old row when the key is duplicated. - e.evalBuffer4Dup.SetDatums(e.row4Update...) - sc := e.Ctx().GetSessionVars().StmtCtx - warnCnt := int(sc.WarningCount()) - for _, col := range cols { - if col.LazyErr != nil { - return col.LazyErr - } - val, err1 := col.Expr.Eval(e.evalBuffer4Dup.ToRow()) - if err1 != nil { - return err1 - } - c := col.Col.ToInfo() - c.Name = col.ColName - e.row4Update[col.Col.Index], err1 = table.CastValue(e.Ctx(), val, c, false, false) - if err1 != nil { - return err1 + // Only evaluate non-generated columns here, + // other fields will be evaluated in updateRecord. + var generated, nonGenerated []*expression.Assignment + cols := e.Table.Cols() + for _, assign := range assigns { + if cols[assign.Col.Index].IsGenerated() { + generated = append(generated, assign) + } else { + nonGenerated = append(nonGenerated, assign) } + } + + warnCnt := int(e.Ctx().GetSessionVars().StmtCtx.WarningCount()) + errorHandler := func(sctx sessionctx.Context, assign *expression.Assignment, val *types.Datum, err error) error { + c := assign.Col.ToInfo() + c.Name = assign.ColName + sc := sctx.GetSessionVars().StmtCtx + if newWarnings := sc.TruncateWarnings(warnCnt); len(newWarnings) > 0 { for k := range newWarnings { // Use `idxInBatch` here for simplicity, since the offset of the batch is unknown under the current context. - newWarnings[k].Err = completeInsertErr(c, &val, idxInBatch, newWarnings[k].Err) + newWarnings[k].Err = completeInsertErr(c, val, idxInBatch, newWarnings[k].Err) } sc.AppendWarnings(newWarnings) warnCnt += len(newWarnings) } - e.evalBuffer4Dup.SetDatum(col.Col.Index, e.row4Update[col.Col.Index]) - assignFlag[col.Col.Index] = true + return err + } + + // Update old row when the key is duplicated. + e.evalBuffer4Dup.SetDatums(e.row4Update...) + sctx := e.Ctx() + for _, assign := range nonGenerated { + var val types.Datum + if assign.LazyErr != nil { + return assign.LazyErr + } + val, err := assign.Expr.Eval(e.evalBuffer4Dup.ToRow()) + if err != nil { + return err + } + + c := assign.Col.ToInfo() + idx := assign.Col.Index + c.Name = assign.ColName + val, err = table.CastValue(sctx, val, c, false, false) + if err != nil { + return err + } + + _ = errorHandler(sctx, assign, &val, nil) + e.evalBuffer4Dup.SetDatum(idx, val) + e.row4Update[assign.Col.Index] = val + assignFlag[assign.Col.Index] = true } newData := e.row4Update[:len(oldRow)] - _, err := updateRecord(ctx, e.Ctx(), handle, oldRow, newData, assignFlag, e.Table, true, e.memTracker, e.fkChecks, e.fkCascades) + _, err := updateRecord( + ctx, e.Ctx(), + handle, oldRow, newData, + 0, generated, e.evalBuffer4Dup, errorHandler, + assignFlag, e.Table, + true, e.memTracker, e.fkChecks, e.fkCascades) + if err != nil { - return err + return errors.Trace(err) } if autoColIdx >= 0 { diff --git a/pkg/executor/insert_test.go b/pkg/executor/insert_test.go index 1bc6f6b126f5e..a6336b8c84d47 100644 --- a/pkg/executor/insert_test.go +++ b/pkg/executor/insert_test.go @@ -249,6 +249,24 @@ func testInsertOnDuplicateKey(t *testing.T, tk *testkit.TestKit) { "//x/1.2", "//x/1.2")) + // Test issue 56829 + tk.MustExec(` + CREATE TABLE cache ( + cache_key varchar(512) NOT NULL, + updated_at datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + expired_at datetime GENERATED ALWAYS AS (if(expires > 0, date_add(updated_at, interval expires second), date_add(updated_at, interval 99 year))) VIRTUAL, + expires int(11), + PRIMARY KEY (cache_key) /*T![clustered_index] CLUSTERED */, + KEY idx_c_on_expired_at (expired_at) + )`) + tk.MustExec("INSERT INTO cache(cache_key, expires) VALUES ('2001-01-01 11:11:11', 60) ON DUPLICATE KEY UPDATE expires = expires + 1") + tk.MustExec("select sleep(1)") + tk.MustExec("INSERT INTO cache(cache_key, expires) VALUES ('2001-01-01 11:11:11', 60) ON DUPLICATE KEY UPDATE expires = expires + 1") + tk.MustExec("admin check table cache") + rs1 := tk.MustQuery("select cache_key, expired_at from cache use index() order by cache_key") + rs2 := tk.MustQuery("select cache_key, expired_at from cache use index(idx_c_on_expired_at) order by cache_key") + require.True(t, rs1.Equal(rs2.Rows())) + // reproduce insert on duplicate key update bug under new row format. tk.MustExec(`drop table if exists t1`) tk.MustExec(`create table t1(c1 decimal(6,4), primary key(c1))`) diff --git a/pkg/executor/update.go b/pkg/executor/update.go index a836f720d6a4b..c24e0085f8eb1 100644 --- a/pkg/executor/update.go +++ b/pkg/executor/update.go @@ -20,6 +20,7 @@ import ( "fmt" "runtime/trace" + "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/executor/internal/exec" "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/kv" @@ -27,6 +28,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -39,7 +41,8 @@ import ( type UpdateExec struct { exec.BaseExecutor - OrderedList []*expression.Assignment + OrderedList []*expression.Assignment + assignmentsPerTable map[int][]*expression.Assignment // updatedRowKeys is a map for unique (TableAlias, handle) pair. // The value is true if the row is changed, or false otherwise @@ -117,7 +120,7 @@ func (e *UpdateExec) prepare(row []types.Datum) (err error) { return nil } -func (e *UpdateExec) merge(row, newData []types.Datum, mergeGenerated bool) error { +func (e *UpdateExec) mergeNonGenerated(row, newData []types.Datum) error { if e.mergedRowData == nil { e.mergedRowData = make(map[int64]*kv.MemAwareHandleMap[[]types.Datum]) } @@ -148,7 +151,7 @@ func (e *UpdateExec) merge(row, newData []types.Datum, mergeGenerated bool) erro if v, ok := e.mergedRowData[content.TblID].Get(handle); ok { mergedData = v for i, flag := range flags { - if tbl.WritableCols()[i].IsGenerated() != mergeGenerated { + if tbl.WritableCols()[i].IsGenerated() { continue } mergedData[i].Copy(&oldData[i]) @@ -169,12 +172,84 @@ func (e *UpdateExec) merge(row, newData []types.Datum, mergeGenerated bool) erro return nil } -func (e *UpdateExec) exec(ctx context.Context, _ *expression.Schema, row, newData []types.Datum) error { +func (e *UpdateExec) mergeGenerated(row, newData []types.Datum, i int, beforeEval bool) error { + if e.virtualAssignmentsOffset >= len(e.OrderedList) { + return nil + } + + var mergedData []types.Datum + // merge updates from and into mergedRowData + var totalMemDelta int64 + + content := e.tblColPosInfos[i] + + if !e.multiUpdateOnSameTable[content.TblID] { + // No need to merge if not multi-updated + return nil + } + if !e.tableUpdatable[i] { + // If there's nothing to update, we can just skip current row + return nil + } + if e.changed[i] { + // Each matched row is updated once, even if it matches the conditions multiple times. + return nil + } + handle := e.handles[i] + flags := e.assignFlag[content.Start:content.End] + + if e.mergedRowData[content.TblID] == nil { + e.mergedRowData[content.TblID] = kv.NewMemAwareHandleMap[[]types.Datum]() + } + tbl := e.tblID2table[content.TblID] + oldData := row[content.Start:content.End] + newTableData := newData[content.Start:content.End] + + // We don't check the second return value here because we have already called mergeNonGenerated() before. + mergedData, _ = e.mergedRowData[content.TblID].Get(handle) + for i, flag := range flags { + if !tbl.WritableCols()[i].IsGenerated() { + continue + } + // Before evaluate generated columns, + // we need to copy new values to both oldData and newData. + // After evaluation, + /// we need to copy new generated values into mergedData. + if beforeEval { + mergedData[i].Copy(&oldData[i]) + if flag < 0 { + mergedData[i].Copy(&newTableData[i]) + } + } else { + if flag >= 0 { + newTableData[i].Copy(&mergedData[i]) + } + } + } + + memDelta := e.mergedRowData[content.TblID].Set(handle, mergedData) + memDelta += types.EstimatedMemUsage(mergedData, 1) + int64(handle.ExtraMemSize()) + totalMemDelta += memDelta + + e.memTracker.Consume(totalMemDelta) + return nil +} + +func (e *UpdateExec) exec( + ctx context.Context, + _ *expression.Schema, + rowIdx int, row, newData []types.Datum, +) error { defer trace.StartRegion(ctx, "UpdateExec").End() bAssignFlag := make([]bool, len(e.assignFlag)) for i, flag := range e.assignFlag { bAssignFlag[i] = flag >= 0 } + + errorHandler := func(_ sessionctx.Context, assign *expression.Assignment, _ *types.Datum, err error) error { + return handleUpdateError(assign.ColName, rowIdx, err) + } + for i, content := range e.tblColPosInfos { if !e.tableUpdatable[i] { // If there's nothing to update, we can just skip current row @@ -195,10 +270,32 @@ func (e *UpdateExec) exec(ctx context.Context, _ *expression.Schema, row, newDat newTableData := newData[content.Start:content.End] flags := bAssignFlag[content.Start:content.End] + // Evaluate generated columns and write to table. + // Evaluated values will be stored in newRow. + var assignments []*expression.Assignment + if a, ok := e.assignmentsPerTable[i]; ok { + assignments = a + } + + // Copy data from merge row to old and new rows + if err := e.mergeGenerated(row, newData, i, true); err != nil { + return errors.Trace(err) + } + // Update row - fkChecks := e.fkChecks[content.TblID] - fkCascades := e.fkCascades[content.TblID] - changed, err1 := updateRecord(ctx, e.Ctx(), handle, oldData, newTableData, flags, tbl, false, e.memTracker, fkChecks, fkCascades) + changed, err1 := updateRecord( + ctx, e.Ctx(), + handle, oldData, newTableData, + content.Start, assignments, e.evalBuffer, errorHandler, + flags, tbl, false, e.memTracker, + e.fkChecks[content.TblID], + e.fkCascades[content.TblID]) + + // Copy data from new row to merge row + if err := e.mergeGenerated(row, newData, i, false); err != nil { + return errors.Trace(err) + } + if err1 == nil { _, exist := e.updatedRowKeys[content.Start].Get(handle) memDelta := e.updatedRowKeys[content.Start].Set(handle, changed) @@ -261,6 +358,21 @@ func (e *UpdateExec) updateRows(ctx context.Context) (int, error) { } memUsageOfChk := int64(0) totalNumRows := 0 + + if e.virtualAssignmentsOffset < len(e.OrderedList) { + e.assignmentsPerTable = make(map[int][]*expression.Assignment, 0) + for _, assign := range e.OrderedList[e.virtualAssignmentsOffset:] { + tblIdx := e.assignFlag[assign.Col.Index] + if tblIdx < 0 { + continue + } + if _, ok := e.assignmentsPerTable[tblIdx]; !ok { + e.assignmentsPerTable[tblIdx] = make([]*expression.Assignment, 0) + } + e.assignmentsPerTable[tblIdx] = append(e.assignmentsPerTable[tblIdx], assign) + } + } + for { e.memTracker.Consume(-memUsageOfChk) err := exec.Next(ctx, e.Children(0), chk) @@ -301,24 +413,20 @@ func (e *UpdateExec) updateRows(ctx context.Context) (int, error) { return 0, err } // merge non-generated columns - if err := e.merge(datumRow, newRow, false); err != nil { + if err := e.mergeNonGenerated(datumRow, newRow); err != nil { return 0, err } + if e.virtualAssignmentsOffset < len(e.OrderedList) { - // compose generated columns - newRow, err = e.composeGeneratedColumns(globalRowIdx, newRow, colsInfo) - if err != nil { - return 0, err - } - // merge generated columns - if err := e.merge(datumRow, newRow, true); err != nil { - return 0, err - } + e.evalBuffer.SetDatums(newRow...) } - // write to table - if err := e.exec(ctx, e.Children(0).Schema(), datumRow, newRow); err != nil { + + if err := e.exec( + ctx, e.Children(0).Schema(), + globalRowIdx, datumRow, newRow); err != nil { return 0, err } + globalRowIdx++ } totalNumRows += chk.NumRows() chk = chunk.Renew(chk, e.MaxChunkSize()) @@ -326,7 +434,7 @@ func (e *UpdateExec) updateRows(ctx context.Context) (int, error) { return totalNumRows, nil } -func (*UpdateExec) handleErr(colName model.CIStr, rowIdx int, err error) error { +func handleUpdateError(colName model.CIStr, rowIdx int, err error) error { if err == nil { return nil } @@ -351,7 +459,7 @@ func (e *UpdateExec) fastComposeNewRow(rowIdx int, oldRow []types.Datum, cols [] } con := assign.Expr.(*expression.Constant) val, err := con.Eval(emptyRow) - if err = e.handleErr(assign.ColName, rowIdx, err); err != nil { + if err = handleUpdateError(assign.ColName, rowIdx, err); err != nil { return nil, err } @@ -359,7 +467,7 @@ func (e *UpdateExec) fastComposeNewRow(rowIdx int, oldRow []types.Datum, cols [] // No need to cast `_tidb_rowid` column value. if cols[assign.Col.Index] != nil { val, err = table.CastValue(e.Ctx(), val, cols[assign.Col.Index].ColumnInfo, false, false) - if err = e.handleErr(assign.ColName, rowIdx, err); err != nil { + if err = handleUpdateError(assign.ColName, rowIdx, err); err != nil { return nil, err } } @@ -385,43 +493,14 @@ func (e *UpdateExec) composeNewRow(rowIdx int, oldRow []types.Datum, cols []*tab // info of `_tidb_rowid` column is nil. // No need to cast `_tidb_rowid` column value. if cols[assign.Col.Index] != nil { - val, err = table.CastValue(e.Ctx(), val, cols[assign.Col.Index].ColumnInfo, false, false) - if err = e.handleErr(assign.ColName, rowIdx, err); err != nil { - return nil, err - } - } - - val.Copy(&newRowData[assign.Col.Index]) - } - return newRowData, nil -} - -func (e *UpdateExec) composeGeneratedColumns(rowIdx int, newRowData []types.Datum, cols []*table.Column) ([]types.Datum, error) { - if e.allAssignmentsAreConstant { - return newRowData, nil - } - e.evalBuffer.SetDatums(newRowData...) - for _, assign := range e.OrderedList[e.virtualAssignmentsOffset:] { - tblIdx := e.assignFlag[assign.Col.Index] - if tblIdx >= 0 && !e.tableUpdatable[tblIdx] { - continue - } - val, err := assign.Expr.Eval(e.evalBuffer.ToRow()) - if err = e.handleErr(assign.ColName, rowIdx, err); err != nil { - return nil, err - } - - // info of `_tidb_rowid` column is nil. - // No need to cast `_tidb_rowid` column value. - if cols[assign.Col.Index] != nil { - val, err = table.CastValue(e.Ctx(), val, cols[assign.Col.Index].ColumnInfo, false, false) - if err = e.handleErr(assign.ColName, rowIdx, err); err != nil { + colInfo := cols[assign.Col.Index].ColumnInfo + val, err = table.CastValue(e.Ctx(), val, colInfo, false, false) + if err = handleUpdateError(assign.ColName, rowIdx, err); err != nil { return nil, err } } val.Copy(&newRowData[assign.Col.Index]) - e.evalBuffer.SetDatum(assign.Col.Index, val) } return newRowData, nil } diff --git a/pkg/executor/update_test.go b/pkg/executor/update_test.go index fdbc1590e6933..41eae451d23ac 100644 --- a/pkg/executor/update_test.go +++ b/pkg/executor/update_test.go @@ -655,3 +655,26 @@ func TestLockUnchangedUniqueKeys(t *testing.T) { } } } + +func TestUpdateWithOnUpdateAndAutoGenerated(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewSteppedTestKit(t, store) + tk.MustExec("use test") + + tk.MustExec(` + CREATE TABLE cache ( + cache_key varchar(512) NOT NULL, + updated_at datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + expired_at datetime GENERATED ALWAYS AS (if(expires > 0, date_add(updated_at, interval expires second), date_add(updated_at, interval 99 year))) VIRTUAL, + expires int(11), + PRIMARY KEY (cache_key) /*T![clustered_index] CLUSTERED */, + KEY idx_c_on_expired_at (expired_at) + )`) + tk.MustExec("INSERT INTO cache(cache_key, expires) VALUES ('2001-01-01 11:11:11', 60) ON DUPLICATE KEY UPDATE expires = expires + 1") + tk.MustExec("select sleep(1)") + tk.MustExec("UPDATE cache SET expires = expires + 1 WHERE cache_key = '2001-01-01 11:11:11';") + tk.MustExec("admin check table cache") + rs1 := tk.MustQuery("select cache_key, expired_at from cache use index(idx_c_on_expired_at) order by cache_key") + rs2 := tk.MustQuery("select cache_key, expired_at from cache use index() order by cache_key") + require.True(t, rs1.Equal(rs2.Rows())) +} diff --git a/pkg/executor/write.go b/pkg/executor/write.go index 39187586979d5..1dcba29d59f93 100644 --- a/pkg/executor/write.go +++ b/pkg/executor/write.go @@ -34,6 +34,7 @@ import ( "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/tablecodec" "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/collate" "github.com/pingcap/tidb/pkg/util/memory" "github.com/pingcap/tidb/pkg/util/tracing" @@ -47,88 +48,127 @@ var ( _ exec.Executor = &LoadDataExec{} ) -// updateRecord updates the row specified by the handle `h`, from `oldData` to `newData`. -// `modified` means which columns are really modified. It's used for secondary indices. -// Length of `oldData` and `newData` equals to length of `t.WritableCols()`. -// The return values: -// 1. changed (bool) : does the update really change the row values. e.g. update set i = 1 where i = 1; -// 2. err (error) : error in the update. +/* + * updateRecord updates the row specified by the handle `h`, from `oldData` to `newData`. + * It is used both in update/insert on duplicate statements. + * + * The `modified` inputed indicates whether columns are explicitly set. + * And this slice will be reused in this function to record which columns are really modified, which is used for secondary indices. + * + * offset, assignments, evalBuffer and errorHandler are used to update auto-generated columns. + * We need to evaluate assignments, and set the result value in newData and evalBuffer respectively. + * Since the column indices in assignments are based on evalbuffer, and newData may be a subset of evalBuffer, + * offset is needed when assigning to newData. + * + * |<---- newData ---->| + * ------------------------------------------------------- + * | t1 | t1 | t3 | + * ------------------------------------------------------- + * |<------------------ evalBuffer ---|----------------->| + * | + * | + * |<------------------------- assign.Col.Idx + * + * Length of `oldData` and `newData` equals to length of `t.Cols()`. + * + * The return values: + * 1. changed (bool): does the update really change the row values. e.g. update set i = 1 where i = 1; + * 2. err (error): error in the update. + */ func updateRecord( - ctx context.Context, sctx sessionctx.Context, h kv.Handle, oldData, newData []types.Datum, modified []bool, + ctx context.Context, sctx sessionctx.Context, + h kv.Handle, oldData, newData []types.Datum, + offset int, + assignments []*expression.Assignment, + evalBuffer chunk.MutRow, + errorHandler func(sctx sessionctx.Context, assign *expression.Assignment, val *types.Datum, err error) error, + modified []bool, t table.Table, - onDup bool, _ *memory.Tracker, fkChecks []*FKCheckExec, fkCascades []*FKCascadeExec, -) (bool, error) { + onDup bool, + _ *memory.Tracker, + fkChecks []*FKCheckExec, + fkCascades []*FKCascadeExec, +) (changed bool, retErr error) { r, ctx := tracing.StartRegionEx(ctx, "executor.updateRecord") defer r.End() - sc := sctx.GetSessionVars().StmtCtx + sessVars := sctx.GetSessionVars() + sc := sessVars.StmtCtx + + // changed, handleChanged indicated whether row/handle is changed changed, handleChanged := false, false - // onUpdateSpecified is for "UPDATE SET ts_field = old_value", the - // timestamp field is explicitly set, but not changed in fact. - onUpdateSpecified := make(map[int]bool) + // onUpdateNeedModify is for "UPDATE SET ts_field = old_value". + // If the on-update-now timestamp field is explicitly set, we don't need to update it again. + onUpdateNeedModify := make(map[int]bool) // We can iterate on public columns not writable columns, // because all of them are sorted by their `Offset`, which // causes all writable columns are after public columns. + cols := t.Cols() - // Handle the bad null error. - for i, col := range t.Cols() { - var err error - if err = col.HandleBadNull(&newData[i], sc, 0); err != nil { - return false, err + // A wrapper function to check whether certain column is changed after evaluation. + checkColumnFunc := func(i int, skipGenerated bool) error { + col := cols[i] + if col.IsGenerated() && skipGenerated { + return nil } - } - // Handle exchange partition - tbl := t.Meta() - if tbl.ExchangePartitionInfo != nil && tbl.GetPartitionInfo() == nil { - if err := checkRowForExchangePartition(sctx, newData, tbl); err != nil { - return false, err + // modified[i] == false means this on-update-now field is not explicited set. + if mysql.HasOnUpdateNowFlag(col.GetFlag()) { + onUpdateNeedModify[i] = !modified[i] } - } - // Compare datum, then handle some flags. - for i, col := range t.Cols() { // We should use binary collation to compare datum, otherwise the result will be incorrect. cmp, err := newData[i].Compare(sc, &oldData[i], collate.GetBinaryCollator()) if err != nil { - return false, err + return err } + modified[i] = cmp != 0 if cmp != 0 { changed = true - modified[i] = true // Rebase auto increment id if the field is changed. if mysql.HasAutoIncrementFlag(col.GetFlag()) { recordID, err := getAutoRecordID(newData[i], &col.FieldType, false) if err != nil { - return false, err + return err } if err = t.Allocators(sctx).Get(autoid.AutoIncrementType).Rebase(ctx, recordID, true); err != nil { - return false, err + return err } } if col.IsPKHandleColumn(t.Meta()) { handleChanged = true // Rebase auto random id if the field is changed. if err := rebaseAutoRandomValue(ctx, sctx, t, &newData[i], col); err != nil { - return false, err + return err } } if col.IsCommonHandleColumn(t.Meta()) { handleChanged = true } - } else { - if mysql.HasOnUpdateNowFlag(col.GetFlag()) && modified[i] { - // It's for "UPDATE t SET ts = ts" and ts is a timestamp. - onUpdateSpecified[i] = true - } - modified[i] = false + } + + return nil + } + + // Before do actual update, We need to ensure that all columns are evaluated in the following order: + // Step 1: non-generated columns (These columns should be evaluated outside this function). + // Step 2: check whether there are some columns changed. + // Step 3: on-update-now columns if non-generated columns are changed. + // Step 4: generated columns if non-generated columns are changed. + // Step 5: handle foreign key errors, bad null errors and exchange partition errors. + // After these are done, we can finally update the record. + + // Step 2: compare already evaluated columns and update changed, handleChanged and handleChanged flags. + for i := range cols { + if err := checkColumnFunc(i, true); err != nil { + return false, err } } - sc.AddTouchedRows(1) // If no changes, nothing to do, return directly. if !changed { + sc.AddTouchedRows(1) // See https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html CLIENT_FOUND_ROWS if sctx.GetSessionVars().ClientCapability&mysql.ClientFoundRows > 0 { sc.AddAffectedRows(1) @@ -141,15 +181,19 @@ func updateRecord( return false, err } - // Fill values into on-update-now fields, only if they are really changed. + // Step 3: fill values into on-update-now fields. for i, col := range t.Cols() { - if mysql.HasOnUpdateNowFlag(col.GetFlag()) && !modified[i] && !onUpdateSpecified[i] { - v, err := expression.GetTimeValue(sctx, strings.ToUpper(ast.CurrentTimestamp), col.GetType(), col.GetDecimal(), nil) + var err error + if mysql.HasOnUpdateNowFlag(col.GetFlag()) && onUpdateNeedModify[i] { + newData[i], err = expression.GetTimeValue(sctx, strings.ToUpper(ast.CurrentTimestamp), col.GetType(), col.GetDecimal(), nil) + modified[i] = true + // For update statement, evalBuffer is initialized on demand. + if chunk.Row(evalBuffer).Chunk() != nil { + evalBuffer.SetDatum(i+offset, newData[i]) + } if err != nil { return false, err } - newData[i] = v - modified[i] = true // Only TIMESTAMP and DATETIME columns can be automatically updated, so it cannot be PKIsHandle. // Ref: https://dev.mysql.com/doc/refman/8.0/en/timestamp-initialization.html if col.IsPKHandleColumn(t.Meta()) { @@ -161,6 +205,48 @@ func updateRecord( } } + // Step 4: fill auto generated columns + for _, assign := range assignments { + // Insert statements may have LazyErr, handle it first. + if assign.LazyErr != nil { + return false, assign.LazyErr + } + + // For Update statements, Index may be larger than len(newData) + // e.g. update t a, t b set a.c1 = 1, b.c2 = 2; + idxInCols := assign.Col.Index - offset + rawVal, err := assign.Expr.Eval(evalBuffer.ToRow()) + if err == nil { + newData[idxInCols], err = table.CastValue(sctx, rawVal, assign.Col.ToInfo(), false, false) + } + evalBuffer.SetDatum(assign.Col.Index, newData[idxInCols]) + + err = errorHandler(sctx, assign, &rawVal, err) + if err != nil { + return false, err + } + + if err := checkColumnFunc(idxInCols, false); err != nil { + return false, err + } + } + + // Step 5: handle bad null errors and exchange partition errors. + for i, col := range t.Cols() { + var err error + if err = col.HandleBadNull(&newData[i], sc, 0); err != nil { + return false, err + } + } + + tbl := t.Meta() + if tbl.ExchangePartitionInfo != nil && tbl.GetPartitionInfo() == nil { + if err := checkRowForExchangePartition(sctx, newData, tbl); err != nil { + return false, err + } + } + + sc.AddTouchedRows(1) // If handle changed, remove the old then add the new record, otherwise update the record. if handleChanged { // For `UPDATE IGNORE`/`INSERT IGNORE ON DUPLICATE KEY UPDATE`