From 174bf4d52dfe7e126472c822a428e48c66ec6db5 Mon Sep 17 00:00:00 2001 From: Myers Carpenter Date: Thu, 22 Jan 2026 08:41:04 -0500 Subject: [PATCH 1/4] Fix UNNEST bulk insert column mapping --- server/ast/select_clause.go | 33 ++++++++++++++++++++++++++++++++- testing/go/functions_test.go | 30 ++++++++++++++++++++++++++++++ testing/go/insert_test.go | 22 ++++++++++++++++++++++ 3 files changed, 84 insertions(+), 1 deletion(-) diff --git a/server/ast/select_clause.go b/server/ast/select_clause.go index f71a7b38e0..2fb1b9118e 100644 --- a/server/ast/select_clause.go +++ b/server/ast/select_clause.go @@ -15,8 +15,9 @@ package ast import ( - "github.com/dolthub/go-mysql-server/sql/expression" + "strings" + "github.com/dolthub/go-mysql-server/sql/expression" vitess "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" @@ -158,6 +159,36 @@ PostJoinRewrite: } } } + // Handle multi-argument UNNEST: UNNEST(arr1, arr2, ...) produces a table with one column per array, + // where corresponding elements are "zipped" together. PostgreSQL pads shorter arrays with NULLs. + // We transform: SELECT * FROM UNNEST(arr1, arr2) + // Into: SELECT * FROM (SELECT unnest(arr1), unnest(arr2)) AS unnest + // GMS's ProjectRowWithNestedIters handles multiple SRFs by zipping them together correctly. + if tableFuncExpr, ok := from[i].(*vitess.TableFuncExpr); ok { + if strings.EqualFold(tableFuncExpr.Name, "unnest") && len(tableFuncExpr.Exprs) > 1 { + selectExprs := make(vitess.SelectExprs, 0, len(tableFuncExpr.Exprs)) + for _, argExpr := range tableFuncExpr.Exprs { + selectExprs = append(selectExprs, &vitess.AliasedExpr{ + Expr: &vitess.FuncExpr{ + Name: vitess.NewColIdent("unnest"), + Exprs: vitess.SelectExprs{argExpr}, + }, + }) + } + alias := tableFuncExpr.Alias + if alias.IsEmpty() { + alias = vitess.NewTableIdent("unnest") + } + from[i] = &vitess.AliasedTableExpr{ + Expr: &vitess.Subquery{ + Select: &vitess.Select{ + SelectExprs: selectExprs, + }, + }, + As: alias, + } + } + } } distinct := node.Distinct var distinctOn vitess.Exprs diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index e2e6ee02c9..375d89c646 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -1552,6 +1552,36 @@ func TestArrayFunctions(t *testing.T) { }, }, }, + { + Name: "multi-argument unnest", + Assertions: []ScriptTestAssertion{ + { + // Basic multi-argument UNNEST with equal-length arrays + Query: `SELECT * FROM UNNEST(ARRAY['a','b','c'], ARRAY[1,2,3])`, + Expected: []sql.Row{{"a", int64(1)}, {"b", int64(2)}, {"c", int64(3)}}, + }, + { + // Multi-argument UNNEST with unequal-length arrays (shorter padded with NULL) + Query: `SELECT * FROM UNNEST(ARRAY['a','b'], ARRAY[1,2,3])`, + Expected: []sql.Row{{"a", int64(1)}, {"b", int64(2)}, {nil, int64(3)}}, + }, + { + // Multi-argument UNNEST with empty array + Query: `SELECT * FROM UNNEST(ARRAY['a','b'], ARRAY[]::int[])`, + Expected: []sql.Row{{"a", nil}, {"b", nil}}, + }, + { + // Multi-argument UNNEST with three arrays (booleans come as "t"/"f" strings from PostgreSQL wire protocol) + Query: `SELECT * FROM UNNEST(ARRAY[1,2], ARRAY['x','y'], ARRAY[true,false])`, + Expected: []sql.Row{{int64(1), "x", "t"}, {int64(2), "y", "f"}}, + }, + { + // Multi-argument UNNEST with alias + Query: `SELECT u.* FROM UNNEST(ARRAY['a','b'], ARRAY[1,2]) AS u`, + Expected: []sql.Row{{"a", int64(1)}, {"b", int64(2)}}, + }, + }, + }, { Name: "array_to_string", SetUpScript: []string{}, diff --git a/testing/go/insert_test.go b/testing/go/insert_test.go index 7e887b913e..57514c2251 100755 --- a/testing/go/insert_test.go +++ b/testing/go/insert_test.go @@ -293,5 +293,27 @@ ON CONFLICT (id) do update set c1 = $4`, }, }, }, + { + Name: "insert from unnest", + SetUpScript: []string{ + `CREATE TABLE "django_content_type" (id serial primary key, app_label varchar, model varchar)`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `INSERT INTO "django_content_type" ("app_label", "model") +SELECT * FROM UNNEST(('{debug_app,debug_app}')::varchar[], + ('{debugmodel1,debugmodel2}')::varchar[]) +RETURNING "django_content_type"."id"`, + Expected: []sql.Row{{1}, {2}}, + }, + { + Query: `SELECT "app_label", "model" FROM "django_content_type" ORDER BY "id"`, + Expected: []sql.Row{ + {"debug_app", "debugmodel1"}, + {"debug_app", "debugmodel2"}, + }, + }, + }, + }, }) } From 0396d701af402bdcd1c47d81351925cdc268bfad Mon Sep 17 00:00:00 2001 From: Myers Carpenter Date: Sat, 31 Jan 2026 21:21:43 -0500 Subject: [PATCH 2/4] Add ROWS FROM support for multi-argument UNNEST Expands multi-argument UNNEST(arr1, arr2, ...) into ROWS FROM(unnest(arr1), unnest(arr2), ...) which properly zips results together with NULL padding for shorter arrays. Changes: - table_expr.go: Detect and expand multi-arg UNNEST to RowsFromExpr - aliased_table_expr.go: Support WITH ORDINALITY via RowsFromExpr - select_clause.go: Update multi-arg UNNEST rewrite to use RowsFromExpr Depends on: - dolthub/vitess#454 - dolthub/go-mysql-server#3412 --- server/ast/aliased_table_expr.go | 135 +++++++++++++++++++++++++------ server/ast/select_clause.go | 14 ++-- server/ast/table_expr.go | 33 +++++++- 3 files changed, 146 insertions(+), 36 deletions(-) diff --git a/server/ast/aliased_table_expr.go b/server/ast/aliased_table_expr.go index db46298b16..879a710f19 100644 --- a/server/ast/aliased_table_expr.go +++ b/server/ast/aliased_table_expr.go @@ -15,6 +15,8 @@ package ast import ( + "strings" + "github.com/cockroachdb/errors" vitess "github.com/dolthub/vitess/go/vt/sqlparser" @@ -24,13 +26,117 @@ import ( ) // nodeAliasedTableExpr handles *tree.AliasedTableExpr nodes. -func nodeAliasedTableExpr(ctx *Context, node *tree.AliasedTableExpr) (*vitess.AliasedTableExpr, error) { - if node.Ordinality { - return nil, errors.Errorf("ordinality is not yet supported") - } +func nodeAliasedTableExpr(ctx *Context, node *tree.AliasedTableExpr) (vitess.TableExpr, error) { if node.IndexFlags != nil { return nil, errors.Errorf("index flags are not yet supported") } + + // Handle RowsFromExpr specially - it can have WITH ORDINALITY and column aliases + if rowsFrom, ok := node.Expr.(*tree.RowsFromExpr); ok { + // Handle multi-argument UNNEST specially: UNNEST(arr1, arr2, ...) + // is syntactic sugar for ROWS FROM(unnest(arr1), unnest(arr2), ...) + // We need to detect this case and expand it to use RowsFromExpr. + if len(rowsFrom.Items) == 1 { + if funcExpr, ok := rowsFrom.Items[0].(*tree.FuncExpr); ok { + funcName := funcExpr.Func.String() + if strings.EqualFold(funcName, "unnest") && len(funcExpr.Exprs) > 1 { + // Expand multi-arg UNNEST into separate unnest calls + selectExprs := make(vitess.SelectExprs, len(funcExpr.Exprs)) + for i, arg := range funcExpr.Exprs { + argExpr, err := nodeExpr(ctx, arg) + if err != nil { + return nil, err + } + selectExprs[i] = &vitess.AliasedExpr{ + Expr: &vitess.FuncExpr{ + Name: vitess.NewColIdent("unnest"), + Exprs: vitess.SelectExprs{&vitess.AliasedExpr{Expr: argExpr}}, + }, + } + } + + var columns vitess.Columns + if len(node.As.Cols) > 0 { + columns = make(vitess.Columns, len(node.As.Cols)) + for i := range node.As.Cols { + columns[i] = vitess.NewColIdent(string(node.As.Cols[i])) + } + } + + return &vitess.RowsFromExpr{ + Exprs: selectExprs, + WithOrdinality: node.Ordinality, + Alias: vitess.NewTableIdent(string(node.As.Alias)), + Columns: columns, + }, nil + } + } + } + + // For single functions or non-multi-arg-UNNEST cases, use the existing + // subquery-based approach that works with the table function infrastructure. + // Only WITH ORDINALITY requires the new RowsFromExpr approach. + if node.Ordinality { + // Use RowsFromExpr for WITH ORDINALITY support + selectExprs := make(vitess.SelectExprs, len(rowsFrom.Items)) + for i, item := range rowsFrom.Items { + expr, err := nodeExpr(ctx, item) + if err != nil { + return nil, err + } + selectExprs[i] = &vitess.AliasedExpr{Expr: expr} + } + + var columns vitess.Columns + if len(node.As.Cols) > 0 { + columns = make(vitess.Columns, len(node.As.Cols)) + for i := range node.As.Cols { + columns[i] = vitess.NewColIdent(string(node.As.Cols[i])) + } + } + + return &vitess.RowsFromExpr{ + Exprs: selectExprs, + WithOrdinality: node.Ordinality, + Alias: vitess.NewTableIdent(string(node.As.Alias)), + Columns: columns, + }, nil + } + + // For non-ordinality cases, fall through to use the existing + // table function infrastructure via nodeTableExpr + tableExpr, err := nodeTableExpr(ctx, rowsFrom) + if err != nil { + return nil, err + } + + // Wrap in a subquery as the original code did + subquery := &vitess.Subquery{ + Select: &vitess.Select{ + From: vitess.TableExprs{tableExpr}, + }, + } + + if len(node.As.Cols) > 0 { + columns := make([]vitess.ColIdent, len(node.As.Cols)) + for i := range node.As.Cols { + columns[i] = vitess.NewColIdent(string(node.As.Cols[i])) + } + subquery.Columns = columns + } + + return &vitess.AliasedTableExpr{ + Expr: subquery, + As: vitess.NewTableIdent(string(node.As.Alias)), + Lateral: node.Lateral, + }, nil + } + + // For non-RowsFromExpr expressions, ordinality is not yet supported + if node.Ordinality { + return nil, errors.Errorf("ordinality is only supported for ROWS FROM expressions") + } + var aliasExpr vitess.SimpleTableExpr var authInfo vitess.AuthInformation @@ -92,27 +198,6 @@ func nodeAliasedTableExpr(ctx *Context, node *tree.AliasedTableExpr) (*vitess.Al Select: selectStmt, } - if len(node.As.Cols) > 0 { - columns := make([]vitess.ColIdent, len(node.As.Cols)) - for i := range node.As.Cols { - columns[i] = vitess.NewColIdent(string(node.As.Cols[i])) - } - subquery.Columns = columns - } - aliasExpr = subquery - case *tree.RowsFromExpr: - tableExpr, err := nodeTableExpr(ctx, expr) - if err != nil { - return nil, err - } - - // TODO: this should be represented as a table function more directly - subquery := &vitess.Subquery{ - Select: &vitess.Select{ - From: vitess.TableExprs{tableExpr}, - }, - } - if len(node.As.Cols) > 0 { columns := make([]vitess.ColIdent, len(node.As.Cols)) for i := range node.As.Cols { diff --git a/server/ast/select_clause.go b/server/ast/select_clause.go index 2fb1b9118e..23cf1f7777 100644 --- a/server/ast/select_clause.go +++ b/server/ast/select_clause.go @@ -162,8 +162,8 @@ PostJoinRewrite: // Handle multi-argument UNNEST: UNNEST(arr1, arr2, ...) produces a table with one column per array, // where corresponding elements are "zipped" together. PostgreSQL pads shorter arrays with NULLs. // We transform: SELECT * FROM UNNEST(arr1, arr2) - // Into: SELECT * FROM (SELECT unnest(arr1), unnest(arr2)) AS unnest - // GMS's ProjectRowWithNestedIters handles multiple SRFs by zipping them together correctly. + // Into: SELECT * FROM ROWS FROM(unnest(arr1), unnest(arr2)) AS unnest + // This uses the native ROWS FROM table function which properly zips SRFs together. if tableFuncExpr, ok := from[i].(*vitess.TableFuncExpr); ok { if strings.EqualFold(tableFuncExpr.Name, "unnest") && len(tableFuncExpr.Exprs) > 1 { selectExprs := make(vitess.SelectExprs, 0, len(tableFuncExpr.Exprs)) @@ -179,13 +179,9 @@ PostJoinRewrite: if alias.IsEmpty() { alias = vitess.NewTableIdent("unnest") } - from[i] = &vitess.AliasedTableExpr{ - Expr: &vitess.Subquery{ - Select: &vitess.Select{ - SelectExprs: selectExprs, - }, - }, - As: alias, + from[i] = &vitess.RowsFromExpr{ + Exprs: selectExprs, + Alias: alias, } } } diff --git a/server/ast/table_expr.go b/server/ast/table_expr.go index 3de0978a64..ef296f5749 100644 --- a/server/ast/table_expr.go +++ b/server/ast/table_expr.go @@ -15,6 +15,8 @@ package ast import ( + "strings" + "github.com/cockroachdb/errors" vitess "github.com/dolthub/vitess/go/vt/sqlparser" @@ -99,12 +101,39 @@ func nodeTableExpr(ctx *Context, node tree.TableExpr) (vitess.TableExpr, error) Exprs: vitess.TableExprs{tableExpr}, }, nil case *tree.RowsFromExpr: + // Handle multi-argument UNNEST specially: UNNEST(arr1, arr2, ...) + // is syntactic sugar for ROWS FROM(unnest(arr1), unnest(arr2), ...) + // We need to detect this case and expand it to use RowsFromExpr. + if len(node.Items) == 1 { + if funcExpr, ok := node.Items[0].(*tree.FuncExpr); ok { + funcName := funcExpr.Func.String() + if strings.EqualFold(funcName, "unnest") && len(funcExpr.Exprs) > 1 { + // Expand multi-arg UNNEST into separate unnest calls + selectExprs := make(vitess.SelectExprs, len(funcExpr.Exprs)) + for i, arg := range funcExpr.Exprs { + argExpr, err := nodeExpr(ctx, arg) + if err != nil { + return nil, err + } + selectExprs[i] = &vitess.AliasedExpr{ + Expr: &vitess.FuncExpr{ + Name: vitess.NewColIdent("unnest"), + Exprs: vitess.SelectExprs{&vitess.AliasedExpr{Expr: argExpr}}, + }, + } + } + return &vitess.RowsFromExpr{ + Exprs: selectExprs, + }, nil + } + } + } + // For single functions or other cases, use the original ValuesStatement approach + // which works with the existing table function infrastructure exprs, err := nodeExprs(ctx, node.Items) if err != nil { return nil, err } - //TODO: not sure if this is correct at all. I think we want to return one result per row, but maybe not. - // This needs to be tested to verify. rows := make([]vitess.ValTuple, len(exprs)) for i := range exprs { rows[i] = vitess.ValTuple{exprs[i]} From 6501e350eba7a108364e4b8d1e191cd9dbce0d0f Mon Sep 17 00:00:00 2001 From: Myers Carpenter Date: Tue, 3 Feb 2026 12:53:49 -0500 Subject: [PATCH 3/4] Add ROWS FROM tests and fix multi-function support - Add comprehensive tests for explicit ROWS FROM syntax: - ROWS FROM with generate_series - ROWS FROM with unnest - ROWS FROM with mixed functions - Fix table_expr.go and aliased_table_expr.go to use RowsFromExpr for all multi-function ROWS FROM cases, not just multi-arg UNNEST Depends on: - dolthub/vitess#454 - dolthub/go-mysql-server#3412 --- server/ast/aliased_table_expr.go | 11 +++---- server/ast/table_expr.go | 19 ++++++++++- testing/go/functions_test.go | 55 ++++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 7 deletions(-) diff --git a/server/ast/aliased_table_expr.go b/server/ast/aliased_table_expr.go index 879a710f19..da4d32f04f 100644 --- a/server/ast/aliased_table_expr.go +++ b/server/ast/aliased_table_expr.go @@ -73,11 +73,10 @@ func nodeAliasedTableExpr(ctx *Context, node *tree.AliasedTableExpr) (vitess.Tab } } - // For single functions or non-multi-arg-UNNEST cases, use the existing - // subquery-based approach that works with the table function infrastructure. - // Only WITH ORDINALITY requires the new RowsFromExpr approach. - if node.Ordinality { - // Use RowsFromExpr for WITH ORDINALITY support + // Use RowsFromExpr for: + // 1. Multiple functions: ROWS FROM(func1(), func2()) AS alias + // 2. WITH ORDINALITY: ROWS FROM(func()) WITH ORDINALITY + if len(rowsFrom.Items) > 1 || node.Ordinality { selectExprs := make(vitess.SelectExprs, len(rowsFrom.Items)) for i, item := range rowsFrom.Items { expr, err := nodeExpr(ctx, item) @@ -103,7 +102,7 @@ func nodeAliasedTableExpr(ctx *Context, node *tree.AliasedTableExpr) (vitess.Tab }, nil } - // For non-ordinality cases, fall through to use the existing + // For single function without ordinality, fall through to use the existing // table function infrastructure via nodeTableExpr tableExpr, err := nodeTableExpr(ctx, rowsFrom) if err != nil { diff --git a/server/ast/table_expr.go b/server/ast/table_expr.go index ef296f5749..ee32a7c95f 100644 --- a/server/ast/table_expr.go +++ b/server/ast/table_expr.go @@ -128,7 +128,24 @@ func nodeTableExpr(ctx *Context, node tree.TableExpr) (vitess.TableExpr, error) } } } - // For single functions or other cases, use the original ValuesStatement approach + + // For explicit ROWS FROM with multiple functions, use RowsFromExpr + // This handles: ROWS FROM(generate_series(1,3), generate_series(10,12)) + if len(node.Items) > 1 { + selectExprs := make(vitess.SelectExprs, len(node.Items)) + for i, item := range node.Items { + expr, err := nodeExpr(ctx, item) + if err != nil { + return nil, err + } + selectExprs[i] = &vitess.AliasedExpr{Expr: expr} + } + return &vitess.RowsFromExpr{ + Exprs: selectExprs, + }, nil + } + + // For single functions, use the original ValuesStatement approach // which works with the existing table function infrastructure exprs, err := nodeExprs(ctx, node.Items) if err != nil { diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index 375d89c646..9f4d8e5acc 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -1582,6 +1582,61 @@ func TestArrayFunctions(t *testing.T) { }, }, }, + { + Name: "ROWS FROM with generate_series", + Assertions: []ScriptTestAssertion{ + { + // Basic ROWS FROM with two generate_series calls + Query: `SELECT * FROM ROWS FROM(generate_series(1,3), generate_series(10,12))`, + Expected: []sql.Row{{int64(1), int64(10)}, {int64(2), int64(11)}, {int64(3), int64(12)}}, + }, + { + // ROWS FROM with unequal-length series (shorter padded with NULL) + Query: `SELECT * FROM ROWS FROM(generate_series(1,2), generate_series(10,13))`, + Expected: []sql.Row{{int64(1), int64(10)}, {int64(2), int64(11)}, {nil, int64(12)}, {nil, int64(13)}}, + }, + { + // ROWS FROM with table alias + Query: `SELECT r.* FROM ROWS FROM(generate_series(1,2), generate_series(10,11)) AS r`, + Expected: []sql.Row{{int64(1), int64(10)}, {int64(2), int64(11)}}, + }, + { + // ROWS FROM with three functions + Query: `SELECT * FROM ROWS FROM(generate_series(1,2), generate_series(10,11), generate_series(100,101))`, + Expected: []sql.Row{{int64(1), int64(10), int64(100)}, {int64(2), int64(11), int64(101)}}, + }, + }, + }, + { + Name: "ROWS FROM with unnest", + Assertions: []ScriptTestAssertion{ + { + // ROWS FROM with explicit unnest calls + Query: `SELECT * FROM ROWS FROM(unnest(ARRAY['a','b']), unnest(ARRAY[1,2]))`, + Expected: []sql.Row{{"a", int64(1)}, {"b", int64(2)}}, + }, + { + // ROWS FROM with unequal-length unnest + Query: `SELECT * FROM ROWS FROM(unnest(ARRAY['a','b','c']), unnest(ARRAY[1,2]))`, + Expected: []sql.Row{{"a", int64(1)}, {"b", int64(2)}, {"c", nil}}, + }, + }, + }, + { + Name: "ROWS FROM mixed functions", + Assertions: []ScriptTestAssertion{ + { + // Mix generate_series and unnest + Query: `SELECT * FROM ROWS FROM(generate_series(1,3), unnest(ARRAY['x','y','z']))`, + Expected: []sql.Row{{int64(1), "x"}, {int64(2), "y"}, {int64(3), "z"}}, + }, + { + // Mix with unequal lengths + Query: `SELECT * FROM ROWS FROM(generate_series(1,2), unnest(ARRAY['x','y','z']))`, + Expected: []sql.Row{{int64(1), "x"}, {int64(2), "y"}, {nil, "z"}}, + }, + }, + }, { Name: "array_to_string", SetUpScript: []string{}, From 2de7f7dbb74cc3a8e50f6a9e45162260fda8a9ce Mon Sep 17 00:00:00 2001 From: Myers Carpenter Date: Wed, 11 Feb 2026 22:05:28 -0500 Subject: [PATCH 4/4] Replace RowsFromExpr with TableFuncExpr and own RowsFrom execution Move RowsFrom and RowsFromIter from GMS into server/node as an ExecBuilderNode. Register BuildMultiExprTableFunc factory to handle nameless TableFuncExpr (ROWS FROM pattern). Update AST conversion to emit TableFuncExpr instead of the removed RowsFromExpr. --- server/analyzer/init.go | 9 + server/ast/aliased_table_expr.go | 7 +- server/ast/select_clause.go | 2 +- server/ast/table_expr.go | 7 +- server/node/rows_from.go | 386 +++++++++++++++++++++++++++++++ 5 files changed, 402 insertions(+), 9 deletions(-) create mode 100644 server/node/rows_from.go diff --git a/server/analyzer/init.go b/server/analyzer/init.go index 1047bb2457..929daaae8c 100644 --- a/server/analyzer/init.go +++ b/server/analyzer/init.go @@ -15,6 +15,7 @@ package analyzer import ( + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/analyzer" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/memo" @@ -22,6 +23,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/planbuilder" pgexpression "github.com/dolthub/doltgresql/server/expression" + pgnodes "github.com/dolthub/doltgresql/server/node" ) // IDs are basically arbitrary, we just need to ensure that they do not conflict with existing IDs @@ -114,6 +116,13 @@ func initEngine() { planbuilder.IsAggregateFunc = IsAggregateFunc + planbuilder.BuildMultiExprTableFunc = func( + exprs []sql.Expression, alias string, + withOrdinality bool, columnAliases []string, + ) (sql.Node, error) { + return pgnodes.NewRowsFrom(exprs, alias, withOrdinality, columnAliases), nil + } + expression.DefaultExpressionFactory = pgexpression.PostgresExpressionFactory{} // There are a couple places during analysis where SplitConjunction in GMS cannot correctly split up diff --git a/server/ast/aliased_table_expr.go b/server/ast/aliased_table_expr.go index da4d32f04f..57c2099f13 100644 --- a/server/ast/aliased_table_expr.go +++ b/server/ast/aliased_table_expr.go @@ -35,7 +35,6 @@ func nodeAliasedTableExpr(ctx *Context, node *tree.AliasedTableExpr) (vitess.Tab if rowsFrom, ok := node.Expr.(*tree.RowsFromExpr); ok { // Handle multi-argument UNNEST specially: UNNEST(arr1, arr2, ...) // is syntactic sugar for ROWS FROM(unnest(arr1), unnest(arr2), ...) - // We need to detect this case and expand it to use RowsFromExpr. if len(rowsFrom.Items) == 1 { if funcExpr, ok := rowsFrom.Items[0].(*tree.FuncExpr); ok { funcName := funcExpr.Func.String() @@ -63,7 +62,7 @@ func nodeAliasedTableExpr(ctx *Context, node *tree.AliasedTableExpr) (vitess.Tab } } - return &vitess.RowsFromExpr{ + return &vitess.TableFuncExpr{ Exprs: selectExprs, WithOrdinality: node.Ordinality, Alias: vitess.NewTableIdent(string(node.As.Alias)), @@ -73,7 +72,7 @@ func nodeAliasedTableExpr(ctx *Context, node *tree.AliasedTableExpr) (vitess.Tab } } - // Use RowsFromExpr for: + // Use TableFuncExpr (nameless) for: // 1. Multiple functions: ROWS FROM(func1(), func2()) AS alias // 2. WITH ORDINALITY: ROWS FROM(func()) WITH ORDINALITY if len(rowsFrom.Items) > 1 || node.Ordinality { @@ -94,7 +93,7 @@ func nodeAliasedTableExpr(ctx *Context, node *tree.AliasedTableExpr) (vitess.Tab } } - return &vitess.RowsFromExpr{ + return &vitess.TableFuncExpr{ Exprs: selectExprs, WithOrdinality: node.Ordinality, Alias: vitess.NewTableIdent(string(node.As.Alias)), diff --git a/server/ast/select_clause.go b/server/ast/select_clause.go index 23cf1f7777..dddbeed3c4 100644 --- a/server/ast/select_clause.go +++ b/server/ast/select_clause.go @@ -179,7 +179,7 @@ PostJoinRewrite: if alias.IsEmpty() { alias = vitess.NewTableIdent("unnest") } - from[i] = &vitess.RowsFromExpr{ + from[i] = &vitess.TableFuncExpr{ Exprs: selectExprs, Alias: alias, } diff --git a/server/ast/table_expr.go b/server/ast/table_expr.go index ee32a7c95f..762164c4a6 100644 --- a/server/ast/table_expr.go +++ b/server/ast/table_expr.go @@ -103,7 +103,6 @@ func nodeTableExpr(ctx *Context, node tree.TableExpr) (vitess.TableExpr, error) case *tree.RowsFromExpr: // Handle multi-argument UNNEST specially: UNNEST(arr1, arr2, ...) // is syntactic sugar for ROWS FROM(unnest(arr1), unnest(arr2), ...) - // We need to detect this case and expand it to use RowsFromExpr. if len(node.Items) == 1 { if funcExpr, ok := node.Items[0].(*tree.FuncExpr); ok { funcName := funcExpr.Func.String() @@ -122,14 +121,14 @@ func nodeTableExpr(ctx *Context, node tree.TableExpr) (vitess.TableExpr, error) }, } } - return &vitess.RowsFromExpr{ + return &vitess.TableFuncExpr{ Exprs: selectExprs, }, nil } } } - // For explicit ROWS FROM with multiple functions, use RowsFromExpr + // For explicit ROWS FROM with multiple functions, use TableFuncExpr (nameless) // This handles: ROWS FROM(generate_series(1,3), generate_series(10,12)) if len(node.Items) > 1 { selectExprs := make(vitess.SelectExprs, len(node.Items)) @@ -140,7 +139,7 @@ func nodeTableExpr(ctx *Context, node tree.TableExpr) (vitess.TableExpr, error) } selectExprs[i] = &vitess.AliasedExpr{Expr: expr} } - return &vitess.RowsFromExpr{ + return &vitess.TableFuncExpr{ Exprs: selectExprs, }, nil } diff --git a/server/node/rows_from.go b/server/node/rows_from.go new file mode 100644 index 0000000000..aa21bd610d --- /dev/null +++ b/server/node/rows_from.go @@ -0,0 +1,386 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package node + +import ( + "errors" + "fmt" + "io" + "strings" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// RowsFrom represents a ROWS FROM table function that executes multiple +// set-returning functions in parallel and zips their results together. +// This is the PostgreSQL-compatible syntax: ROWS FROM(func1(...), func2(...), ...) +type RowsFrom struct { + // Functions contains the set-returning function expressions to execute + Functions []sql.Expression + // withOrdinality when true, adds an ordinality column to the result + withOrdinality bool + // alias is the table alias for this ROWS FROM expression + alias string + // columnAliases are optional column names for the result columns + columnAliases []string + // colset tracks the column IDs for this node + colset sql.ColSet + // id is the table ID for this node + id sql.TableId +} + +var _ sql.Node = (*RowsFrom)(nil) +var _ sql.Expressioner = (*RowsFrom)(nil) +var _ sql.CollationCoercible = (*RowsFrom)(nil) +var _ plan.TableIdNode = (*RowsFrom)(nil) +var _ sql.RenameableNode = (*RowsFrom)(nil) +var _ sql.ExecBuilderNode = (*RowsFrom)(nil) + +// NewRowsFrom creates a new RowsFrom node with the given function expressions. +func NewRowsFrom(exprs []sql.Expression, alias string, withOrdinality bool, columnAliases []string) *RowsFrom { + return &RowsFrom{ + Functions: exprs, + withOrdinality: withOrdinality, + alias: alias, + columnAliases: columnAliases, + } +} + +// BuildRowIter implements sql.ExecBuilderNode. +func (r *RowsFrom) BuildRowIter(ctx *sql.Context, b sql.NodeExecBuilder, row sql.Row) (sql.RowIter, error) { + return NewRowsFromIter(r.Functions, r.withOrdinality, row), nil +} + +// WithId implements plan.TableIdNode +func (r *RowsFrom) WithId(id sql.TableId) plan.TableIdNode { + ret := *r + ret.id = id + return &ret +} + +// Id implements plan.TableIdNode +func (r *RowsFrom) Id() sql.TableId { + return r.id +} + +// WithColumns implements plan.TableIdNode +func (r *RowsFrom) WithColumns(set sql.ColSet) plan.TableIdNode { + ret := *r + ret.colset = set + return &ret +} + +// Columns implements plan.TableIdNode +func (r *RowsFrom) Columns() sql.ColSet { + return r.colset +} + +// Name returns the alias name for this ROWS FROM expression +func (r *RowsFrom) Name() string { + if r.alias != "" { + return r.alias + } + return "rows_from" +} + +// WithName implements sql.RenameableNode +func (r *RowsFrom) WithName(s string) sql.Node { + ret := *r + ret.alias = s + return &ret +} + +// Schema implements the sql.Node interface. +func (r *RowsFrom) Schema() sql.Schema { + var schema sql.Schema + + for i, f := range r.Functions { + colName := fmt.Sprintf("col%d", i) + if i < len(r.columnAliases) && r.columnAliases[i] != "" { + colName = r.columnAliases[i] + } else if nameable, ok := f.(sql.Nameable); ok { + colName = nameable.Name() + } + + schema = append(schema, &sql.Column{ + Name: colName, + Type: f.Type(), + Nullable: true, // SRF results can be NULL when zipping unequal-length results + Source: r.Name(), + }) + } + + if r.withOrdinality { + schema = append(schema, &sql.Column{ + Name: "ordinality", + Type: types.Int64, + Nullable: false, + Source: r.Name(), + }) + } + + return schema +} + +// Children implements the sql.Node interface. +func (r *RowsFrom) Children() []sql.Node { + return nil +} + +// Resolved implements the sql.Resolvable interface. +func (r *RowsFrom) Resolved() bool { + for _, f := range r.Functions { + if !f.Resolved() { + return false + } + } + return true +} + +// IsReadOnly implements the sql.Node interface. +func (r *RowsFrom) IsReadOnly() bool { + return true +} + +// String implements the sql.Node interface. +func (r *RowsFrom) String() string { + var sb strings.Builder + sb.WriteString("ROWS FROM(") + for i, f := range r.Functions { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(f.String()) + } + sb.WriteString(")") + if r.withOrdinality { + sb.WriteString(" WITH ORDINALITY") + } + if r.alias != "" { + sb.WriteString(" AS ") + sb.WriteString(r.alias) + } + return sb.String() +} + +// DebugString implements the sql.DebugStringer interface. +func (r *RowsFrom) DebugString() string { + var sb strings.Builder + sb.WriteString("RowsFrom(") + for i, f := range r.Functions { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(sql.DebugString(f)) + } + sb.WriteString(")") + if r.withOrdinality { + sb.WriteString(" WITH ORDINALITY") + } + if r.alias != "" { + sb.WriteString(" AS ") + sb.WriteString(r.alias) + } + return sb.String() +} + +// Expressions implements the sql.Expressioner interface. +func (r *RowsFrom) Expressions() []sql.Expression { + return r.Functions +} + +// WithExpressions implements the sql.Expressioner interface. +func (r *RowsFrom) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != len(r.Functions) { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(exprs), len(r.Functions)) + } + ret := *r + ret.Functions = exprs + return &ret, nil +} + +// WithChildren implements the sql.Node interface. +func (r *RowsFrom) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 0) + } + return r, nil +} + +// CollationCoercibility implements the interface sql.CollationCoercible. +func (*RowsFrom) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + return sql.Collation_binary, 7 +} + +// RowsFromIter is an iterator for the RowsFrom node. +// It executes multiple set-returning functions in parallel and zips their results together. +// When one function is exhausted before another, NULL is used for its values. +type RowsFromIter struct { + functions []sql.Expression + iters []sql.RowIter + finished []bool + withOrdinality bool + ordinality int64 + initialized bool + sourceRow sql.Row +} + +var _ sql.RowIter = (*RowsFromIter)(nil) + +// NewRowsFromIter creates a new RowsFromIter. +func NewRowsFromIter(functions []sql.Expression, withOrdinality bool, row sql.Row) *RowsFromIter { + return &RowsFromIter{ + functions: functions, + withOrdinality: withOrdinality, + sourceRow: row, + finished: make([]bool, len(functions)), + } +} + +// Next implements the sql.RowIter interface. +func (r *RowsFromIter) Next(ctx *sql.Context) (sql.Row, error) { + if !r.initialized { + if err := r.initIterators(ctx); err != nil { + return nil, err + } + r.initialized = true + } + + allFinished := true + for _, f := range r.finished { + if !f { + allFinished = false + break + } + } + if allFinished { + return nil, io.EOF + } + + row := make(sql.Row, len(r.functions)) + for i, iter := range r.iters { + if r.finished[i] { + row[i] = nil + continue + } + + nextRow, err := iter.Next(ctx) + if err != nil { + if errors.Is(err, io.EOF) { + r.finished[i] = true + row[i] = nil + continue + } + return nil, err + } + + if len(nextRow) > 0 { + row[i] = nextRow[0] + } else { + row[i] = nil + } + } + + allFinished = true + for _, f := range r.finished { + if !f { + allFinished = false + break + } + } + + allNulls := true + for _, v := range row { + if v != nil { + allNulls = false + break + } + } + + if allFinished && allNulls { + return nil, io.EOF + } + + r.ordinality++ + if r.withOrdinality { + row = append(row, r.ordinality) + } + + return row, nil +} + +func (r *RowsFromIter) initIterators(ctx *sql.Context) error { + r.iters = make([]sql.RowIter, len(r.functions)) + + for i, f := range r.functions { + if rie, ok := f.(sql.RowIterExpression); ok && rie.ReturnsRowIter() { + iter, err := rie.EvalRowIter(ctx, r.sourceRow) + if err != nil { + for j := 0; j < i; j++ { + if r.iters[j] != nil { + r.iters[j].Close(ctx) + } + } + return err + } + r.iters[i] = iter + } else { + val, err := f.Eval(ctx, r.sourceRow) + if err != nil { + for j := 0; j < i; j++ { + if r.iters[j] != nil { + r.iters[j].Close(ctx) + } + } + return err + } + r.iters[i] = &singleValueIter{value: val} + } + } + + return nil +} + +// Close implements the sql.RowIter interface. +func (r *RowsFromIter) Close(ctx *sql.Context) error { + var firstErr error + for _, iter := range r.iters { + if iter != nil { + if err := iter.Close(ctx); err != nil && firstErr == nil { + firstErr = err + } + } + } + return firstErr +} + +type singleValueIter struct { + value interface{} + consumed bool +} + +func (s *singleValueIter) Next(ctx *sql.Context) (sql.Row, error) { + if s.consumed { + return nil, io.EOF + } + s.consumed = true + return sql.Row{s.value}, nil +} + +func (s *singleValueIter) Close(ctx *sql.Context) error { + return nil +}