From 49a31f1402614c6e7fbcd9119b059ba8edc1a8ab Mon Sep 17 00:00:00 2001 From: Lynn Date: Thu, 23 Apr 2020 15:18:01 +0800 Subject: [PATCH] *: fix the problem that PointGet returns wrong results in the case of overflow (#14776) --- executor/point_get_test.go | 40 +++++++---- planner/core/point_get_plan.go | 121 +++++++++++++++++++-------------- util/testkit/testkit.go | 2 +- 3 files changed, 97 insertions(+), 66 deletions(-) diff --git a/executor/point_get_test.go b/executor/point_get_test.go index 73bca25ea6f98..89a09d9eade62 100644 --- a/executor/point_get_test.go +++ b/executor/point_get_test.go @@ -141,10 +141,10 @@ func (s *testPointGetSuite) TestPointGetCharPK(c *C) { tk.MustExec(`set @@sql_mode="";`) tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows(`a b`)) tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows()) - tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows()) tk.MustPointGet(`select * from t where a = "";`).Check(testkit.Rows(` `)) tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows()) - tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t where a = " ";`).Check(testkit.Rows()) } @@ -157,7 +157,7 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) { tk.MustExec(`set @@sql_mode="";`) tk.MustPointGet(`select * from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`)) - tk.MustPointGet(`select * from t tmp where a = "aab";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t tmp where a = "aab";`).Check(testkit.Rows()) tk.MustExec(`truncate table t;`) tk.MustExec(`insert into t values("a ", "b ");`) @@ -165,7 +165,7 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) { tk.MustExec(`set @@sql_mode="";`) tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`)) tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) - tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) // Test CHAR BINARY. tk.MustExec(`drop table if exists t;`) @@ -176,10 +176,10 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) { tk.MustExec(`set @@sql_mode="";`) tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`)) tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) - tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) tk.MustPointGet(`select * from t tmp where a = "";`).Check(testkit.Rows(` `)) tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows()) - tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t tmp where a = " ";`).Check(testkit.Rows()) // Test both wildcard and column name exist in select field list tk.MustExec(`set @@sql_mode="";`) @@ -192,9 +192,9 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) { tk.MustPointGet(`select tmp.* from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`)) tk.MustPointGet(`select tmp.a, tmp.b from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`)) tk.MustPointGet(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb aa bb`)) - tk.MustPointGet(`select tmp.* from t tmp where a = "aab";`).Check(testkit.Rows()) - tk.MustPointGet(`select tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows()) - tk.MustPointGet(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows()) + tk.MustTableDual(`select tmp.* from t tmp where a = "aab";`).Check(testkit.Rows()) + tk.MustTableDual(`select tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows()) + tk.MustTableDual(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows()) // Test using table alias in where clause tk.MustPointGet(`select * from t tmp where tmp.a = "aa";`).Check(testkit.Rows(`aa bb`)) @@ -265,7 +265,7 @@ func (s *testPointGetSuite) TestPointGetVarcharPK(c *C) { tk.MustExec(`set @@sql_mode="";`) tk.MustPointGet(`select * from t where a = "aa";`).Check(testkit.Rows(`aa bb`)) - tk.MustPointGet(`select * from t where a = "aab";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t where a = "aab";`).Check(testkit.Rows()) tk.MustExec(`truncate table t;`) tk.MustExec(`insert into t values("a ", "b ");`) @@ -273,7 +273,7 @@ func (s *testPointGetSuite) TestPointGetVarcharPK(c *C) { tk.MustExec(`set @@sql_mode="";`) tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows()) tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows(`a b `)) - tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows()) // // Test VARCHAR BINARY. tk.MustExec(`drop table if exists t;`) @@ -284,10 +284,10 @@ func (s *testPointGetSuite) TestPointGetVarcharPK(c *C) { tk.MustExec(`set @@sql_mode="";`) tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows()) tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows(`a b `)) - tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows()) tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows()) tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows(` `)) - tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t where a = " ";`).Check(testkit.Rows()) } @@ -368,6 +368,20 @@ func (s *testPointGetSuite) TestIndexLookupBinary(c *C) { } +func (s *testPointGetSuite) TestOverflowOrTruncated(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("create table t6 (id bigint, a bigint, primary key(id), unique key(a));") + tk.MustExec("insert into t6 values(9223372036854775807, 9223372036854775807);") + tk.MustExec("insert into t6 values(1, 1);") + var nilVal []string + // for unique key + tk.MustQuery("select * from t6 where a = 9223372036854775808").Check(testkit.Rows(nilVal...)) + tk.MustQuery("select * from t6 where a = '1.123'").Check(testkit.Rows(nilVal...)) + // for primary key + tk.MustQuery("select * from t6 where id = 9223372036854775808").Check(testkit.Rows(nilVal...)) + tk.MustQuery("select * from t6 where id = '1.123'").Check(testkit.Rows(nilVal...)) +} + func (s *testPointGetSuite) TestIssue10448(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 3c01372b193ec..e01deb38ab1e5 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -19,6 +19,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/charset" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/opcode" @@ -652,6 +653,9 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP return nil } pi := tbl.GetPartitionInfo() + if pi != nil && pi.Type != model.PartitionTypeHash { + return nil + } for _, col := range tbl.Columns { // Do not handle generated columns. if col.IsGenerated() { @@ -662,53 +666,40 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP return nil } } + schema, names := buildSchemaFromFields(tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields) + if schema == nil { + return nil + } + dbName := tblName.Schema.L + if dbName == "" { + dbName = ctx.GetSessionVars().CurrentDB + } + pairs := make([]nameValuePair, 0, 4) - pairs = getNameValuePairs(pairs, tblAlias, selStmt.Where) - if pairs == nil { + pairs, isTableDual := getNameValuePairs(ctx.GetSessionVars().StmtCtx, tbl, tblAlias, pairs, selStmt.Where) + if pairs == nil && !isTableDual { return nil } var partitionInfo *model.PartitionDefinition var pos int if pi != nil { - if pi.Type != model.PartitionTypeHash { - return nil - } partitionInfo, pos = getPartitionInfo(ctx, tbl, pairs) if partitionInfo == nil { return nil } } + handlePair, fieldType := findPKHandle(tbl, pairs) if handlePair.value.Kind() != types.KindNull && len(pairs) == 1 { - schema, names := buildSchemaFromFields(tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields) - if schema == nil { - return nil - } - dbName := tblName.Schema.L - if dbName == "" { - dbName = ctx.GetSessionVars().CurrentDB - } - p := newPointGetPlan(ctx, dbName, schema, tbl, names) - intDatum, err := handlePair.value.ConvertTo(ctx.GetSessionVars().StmtCtx, fieldType) - if err != nil { - if terror.ErrorEqual(types.ErrOverflow, err) { - p.IsTableDual = true - return p - } - // some scenarios cast to int with error, but we may use this value in point get - if !terror.ErrorEqual(types.ErrTruncatedWrongVal, err) { - return nil - } - } - cmp, err := intDatum.CompareDatum(ctx.GetSessionVars().StmtCtx, &handlePair.value) - if err != nil { - return nil - } else if cmp != 0 { + if isTableDual { + p := newPointGetPlan(ctx, tblName.Schema.O, schema, tbl, names) p.IsTableDual = true return p } - p.Handle = intDatum.GetInt64() + + p := newPointGetPlan(ctx, dbName, schema, tbl, names) + p.Handle = handlePair.value.GetInt64() p.UnsignedHandle = mysql.HasUnsignedFlag(fieldType.Flag) p.HandleParam = handlePair.param p.PartitionInfo = partitionInfo @@ -722,18 +713,16 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP if idxInfo.State != model.StatePublic { continue } + if isTableDual { + p := newPointGetPlan(ctx, tblName.Schema.O, schema, tbl, names) + p.IsTableDual = true + return p + } + idxValues, idxValueParams := getIndexValues(idxInfo, pairs) if idxValues == nil { continue } - schema, names := buildSchemaFromFields(tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields) - if schema == nil { - return nil - } - dbName := tblName.Schema.L - if dbName == "" { - dbName = ctx.GetSessionVars().CurrentDB - } p := newPointGetPlan(ctx, dbName, schema, tbl, names) p.IndexInfo = idxInfo p.IndexValues = idxValues @@ -864,21 +853,22 @@ func getSingleTableNameAndAlias(tableRefs *ast.TableRefsClause) (tblName *ast.Ta } // getNameValuePairs extracts `column = constant/paramMarker` conditions from expr as name value pairs. -func getNameValuePairs(nvPairs []nameValuePair, tblName model.CIStr, expr ast.ExprNode) []nameValuePair { +func getNameValuePairs(stmtCtx *stmtctx.StatementContext, tbl *model.TableInfo, tblName model.CIStr, nvPairs []nameValuePair, expr ast.ExprNode) ( + pairs []nameValuePair, isTableDual bool) { binOp, ok := expr.(*ast.BinaryOperationExpr) if !ok { - return nil + return nil, false } if binOp.Op == opcode.LogicAnd { - nvPairs = getNameValuePairs(nvPairs, tblName, binOp.L) - if nvPairs == nil { - return nil + nvPairs, isTableDual = getNameValuePairs(stmtCtx, tbl, tblName, nvPairs, binOp.L) + if nvPairs == nil || isTableDual { + return nil, isTableDual } - nvPairs = getNameValuePairs(nvPairs, tblName, binOp.R) - if nvPairs == nil { - return nil + nvPairs, isTableDual = getNameValuePairs(stmtCtx, tbl, tblName, nvPairs, binOp.R) + if nvPairs == nil || isTableDual { + return nil, isTableDual } - return nvPairs + return nvPairs, isTableDual } else if binOp.Op == opcode.EQ { var d types.Datum var colName *ast.ColumnNameExpr @@ -901,17 +891,44 @@ func getNameValuePairs(nvPairs []nameValuePair, tblName model.CIStr, expr ast.Ex param = x } } else { - return nil + return nil, false } if d.IsNull() { - return nil + return nil, false + } + // Views' columns have no FieldType. + if tbl.IsView() { + return nil, false } if colName.Name.Table.L != "" && colName.Name.Table.L != tblName.L { - return nil + return nil, false + } + col := model.FindColumnInfo(tbl.Cols(), colName.Name.Name.L) + if col == nil || // Handling the case when the column is _tidb_rowid. + (col.Tp == mysql.TypeString && col.Collate == charset.CollationBin) { // This type we needn't to pad `\0` in here. + return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param}), false } - return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param}) + dVal, err := d.ConvertTo(stmtCtx, &col.FieldType) + if err != nil { + if terror.ErrorEqual(types.ErrOverflow, err) { + return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param}), true + } + // Some scenarios cast to int with error, but we may use this value in point get. + if !terror.ErrorEqual(types.ErrTruncatedWrongVal, err) { + return nil, false + } + } + // The converted result must be same as original datum. + cmp, err := d.CompareDatum(stmtCtx, &dVal) + if err != nil { + return nil, false + } else if cmp != 0 { + return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: dVal, param: param}), true + } + + return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: dVal, param: param}), false } - return nil + return nil, false } func findPKHandle(tblInfo *model.TableInfo, pairs []nameValuePair) (handlePair nameValuePair, fieldType *types.FieldType) { diff --git a/util/testkit/testkit.go b/util/testkit/testkit.go index 87edb0a020ecd..cd7c389385437 100644 --- a/util/testkit/testkit.go +++ b/util/testkit/testkit.go @@ -224,7 +224,7 @@ func (tk *TestKit) MustTableDual(sql string, args ...interface{}) *Result { func (tk *TestKit) MustPointGet(sql string, args ...interface{}) *Result { rs := tk.MustQuery("explain "+sql, args...) tk.c.Assert(len(rs.rows), check.Equals, 1) - tk.c.Assert(strings.Contains(rs.rows[0][0], "Point_Get"), check.IsTrue) + tk.c.Assert(strings.Contains(rs.rows[0][0], "Point_Get"), check.IsTrue, check.Commentf("plan %v", rs.rows[0][0])) return tk.MustQuery(sql, args...) }