From 3003e3fbc02718c3900a06665a348b82aef8a9dc Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 4 Mar 2026 00:44:03 +0000 Subject: [PATCH] refactor: split sql builders from the sqlglot_ir class --- bigframes/bigquery/_operations/sql.py | 7 +- .../sqlglot/aggregations/unary_compiler.py | 16 +- bigframes/core/compile/sqlglot/compiler.py | 82 +++--- .../compile/sqlglot/expression_compiler.py | 4 +- .../sqlglot/expressions/comparison_ops.py | 10 +- .../sqlglot/expressions/generic_ops.py | 49 ++-- .../core/compile/sqlglot/sql/__init__.py | 39 +++ bigframes/core/compile/sqlglot/sql/base.py | 156 +++++++++++ bigframes/core/compile/sqlglot/sql/dml.py | 57 ++++ bigframes/core/compile/sqlglot/sqlglot_ir.py | 260 ++++-------------- bigframes/core/sql/__init__.py | 14 +- bigframes/core/sql/ml.py | 22 +- bigframes/ml/compose.py | 8 +- bigframes/ml/sql.py | 36 +-- bigframes/session/_io/bigquery/__init__.py | 4 +- bigframes/session/bigquery_session.py | 8 +- bigframes/session/bq_caching_executor.py | 8 +- 17 files changed, 436 insertions(+), 344 deletions(-) create mode 100644 bigframes/core/compile/sqlglot/sql/__init__.py create mode 100644 bigframes/core/compile/sqlglot/sql/base.py create mode 100644 bigframes/core/compile/sqlglot/sql/dml.py diff --git a/bigframes/bigquery/_operations/sql.py b/bigframes/bigquery/_operations/sql.py index e6ac1b9c27..c3846b8335 100644 --- a/bigframes/bigquery/_operations/sql.py +++ b/bigframes/bigquery/_operations/sql.py @@ -20,7 +20,7 @@ import google.cloud.bigquery -from bigframes.core.compile.sqlglot import sqlglot_ir +from bigframes.core.compile.sqlglot import sql import bigframes.dtypes import bigframes.operations import bigframes.series @@ -68,10 +68,7 @@ def sql_scalar( # Another benefit of this is that if there is a syntax error in the SQL # template, then this will fail with an error earlier in the process, # aiding users in debugging. - literals_sql = [ - sqlglot_ir._literal(None, column.dtype).sql(dialect="bigquery") - for column in columns - ] + literals_sql = [sql.to_sql(sql.literal(None, column.dtype)) for column in columns] select_sql = sql_template.format(*literals_sql) dry_run_sql = f"SELECT {select_sql}" diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index add3ccd923..cca0f02133 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -22,11 +22,11 @@ from bigframes import dtypes from bigframes.core import window_spec +from bigframes.core.compile.sqlglot import sql import bigframes.core.compile.sqlglot.aggregations.op_registration as reg from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present from bigframes.core.compile.sqlglot.expressions import constants import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr -import bigframes.core.compile.sqlglot.sqlglot_ir as ir from bigframes.operations import aggregations as agg_ops UNARY_OP_REGISTRATION = reg.OpRegistration() @@ -157,9 +157,9 @@ def _cut_ops_w_int_bins( for this_bin in range(bins): value: sge.Expression if op.labels is False: - value = ir._literal(this_bin, dtypes.INT_DTYPE) + value = sql.literal(this_bin, dtypes.INT_DTYPE) elif isinstance(op.labels, typing.Iterable): - value = ir._literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE) + value = sql.literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE) else: left_adj: sge.Expression = ( adj if this_bin == 0 and op.right else sge.convert(0) @@ -217,10 +217,10 @@ def _cut_ops_w_intervals( ) -> sge.Case: case_expr = sge.Case() for this_bin, interval in enumerate(bins): - left: sge.Expression = ir._literal( + left: sge.Expression = sql.literal( interval[0], dtypes.infer_literal_type(interval[0]) ) - right: sge.Expression = ir._literal( + right: sge.Expression = sql.literal( interval[1], dtypes.infer_literal_type(interval[1]) ) condition: sge.Expression @@ -237,9 +237,9 @@ def _cut_ops_w_intervals( value: sge.Expression if op.labels is False: - value = ir._literal(this_bin, dtypes.INT_DTYPE) + value = sql.literal(this_bin, dtypes.INT_DTYPE) elif isinstance(op.labels, typing.Iterable): - value = ir._literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE) + value = sql.literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE) else: if op.right: left_identifier = sge.Identifier(this="left_exclusive", quoted=True) @@ -609,7 +609,7 @@ def _( # Will be null if all inputs are null. Pandas defaults to zero sum though. zero = pd.to_timedelta(0) if column.dtype == dtypes.TIMEDELTA_DTYPE else 0 - return sge.func("IFNULL", expr, ir._literal(zero, column.dtype)) + return sge.func("IFNULL", expr, sql.literal(zero, column.dtype)) @UNARY_OP_REGISTRATION.register(agg_ops.VarOp) diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index 6b90b94067..d4ae01f511 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -30,11 +30,11 @@ sql_nodes, ) from bigframes.core.compile import configs +from bigframes.core.compile.sqlglot import sql, sqlglot_ir import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler from bigframes.core.compile.sqlglot.aggregations import windows import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler from bigframes.core.compile.sqlglot.expressions import typed_expr -import bigframes.core.compile.sqlglot.sqlglot_ir as ir from bigframes.core.logging import data_types as data_type_logger import bigframes.core.ordering as bf_ordering from bigframes.core.rewrite import schema_binding @@ -108,20 +108,20 @@ def _compile_result_node(root: nodes.ResultNode) -> str: # Probably, should defer even further root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root)) - sqlglot_ir = compile_node(rewrite.as_sql_nodes(root), uid_gen) - return sqlglot_ir.sql + sqlglot_ir_obj = compile_node(rewrite.as_sql_nodes(root), uid_gen) + return sqlglot_ir_obj.sql def compile_node( node: nodes.BigFrameNode, uid_gen: guid.SequentialUIDGenerator -) -> ir.SQLGlotIR: +) -> sqlglot_ir.SQLGlotIR: """Compiles the given BigFrameNode from bottem-up into SQLGlotIR.""" - bf_to_sqlglot: dict[nodes.BigFrameNode, ir.SQLGlotIR] = {} - child_results: tuple[ir.SQLGlotIR, ...] = () + bf_to_sqlglot: dict[nodes.BigFrameNode, sqlglot_ir.SQLGlotIR] = {} + child_results: tuple[sqlglot_ir.SQLGlotIR, ...] = () for current_node in list(node.iter_nodes_topo()): if current_node.child_nodes == (): # For leaf node, generates a dumpy child to pass the UID generator. - child_results = tuple([ir.SQLGlotIR(uid_gen=uid_gen)]) + child_results = tuple([sqlglot_ir.SQLGlotIR(uid_gen=uid_gen)]) else: # Child nodes should have been compiled in the reverse topological order. child_results = tuple( @@ -135,14 +135,14 @@ def compile_node( @functools.singledispatch def _compile_node( - node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR -) -> ir.SQLGlotIR: + node: nodes.BigFrameNode, *compiled_children: sqlglot_ir.SQLGlotIR +) -> sqlglot_ir.SQLGlotIR: """Defines transformation but isn't cached, always use compile_node instead""" raise ValueError(f"Can't compile unrecognized node: {node}") @_compile_node.register -def compile_sql_select(node: sql_nodes.SqlSelectNode, child: ir.SQLGlotIR): +def compile_sql_select(node: sql_nodes.SqlSelectNode, child: sqlglot_ir.SQLGlotIR): ordering_cols = tuple( sge.Ordered( this=expression_compiler.expression_compiler.compile_expression( @@ -175,7 +175,9 @@ def compile_sql_select(node: sql_nodes.SqlSelectNode, child: ir.SQLGlotIR): @_compile_node.register -def compile_readlocal(node: nodes.ReadLocalNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR: +def compile_readlocal( + node: nodes.ReadLocalNode, child: sqlglot_ir.SQLGlotIR +) -> sqlglot_ir.SQLGlotIR: pa_table = node.local_data_source.data pa_table = pa_table.select([item.source_id for item in node.scan_list.items]) pa_table = pa_table.rename_columns([item.id.sql for item in node.scan_list.items]) @@ -184,16 +186,18 @@ def compile_readlocal(node: nodes.ReadLocalNode, child: ir.SQLGlotIR) -> ir.SQLG if offsets: pa_table = pyarrow_utils.append_offsets(pa_table, offsets) - return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema, uid_gen=child.uid_gen) + return sqlglot_ir.SQLGlotIR.from_pyarrow( + pa_table, node.schema, uid_gen=child.uid_gen + ) @_compile_node.register -def compile_readtable(node: sql_nodes.SqlDataSource, child: ir.SQLGlotIR): - table = node.source.table - return ir.SQLGlotIR.from_table( - table.project_id, - table.dataset_id, - table.table_id, +def compile_readtable(node: sql_nodes.SqlDataSource, child: sqlglot_ir.SQLGlotIR): + table_obj = node.source.table + return sqlglot_ir.SQLGlotIR.from_table( + table_obj.project_id, + table_obj.dataset_id, + table_obj.table_id, uid_gen=child.uid_gen, sql_predicate=node.source.sql_predicate, system_time=node.source.at_time, @@ -202,20 +206,20 @@ def compile_readtable(node: sql_nodes.SqlDataSource, child: ir.SQLGlotIR): @_compile_node.register def compile_join( - node: nodes.JoinNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR -) -> ir.SQLGlotIR: + node: nodes.JoinNode, left: sqlglot_ir.SQLGlotIR, right: sqlglot_ir.SQLGlotIR +) -> sqlglot_ir.SQLGlotIR: conditions = tuple( ( typed_expr.TypedExpr( - expression_compiler.expression_compiler.compile_expression(left), - left.output_type, + expression_compiler.expression_compiler.compile_expression(left_expr), + left_expr.output_type, ), typed_expr.TypedExpr( - expression_compiler.expression_compiler.compile_expression(right), - right.output_type, + expression_compiler.expression_compiler.compile_expression(right_expr), + right_expr.output_type, ), ) - for left, right in node.conditions + for left_expr, right_expr in node.conditions ) return left.join( @@ -228,8 +232,8 @@ def compile_join( @_compile_node.register def compile_isin_join( - node: nodes.InNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR -) -> ir.SQLGlotIR: + node: nodes.InNode, left: sqlglot_ir.SQLGlotIR, right: sqlglot_ir.SQLGlotIR +) -> sqlglot_ir.SQLGlotIR: right_field = node.right_child.fields[0] conditions = ( typed_expr.TypedExpr( @@ -253,7 +257,9 @@ def compile_isin_join( @_compile_node.register -def compile_concat(node: nodes.ConcatNode, *children: ir.SQLGlotIR) -> ir.SQLGlotIR: +def compile_concat( + node: nodes.ConcatNode, *children: sqlglot_ir.SQLGlotIR +) -> sqlglot_ir.SQLGlotIR: assert len(children) >= 1 uid_gen = children[0].uid_gen @@ -264,7 +270,7 @@ def compile_concat(node: nodes.ConcatNode, *children: ir.SQLGlotIR) -> ir.SQLGlo for default_output_id, output_id in zip(default_output_ids, node.output_ids) ] - return ir.SQLGlotIR.from_union( + return sqlglot_ir.SQLGlotIR.from_union( [child._as_select() for child in children], output_aliases=output_aliases, uid_gen=uid_gen, @@ -272,7 +278,9 @@ def compile_concat(node: nodes.ConcatNode, *children: ir.SQLGlotIR) -> ir.SQLGlo @_compile_node.register -def compile_explode(node: nodes.ExplodeNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR: +def compile_explode( + node: nodes.ExplodeNode, child: sqlglot_ir.SQLGlotIR +) -> sqlglot_ir.SQLGlotIR: offsets_col = node.offsets_col.sql if (node.offsets_col is not None) else None columns = tuple(ref.id.sql for ref in node.column_ids) return child.explode(columns, offsets_col) @@ -280,8 +288,8 @@ def compile_explode(node: nodes.ExplodeNode, child: ir.SQLGlotIR) -> ir.SQLGlotI @_compile_node.register def compile_fromrange( - node: nodes.FromRangeNode, start: ir.SQLGlotIR, end: ir.SQLGlotIR -) -> ir.SQLGlotIR: + node: nodes.FromRangeNode, start: sqlglot_ir.SQLGlotIR, end: sqlglot_ir.SQLGlotIR +) -> sqlglot_ir.SQLGlotIR: start_col_id = node.start.fields[0].id end_col_id = node.end.fields[0].id @@ -291,20 +299,22 @@ def compile_fromrange( end_expr = expression_compiler.expression_compiler.compile_expression( expression.DerefOp(end_col_id) ) - step_expr = ir._literal(node.step, dtypes.INT_DTYPE) + step_expr = sql.literal(node.step, dtypes.INT_DTYPE) return start.resample(end, node.output_id.sql, start_expr, end_expr, step_expr) @_compile_node.register def compile_random_sample( - node: nodes.RandomSampleNode, child: ir.SQLGlotIR -) -> ir.SQLGlotIR: + node: nodes.RandomSampleNode, child: sqlglot_ir.SQLGlotIR +) -> sqlglot_ir.SQLGlotIR: return child.sample(node.fraction) @_compile_node.register -def compile_aggregate(node: nodes.AggregateNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR: +def compile_aggregate( + node: nodes.AggregateNode, child: sqlglot_ir.SQLGlotIR +) -> sqlglot_ir.SQLGlotIR: # The BigQuery ordered aggregation cannot support for NULL FIRST/LAST, # so we need to add extra expressions to enforce the null ordering. ordering_cols = windows.get_window_order_by(node.order_by, override_null_order=True) diff --git a/bigframes/core/compile/sqlglot/expression_compiler.py b/bigframes/core/compile/sqlglot/expression_compiler.py index b2ff34bf74..49780fbaea 100644 --- a/bigframes/core/compile/sqlglot/expression_compiler.py +++ b/bigframes/core/compile/sqlglot/expression_compiler.py @@ -19,8 +19,8 @@ import bigframes_vendored.sqlglot.expressions as sge import bigframes.core.agg_expressions as agg_exprs +from bigframes.core.compile.sqlglot import sql from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr -import bigframes.core.compile.sqlglot.sqlglot_ir as ir import bigframes.core.expression as ex import bigframes.operations as ops @@ -77,7 +77,7 @@ def _(self, expr: ex.DerefOp) -> sge.Expression: @compile_expression.register def _(self, expr: ex.ScalarConstantExpression) -> sge.Expression: - return ir._literal(expr.value, expr.dtype) + return sql.literal(expr.value, expr.dtype) @compile_expression.register def _(self, expr: agg_exprs.WindowExpression) -> sge.Expression: diff --git a/bigframes/core/compile/sqlglot/expressions/comparison_ops.py b/bigframes/core/compile/sqlglot/expressions/comparison_ops.py index f767314be7..44103e500f 100644 --- a/bigframes/core/compile/sqlglot/expressions/comparison_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/comparison_ops.py @@ -22,7 +22,7 @@ from bigframes import dtypes from bigframes import operations as ops -from bigframes.core.compile.sqlglot import sqlglot_ir +from bigframes.core.compile.sqlglot import sql import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr @@ -59,9 +59,9 @@ def _(expr: TypedExpr, op: ops.IsInOp) -> sge.Expression: @register_binary_op(ops.eq_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - if sqlglot_ir._is_null_literal(left.expr): + if sql.is_null_literal(left.expr): return sge.Is(this=right.expr, expression=sge.Null()) - if sqlglot_ir._is_null_literal(right.expr): + if sql.is_null_literal(right.expr): return sge.Is(this=left.expr, expression=sge.Null()) left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) @@ -140,12 +140,12 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.ne_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - if sqlglot_ir._is_null_literal(left.expr): + if sql.is_null_literal(left.expr): return sge.Is( this=sge.paren(right.expr, copy=False), expression=sg.not_(sge.Null(), copy=False), ) - if sqlglot_ir._is_null_literal(right.expr): + if sql.is_null_literal(right.expr): return sge.Is( this=sge.paren(left.expr, copy=False), expression=sg.not_(sge.Null(), copy=False), diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 14af91e591..46032145e2 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -19,7 +19,7 @@ from bigframes import dtypes from bigframes import operations as ops -from bigframes.core.compile.sqlglot import sqlglot_ir, sqlglot_types +from bigframes.core.compile.sqlglot import sql, sqlglot_types import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr @@ -48,8 +48,8 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: return result if to_type == dtypes.FLOAT_DTYPE and from_type == dtypes.BOOL_DTYPE: - sg_expr = _cast(sg_expr, "INT64", op.safe) - return _cast(sg_expr, sg_to_type, op.safe) + sg_expr = sql.cast(sg_expr, "INT64", op.safe) + return sql.cast(sg_expr, sg_to_type, op.safe) if to_type == dtypes.BOOL_DTYPE: if from_type == dtypes.BOOL_DTYPE: @@ -58,16 +58,16 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: return sge.NEQ(this=sg_expr, expression=sge.convert(0)) if to_type == dtypes.STRING_DTYPE: - sg_expr = _cast(sg_expr, sg_to_type, op.safe) + sg_expr = sql.cast(sg_expr, sg_to_type, op.safe) if from_type == dtypes.BOOL_DTYPE: sg_expr = sge.func("INITCAP", sg_expr) return sg_expr if dtypes.is_time_like(to_type) and from_type == dtypes.INT_DTYPE: sg_expr = sge.func("TIMESTAMP_MICROS", sg_expr) - return _cast(sg_expr, sg_to_type, op.safe) + return sql.cast(sg_expr, sg_to_type, op.safe) - return _cast(sg_expr, sg_to_type, op.safe) + return sql.cast(sg_expr, sg_to_type, op.safe) @register_unary_op(ops.hash_op) @@ -104,17 +104,19 @@ def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression: mappings = [ ( - sqlglot_ir._literal(key, dtypes.is_compatible(key, expr.dtype)), - sqlglot_ir._literal(value, dtypes.is_compatible(value, expr.dtype)), + sql.literal(key, dtypes.is_compatible(key, expr.dtype)), + sql.literal(value, dtypes.is_compatible(value, expr.dtype)), ) for key, value in op.mappings ] return sge.Case( ifs=[ sge.If( - this=sge.EQ(this=expr.expr, expression=key) - if not sqlglot_ir._is_null_literal(key) - else sge.Is(this=expr.expr, expression=sge.Null()), + this=( + sge.EQ(this=expr.expr, expression=key) + if not sql.is_null_literal(key) + else sge.Is(this=expr.expr, expression=sge.Null()) + ), true=value, ) for key, value in mappings @@ -201,12 +203,14 @@ def _(*cases_and_outputs: TypedExpr) -> sge.Expression: ) if do_upcast_bool: result_values = tuple( - TypedExpr( - sge.Cast(this=val.expr, to="INT64"), - dtypes.INT_DTYPE, + ( + TypedExpr( + sge.Cast(this=val.expr, to="INT64"), + dtypes.INT_DTYPE, + ) + if val.dtype == dtypes.BOOL_DTYPE + else val ) - if val.dtype == dtypes.BOOL_DTYPE - else val for val in result_values ) @@ -286,30 +290,23 @@ def _cast_to_int(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression | None: sg_expr = expr.expr # Cannot cast DATETIME to INT directly so need to convert to TIMESTAMP first. if from_type == dtypes.DATETIME_DTYPE: - sg_expr = _cast(sg_expr, "TIMESTAMP", op.safe) + sg_expr = sql.cast(sg_expr, "TIMESTAMP", op.safe) return sge.func("UNIX_MICROS", sg_expr) if from_type == dtypes.TIMESTAMP_DTYPE: return sge.func("UNIX_MICROS", sg_expr) if from_type == dtypes.TIME_DTYPE: return sge.func( "TIME_DIFF", - _cast(sg_expr, "TIME", op.safe), + sql.cast(sg_expr, "TIME", op.safe), sge.convert("00:00:00"), "MICROSECOND", ) if from_type == dtypes.NUMERIC_DTYPE or from_type == dtypes.FLOAT_DTYPE: sg_expr = sge.func("TRUNC", sg_expr) - return _cast(sg_expr, "INT64", op.safe) + return sql.cast(sg_expr, "INT64", op.safe) return None -def _cast(expr: sge.Expression, to: str, safe: bool): - if safe: - return sge.TryCast(this=expr, to=to) - else: - return sge.Cast(this=expr, to=to) - - def _convert_to_nonnull_string_sqlglot(expr: TypedExpr) -> sge.Expression: col_type = expr.dtype sg_expr = expr.expr diff --git a/bigframes/core/compile/sqlglot/sql/__init__.py b/bigframes/core/compile/sqlglot/sql/__init__.py new file mode 100644 index 0000000000..6d2dbd65a6 --- /dev/null +++ b/bigframes/core/compile/sqlglot/sql/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from bigframes.core.compile.sqlglot.sql.base import ( + cast, + escape_chars, + identifier, + is_null_literal, + literal, + table, + to_sql, +) +from bigframes.core.compile.sqlglot.sql.dml import insert, replace + +__all__ = [ + # From base.py + "cast", + "escape_chars", + "identifier", + "is_null_literal", + "literal", + "table", + "to_sql", + # From dml.py + "insert", + "replace", +] diff --git a/bigframes/core/compile/sqlglot/sql/base.py b/bigframes/core/compile/sqlglot/sql/base.py new file mode 100644 index 0000000000..d268a57357 --- /dev/null +++ b/bigframes/core/compile/sqlglot/sql/base.py @@ -0,0 +1,156 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import typing + +import bigframes_vendored.sqlglot as sg +import bigframes_vendored.sqlglot.expressions as sge +from google.cloud import bigquery +import numpy as np +import pandas as pd +import pyarrow as pa + +from bigframes import dtypes +from bigframes.core import utils +from bigframes.core.compile.sqlglot.expressions import constants +import bigframes.core.compile.sqlglot.sqlglot_types as sgt + +# shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0. +try: + from shapely.io import to_wkt # type: ignore +except ImportError: + from shapely.wkt import dumps # type: ignore + + to_wkt = dumps + + +QUOTED: bool = True +"""Whether to quote identifiers in the generated SQL.""" + +PRETTY: bool = True +"""Whether to pretty-print the generated SQL.""" + +DIALECT = sg.dialects.bigquery.BigQuery +"""The SQL dialect used for generation.""" + + +def to_sql(expr: sge.Expression) -> str: + """Generate SQL string from the given expression.""" + return expr.sql(dialect=DIALECT, pretty=PRETTY) + + +def identifier(id: str) -> sge.Identifier: + """Return a string representing column reference in a SQL.""" + return sge.to_identifier(id, quoted=QUOTED) + + +def literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: + sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None + if sqlglot_type is None: + if not pd.isna(value): + raise ValueError(f"Cannot infer SQLGlot type from None dtype: {value}") + return sge.Null() + + if value is None: + return cast(sge.Null(), sqlglot_type) + if dtypes.is_struct_like(dtype): + items = [ + literal(value=value[field_name], dtype=field_dtype).as_( + field_name, quoted=True + ) + for field_name, field_dtype in dtypes.get_struct_fields(dtype).items() + ] + return sge.Struct.from_arg_list(items) + elif dtypes.is_array_like(dtype): + value_type = dtypes.get_array_inner_type(dtype) + values = sge.Array( + expressions=[literal(value=v, dtype=value_type) for v in value] + ) + return values if len(value) > 0 else cast(values, sqlglot_type) + elif pd.isna(value) or (isinstance(value, pa.Scalar) and not value.is_valid): + return cast(sge.Null(), sqlglot_type) + elif dtype == dtypes.JSON_DTYPE: + return sge.ParseJSON(this=sge.convert(str(value))) + elif dtype == dtypes.BYTES_DTYPE: + return cast(str(value), sqlglot_type) + elif dtypes.is_time_like(dtype): + if isinstance(value, str): + return cast(sge.convert(value), sqlglot_type) + if isinstance(value, np.generic): + value = value.item() + return cast(sge.convert(value.isoformat()), sqlglot_type) + elif dtype in (dtypes.NUMERIC_DTYPE, dtypes.BIGNUMERIC_DTYPE): + return cast(sge.convert(value), sqlglot_type) + elif dtypes.is_geo_like(dtype): + wkt = value if isinstance(value, str) else to_wkt(value) + return sge.func("ST_GEOGFROMTEXT", sge.convert(wkt)) + elif dtype == dtypes.TIMEDELTA_DTYPE: + return sge.convert(utils.timedelta_to_micros(value)) + elif dtype == dtypes.FLOAT_DTYPE: + if np.isinf(value): + return constants._INF if value > 0 else constants._NEG_INF + return sge.convert(value) + else: + if isinstance(value, np.generic): + value = value.item() + return sge.convert(value) + + +def cast(arg: typing.Any, to: str, safe: bool = False) -> sge.Cast | sge.TryCast: + if safe: + return sge.TryCast(this=arg, to=to) + else: + return sge.Cast(this=arg, to=to) + + +def table(table: bigquery.TableReference) -> sge.Table: + return sge.Table( + this=sge.to_identifier(table.table_id, quoted=True), + db=sge.to_identifier(table.dataset_id, quoted=True), + catalog=sge.to_identifier(table.project, quoted=True), + ) + + +def escape_chars(value: str): + """Escapes all special characters""" + # TODO: Reuse literal's escaping logic instead of re-implementing it here. + # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#string_and_bytes_literals + trans_table = str.maketrans( + { + "\a": r"\a", + "\b": r"\b", + "\f": r"\f", + "\n": r"\n", + "\r": r"\r", + "\t": r"\t", + "\v": r"\v", + "\\": r"\\", + "?": r"\?", + '"': r"\"", + "'": r"\'", + "`": r"\`", + } + ) + return value.translate(trans_table) + + +def is_null_literal(expr: sge.Expression) -> bool: + """Checks if the given expression is a NULL literal.""" + if isinstance(expr, sge.Null): + return True + if isinstance(expr, sge.Cast) and isinstance(expr.this, sge.Null): + return True + return False diff --git a/bigframes/core/compile/sqlglot/sql/dml.py b/bigframes/core/compile/sqlglot/sql/dml.py new file mode 100644 index 0000000000..1a1b140ee2 --- /dev/null +++ b/bigframes/core/compile/sqlglot/sql/dml.py @@ -0,0 +1,57 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import typing + +import bigframes_vendored.sqlglot.expressions as sge +from google.cloud import bigquery + +from bigframes import dtypes +from bigframes.core.compile.sqlglot.sql import base + + +def insert( + query_or_table: typing.Union[sge.Select, sge.Table], + destination: bigquery.TableReference, +) -> sge.Insert: + """Generates an INSERT INTO SQL statement from the given SELECT statement or + table reference.""" + return sge.insert(_as_from_item(query_or_table), base.table(destination)) + + +def replace( + query_or_table: typing.Union[sge.Select, sge.Table], + destination: bigquery.TableReference, +) -> sge.Merge: + """Generates a MERGE statement to replace the contents of the destination table.""" + return sge.Merge( + this=base.table(destination), + using=_as_from_item(query_or_table), + on=base.literal(False, dtypes.BOOL_DTYPE), + whens=[ + sge.When(matched=False, source=True, then=sge.Delete()), + sge.When(matched=False, then=sge.Insert(this=sge.Var(this="ROW"))), + ], + ) + + +def _as_from_item( + query_or_table: typing.Union[sge.Select, sge.Table] +) -> typing.Union[sge.Subquery, sge.Table]: + if isinstance(query_or_table, sge.Select): + return query_or_table.subquery() + else: # table + return query_or_table diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index 94ffa39dae..52906ffb69 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -21,14 +21,12 @@ import bigframes_vendored.sqlglot as sg import bigframes_vendored.sqlglot.expressions as sge -from google.cloud import bigquery -import numpy as np -import pandas as pd import pyarrow as pa from bigframes import dtypes -from bigframes.core import guid, local_data, schema, utils -from bigframes.core.compile.sqlglot.expressions import constants, typed_expr +from bigframes.core import guid, local_data, schema +from bigframes.core.compile.sqlglot import sql +from bigframes.core.compile.sqlglot.expressions import typed_expr import bigframes.core.compile.sqlglot.sqlglot_types as sgt # shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0. @@ -47,22 +45,13 @@ class SQLGlotIR: expr: typing.Union[sge.Select, sge.Table] = sg.select() """The SQLGlot expression representing the query.""" - dialect = sg.dialects.bigquery.BigQuery - """The SQL dialect used for generation.""" - - quoted: bool = True - """Whether to quote identifiers in the generated SQL.""" - - pretty: bool = True - """Whether to pretty-print the generated SQL.""" - uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator() """Generator for unique identifiers.""" @property def sql(self) -> str: """Generate SQL string from the given expression.""" - return self.expr.sql(dialect=self.dialect, pretty=self.pretty) + return sql.to_sql(self.expr) @classmethod def from_pyarrow( @@ -89,7 +78,7 @@ def from_pyarrow( data_expr = [ sge.Struct( expressions=tuple( - _literal( + sql.literal( value=value, dtype=field.dtype, ) @@ -143,16 +132,16 @@ def from_table( ) table_alias = next(uid_gen.get_uid_stream("bft_")) table_expr = sge.Table( - this=sg.to_identifier(table_id, quoted=cls.quoted), - db=sg.to_identifier(dataset_id, quoted=cls.quoted), - catalog=sg.to_identifier(project_id, quoted=cls.quoted), + this=sql.identifier(table_id), + db=sql.identifier(dataset_id), + catalog=sql.identifier(project_id), version=version, - alias=sge.Identifier(this=table_alias, quoted=cls.quoted), + alias=sql.identifier(table_alias), ) if sql_predicate: select_expr = sge.Select().select(sge.Star()).from_(table_expr) select_expr = select_expr.where( - sg.parse_one(sql_predicate, dialect=cls.dialect), append=False + sg.parse_one(sql_predicate, dialect=sql.base.DIALECT), append=False ) return cls(expr=select_expr, uid_gen=uid_gen) @@ -178,7 +167,7 @@ def select( to_select = [ sge.Alias( this=expr, - alias=sge.to_identifier(id, quoted=self.quoted), + alias=sql.identifier(id), ) if expr.alias_or_name != id else expr @@ -197,7 +186,7 @@ def select( return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) @classmethod - def from_query_string( + def from_unparsed_query( cls, query_string: str, ) -> SQLGlotIR: @@ -205,9 +194,7 @@ def from_query_string( in a CTE can avoid the query parsing issue for unsupported syntax in SQLGlot.""" uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator() - cte_name = sge.to_identifier( - next(uid_gen.get_uid_stream("bfcte_")), quoted=cls.quoted - ) + cte_name = sql.identifier(next(uid_gen.get_uid_stream("bfcte_"))) cte = sge.CTE( this=query_string, alias=cte_name, @@ -251,8 +238,8 @@ def from_union( selections = [ sge.Alias( - this=sge.to_identifier(old_name, quoted=cls.quoted), - alias=sge.to_identifier(new_name, quoted=cls.quoted), + this=sql.identifier(old_name), + alias=sql.identifier(new_name), ) for old_name, new_name in output_aliases ] @@ -318,9 +305,7 @@ def isin_join( new_column: sge.Expression if joins_nulls: - right_table_name = sge.to_identifier( - next(self.uid_gen.get_uid_stream("bft_")), quoted=self.quoted - ) + right_table_name = sql.identifier(next(self.uid_gen.get_uid_stream("bft_"))) right_condition = typed_expr.TypedExpr( sge.Column(this=conditions[1].expr, table=right_table_name), conditions[1].dtype, @@ -341,7 +326,7 @@ def isin_join( new_column = sge.Alias( this=new_column, - alias=sge.to_identifier(indicator_col, quoted=self.quoted), + alias=sql.identifier(indicator_col), ) new_expr = ( @@ -370,7 +355,7 @@ def sample(self, fraction: float) -> SQLGlotIR: """Uniform samples a fraction of the rows.""" condition = sge.LT( this=sge.func("RAND"), - expression=_literal(fraction, dtypes.FLOAT_DTYPE), + expression=sql.literal(fraction, dtypes.FLOAT_DTYPE), ) new_expr = self._select_to_cte()[0].where(condition, append=False) @@ -392,7 +377,7 @@ def aggregate( aggregations_expr = [ sge.Alias( this=expr, - alias=sge.to_identifier(id, quoted=self.quoted), + alias=sql.identifier(id), ) for id, expr in aggregations ] @@ -431,15 +416,15 @@ def resample( generate_array = sge.func("GENERATE_ARRAY", start_expr, stop_expr, step_expr) - unnested_column_alias = sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcol_")), quoted=self.quoted + unnested_column_alias = sql.identifier( + next(self.uid_gen.get_uid_stream("bfcol_")) ) unnest_expr = sge.Unnest( expressions=[generate_array], alias=sge.TableAlias(columns=[unnested_column_alias]), ) - final_col_id = sge.to_identifier(array_col_name, quoted=self.quoted) + final_col_id = sql.identifier(array_col_name) # Build final expression by joining everything directly in a single SELECT new_expr = ( @@ -453,50 +438,14 @@ def resample( return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) - def insert( - self, - destination: bigquery.TableReference, - ) -> str: - """Generates an INSERT INTO SQL statement from the current SELECT clause.""" - return sge.insert(self._as_from_item(), _table(destination)).sql( - dialect=self.dialect, pretty=self.pretty - ) - - def replace( - self, - destination: bigquery.TableReference, - ) -> str: - """Generates a MERGE statement to replace the destination table's contents. - by the current SELECT clause. - """ - # Workaround for SQLGlot breaking change: - # https://github.com/tobymao/sqlglot/pull/4495 - whens_expr = [ - sge.When(matched=False, source=True, then=sge.Delete()), - sge.When(matched=False, then=sge.Insert(this=sge.Var(this="ROW"))), - ] - whens_str = "\n".join( - when_expr.sql(dialect=self.dialect, pretty=self.pretty) - for when_expr in whens_expr - ) - - merge_str = sge.Merge( - this=_table(destination), - using=self._as_from_item(), - on=_literal(False, dtypes.BOOL_DTYPE), - ).sql(dialect=self.dialect, pretty=self.pretty) - return f"{merge_str}\n{whens_str}" - def _explode_single_column( self, column_name: str, offsets_col: typing.Optional[str] ) -> SQLGlotIR: """Helper method to handle the case of exploding a single column.""" - offset = ( - sge.to_identifier(offsets_col, quoted=self.quoted) if offsets_col else None - ) - column = sge.to_identifier(column_name, quoted=self.quoted) - unnested_column_alias = sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcol_")), quoted=self.quoted + offset = sql.identifier(offsets_col) if offsets_col else None + column = sql.identifier(column_name) + unnested_column_alias = sql.identifier( + next(self.uid_gen.get_uid_stream("bfcol_")) ) unnest_expr = sge.Unnest( expressions=[column], @@ -518,27 +467,19 @@ def _explode_multiple_columns( offsets_col: typing.Optional[str], ) -> SQLGlotIR: """Helper method to handle the case of exploding multiple columns.""" - offset = ( - sge.to_identifier(offsets_col, quoted=self.quoted) if offsets_col else None - ) - columns = [ - sge.to_identifier(column_name, quoted=self.quoted) - for column_name in column_names - ] + offset = sql.identifier(offsets_col) if offsets_col else None + columns = [sql.identifier(column_name) for column_name in column_names] # If there are multiple columns, we need to unnest by zipping the arrays: # https://cloud.google.com/bigquery/docs/arrays#zipping_arrays - column_lengths = [ - sge.func("ARRAY_LENGTH", sge.to_identifier(column, quoted=self.quoted)) - 1 - for column in columns - ] + column_lengths = [sge.func("ARRAY_LENGTH", column) - 1 for column in columns] generate_array = sge.func( "GENERATE_ARRAY", sge.convert(0), sge.func("LEAST", *column_lengths), ) - unnested_offset_alias = sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcol_")), quoted=self.quoted + unnested_offset_alias = sql.identifier( + next(self.uid_gen.get_uid_stream("bfcol_")) ) unnest_expr = sge.Unnest( expressions=[generate_array], @@ -563,12 +504,6 @@ def _explode_multiple_columns( ) return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) - def _as_from_item(self) -> typing.Union[sge.Table, sge.Subquery]: - if isinstance(self.expr, sge.Select): - return self.expr.subquery() - else: # table - return self.expr - def _as_select(self) -> sge.Select: if isinstance(self.expr, sge.Select): return self.expr @@ -582,9 +517,7 @@ def _select_to_cte(self) -> tuple[sge.Select, sge.Identifier]: """Transforms a given sge.Select query by pushing its main SELECT statement into a new CTE and then generates a 'SELECT * FROM new_cte_name' for the new query.""" - cte_name = sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted - ) + cte_name = sql.identifier(next(self.uid_gen.get_uid_stream("bfcte_"))) select_expr = self._as_select().copy() select_expr, existing_ctes = _pop_query_ctes(select_expr) new_cte = sge.CTE( @@ -598,107 +531,6 @@ def _select_to_cte(self) -> tuple[sge.Select, sge.Identifier]: return new_select_expr, cte_name -def identifier(id: str) -> str: - """Return a string representing column reference in a SQL.""" - return sge.to_identifier(id, quoted=SQLGlotIR.quoted).sql(dialect=SQLGlotIR.dialect) - - -def _escape_chars(value: str): - """Escapes all special characters""" - # TODO: Reuse _literal's escaping logic instead of re-implementing it here. - # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#string_and_bytes_literals - trans_table = str.maketrans( - { - "\a": r"\a", - "\b": r"\b", - "\f": r"\f", - "\n": r"\n", - "\r": r"\r", - "\t": r"\t", - "\v": r"\v", - "\\": r"\\", - "?": r"\?", - '"': r"\"", - "'": r"\'", - "`": r"\`", - } - ) - return value.translate(trans_table) - - -def _is_null_literal(expr: sge.Expression) -> bool: - """Checks if the given expression is a NULL literal.""" - if isinstance(expr, sge.Null): - return True - if isinstance(expr, sge.Cast) and isinstance(expr.this, sge.Null): - return True - return False - - -def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: - sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None - if sqlglot_type is None: - if not pd.isna(value): - raise ValueError(f"Cannot infer SQLGlot type from None dtype: {value}") - return sge.Null() - - if value is None: - return _cast(sge.Null(), sqlglot_type) - if dtypes.is_struct_like(dtype): - items = [ - _literal(value=value[field_name], dtype=field_dtype).as_( - field_name, quoted=True - ) - for field_name, field_dtype in dtypes.get_struct_fields(dtype).items() - ] - return sge.Struct.from_arg_list(items) - elif dtypes.is_array_like(dtype): - value_type = dtypes.get_array_inner_type(dtype) - values = sge.Array( - expressions=[_literal(value=v, dtype=value_type) for v in value] - ) - return values if len(value) > 0 else _cast(values, sqlglot_type) - elif pd.isna(value) or (isinstance(value, pa.Scalar) and not value.is_valid): - return _cast(sge.Null(), sqlglot_type) - elif dtype == dtypes.JSON_DTYPE: - return sge.ParseJSON(this=sge.convert(str(value))) - elif dtype == dtypes.BYTES_DTYPE: - return _cast(str(value), sqlglot_type) - elif dtypes.is_time_like(dtype): - if isinstance(value, str): - return _cast(sge.convert(value), sqlglot_type) - if isinstance(value, np.generic): - value = value.item() - return _cast(sge.convert(value.isoformat()), sqlglot_type) - elif dtype in (dtypes.NUMERIC_DTYPE, dtypes.BIGNUMERIC_DTYPE): - return _cast(sge.convert(value), sqlglot_type) - elif dtypes.is_geo_like(dtype): - wkt = value if isinstance(value, str) else to_wkt(value) - return sge.func("ST_GEOGFROMTEXT", sge.convert(wkt)) - elif dtype == dtypes.TIMEDELTA_DTYPE: - return sge.convert(utils.timedelta_to_micros(value)) - elif dtype == dtypes.FLOAT_DTYPE: - if np.isinf(value): - return constants._INF if value > 0 else constants._NEG_INF - return sge.convert(value) - else: - if isinstance(value, np.generic): - value = value.item() - return sge.convert(value) - - -def _cast(arg: typing.Any, to: str) -> sge.Cast: - return sge.Cast(this=arg, to=to) - - -def _table(table: bigquery.TableReference) -> sge.Table: - return sge.Table( - this=sg.to_identifier(table.table_id, quoted=True), - db=sg.to_identifier(table.dataset_id, quoted=True), - catalog=sg.to_identifier(table.project, quoted=True), - ) - - def _and(conditions: tuple[sge.Expression, ...]) -> typing.Optional[sge.Expression]: """Chains multiple expressions together using a logical AND.""" if not conditions: @@ -751,12 +583,12 @@ def _join_condition_for_others( """Generates a join condition for non-numeric types to match pandas's null-handling logic. """ - left_str = _cast(left.expr, "STRING") - right_str = _cast(right.expr, "STRING") - left_0 = sge.func("COALESCE", left_str, _literal("0", dtypes.STRING_DTYPE)) - left_1 = sge.func("COALESCE", left_str, _literal("1", dtypes.STRING_DTYPE)) - right_0 = sge.func("COALESCE", right_str, _literal("0", dtypes.STRING_DTYPE)) - right_1 = sge.func("COALESCE", right_str, _literal("1", dtypes.STRING_DTYPE)) + left_str = sql.cast(left.expr, "STRING") + right_str = sql.cast(right.expr, "STRING") + left_0 = sge.func("COALESCE", left_str, sql.literal("0", dtypes.STRING_DTYPE)) + left_1 = sge.func("COALESCE", left_str, sql.literal("1", dtypes.STRING_DTYPE)) + right_0 = sge.func("COALESCE", right_str, sql.literal("0", dtypes.STRING_DTYPE)) + right_1 = sge.func("COALESCE", right_str, sql.literal("1", dtypes.STRING_DTYPE)) return sge.And( this=sge.EQ(this=left_0, expression=right_0), expression=sge.EQ(this=left_1, expression=right_1), @@ -774,10 +606,10 @@ def _join_condition_for_numeric( is_floating_types = ( left.dtype == dtypes.FLOAT_DTYPE and right.dtype == dtypes.FLOAT_DTYPE ) - left_0 = sge.func("COALESCE", left.expr, _literal(0, left.dtype)) - left_1 = sge.func("COALESCE", left.expr, _literal(1, left.dtype)) - right_0 = sge.func("COALESCE", right.expr, _literal(0, right.dtype)) - right_1 = sge.func("COALESCE", right.expr, _literal(1, right.dtype)) + left_0 = sge.func("COALESCE", left.expr, sql.literal(0, left.dtype)) + left_1 = sge.func("COALESCE", left.expr, sql.literal(1, left.dtype)) + right_0 = sge.func("COALESCE", right.expr, sql.literal(0, right.dtype)) + right_1 = sge.func("COALESCE", right.expr, sql.literal(1, right.dtype)) if not is_floating_types: return sge.And( this=sge.EQ(this=left_0, expression=right_0), @@ -785,16 +617,16 @@ def _join_condition_for_numeric( ) left_2 = sge.If( - this=sge.IsNan(this=left.expr), true=_literal(2, left.dtype), false=left_0 + this=sge.IsNan(this=left.expr), true=sql.literal(2, left.dtype), false=left_0 ) left_3 = sge.If( - this=sge.IsNan(this=left.expr), true=_literal(3, left.dtype), false=left_1 + this=sge.IsNan(this=left.expr), true=sql.literal(3, left.dtype), false=left_1 ) right_2 = sge.If( - this=sge.IsNan(this=right.expr), true=_literal(2, right.dtype), false=right_0 + this=sge.IsNan(this=right.expr), true=sql.literal(2, right.dtype), false=right_0 ) right_3 = sge.If( - this=sge.IsNan(this=right.expr), true=_literal(3, right.dtype), false=right_1 + this=sge.IsNan(this=right.expr), true=sql.literal(3, right.dtype), false=right_1 ) return sge.And( this=sge.EQ(this=left_2, expression=right_2), diff --git a/bigframes/core/sql/__init__.py b/bigframes/core/sql/__init__.py index b025ca07c2..e17830042d 100644 --- a/bigframes/core/sql/__init__.py +++ b/bigframes/core/sql/__init__.py @@ -26,7 +26,7 @@ import bigframes_vendored.sqlglot.expressions as sge import shapely.geometry.base # type: ignore -from bigframes.core.compile.sqlglot import sqlglot_ir +from bigframes.core.compile.sqlglot import sql if TYPE_CHECKING: import google.cloud.bigquery as bigquery @@ -66,7 +66,7 @@ def simple_literal(value: Union[SIMPLE_LITERAL_TYPES, None]) -> str: return "NULL" elif isinstance(value, str): # Single quoting seems to work nicer with ibis than double quoting - return f"'{sqlglot_ir._escape_chars(value)}'" + return f"'{sql.escape_chars(value)}'" elif isinstance(value, bytes): return repr(value) elif isinstance(value, (bool, int)): @@ -119,7 +119,7 @@ def cast_as_string(column_name: str) -> str: def to_json_string(column_name: str) -> str: """Return a string representing JSON version of a column.""" - return f"TO_JSON_STRING({sqlglot_ir.identifier(column_name)})" + return f"TO_JSON_STRING({sql.to_sql(sql.identifier(column_name))})" def csv(values: Iterable[str]) -> str: @@ -202,7 +202,7 @@ def create_vector_index_ddl( if len(stored_column_names) > 0: escaped_stored = [ - f"{sqlglot_ir.identifier(name)}" for name in stored_column_names + f"{sql.to_sql(sql.identifier(name))}" for name in stored_column_names ] storing = f"STORING({', '.join(escaped_stored)}) " else: @@ -216,8 +216,8 @@ def create_vector_index_ddl( ) return f""" - {create} {sqlglot_ir.identifier(index_name)} - ON {sqlglot_ir.identifier(table_name)}({sqlglot_ir.identifier(column_name)}) + {create} {sql.to_sql(sql.identifier(index_name))} + ON {sql.to_sql(sql.identifier(table_name))}({sql.to_sql(sql.identifier(column_name))}) {storing} OPTIONS({rendered_options}); """ @@ -236,7 +236,7 @@ def create_vector_search_sql( """Encode the VECTOR SEARCH statement for BigQuery Vector Search.""" vector_search_args = [ - f"TABLE {sqlglot_ir.identifier(cast(str, base_table))}", + f"TABLE {sql.to_sql(sql.identifier(cast(str, base_table)))}", f"{simple_literal(column_to_search)}", f"({sql_string})", ] diff --git a/bigframes/core/sql/ml.py b/bigframes/core/sql/ml.py index 38d66ab9a5..391d905d2f 100644 --- a/bigframes/core/sql/ml.py +++ b/bigframes/core/sql/ml.py @@ -16,7 +16,7 @@ from typing import Any, Dict, List, Mapping, Optional, Union -from bigframes.core.compile.sqlglot import sqlglot_ir +from bigframes.core.compile.sqlglot import sql as sg_sql import bigframes.core.sql import bigframes.core.sql.literals @@ -46,7 +46,7 @@ def create_model_ddl( else: create = "CREATE MODEL " - ddl = f"{create}{sqlglot_ir.identifier(model_name)}\n" + ddl = f"{create}{sg_sql.to_sql(sg_sql.identifier(model_name))}\n" # [TRANSFORM (select_list)] if transform: @@ -66,7 +66,7 @@ def create_model_ddl( if connection_name.upper() == "DEFAULT": ddl += "REMOTE WITH CONNECTION DEFAULT\n" else: - ddl += f"REMOTE WITH CONNECTION {sqlglot_ir.identifier(connection_name)}\n" + ddl += f"REMOTE WITH CONNECTION {sg_sql.to_sql(sg_sql.identifier(connection_name))}\n" # [OPTIONS(model_option_list)] if options: @@ -130,7 +130,7 @@ def evaluate( if confidence_level is not None: struct_options["confidence_level"] = confidence_level - sql = f"SELECT * FROM ML.EVALUATE(MODEL {sqlglot_ir.identifier(model_name)}" + sql = f"SELECT * FROM ML.EVALUATE(MODEL {sg_sql.to_sql(sg_sql.identifier(model_name))}" if table: sql += f", ({table})" @@ -158,9 +158,7 @@ def predict( if trial_id is not None: struct_options["trial_id"] = trial_id - sql = ( - f"SELECT * FROM ML.PREDICT(MODEL {sqlglot_ir.identifier(model_name)}, ({table})" - ) + sql = f"SELECT * FROM ML.PREDICT(MODEL {sg_sql.to_sql(sg_sql.identifier(model_name))}, ({table})" sql += _build_struct_sql(struct_options) sql += ")\n" return sql @@ -190,7 +188,7 @@ def explain_predict( if approx_feature_contrib is not None: struct_options["approx_feature_contrib"] = approx_feature_contrib - sql = f"SELECT * FROM ML.EXPLAIN_PREDICT(MODEL {sqlglot_ir.identifier(model_name)}, ({table})" + sql = f"SELECT * FROM ML.EXPLAIN_PREDICT(MODEL {sg_sql.to_sql(sg_sql.identifier(model_name))}, ({table})" sql += _build_struct_sql(struct_options) sql += ")\n" return sql @@ -208,7 +206,7 @@ def global_explain( if class_level_explain is not None: struct_options["class_level_explain"] = class_level_explain - sql = f"SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL {sqlglot_ir.identifier(model_name)}" + sql = f"SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL {sg_sql.to_sql(sg_sql.identifier(model_name))}" sql += _build_struct_sql(struct_options) sql += ")\n" return sql @@ -221,7 +219,7 @@ def transform( """Encode the ML.TRANSFORM statement. See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-transform for reference. """ - sql = f"SELECT * FROM ML.TRANSFORM(MODEL {sqlglot_ir.identifier(model_name)}, ({table}))\n" + sql = f"SELECT * FROM ML.TRANSFORM(MODEL {sg_sql.to_sql(sg_sql.identifier(model_name))}, ({table}))\n" return sql @@ -262,7 +260,7 @@ def generate_text( if request_type is not None: struct_options["request_type"] = request_type - sql = f"SELECT * FROM ML.GENERATE_TEXT(MODEL {sqlglot_ir.identifier(model_name)}, ({table})" + sql = f"SELECT * FROM ML.GENERATE_TEXT(MODEL {sg_sql.to_sql(sg_sql.identifier(model_name))}, ({table})" sql += _build_struct_sql(struct_options) sql += ")\n" return sql @@ -290,7 +288,7 @@ def generate_embedding( if output_dimensionality is not None: struct_options["output_dimensionality"] = output_dimensionality - sql = f"SELECT * FROM ML.GENERATE_EMBEDDING(MODEL {sqlglot_ir.identifier(model_name)}, ({table})" + sql = f"SELECT * FROM ML.GENERATE_EMBEDDING(MODEL {sg_sql.to_sql(sg_sql.identifier(model_name))}, ({table})" sql += _build_struct_sql(struct_options) sql += ")\n" return sql diff --git a/bigframes/ml/compose.py b/bigframes/ml/compose.py index 9413cd0695..d81e3ab1bd 100644 --- a/bigframes/ml/compose.py +++ b/bigframes/ml/compose.py @@ -27,7 +27,7 @@ import bigframes_vendored.sklearn.compose._column_transformer from google.cloud import bigquery -import bigframes.core.compile.sqlglot.sqlglot_ir as sql_utils +from bigframes.core.compile.sqlglot import sql as sg_sql from bigframes.core.logging import log_adapter import bigframes.core.utils as core_utils from bigframes.ml import base, core, globals, impute, preprocessing, utils @@ -111,9 +111,9 @@ def _compile_to_sql( columns, _ = core_utils.get_standardized_ids(columns) result = [] for column in columns: - current_sql = self._sql.format(sql_utils.identifier(column)) - current_target_column = sql_utils.identifier( - self._target_column.format(column) + current_sql = self._sql.format(sg_sql.to_sql(sg_sql.identifier(column))) + current_target_column = sg_sql.to_sql( + sg_sql.identifier(self._target_column.format(column)) ) result.append(f"{current_sql} AS {current_target_column}") return result diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py index d90d23a474..be9055e956 100644 --- a/bigframes/ml/sql.py +++ b/bigframes/ml/sql.py @@ -21,7 +21,7 @@ import bigframes_vendored.constants as constants import google.cloud.bigquery -import bigframes.core.compile.sqlglot.sqlglot_ir as sql_utils +from bigframes.core.compile.sqlglot import sql as sg_sql import bigframes.core.sql as sql_vals INDENT_STR = " " @@ -62,7 +62,7 @@ def build_structs(self, **kwargs: Union[int, float, str, Mapping]) -> str: v_trans = self.build_schema(**v) if isinstance(v, Mapping) else v param_strs.append( - f"{sql_vals.simple_literal(v_trans)} AS {sql_utils.identifier(k)}" + f"{sql_vals.simple_literal(v_trans)} AS {sg_sql.to_sql(sg_sql.identifier(k))}" ) return "\n" + INDENT_STR + f",\n{INDENT_STR}".join(param_strs) @@ -73,7 +73,9 @@ def build_expressions(self, *expr_sqls: str) -> str: def build_schema(self, **kwargs: str) -> str: """Encode a dict of values into a formatted schema type items for SQL""" - param_strs = [f"{sql_utils.identifier(k)} {v}" for k, v in kwargs.items()] + param_strs = [ + f"{sg_sql.to_sql(sg_sql.identifier(k))} {v}" for k, v in kwargs.items() + ] return "\n" + INDENT_STR + f",\n{INDENT_STR}".join(param_strs) def options(self, **kwargs: Union[str, int, float, Iterable[str]]) -> str: @@ -86,7 +88,9 @@ def struct_options(self, **kwargs: Union[int, float, Mapping]) -> str: def struct_columns(self, columns: Iterable[str]) -> str: """Encode a BQ Table columns to a STRUCT.""" - columns_str = ", ".join(map(sql_utils.identifier, columns)) + columns_str = ", ".join( + map(lambda x: sg_sql.to_sql(sg_sql.identifier(x)), columns) + ) return f"STRUCT({columns_str})" def input(self, **kwargs: str) -> str: @@ -109,15 +113,15 @@ def transform(self, *expr_sqls: str) -> str: def ml_standard_scaler(self, numeric_expr_sql: str, name: str) -> str: """Encode ML.STANDARD_SCALER for BQML""" - return f"""ML.STANDARD_SCALER({sql_utils.identifier(numeric_expr_sql)}) OVER() AS {sql_utils.identifier(name)}""" + return f"""ML.STANDARD_SCALER({sg_sql.to_sql(sg_sql.identifier(numeric_expr_sql))}) OVER() AS {sg_sql.to_sql(sg_sql.identifier(name))}""" def ml_max_abs_scaler(self, numeric_expr_sql: str, name: str) -> str: """Encode ML.MAX_ABS_SCALER for BQML""" - return f"""ML.MAX_ABS_SCALER({sql_utils.identifier(numeric_expr_sql)}) OVER() AS {sql_utils.identifier(name)}""" + return f"""ML.MAX_ABS_SCALER({sg_sql.to_sql(sg_sql.identifier(numeric_expr_sql))}) OVER() AS {sg_sql.to_sql(sg_sql.identifier(name))}""" def ml_min_max_scaler(self, numeric_expr_sql: str, name: str) -> str: """Encode ML.MIN_MAX_SCALER for BQML""" - return f"""ML.MIN_MAX_SCALER({sql_utils.identifier(numeric_expr_sql)}) OVER() AS {sql_utils.identifier(name)}""" + return f"""ML.MIN_MAX_SCALER({sg_sql.to_sql(sg_sql.identifier(numeric_expr_sql))}) OVER() AS {sg_sql.to_sql(sg_sql.identifier(name))}""" def ml_imputer( self, @@ -126,7 +130,7 @@ def ml_imputer( name: str, ) -> str: """Encode ML.IMPUTER for BQML""" - return f"""ML.IMPUTER({sql_utils.identifier(col_name)}, '{strategy}') OVER() AS {sql_utils.identifier(name)}""" + return f"""ML.IMPUTER({sg_sql.to_sql(sg_sql.identifier(col_name))}, '{strategy}') OVER() AS {sg_sql.to_sql(sg_sql.identifier(name))}""" def ml_bucketize( self, @@ -140,7 +144,7 @@ def ml_bucketize( point.item() if hasattr(point, "item") else point for point in array_split_points ] - return f"""ML.BUCKETIZE({sql_utils.identifier(input_id)}, {points}, FALSE) AS {sql_utils.identifier(output_id)}""" + return f"""ML.BUCKETIZE({sg_sql.to_sql(sg_sql.identifier(input_id))}, {points}, FALSE) AS {sg_sql.to_sql(sg_sql.identifier(output_id))}""" def ml_quantile_bucketize( self, @@ -149,7 +153,7 @@ def ml_quantile_bucketize( name: str, ) -> str: """Encode ML.QUANTILE_BUCKETIZE for BQML""" - return f"""ML.QUANTILE_BUCKETIZE({sql_utils.identifier(numeric_expr_sql)}, {num_bucket}) OVER() AS {sql_utils.identifier(name)}""" + return f"""ML.QUANTILE_BUCKETIZE({sg_sql.to_sql(sg_sql.identifier(numeric_expr_sql))}, {num_bucket}) OVER() AS {sg_sql.to_sql(sg_sql.identifier(name))}""" def ml_one_hot_encoder( self, @@ -162,7 +166,7 @@ def ml_one_hot_encoder( """Encode ML.ONE_HOT_ENCODER for BQML. https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-one-hot-encoder for params. """ - return f"""ML.ONE_HOT_ENCODER({sql_utils.identifier(numeric_expr_sql)}, '{drop}', {top_k}, {frequency_threshold}) OVER() AS {sql_utils.identifier(name)}""" + return f"""ML.ONE_HOT_ENCODER({sg_sql.to_sql(sg_sql.identifier(numeric_expr_sql))}, '{drop}', {top_k}, {frequency_threshold}) OVER() AS {sg_sql.to_sql(sg_sql.identifier(name))}""" def ml_label_encoder( self, @@ -174,7 +178,7 @@ def ml_label_encoder( """Encode ML.LABEL_ENCODER for BQML. https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-label-encoder for params. """ - return f"""ML.LABEL_ENCODER({sql_utils.identifier(numeric_expr_sql)}, {top_k}, {frequency_threshold}) OVER() AS {sql_utils.identifier(name)}""" + return f"""ML.LABEL_ENCODER({sg_sql.to_sql(sg_sql.identifier(numeric_expr_sql))}, {top_k}, {frequency_threshold}) OVER() AS {sg_sql.to_sql(sg_sql.identifier(name))}""" def ml_polynomial_expand( self, columns: Iterable[str], degree: int, name: str @@ -182,7 +186,7 @@ def ml_polynomial_expand( """Encode ML.POLYNOMIAL_EXPAND. https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-polynomial-expand """ - return f"""ML.POLYNOMIAL_EXPAND({self.struct_columns(columns)}, {degree}) AS {sql_utils.identifier(name)}""" + return f"""ML.POLYNOMIAL_EXPAND({self.struct_columns(columns)}, {degree}) AS {sg_sql.to_sql(sg_sql.identifier(name))}""" def ml_distance( self, @@ -195,7 +199,7 @@ def ml_distance( """Encode ML.DISTANCE for BQML. https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-distance """ - return f"""SELECT *, ML.DISTANCE({sql_utils.identifier(col_x)}, {sql_utils.identifier(col_y)}, '{type}') AS {sql_utils.identifier(name)} FROM ({source_sql})""" + return f"""SELECT *, ML.DISTANCE({sg_sql.to_sql(sg_sql.identifier(col_x))}, {sg_sql.to_sql(sg_sql.identifier(col_y))}, '{type}') AS {sg_sql.to_sql(sg_sql.identifier(name))} FROM ({source_sql})""" def ai_forecast( self, @@ -217,7 +221,7 @@ def _model_id_sql( self, model_ref: google.cloud.bigquery.ModelReference, ): - return f"{sql_utils.identifier(model_ref.project)}.{sql_utils.identifier(model_ref.dataset_id)}.{sql_utils.identifier(model_ref.model_id)}" + return f"{sg_sql.to_sql(sg_sql.identifier(model_ref.project))}.{sg_sql.to_sql(sg_sql.identifier(model_ref.dataset_id))}.{sg_sql.to_sql(sg_sql.identifier(model_ref.model_id))}" # Model create and alter def create_model( @@ -308,7 +312,7 @@ def __init__(self, model_ref: google.cloud.bigquery.ModelReference): self._model_ref = model_ref def _model_ref_sql(self) -> str: - return f"{sql_utils.identifier(self._model_ref.project)}.{sql_utils.identifier(self._model_ref.dataset_id)}.{sql_utils.identifier(self._model_ref.model_id)}" + return f"{sg_sql.to_sql(sg_sql.identifier(self._model_ref.project))}.{sg_sql.to_sql(sg_sql.identifier(self._model_ref.dataset_id))}.{sg_sql.to_sql(sg_sql.identifier(self._model_ref.model_id))}" # Alter model def alter_model( diff --git a/bigframes/session/_io/bigquery/__init__.py b/bigframes/session/_io/bigquery/__init__.py index 1d1dc57c30..a9abf6602d 100644 --- a/bigframes/session/_io/bigquery/__init__.py +++ b/bigframes/session/_io/bigquery/__init__.py @@ -32,7 +32,7 @@ import google.cloud.bigquery._job_helpers import google.cloud.bigquery.table -from bigframes.core.compile.sqlglot import sqlglot_ir +from bigframes.core.compile.sqlglot import sql as sg_sql import bigframes.core.events from bigframes.core.logging import log_adapter import bigframes.core.sql @@ -599,7 +599,7 @@ def compile_filters(filters: third_party_pandas_gbq.FiltersType) -> str: operator_str = valid_operators[operator] - column_ref = sqlglot_ir.identifier(column) + column_ref = sg_sql.to_sql(sg_sql.identifier(column)) if operator_str in ["IN", "NOT IN"]: value_literal = bigframes.core.sql.multi_literal(*value) else: diff --git a/bigframes/session/bigquery_session.py b/bigframes/session/bigquery_session.py index 1a38bca1e8..79fb21486c 100644 --- a/bigframes/session/bigquery_session.py +++ b/bigframes/session/bigquery_session.py @@ -24,7 +24,7 @@ import bigframes_vendored.ibis.backends.bigquery.datatypes as ibis_bq import google.cloud.bigquery as bigquery -from bigframes.core.compile.sqlglot import sqlglot_ir +from bigframes.core.compile.sqlglot import sql as sg_sql import bigframes.core.events from bigframes.session import temporary_storage import bigframes.session._io.bigquery as bfbqio @@ -80,7 +80,7 @@ def create_temp_table( ibis_schema = ibis_bq.BigQuerySchema.to_ibis(list(schema)) fields = [ - f"{sqlglot_ir.identifier(name)} {ibis_bq.BigQueryType.from_ibis(ibis_type)}" + f"{sg_sql.to_sql(sg_sql.identifier(name))} {ibis_bq.BigQueryType.from_ibis(ibis_type)}" for name, ibis_type in ibis_schema.fields.items() ] fields_string = ",".join(fields) @@ -88,12 +88,12 @@ def create_temp_table( cluster_string = "" if cluster_cols: cluster_cols_sql = ", ".join( - f"{sqlglot_ir.identifier(cluster_col)}" + f"{sg_sql.to_sql(sg_sql.identifier(cluster_col))}" for cluster_col in cluster_cols ) cluster_string = f"\nCLUSTER BY {cluster_cols_sql}" - ddl = f"CREATE TEMP TABLE `_SESSION`.{sqlglot_ir.identifier(table_ref.table_id)} ({fields_string}){cluster_string}" + ddl = f"CREATE TEMP TABLE `_SESSION`.{sg_sql.to_sql(sg_sql.identifier(table_ref.table_id))} ({fields_string}){cluster_string}" _, job = bfbqio.start_query_with_client( self.bqclient, diff --git a/bigframes/session/bq_caching_executor.py b/bigframes/session/bq_caching_executor.py index a5c0176565..cd1642f6de 100644 --- a/bigframes/session/bq_caching_executor.py +++ b/bigframes/session/bq_caching_executor.py @@ -30,6 +30,7 @@ import bigframes.constants import bigframes.core from bigframes.core import bq_data, compile, local_data, rewrite +from bigframes.core.compile.sqlglot import sql as sg_sql from bigframes.core.compile.sqlglot import sqlglot_ir import bigframes.core.events import bigframes.core.guid @@ -304,12 +305,13 @@ def _export_gbq( # BigQuery `RATE_LIMIT_EXCEEDED` errors, as per quota limits: # https://cloud.google.com/bigquery/quotas#standard_tables job_config = bigquery.QueryJobConfig() - ir = sqlglot_ir.SQLGlotIR.from_query_string(sql) + + ir = sqlglot_ir.SQLGlotIR.from_unparsed_query(sql) if spec.if_exists == "append": - sql = ir.insert(spec.table) + sql = sg_sql.to_sql(sg_sql.insert(ir.expr, spec.table)) else: # for "replace" assert spec.if_exists == "replace" - sql = ir.replace(spec.table) + sql = sg_sql.to_sql(sg_sql.replace(ir.expr, spec.table)) else: dispositions = { "fail": bigquery.WriteDisposition.WRITE_EMPTY,