diff --git a/server/analyzer/init.go b/server/analyzer/init.go index 1047bb2457..95c99cc05c 100644 --- a/server/analyzer/init.go +++ b/server/analyzer/init.go @@ -49,6 +49,7 @@ const ( ruleId_ValidateCreateSchema // validateCreateSchema ruleId_ResolveAlterColumn // resolveAlterColumn ruleId_ValidateCreateFunction + ruleId_ResolveValuesTypes // resolveValuesTypes ) // Init adds additional rules to the analyzer to handle Doltgres-specific functionality. @@ -56,6 +57,7 @@ func Init() { analyzer.AlwaysBeforeDefault = append(analyzer.AlwaysBeforeDefault, analyzer.Rule{Id: ruleId_ResolveType, Apply: ResolveType}, analyzer.Rule{Id: ruleId_TypeSanitizer, Apply: TypeSanitizer}, + analyzer.Rule{Id: ruleId_ResolveValuesTypes, Apply: ResolveValuesTypes}, analyzer.Rule{Id: ruleId_GenerateForeignKeyName, Apply: generateForeignKeyName}, analyzer.Rule{Id: ruleId_AddDomainConstraints, Apply: AddDomainConstraints}, analyzer.Rule{Id: ruleId_ValidateColumnDefaults, Apply: ValidateColumnDefaults}, diff --git a/server/analyzer/resolve_values_types.go b/server/analyzer/resolve_values_types.go new file mode 100644 index 0000000000..5e6cf08564 --- /dev/null +++ b/server/analyzer/resolve_values_types.go @@ -0,0 +1,261 @@ +// Copyright 2026 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package analyzer + +import ( + "strings" + + "github.com/cockroachdb/errors" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/analyzer" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/transform" + + pgexprs "github.com/dolthub/doltgresql/server/expression" + "github.com/dolthub/doltgresql/server/functions/framework" + pgtransform "github.com/dolthub/doltgresql/server/transform" + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// ResolveValuesTypes determines the common type for each column in a VALUES clause +// by examining all rows, following PostgreSQL's type resolution rules. +// This ensures VALUES(1),(2.01),(3) correctly infers numeric type, not integer. +func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope *plan.Scope, selector analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { + // Walk the tree and wrap mixed-type VALUES columns with ImplicitCast. + // We record which VDTs changed so we can fix up GetField types afterward. + transformedVDTs := make(map[sql.TableId]sql.Schema) + node, same, err := transform.NodeWithOpaque(node, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { + newNode, same, err := transformValuesNode(n) + if err != nil { + return nil, same, err + } + if !same { + if vdt, ok := newNode.(*plan.ValueDerivedTable); ok { + transformedVDTs[vdt.Id()] = vdt.Schema() + } + } + return newNode, same, err + }) + if err != nil { + return nil, transform.SameTree, err + } + + // Now, fix GetField types that reference a transformed VDT. For example, + // after wrapping VALUES(1),(2.5) with ImplicitCast to numeric, any + // GetField reading column "n" from that VDT still says int4 and needs + // to be updated to numeric. + if len(transformedVDTs) > 0 { + node, _, err = pgtransform.NodeExprsWithOpaque(node, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) { + gf, ok := expr.(*expression.GetField) + if !ok { + return expr, transform.SameTree, nil + } + newSch, ok := transformedVDTs[gf.TableId()] + if !ok { + return expr, transform.SameTree, nil + } + + // We match by column name because GetField indices are global + // across all tables in a JOIN (e.g., a.n=0, b.id=1, b.label=2). + // We can't convert a global index to a per-table position without + // knowing the table's starting offset, which we don't have here. + schemaIdx := -1 + for i, col := range newSch { + if col.Name == gf.Name() { + schemaIdx = i + break + } + } + if schemaIdx < 0 { + return expr, transform.SameTree, nil + } + + newType := newSch[schemaIdx].Type + if gf.Type() == newType { + return expr, transform.SameTree, nil + } + + return expression.NewGetFieldWithTable( + gf.Index(), int(gf.TableId()), newType, + gf.Database(), gf.Table(), gf.Name(), gf.IsNullable(), + ), transform.NewTree, nil + }) + if err != nil { + return nil, transform.SameTree, err + } + + // The pass above only fixed GetFields that read directly from a VDT + // (matched by tableId). But changing a VDT column's type can have a + // ripple effect: if that column feeds into an aggregate like MIN or + // MAX, the aggregate's return type changes too. Parent nodes that + // read the aggregate result still have the old type. For example: + // + // SELECT MIN(n) FROM (VALUES(1),(2.5)) v(n) + // + // Project [GetField("min(v.n)", tableId=GroupBy, type=int4)] + // └── GroupBy [MIN(GetField("n", tableId=VDT, type=numeric))] + // └── VDT [n: int4 → numeric] + // + // The pass above fixed "n" inside MIN because its tableId=VDT. + // MIN now returns numeric, so GroupBy produces numeric. But the + // Project's GetField still says int4 because its tableId=GroupBy, + // which wasn't in transformedVDTs. At runtime this causes a panic + // because the actual value is decimal.Decimal but the type says int32. + // + // This pass catches those: for each GetField, check if its type + // disagrees with what the child node actually produces. + node, _, err = pgtransform.NodeExprsWithNodeWithOpaque(node, func(n sql.Node, expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) { + gf, ok := expr.(*expression.GetField) + if !ok { + return expr, transform.SameTree, nil + } + // Skip VDT GetFields — the first pass already handled these + if _, isVDT := transformedVDTs[gf.TableId()]; isVDT { + return expr, transform.SameTree, nil + } + // Collect the schema that this node's children produce + var childSchema sql.Schema + for _, child := range n.Children() { + childSchema = append(childSchema, child.Schema()...) + } + // TODO: resolve GMS case asymmetry issues. + // GMS has a casing asymmetry for aggregate names that forces + // case-insensitive matching here. GMS's Builder.buildAggregateFunc() + // in planbuilder/aggregates.go lowercases the entire aggregate + // name producing "sum(v.n)", but GroupBy.Schema() in + // plan/group_by.go keeps original casing from e.String() + // producing "SUM(v.n)". Without strings.ToLower, the match + // fails silently and aggregate type propagation breaks, causing + // runtime panics (interface conversion: interface {} is + // decimal.Decimal, not int32). + // + // We can't use non-name matching because sql.Column has no + // ColumnId field, so there is nothing on the child schema side + // to match against GetField.Id(). Name is the only shared + // identifier. + // + // This causes a known false-match when two quoted column names + // differ only by case (e.g., "Val" vs "val"), since the + // planbuilder has already lowered both to the same GetField + // name. GMS originated as a MySQL engine where identifiers are + // case-insensitive, but Postgres requires case-sensitivity for + // quoted identifiers. A proper fix requires structured + // case-sensitivity discrimination in GMS, either by adding + // ColumnId to sql.Column or by fixing the casing asymmetry in + // Builder.buildAggregateFunc() and GroupBy.Schema(). + gfName := strings.ToLower(gf.Name()) + for _, col := range childSchema { + if strings.ToLower(col.Name) == gfName && gf.Type() != col.Type { + return expression.NewGetFieldWithTable( + gf.Index(), int(gf.TableId()), col.Type, + gf.Database(), gf.Table(), gf.Name(), gf.IsNullable(), + ), transform.NewTree, nil + } + } + return expr, transform.SameTree, nil + }) + if err != nil { + return nil, transform.SameTree, err + } + } + + return node, same, nil +} + +// transformValuesNode transforms a plan.Values or plan.ValueDerivedTable node to use common types +func transformValuesNode(n sql.Node) (sql.Node, transform.TreeIdentity, error) { + var values *plan.Values + var expressionerNode sql.Expressioner + switch v := n.(type) { + case *plan.ValueDerivedTable: + values = v.Values + expressionerNode = v + case *plan.Values: + values = v + expressionerNode = v + default: + return n, transform.SameTree, nil + } + + // Skip if no rows or single row (nothing to unify) + if len(values.ExpressionTuples) <= 1 { + return n, transform.SameTree, nil + } + numCols := len(values.ExpressionTuples[0]) + for i := 1; i < len(values.ExpressionTuples); i++ { + if len(values.ExpressionTuples[i]) != numCols { + return nil, transform.NewTree, errors.Errorf("VALUES: row %d has %d columns, expected %d", i+1, len(values.ExpressionTuples[i]), numCols) + } + } + if numCols == 0 { + return n, transform.SameTree, nil + } + + // Collect types for each column across all rows + columnTypes := make([][]*pgtypes.DoltgresType, numCols) + for colIdx := 0; colIdx < numCols; colIdx++ { + columnTypes[colIdx] = make([]*pgtypes.DoltgresType, len(values.ExpressionTuples)) + for rowIdx, row := range values.ExpressionTuples { + exprType := row[colIdx].Type() + if exprType == nil { + columnTypes[colIdx][rowIdx] = pgtypes.Unknown + } else if pgType, ok := exprType.(*pgtypes.DoltgresType); ok { + columnTypes[colIdx][rowIdx] = pgType + } else { + return nil, transform.NewTree, errors.Errorf("VALUES: non-Doltgres type found in row %d, column %d: %s", rowIdx, colIdx, exprType.String()) + } + } + } + + // Find common type for each column + var newTuples [][]sql.Expression + for colIdx := 0; colIdx < numCols; colIdx++ { + commonType, requiresCasts, err := framework.FindCommonType(columnTypes[colIdx]) + if err != nil { + return nil, transform.NewTree, err + } + // If we require any casts, then we'll add casting to all expressions in the list + if requiresCasts { + if len(newTuples) == 0 { + // Deep copy to avoid mutating the original expression tuples. + newTuples = make([][]sql.Expression, len(values.ExpressionTuples)) + for i, row := range values.ExpressionTuples { + newTuples[i] = make([]sql.Expression, len(row)) + copy(newTuples[i], row) + } + } + for rowIdx := 0; rowIdx < len(newTuples); rowIdx++ { + newTuples[rowIdx][colIdx] = pgexprs.NewImplicitCast( + newTuples[rowIdx][colIdx], columnTypes[colIdx][rowIdx], commonType) + } + } + } + // If we didn't require any casts, then we can simply return our old node + if len(newTuples) == 0 { + return n, transform.SameTree, nil + } + + // Flatten the new tuples into a single expression slice for WithExpressions + flatExprs := make([]sql.Expression, 0, len(newTuples)*len(newTuples[0])) + for _, row := range newTuples { + flatExprs = append(flatExprs, row...) + } + newNode, err := expressionerNode.WithExpressions(flatExprs...) + if err != nil { + return nil, transform.NewTree, err + } + return newNode, transform.NewTree, nil +} diff --git a/server/expression/explicit_cast.go b/server/expression/explicit_cast.go index 36186cf18a..f506db768e 100644 --- a/server/expression/explicit_cast.go +++ b/server/expression/explicit_cast.go @@ -96,8 +96,11 @@ func (c *ExplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { baseCastToType := checkForDomainType(c.castToType) castFunction := framework.GetExplicitCast(fromType, baseCastToType) if castFunction == nil { - return nil, errors.Errorf("EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s", - fromType.String(), c.castToType.String(), c.sqlChild.String()) + return nil, errors.Errorf( + "EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s", + fromType.String(), c.castToType.String(), c.sqlChild.String(), + ) + } castResult, err := castFunction(ctx, val, c.castToType) if err != nil { diff --git a/testing/bats/types.bats b/testing/bats/types.bats index aaeb6a80b3..258432b9b6 100644 --- a/testing/bats/types.bats +++ b/testing/bats/types.bats @@ -37,4 +37,4 @@ SQL [[ "$output" =~ '4,"{f,f}"' ]] || false [[ "$output" =~ '5,{t}' ]] || false [[ "$output" =~ '6,{f}' ]] || false -} +} \ No newline at end of file diff --git a/testing/bats/values.bats b/testing/bats/values.bats new file mode 100644 index 0000000000..834ca857e6 --- /dev/null +++ b/testing/bats/values.bats @@ -0,0 +1,219 @@ +#!/usr/bin/env bats +load $BATS_TEST_DIRNAME/setup/common.bash + +setup() { + setup_common + start_sql_server + +} + +teardown() { + teardown_common +} + +@test 'values: mixed int and decimal' { + # Integer first, then decimal - should resolve to numeric + run query_server -t -c "SELECT * FROM (VALUES(1),(2.01),(3)) v(n);" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "2.01" ]] || false + [[ "$output" =~ "3" ]] || false +} + +@test 'values: decimal first then int' { + # Decimal first, then integers - should resolve to numeric + run query_server -t -c "SELECT * FROM (VALUES(1.01),(2),(3)) v(n);" + [ "$status" -eq 0 ] + [[ "$output" =~ "1.01" ]] || false + [[ "$output" =~ "2" ]] || false + [[ "$output" =~ "3" ]] || false +} + +@test 'values: SUM with mixed types' { + # SUM should work directly now that VALUES has correct type + run query_server -t -c "SELECT SUM(n) FROM (VALUES(1),(2.01),(3)) v(n);" + [ "$status" -eq 0 ] + [[ "$output" =~ "6.01" ]] || false +} + +@test 'values: multiple columns mixed types' { + run query_server -t -c "SELECT * FROM (VALUES(1, 'a'), (2.5, 'b')) v(num, str);" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "a" ]] || false + [[ "$output" =~ "2.5" ]] || false + [[ "$output" =~ "b" ]] || false +} + +@test 'values: SUM with explicit cast' { + run query_server -t -c "SELECT SUM(n::numeric) FROM (VALUES(1),(2.01),(3)) v(n);" + [ "$status" -eq 0 ] + [[ "$output" =~ "6.01" ]] || false +} + +@test 'values: MIN and MAX with mixed types' { + run query_server -t -c "SELECT MIN(n), MAX(n) FROM (VALUES(1),(2.5),(3),(0.5)) v(n);" + [ "$status" -eq 0 ] + [[ "$output" =~ "0.5" ]] || false + [[ "$output" =~ "3" ]] || false +} + +@test 'values: GROUP BY with mixed types' { + run query_server -t -c "SELECT n, COUNT(*) FROM (VALUES(1),(2.5),(1),(3.5),(2.5)) v(n) GROUP BY n ORDER BY n;" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "2.5" ]] || false + [[ "$output" =~ "3.5" ]] || false +} + +@test 'values: SUM GROUP BY with mixed types' { + run query_server -t -c "SELECT category, SUM(amount) FROM (VALUES('a', 1),('b', 2.5),('a', 3),('b', 4.5)) v(category, amount) GROUP BY category ORDER BY category;" + [ "$status" -eq 0 ] + [[ "$output" =~ "a" ]] || false + [[ "$output" =~ "4" ]] || false + [[ "$output" =~ "b" ]] || false + [[ "$output" =~ "7.0" ]] || false +} + +@test 'values: DISTINCT with mixed types' { + run query_server -t -c "SELECT DISTINCT n FROM (VALUES(1),(2.5),(1),(2.5),(3)) v(n) ORDER BY n;" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "2.5" ]] || false + [[ "$output" =~ "3" ]] || false +} + +@test 'values: ORDER BY with mixed types' { + run query_server -t -c "SELECT * FROM (VALUES(3),(1.5),(2),(4.5)) v(n) ORDER BY n;" + [ "$status" -eq 0 ] + [[ "$output" =~ "1.5" ]] || false + [[ "$output" =~ "2" ]] || false + [[ "$output" =~ "3" ]] || false + [[ "$output" =~ "4.5" ]] || false +} + +@test 'values: ORDER BY DESC with mixed types' { + run query_server -t -c "SELECT * FROM (VALUES(3),(1.5),(2),(4.5)) v(n) ORDER BY n DESC;" + [ "$status" -eq 0 ] + [[ "$output" =~ "4.5" ]] || false + [[ "$output" =~ "3" ]] || false + [[ "$output" =~ "2" ]] || false + [[ "$output" =~ "1.5" ]] || false +} + +@test 'values: LIMIT with mixed types' { + run query_server -t -c "SELECT * FROM (VALUES(1),(2.5),(3),(4.5),(5)) v(n) LIMIT 3;" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "2.5" ]] || false + [[ "$output" =~ "3" ]] || false + ! [[ "$output" =~ "4.5" ]] || false + ! [[ "$output" =~ " 5" ]] || false +} + +@test 'values: WHERE filter with mixed types' { + run query_server -t -c "SELECT * FROM (VALUES(1),(2.5),(3),(4.5),(5)) v(n) WHERE n > 2;" + [ "$status" -eq 0 ] + [[ "$output" =~ "2.5" ]] || false + [[ "$output" =~ "3" ]] || false + [[ "$output" =~ "4.5" ]] || false + [[ "$output" =~ "5" ]] || false +} + +@test 'values: NULLs with mixed types' { + run query_server -t -c "SELECT * FROM (VALUES(1),(NULL),(2.5)) v(n);" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "2.5" ]] || false +} + +@test 'values: all same type no cast needed' { + run query_server -t -c "SELECT * FROM (VALUES(1),(2),(3)) v(n);" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "2" ]] || false + [[ "$output" =~ "3" ]] || false +} + +@test 'values: all string literals' { + run query_server -t -c "SELECT * FROM (VALUES('a'),('b'),('c')) v(n);" + [ "$status" -eq 0 ] + [[ "$output" =~ "a" ]] || false + [[ "$output" =~ "b" ]] || false + [[ "$output" =~ "c" ]] || false +} + +@test 'values: string concatenation' { + run query_server -t -c "SELECT n || '!' FROM (VALUES('hello'),('world')) v(n);" + [ "$status" -eq 0 ] + [[ "$output" =~ "hello!" ]] || false + [[ "$output" =~ "world!" ]] || false +} + +@test 'values: type mismatch bool and int errors' { + run query_server -t -c "SELECT * FROM (VALUES(true),(1),(false)) v(n);" + [ "$status" -ne 0 ] + [[ "$output" =~ "cannot be matched" ]] || false +} + +@test 'values: JOIN with same types' { + run query_server -t -c "SELECT a.n, b.label FROM (VALUES(1),(2),(3)) a(n) JOIN (VALUES(1, 'one'),(2, 'two'),(3, 'three')) b(id, label) ON a.n = b.id;" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "one" ]] || false + [[ "$output" =~ "2" ]] || false + [[ "$output" =~ "two" ]] || false + [[ "$output" =~ "3" ]] || false + [[ "$output" =~ "three" ]] || false +} + +@test 'values: JOIN with mixed types' { + run query_server -t -c "SELECT a.n, b.label FROM (VALUES(1),(2.5),(3)) a(n) JOIN (VALUES(1, 'one'),(3, 'three')) b(id, label) ON a.n = b.id;" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "one" ]] || false + [[ "$output" =~ "3" ]] || false + [[ "$output" =~ "three" ]] || false +} + +@test 'values: CTE with mixed types' { + run query_server -t -c "WITH nums AS (SELECT * FROM (VALUES(1),(2.5),(3)) v(n)) SELECT * FROM nums;" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "2.5" ]] || false + [[ "$output" =~ "3" ]] || false +} + +@test 'values: CTE SUM with mixed types' { + run query_server -t -c "WITH nums AS (SELECT * FROM (VALUES(1),(2.5),(3)) v(n)) SELECT SUM(n) FROM nums;" + [ "$status" -eq 0 ] + [[ "$output" =~ "6.5" ]] || false +} + +@test 'values: multi-column partial cast' { + # Only second column needs cast, first stays int + run query_server -t -c "SELECT * FROM (VALUES(1, 10),(2, 20.5),(3, 30)) v(a, b);" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "10" ]] || false + [[ "$output" =~ "2" ]] || false + [[ "$output" =~ "20.5" ]] || false + [[ "$output" =~ "3" ]] || false + [[ "$output" =~ "30" ]] || false +} + +@test 'values: combined GROUP BY ORDER BY LIMIT' { + run query_server -t -c "SELECT n, COUNT(*) as cnt FROM (VALUES(1),(2.5),(1),(2.5),(3),(1)) v(n) GROUP BY n ORDER BY cnt DESC LIMIT 2;" + [ "$status" -eq 0 ] + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "3" ]] || false + [[ "$output" =~ "2.5" ]] || false + [[ "$output" =~ "2" ]] || false +} + +@test 'values: combined WHERE ORDER BY LIMIT' { + run query_server -t -c "SELECT * FROM (VALUES(1),(2.5),(3),(4.5),(5)) v(n) WHERE n > 1 ORDER BY n DESC LIMIT 2;" + [ "$status" -eq 0 ] + [[ "$output" =~ "5" ]] || false + [[ "$output" =~ "4.5" ]] || false +} diff --git a/testing/go/values_statement_test.go b/testing/go/values_statement_test.go index cf995bd301..6d5c82bad8 100644 --- a/testing/go/values_statement_test.go +++ b/testing/go/values_statement_test.go @@ -26,8 +26,7 @@ func TestValuesStatement(t *testing.T) { var ValuesStatementTests = []ScriptTest{ { - Name: "basic values statements", - SetUpScript: []string{}, + Name: "basic values statements", Assertions: []ScriptTestAssertion{ { Query: `SELECT * FROM (VALUES (1), (2), (3)) sqa;`, @@ -53,4 +52,526 @@ var ValuesStatementTests = []ScriptTest{ }, }, }, + { + Name: "VALUES with mixed int and decimal", + Assertions: []ScriptTestAssertion{ + { + // Integer first, then decimal - should resolve to numeric + Query: `SELECT * FROM (VALUES(1),(2.01),(3)) v(n);`, + Expected: []sql.Row{ + {Numeric("1")}, + {Numeric("2.01")}, + {Numeric("3")}, + }, + }, + { + // Decimal first, then integers - should resolve to numeric + Query: `SELECT * FROM (VALUES(1.01),(2),(3)) v(n);`, + Expected: []sql.Row{ + {Numeric("1.01")}, + {Numeric("2")}, + {Numeric("3")}, + }, + }, + { + // SUM should work directly now that VALUES has correct type + Query: `SELECT SUM(n) FROM (VALUES(1),(2.01),(3)) v(n);`, + Expected: []sql.Row{{Numeric("6.01")}}, + }, + { + // Exact repro from issue #1648: integer first, explicit cast to numeric + Query: `SELECT SUM(n::numeric) FROM (VALUES(1),(2.01),(3)) v(n);`, + Expected: []sql.Row{{Numeric("6.01")}}, + }, + { + // Exact repro from issue #1648: decimal first, explicit cast to numeric + Query: `SELECT SUM(n::numeric) FROM (VALUES(1.01),(2),(3)) v(n);`, + Expected: []sql.Row{{Numeric("6.01")}}, + }, + }, + }, + { + Name: "VALUES with multiple columns mixed types", + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT * FROM (VALUES(1, 'a'), (2.5, 'b')) v(num, str);`, + Expected: []sql.Row{ + {Numeric("1"), "a"}, + {Numeric("2.5"), "b"}, + }, + }, + }, + }, + { + Name: "VALUES with GROUP BY", + Assertions: []ScriptTestAssertion{ + { + // GROUP BY on mixed type VALUES - tests that GetField types are updated correctly + Query: `SELECT n, COUNT(*) FROM (VALUES(1),(2.5),(1),(3.5),(2.5)) v(n) GROUP BY n ORDER BY n;`, + Expected: []sql.Row{ + {Numeric("1"), int64(2)}, + {Numeric("2.5"), int64(2)}, + {Numeric("3.5"), int64(1)}, + }, + }, + { + // SUM with GROUP BY + Query: `SELECT category, SUM(amount) FROM (VALUES('a', 1),('b', 2.5),('a', 3),('b', 4.5)) v(category, amount) GROUP BY category ORDER BY category;`, + Expected: []sql.Row{ + {"a", Numeric("4")}, + {"b", Numeric("7.0")}, + }, + }, + }, + }, + { + Name: "VALUES with DISTINCT", + Assertions: []ScriptTestAssertion{ + { + // DISTINCT on mixed type VALUES + Query: `SELECT DISTINCT n FROM (VALUES(1),(2.5),(1),(2.5),(3)) v(n) ORDER BY n;`, + Expected: []sql.Row{ + {Numeric("1")}, + {Numeric("2.5")}, + {Numeric("3")}, + }, + }, + }, + }, + { + Name: "VALUES with LIMIT and OFFSET", + Assertions: []ScriptTestAssertion{ + { + // LIMIT on mixed type VALUES + Query: `SELECT * FROM (VALUES(1),(2.5),(3),(4.5),(5)) v(n) LIMIT 3;`, + Expected: []sql.Row{ + {Numeric("1")}, + {Numeric("2.5")}, + {Numeric("3")}, + }, + }, + { + // LIMIT with OFFSET + Query: `SELECT * FROM (VALUES(1),(2.5),(3),(4.5),(5)) v(n) LIMIT 2 OFFSET 2;`, + Expected: []sql.Row{ + {Numeric("3")}, + {Numeric("4.5")}, + }, + }, + }, + }, + { + Name: "VALUES with ORDER BY", + Assertions: []ScriptTestAssertion{ + { + // ORDER BY on mixed type VALUES - ascending + Query: `SELECT * FROM (VALUES(3),(1.5),(2),(4.5)) v(n) ORDER BY n;`, + Expected: []sql.Row{ + {Numeric("1.5")}, + {Numeric("2")}, + {Numeric("3")}, + {Numeric("4.5")}, + }, + }, + { + // ORDER BY descending + Query: `SELECT * FROM (VALUES(3),(1.5),(2),(4.5)) v(n) ORDER BY n DESC;`, + Expected: []sql.Row{ + {Numeric("4.5")}, + {Numeric("3")}, + {Numeric("2")}, + {Numeric("1.5")}, + }, + }, + }, + }, + { + Name: "VALUES in subquery", + Assertions: []ScriptTestAssertion{ + { + // VALUES as subquery in FROM clause + // TODO: pre-existing bug: arithmetic in subquery over VALUES is not applied (returns original values) + Skip: true, + Query: `SELECT * FROM (SELECT n * 2 AS doubled FROM (VALUES(1),(2.5),(3)) v(n)) sub;`, + Expected: []sql.Row{ + {Numeric("2")}, + {Numeric("5.0")}, + {Numeric("6")}, + }, + }, + { + // VALUES with LIMIT inside subquery + // TODO: pre-existing bug: LIMIT inside subquery over VALUES is ignored (returns all rows) + Skip: true, + Query: `SELECT * FROM (SELECT * FROM (VALUES(1),(2.5),(3),(4.5)) v(n) LIMIT 2) sub;`, + Expected: []sql.Row{ + {Numeric("1")}, + {Numeric("2.5")}, + }, + }, + { + // VALUES with ORDER BY inside subquery + // TODO: pre-existing bug - ORDER BY inside subquery over VALUES is ignored + Skip: true, + Query: `SELECT * FROM (SELECT * FROM (VALUES(3),(1.5),(2)) v(n) ORDER BY n) sub;`, + Expected: []sql.Row{ + {Numeric("1.5")}, + {Numeric("2")}, + {Numeric("3")}, + }, + }, + }, + }, + { + Name: "VALUES with WHERE clause (Filter node)", + Assertions: []ScriptTestAssertion{ + { + // Filter on mixed type VALUES + Query: `SELECT * FROM (VALUES(1),(2.5),(3),(4.5),(5)) v(n) WHERE n > 2;`, + Expected: []sql.Row{ + {Numeric("2.5")}, + {Numeric("3")}, + {Numeric("4.5")}, + {Numeric("5")}, + }, + }, + { + // Filter with multiple conditions + Query: `SELECT * FROM (VALUES(1),(2.5),(3),(4.5),(5)) v(n) WHERE n > 1 AND n < 4.5;`, + Expected: []sql.Row{ + {Numeric("2.5")}, + {Numeric("3")}, + }, + }, + }, + }, + { + Name: "VALUES with aggregate functions", + Assertions: []ScriptTestAssertion{ + { + // AVG on mixed types + Query: `SELECT AVG(n) FROM (VALUES(1),(2),(3),(4)) v(n);`, + Expected: []sql.Row{{2.5}}, + }, + { + // MIN/MAX on mixed types + Query: `SELECT MIN(n), MAX(n) FROM (VALUES(1),(2.5),(3),(0.5)) v(n);`, + Expected: []sql.Row{ + {Numeric("0.5"), Numeric("3")}, + }, + }, + }, + }, + { + Name: "VALUES combined operations", + Assertions: []ScriptTestAssertion{ + { + // GROUP BY + ORDER BY + LIMIT + Query: `SELECT n, COUNT(*) as cnt FROM (VALUES(1),(2.5),(1),(2.5),(3),(1)) v(n) GROUP BY n ORDER BY cnt DESC LIMIT 2;`, + Expected: []sql.Row{ + {Numeric("1"), int64(3)}, + {Numeric("2.5"), int64(2)}, + }, + }, + { + // DISTINCT + ORDER BY + LIMIT + Query: `SELECT DISTINCT n FROM (VALUES(1),(2.5),(1),(3),(2.5),(4)) v(n) ORDER BY n DESC LIMIT 3;`, + Expected: []sql.Row{ + {Numeric("4")}, + {Numeric("3")}, + {Numeric("2.5")}, + }, + }, + { + // WHERE + ORDER BY + LIMIT + Query: `SELECT * FROM (VALUES(1),(2.5),(3),(4.5),(5)) v(n) WHERE n > 1 ORDER BY n DESC LIMIT 2;`, + Expected: []sql.Row{ + {Numeric("5")}, + {Numeric("4.5")}, + }, + }, + }, + }, + { + Name: "VALUES with single row (no type unification needed)", + Assertions: []ScriptTestAssertion{ + { + // Single row should pass through unchanged + Query: `SELECT * FROM (VALUES(42)) v(n);`, + Expected: []sql.Row{ + {int32(42)}, + }, + }, + { + // Single row with decimal + Query: `SELECT * FROM (VALUES(3.14)) v(n);`, + Expected: []sql.Row{ + {Numeric("3.14")}, + }, + }, + }, + }, + { + Name: "VALUES with NULL values", + Assertions: []ScriptTestAssertion{ + { + // NULL mixed with integers - should resolve to integer, NULL stays NULL + Query: `SELECT * FROM (VALUES(1),(NULL),(3)) v(n);`, + Expected: []sql.Row{ + {int32(1)}, + {nil}, + {int32(3)}, + }, + }, + { + // NULL mixed with decimals - should resolve to numeric + Query: `SELECT * FROM (VALUES(1.5),(NULL),(3.5)) v(n);`, + Expected: []sql.Row{ + {Numeric("1.5")}, + {nil}, + {Numeric("3.5")}, + }, + }, + { + // NULL mixed with int and decimal - should resolve to numeric + Query: `SELECT * FROM (VALUES(1),(NULL),(2.5)) v(n);`, + Expected: []sql.Row{ + {Numeric("1")}, + {nil}, + {Numeric("2.5")}, + }, + }, + { + // All NULLs - should resolve to text (PostgreSQL behavior) + Query: `SELECT * FROM (VALUES(NULL),(NULL)) v(n);`, + Expected: []sql.Row{ + {nil}, + {nil}, + }, + }, + }, + }, + { + Name: "VALUES type mismatch errors", + Assertions: []ScriptTestAssertion{ + { + // Integer and unknown('text'): FindCommonType resolves to int4 (the non-unknown type), + // then the I/O cast from 'text' to int4 fails at execution time. This matches PostgreSQL behavior: + // psql returns "invalid input syntax for type integer: "text"" + Query: `SELECT * FROM (VALUES(1),('text'),(3)) v(n);`, + ExpectedErr: "invalid input syntax for type int4", + }, + { + // Boolean and integer cannot be matched + Query: `SELECT * FROM (VALUES(true),(1),(false)) v(n);`, + ExpectedErr: "cannot be matched", + }, + }, + }, + { + Name: "VALUES with all unknown types (string literals)", + Assertions: []ScriptTestAssertion{ + { + // All string literals should resolve to text + Query: `SELECT * FROM (VALUES('a'),('b'),('c')) v(n);`, + Expected: []sql.Row{ + {"a"}, + {"b"}, + {"c"}, + }, + }, + { + // String literals with operations + Query: `SELECT n || '!' FROM (VALUES('hello'),('world')) v(n);`, + Expected: []sql.Row{ + {"hello!"}, + {"world!"}, + }, + }, + }, + }, + { + Name: "VALUES with array types", + Assertions: []ScriptTestAssertion{ + { + // Integer arrays: doltgresql returns arrays in text format over the wire + Query: `SELECT * FROM (VALUES(ARRAY[1,2]),(ARRAY[3,4])) v(arr);`, + Expected: []sql.Row{ + {"{1,2}"}, + {"{3,4}"}, + }, + }, + { + // Text arrays: doltgresql returns arrays in text format over the wire + Query: `SELECT * FROM (VALUES(ARRAY['a','b']),(ARRAY['c','d'])) v(arr);`, + Expected: []sql.Row{ + {"{a,b}"}, + {"{c,d}"}, + }, + }, + }, + }, + { + Name: "VALUES with all same type multi-row (no casts needed)", + Assertions: []ScriptTestAssertion{ + { + // All integers + Query: `SELECT * FROM (VALUES(1),(2),(3)) v(n);`, + Expected: []sql.Row{ + {int32(1)}, + {int32(2)}, + {int32(3)}, + }, + }, + { + // All decimals + Query: `SELECT * FROM (VALUES(1.5),(2.5),(3.5)) v(n);`, + Expected: []sql.Row{ + {Numeric("1.5")}, + {Numeric("2.5")}, + {Numeric("3.5")}, + }, + }, + { + // All text + Query: `SELECT * FROM (VALUES('x'),('y'),('z')) v(n);`, + Expected: []sql.Row{ + {"x"}, + {"y"}, + {"z"}, + }, + }, + }, + }, + { + Name: "VALUES with multi-column partial cast", + Assertions: []ScriptTestAssertion{ + { + // Only first column needs cast + Query: `SELECT * FROM (VALUES(1, 'a'),(2.5, 'b'),(3, 'c')) v(num, str);`, + Expected: []sql.Row{ + {Numeric("1"), "a"}, + {Numeric("2.5"), "b"}, + {Numeric("3"), "c"}, + }, + }, + { + // Only second column needs cast + Query: `SELECT * FROM (VALUES(1, 10),(2, 20.5),(3, 30)) v(a, b);`, + Expected: []sql.Row{ + {int32(1), Numeric("10")}, + {int32(2), Numeric("20.5")}, + {int32(3), Numeric("30")}, + }, + }, + }, + }, + { + Name: "VALUES in CTE (WITH clause)", + Assertions: []ScriptTestAssertion{ + { + // Mixed types via CTE + Query: `WITH nums AS (SELECT * FROM (VALUES(1),(2.5),(3)) v(n)) SELECT * FROM nums;`, + Expected: []sql.Row{ + {Numeric("1")}, + {Numeric("2.5")}, + {Numeric("3")}, + }, + }, + { + // SUM over CTE + Query: `WITH nums AS (SELECT * FROM (VALUES(1),(2.5),(3)) v(n)) SELECT SUM(n) FROM nums;`, + Expected: []sql.Row{{Numeric("6.5")}}, + }, + }, + }, + { + Name: "VALUES with JOIN", + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT a.n, b.label FROM (VALUES(1),(2),(3)) a(n) JOIN (VALUES(1, 'one'),(2, 'two'),(3, 'three')) b(id, label) ON a.n = b.id;`, + Expected: []sql.Row{ + {int32(1), "one"}, + {int32(2), "two"}, + {int32(3), "three"}, + }, + }, + { + // Mixed types in one of the joined VALUES + Query: `SELECT a.n, b.label FROM (VALUES(1),(2.5),(3)) a(n) JOIN (VALUES(1, 'one'),(3, 'three')) b(id, label) ON a.n = b.id;`, + Expected: []sql.Row{ + {Numeric("1"), "one"}, + {Numeric("3"), "three"}, + }, + }, + }, + }, + { + Name: "VALUES with same-type booleans", + Assertions: []ScriptTestAssertion{ + { + // All booleans, returned as "t"/"f" over the wire + Query: `SELECT * FROM (VALUES(true),(false),(true)) v(b);`, + Expected: []sql.Row{ + {"t"}, + {"f"}, + {"t"}, + }, + }, + { + // Boolean WHERE filter + Query: `SELECT * FROM (VALUES(true),(false),(true),(false)) v(b) WHERE b = true;`, + Expected: []sql.Row{ + {"t"}, + {"t"}, + }, + }, + }, + }, + { + Name: "VALUES with case-sensitive quoted column names", + Assertions: []ScriptTestAssertion{ + { + // Column names w/ quotes preserve case; unquoted are lowercased by the parser + Query: `SELECT "ColA", "colb" FROM (VALUES(1, 2),(3.5, 4.5)) v("ColA", "colb");`, + Expected: []sql.Row{ + {Numeric("1"), Numeric("2")}, + {Numeric("3.5"), Numeric("4.5")}, + }, + }, + { + // Mixed case: one quoted (preserved), one unquoted (lowered) + Query: `SELECT "MixedCase", plain FROM (VALUES(1, 'a'),(2.5, 'b')) v("MixedCase", plain);`, + Expected: []sql.Row{ + {Numeric("1"), "a"}, + {Numeric("2.5"), "b"}, + }, + }, + { + // SUM with quoted column name + Query: `SELECT SUM("Val") FROM (VALUES(1),(2.5),(3)) v("Val");`, + Expected: []sql.Row{{Numeric("6.5")}}, + }, + }, + }, + { + Name: "VALUES with case-differing quoted columns and aggregates", + Assertions: []ScriptTestAssertion{ + { + // TODO: resolve GMS case asymmetry issues. + // Builder.buildAggregateFunc() in planbuilder/aggregates.go + // lowercases entire aggregate names, so SUM("Val") and + // SUM("val") both become "sum(v.val)" in the GetField. The + // second pass in ResolveValuesTypes must use strings.ToLower + // to match against GroupBy.Schema() which keeps original + // casing. This causes a false match when two columns differ + // only by case. See resolve_values_types.go for details. + Skip: true, + Query: `SELECT SUM("Val"), SUM("val") FROM (VALUES(1, 10),(2.5, 20)) v("Val", "val");`, + Expected: []sql.Row{ + {Numeric("3.5"), int64(30)}, + }, + }, + }, + }, }