diff --git a/server/analyzer/init.go b/server/analyzer/init.go index 1047bb2457..929daaae8c 100644 --- a/server/analyzer/init.go +++ b/server/analyzer/init.go @@ -15,6 +15,7 @@ package analyzer import ( + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/analyzer" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/memo" @@ -22,6 +23,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/planbuilder" pgexpression "github.com/dolthub/doltgresql/server/expression" + pgnodes "github.com/dolthub/doltgresql/server/node" ) // IDs are basically arbitrary, we just need to ensure that they do not conflict with existing IDs @@ -114,6 +116,13 @@ func initEngine() { planbuilder.IsAggregateFunc = IsAggregateFunc + planbuilder.BuildMultiExprTableFunc = func( + exprs []sql.Expression, alias string, + withOrdinality bool, columnAliases []string, + ) (sql.Node, error) { + return pgnodes.NewRowsFrom(exprs, alias, withOrdinality, columnAliases), nil + } + expression.DefaultExpressionFactory = pgexpression.PostgresExpressionFactory{} // There are a couple places during analysis where SplitConjunction in GMS cannot correctly split up diff --git a/server/ast/aliased_table_expr.go b/server/ast/aliased_table_expr.go index db46298b16..57c2099f13 100644 --- a/server/ast/aliased_table_expr.go +++ b/server/ast/aliased_table_expr.go @@ -15,6 +15,8 @@ package ast import ( + "strings" + "github.com/cockroachdb/errors" vitess "github.com/dolthub/vitess/go/vt/sqlparser" @@ -24,13 +26,115 @@ import ( ) // nodeAliasedTableExpr handles *tree.AliasedTableExpr nodes. -func nodeAliasedTableExpr(ctx *Context, node *tree.AliasedTableExpr) (*vitess.AliasedTableExpr, error) { - if node.Ordinality { - return nil, errors.Errorf("ordinality is not yet supported") - } +func nodeAliasedTableExpr(ctx *Context, node *tree.AliasedTableExpr) (vitess.TableExpr, error) { if node.IndexFlags != nil { return nil, errors.Errorf("index flags are not yet supported") } + + // Handle RowsFromExpr specially - it can have WITH ORDINALITY and column aliases + if rowsFrom, ok := node.Expr.(*tree.RowsFromExpr); ok { + // Handle multi-argument UNNEST specially: UNNEST(arr1, arr2, ...) + // is syntactic sugar for ROWS FROM(unnest(arr1), unnest(arr2), ...) + if len(rowsFrom.Items) == 1 { + if funcExpr, ok := rowsFrom.Items[0].(*tree.FuncExpr); ok { + funcName := funcExpr.Func.String() + if strings.EqualFold(funcName, "unnest") && len(funcExpr.Exprs) > 1 { + // Expand multi-arg UNNEST into separate unnest calls + selectExprs := make(vitess.SelectExprs, len(funcExpr.Exprs)) + for i, arg := range funcExpr.Exprs { + argExpr, err := nodeExpr(ctx, arg) + if err != nil { + return nil, err + } + selectExprs[i] = &vitess.AliasedExpr{ + Expr: &vitess.FuncExpr{ + Name: vitess.NewColIdent("unnest"), + Exprs: vitess.SelectExprs{&vitess.AliasedExpr{Expr: argExpr}}, + }, + } + } + + var columns vitess.Columns + if len(node.As.Cols) > 0 { + columns = make(vitess.Columns, len(node.As.Cols)) + for i := range node.As.Cols { + columns[i] = vitess.NewColIdent(string(node.As.Cols[i])) + } + } + + return &vitess.TableFuncExpr{ + Exprs: selectExprs, + WithOrdinality: node.Ordinality, + Alias: vitess.NewTableIdent(string(node.As.Alias)), + Columns: columns, + }, nil + } + } + } + + // Use TableFuncExpr (nameless) for: + // 1. Multiple functions: ROWS FROM(func1(), func2()) AS alias + // 2. WITH ORDINALITY: ROWS FROM(func()) WITH ORDINALITY + if len(rowsFrom.Items) > 1 || node.Ordinality { + selectExprs := make(vitess.SelectExprs, len(rowsFrom.Items)) + for i, item := range rowsFrom.Items { + expr, err := nodeExpr(ctx, item) + if err != nil { + return nil, err + } + selectExprs[i] = &vitess.AliasedExpr{Expr: expr} + } + + var columns vitess.Columns + if len(node.As.Cols) > 0 { + columns = make(vitess.Columns, len(node.As.Cols)) + for i := range node.As.Cols { + columns[i] = vitess.NewColIdent(string(node.As.Cols[i])) + } + } + + return &vitess.TableFuncExpr{ + Exprs: selectExprs, + WithOrdinality: node.Ordinality, + Alias: vitess.NewTableIdent(string(node.As.Alias)), + Columns: columns, + }, nil + } + + // For single function without ordinality, fall through to use the existing + // table function infrastructure via nodeTableExpr + tableExpr, err := nodeTableExpr(ctx, rowsFrom) + if err != nil { + return nil, err + } + + // Wrap in a subquery as the original code did + subquery := &vitess.Subquery{ + Select: &vitess.Select{ + From: vitess.TableExprs{tableExpr}, + }, + } + + if len(node.As.Cols) > 0 { + columns := make([]vitess.ColIdent, len(node.As.Cols)) + for i := range node.As.Cols { + columns[i] = vitess.NewColIdent(string(node.As.Cols[i])) + } + subquery.Columns = columns + } + + return &vitess.AliasedTableExpr{ + Expr: subquery, + As: vitess.NewTableIdent(string(node.As.Alias)), + Lateral: node.Lateral, + }, nil + } + + // For non-RowsFromExpr expressions, ordinality is not yet supported + if node.Ordinality { + return nil, errors.Errorf("ordinality is only supported for ROWS FROM expressions") + } + var aliasExpr vitess.SimpleTableExpr var authInfo vitess.AuthInformation @@ -92,27 +196,6 @@ func nodeAliasedTableExpr(ctx *Context, node *tree.AliasedTableExpr) (*vitess.Al Select: selectStmt, } - if len(node.As.Cols) > 0 { - columns := make([]vitess.ColIdent, len(node.As.Cols)) - for i := range node.As.Cols { - columns[i] = vitess.NewColIdent(string(node.As.Cols[i])) - } - subquery.Columns = columns - } - aliasExpr = subquery - case *tree.RowsFromExpr: - tableExpr, err := nodeTableExpr(ctx, expr) - if err != nil { - return nil, err - } - - // TODO: this should be represented as a table function more directly - subquery := &vitess.Subquery{ - Select: &vitess.Select{ - From: vitess.TableExprs{tableExpr}, - }, - } - if len(node.As.Cols) > 0 { columns := make([]vitess.ColIdent, len(node.As.Cols)) for i := range node.As.Cols { diff --git a/server/ast/select_clause.go b/server/ast/select_clause.go index f71a7b38e0..dddbeed3c4 100644 --- a/server/ast/select_clause.go +++ b/server/ast/select_clause.go @@ -15,8 +15,9 @@ package ast import ( - "github.com/dolthub/go-mysql-server/sql/expression" + "strings" + "github.com/dolthub/go-mysql-server/sql/expression" vitess "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" @@ -158,6 +159,32 @@ PostJoinRewrite: } } } + // Handle multi-argument UNNEST: UNNEST(arr1, arr2, ...) produces a table with one column per array, + // where corresponding elements are "zipped" together. PostgreSQL pads shorter arrays with NULLs. + // We transform: SELECT * FROM UNNEST(arr1, arr2) + // Into: SELECT * FROM ROWS FROM(unnest(arr1), unnest(arr2)) AS unnest + // This uses the native ROWS FROM table function which properly zips SRFs together. + if tableFuncExpr, ok := from[i].(*vitess.TableFuncExpr); ok { + if strings.EqualFold(tableFuncExpr.Name, "unnest") && len(tableFuncExpr.Exprs) > 1 { + selectExprs := make(vitess.SelectExprs, 0, len(tableFuncExpr.Exprs)) + for _, argExpr := range tableFuncExpr.Exprs { + selectExprs = append(selectExprs, &vitess.AliasedExpr{ + Expr: &vitess.FuncExpr{ + Name: vitess.NewColIdent("unnest"), + Exprs: vitess.SelectExprs{argExpr}, + }, + }) + } + alias := tableFuncExpr.Alias + if alias.IsEmpty() { + alias = vitess.NewTableIdent("unnest") + } + from[i] = &vitess.TableFuncExpr{ + Exprs: selectExprs, + Alias: alias, + } + } + } } distinct := node.Distinct var distinctOn vitess.Exprs diff --git a/server/ast/table_expr.go b/server/ast/table_expr.go index 3de0978a64..762164c4a6 100644 --- a/server/ast/table_expr.go +++ b/server/ast/table_expr.go @@ -15,6 +15,8 @@ package ast import ( + "strings" + "github.com/cockroachdb/errors" vitess "github.com/dolthub/vitess/go/vt/sqlparser" @@ -99,12 +101,55 @@ func nodeTableExpr(ctx *Context, node tree.TableExpr) (vitess.TableExpr, error) Exprs: vitess.TableExprs{tableExpr}, }, nil case *tree.RowsFromExpr: + // Handle multi-argument UNNEST specially: UNNEST(arr1, arr2, ...) + // is syntactic sugar for ROWS FROM(unnest(arr1), unnest(arr2), ...) + if len(node.Items) == 1 { + if funcExpr, ok := node.Items[0].(*tree.FuncExpr); ok { + funcName := funcExpr.Func.String() + if strings.EqualFold(funcName, "unnest") && len(funcExpr.Exprs) > 1 { + // Expand multi-arg UNNEST into separate unnest calls + selectExprs := make(vitess.SelectExprs, len(funcExpr.Exprs)) + for i, arg := range funcExpr.Exprs { + argExpr, err := nodeExpr(ctx, arg) + if err != nil { + return nil, err + } + selectExprs[i] = &vitess.AliasedExpr{ + Expr: &vitess.FuncExpr{ + Name: vitess.NewColIdent("unnest"), + Exprs: vitess.SelectExprs{&vitess.AliasedExpr{Expr: argExpr}}, + }, + } + } + return &vitess.TableFuncExpr{ + Exprs: selectExprs, + }, nil + } + } + } + + // For explicit ROWS FROM with multiple functions, use TableFuncExpr (nameless) + // This handles: ROWS FROM(generate_series(1,3), generate_series(10,12)) + if len(node.Items) > 1 { + selectExprs := make(vitess.SelectExprs, len(node.Items)) + for i, item := range node.Items { + expr, err := nodeExpr(ctx, item) + if err != nil { + return nil, err + } + selectExprs[i] = &vitess.AliasedExpr{Expr: expr} + } + return &vitess.TableFuncExpr{ + Exprs: selectExprs, + }, nil + } + + // For single functions, use the original ValuesStatement approach + // which works with the existing table function infrastructure exprs, err := nodeExprs(ctx, node.Items) if err != nil { return nil, err } - //TODO: not sure if this is correct at all. I think we want to return one result per row, but maybe not. - // This needs to be tested to verify. rows := make([]vitess.ValTuple, len(exprs)) for i := range exprs { rows[i] = vitess.ValTuple{exprs[i]} diff --git a/server/node/rows_from.go b/server/node/rows_from.go new file mode 100644 index 0000000000..aa21bd610d --- /dev/null +++ b/server/node/rows_from.go @@ -0,0 +1,386 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package node + +import ( + "errors" + "fmt" + "io" + "strings" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// RowsFrom represents a ROWS FROM table function that executes multiple +// set-returning functions in parallel and zips their results together. +// This is the PostgreSQL-compatible syntax: ROWS FROM(func1(...), func2(...), ...) +type RowsFrom struct { + // Functions contains the set-returning function expressions to execute + Functions []sql.Expression + // withOrdinality when true, adds an ordinality column to the result + withOrdinality bool + // alias is the table alias for this ROWS FROM expression + alias string + // columnAliases are optional column names for the result columns + columnAliases []string + // colset tracks the column IDs for this node + colset sql.ColSet + // id is the table ID for this node + id sql.TableId +} + +var _ sql.Node = (*RowsFrom)(nil) +var _ sql.Expressioner = (*RowsFrom)(nil) +var _ sql.CollationCoercible = (*RowsFrom)(nil) +var _ plan.TableIdNode = (*RowsFrom)(nil) +var _ sql.RenameableNode = (*RowsFrom)(nil) +var _ sql.ExecBuilderNode = (*RowsFrom)(nil) + +// NewRowsFrom creates a new RowsFrom node with the given function expressions. +func NewRowsFrom(exprs []sql.Expression, alias string, withOrdinality bool, columnAliases []string) *RowsFrom { + return &RowsFrom{ + Functions: exprs, + withOrdinality: withOrdinality, + alias: alias, + columnAliases: columnAliases, + } +} + +// BuildRowIter implements sql.ExecBuilderNode. +func (r *RowsFrom) BuildRowIter(ctx *sql.Context, b sql.NodeExecBuilder, row sql.Row) (sql.RowIter, error) { + return NewRowsFromIter(r.Functions, r.withOrdinality, row), nil +} + +// WithId implements plan.TableIdNode +func (r *RowsFrom) WithId(id sql.TableId) plan.TableIdNode { + ret := *r + ret.id = id + return &ret +} + +// Id implements plan.TableIdNode +func (r *RowsFrom) Id() sql.TableId { + return r.id +} + +// WithColumns implements plan.TableIdNode +func (r *RowsFrom) WithColumns(set sql.ColSet) plan.TableIdNode { + ret := *r + ret.colset = set + return &ret +} + +// Columns implements plan.TableIdNode +func (r *RowsFrom) Columns() sql.ColSet { + return r.colset +} + +// Name returns the alias name for this ROWS FROM expression +func (r *RowsFrom) Name() string { + if r.alias != "" { + return r.alias + } + return "rows_from" +} + +// WithName implements sql.RenameableNode +func (r *RowsFrom) WithName(s string) sql.Node { + ret := *r + ret.alias = s + return &ret +} + +// Schema implements the sql.Node interface. +func (r *RowsFrom) Schema() sql.Schema { + var schema sql.Schema + + for i, f := range r.Functions { + colName := fmt.Sprintf("col%d", i) + if i < len(r.columnAliases) && r.columnAliases[i] != "" { + colName = r.columnAliases[i] + } else if nameable, ok := f.(sql.Nameable); ok { + colName = nameable.Name() + } + + schema = append(schema, &sql.Column{ + Name: colName, + Type: f.Type(), + Nullable: true, // SRF results can be NULL when zipping unequal-length results + Source: r.Name(), + }) + } + + if r.withOrdinality { + schema = append(schema, &sql.Column{ + Name: "ordinality", + Type: types.Int64, + Nullable: false, + Source: r.Name(), + }) + } + + return schema +} + +// Children implements the sql.Node interface. +func (r *RowsFrom) Children() []sql.Node { + return nil +} + +// Resolved implements the sql.Resolvable interface. +func (r *RowsFrom) Resolved() bool { + for _, f := range r.Functions { + if !f.Resolved() { + return false + } + } + return true +} + +// IsReadOnly implements the sql.Node interface. +func (r *RowsFrom) IsReadOnly() bool { + return true +} + +// String implements the sql.Node interface. +func (r *RowsFrom) String() string { + var sb strings.Builder + sb.WriteString("ROWS FROM(") + for i, f := range r.Functions { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(f.String()) + } + sb.WriteString(")") + if r.withOrdinality { + sb.WriteString(" WITH ORDINALITY") + } + if r.alias != "" { + sb.WriteString(" AS ") + sb.WriteString(r.alias) + } + return sb.String() +} + +// DebugString implements the sql.DebugStringer interface. +func (r *RowsFrom) DebugString() string { + var sb strings.Builder + sb.WriteString("RowsFrom(") + for i, f := range r.Functions { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(sql.DebugString(f)) + } + sb.WriteString(")") + if r.withOrdinality { + sb.WriteString(" WITH ORDINALITY") + } + if r.alias != "" { + sb.WriteString(" AS ") + sb.WriteString(r.alias) + } + return sb.String() +} + +// Expressions implements the sql.Expressioner interface. +func (r *RowsFrom) Expressions() []sql.Expression { + return r.Functions +} + +// WithExpressions implements the sql.Expressioner interface. +func (r *RowsFrom) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != len(r.Functions) { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(exprs), len(r.Functions)) + } + ret := *r + ret.Functions = exprs + return &ret, nil +} + +// WithChildren implements the sql.Node interface. +func (r *RowsFrom) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 0) + } + return r, nil +} + +// CollationCoercibility implements the interface sql.CollationCoercible. +func (*RowsFrom) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + return sql.Collation_binary, 7 +} + +// RowsFromIter is an iterator for the RowsFrom node. +// It executes multiple set-returning functions in parallel and zips their results together. +// When one function is exhausted before another, NULL is used for its values. +type RowsFromIter struct { + functions []sql.Expression + iters []sql.RowIter + finished []bool + withOrdinality bool + ordinality int64 + initialized bool + sourceRow sql.Row +} + +var _ sql.RowIter = (*RowsFromIter)(nil) + +// NewRowsFromIter creates a new RowsFromIter. +func NewRowsFromIter(functions []sql.Expression, withOrdinality bool, row sql.Row) *RowsFromIter { + return &RowsFromIter{ + functions: functions, + withOrdinality: withOrdinality, + sourceRow: row, + finished: make([]bool, len(functions)), + } +} + +// Next implements the sql.RowIter interface. +func (r *RowsFromIter) Next(ctx *sql.Context) (sql.Row, error) { + if !r.initialized { + if err := r.initIterators(ctx); err != nil { + return nil, err + } + r.initialized = true + } + + allFinished := true + for _, f := range r.finished { + if !f { + allFinished = false + break + } + } + if allFinished { + return nil, io.EOF + } + + row := make(sql.Row, len(r.functions)) + for i, iter := range r.iters { + if r.finished[i] { + row[i] = nil + continue + } + + nextRow, err := iter.Next(ctx) + if err != nil { + if errors.Is(err, io.EOF) { + r.finished[i] = true + row[i] = nil + continue + } + return nil, err + } + + if len(nextRow) > 0 { + row[i] = nextRow[0] + } else { + row[i] = nil + } + } + + allFinished = true + for _, f := range r.finished { + if !f { + allFinished = false + break + } + } + + allNulls := true + for _, v := range row { + if v != nil { + allNulls = false + break + } + } + + if allFinished && allNulls { + return nil, io.EOF + } + + r.ordinality++ + if r.withOrdinality { + row = append(row, r.ordinality) + } + + return row, nil +} + +func (r *RowsFromIter) initIterators(ctx *sql.Context) error { + r.iters = make([]sql.RowIter, len(r.functions)) + + for i, f := range r.functions { + if rie, ok := f.(sql.RowIterExpression); ok && rie.ReturnsRowIter() { + iter, err := rie.EvalRowIter(ctx, r.sourceRow) + if err != nil { + for j := 0; j < i; j++ { + if r.iters[j] != nil { + r.iters[j].Close(ctx) + } + } + return err + } + r.iters[i] = iter + } else { + val, err := f.Eval(ctx, r.sourceRow) + if err != nil { + for j := 0; j < i; j++ { + if r.iters[j] != nil { + r.iters[j].Close(ctx) + } + } + return err + } + r.iters[i] = &singleValueIter{value: val} + } + } + + return nil +} + +// Close implements the sql.RowIter interface. +func (r *RowsFromIter) Close(ctx *sql.Context) error { + var firstErr error + for _, iter := range r.iters { + if iter != nil { + if err := iter.Close(ctx); err != nil && firstErr == nil { + firstErr = err + } + } + } + return firstErr +} + +type singleValueIter struct { + value interface{} + consumed bool +} + +func (s *singleValueIter) Next(ctx *sql.Context) (sql.Row, error) { + if s.consumed { + return nil, io.EOF + } + s.consumed = true + return sql.Row{s.value}, nil +} + +func (s *singleValueIter) Close(ctx *sql.Context) error { + return nil +} diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index e2e6ee02c9..9f4d8e5acc 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -1552,6 +1552,91 @@ func TestArrayFunctions(t *testing.T) { }, }, }, + { + Name: "multi-argument unnest", + Assertions: []ScriptTestAssertion{ + { + // Basic multi-argument UNNEST with equal-length arrays + Query: `SELECT * FROM UNNEST(ARRAY['a','b','c'], ARRAY[1,2,3])`, + Expected: []sql.Row{{"a", int64(1)}, {"b", int64(2)}, {"c", int64(3)}}, + }, + { + // Multi-argument UNNEST with unequal-length arrays (shorter padded with NULL) + Query: `SELECT * FROM UNNEST(ARRAY['a','b'], ARRAY[1,2,3])`, + Expected: []sql.Row{{"a", int64(1)}, {"b", int64(2)}, {nil, int64(3)}}, + }, + { + // Multi-argument UNNEST with empty array + Query: `SELECT * FROM UNNEST(ARRAY['a','b'], ARRAY[]::int[])`, + Expected: []sql.Row{{"a", nil}, {"b", nil}}, + }, + { + // Multi-argument UNNEST with three arrays (booleans come as "t"/"f" strings from PostgreSQL wire protocol) + Query: `SELECT * FROM UNNEST(ARRAY[1,2], ARRAY['x','y'], ARRAY[true,false])`, + Expected: []sql.Row{{int64(1), "x", "t"}, {int64(2), "y", "f"}}, + }, + { + // Multi-argument UNNEST with alias + Query: `SELECT u.* FROM UNNEST(ARRAY['a','b'], ARRAY[1,2]) AS u`, + Expected: []sql.Row{{"a", int64(1)}, {"b", int64(2)}}, + }, + }, + }, + { + Name: "ROWS FROM with generate_series", + Assertions: []ScriptTestAssertion{ + { + // Basic ROWS FROM with two generate_series calls + Query: `SELECT * FROM ROWS FROM(generate_series(1,3), generate_series(10,12))`, + Expected: []sql.Row{{int64(1), int64(10)}, {int64(2), int64(11)}, {int64(3), int64(12)}}, + }, + { + // ROWS FROM with unequal-length series (shorter padded with NULL) + Query: `SELECT * FROM ROWS FROM(generate_series(1,2), generate_series(10,13))`, + Expected: []sql.Row{{int64(1), int64(10)}, {int64(2), int64(11)}, {nil, int64(12)}, {nil, int64(13)}}, + }, + { + // ROWS FROM with table alias + Query: `SELECT r.* FROM ROWS FROM(generate_series(1,2), generate_series(10,11)) AS r`, + Expected: []sql.Row{{int64(1), int64(10)}, {int64(2), int64(11)}}, + }, + { + // ROWS FROM with three functions + Query: `SELECT * FROM ROWS FROM(generate_series(1,2), generate_series(10,11), generate_series(100,101))`, + Expected: []sql.Row{{int64(1), int64(10), int64(100)}, {int64(2), int64(11), int64(101)}}, + }, + }, + }, + { + Name: "ROWS FROM with unnest", + Assertions: []ScriptTestAssertion{ + { + // ROWS FROM with explicit unnest calls + Query: `SELECT * FROM ROWS FROM(unnest(ARRAY['a','b']), unnest(ARRAY[1,2]))`, + Expected: []sql.Row{{"a", int64(1)}, {"b", int64(2)}}, + }, + { + // ROWS FROM with unequal-length unnest + Query: `SELECT * FROM ROWS FROM(unnest(ARRAY['a','b','c']), unnest(ARRAY[1,2]))`, + Expected: []sql.Row{{"a", int64(1)}, {"b", int64(2)}, {"c", nil}}, + }, + }, + }, + { + Name: "ROWS FROM mixed functions", + Assertions: []ScriptTestAssertion{ + { + // Mix generate_series and unnest + Query: `SELECT * FROM ROWS FROM(generate_series(1,3), unnest(ARRAY['x','y','z']))`, + Expected: []sql.Row{{int64(1), "x"}, {int64(2), "y"}, {int64(3), "z"}}, + }, + { + // Mix with unequal lengths + Query: `SELECT * FROM ROWS FROM(generate_series(1,2), unnest(ARRAY['x','y','z']))`, + Expected: []sql.Row{{int64(1), "x"}, {int64(2), "y"}, {nil, "z"}}, + }, + }, + }, { Name: "array_to_string", SetUpScript: []string{}, diff --git a/testing/go/insert_test.go b/testing/go/insert_test.go index 7e887b913e..57514c2251 100755 --- a/testing/go/insert_test.go +++ b/testing/go/insert_test.go @@ -293,5 +293,27 @@ ON CONFLICT (id) do update set c1 = $4`, }, }, }, + { + Name: "insert from unnest", + SetUpScript: []string{ + `CREATE TABLE "django_content_type" (id serial primary key, app_label varchar, model varchar)`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `INSERT INTO "django_content_type" ("app_label", "model") +SELECT * FROM UNNEST(('{debug_app,debug_app}')::varchar[], + ('{debugmodel1,debugmodel2}')::varchar[]) +RETURNING "django_content_type"."id"`, + Expected: []sql.Row{{1}, {2}}, + }, + { + Query: `SELECT "app_label", "model" FROM "django_content_type" ORDER BY "id"`, + Expected: []sql.Row{ + {"debug_app", "debugmodel1"}, + {"debug_app", "debugmodel2"}, + }, + }, + }, + }, }) }