Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions bigframes/bigquery/_operations/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import google.cloud.bigquery

from bigframes.core.compile.sqlglot import sqlglot_ir
from bigframes.core.compile.sqlglot import sql
import bigframes.dtypes
import bigframes.operations
import bigframes.series
Expand Down Expand Up @@ -68,10 +68,7 @@ def sql_scalar(
# Another benefit of this is that if there is a syntax error in the SQL
# template, then this will fail with an error earlier in the process,
# aiding users in debugging.
literals_sql = [
sqlglot_ir._literal(None, column.dtype).sql(dialect="bigquery")
for column in columns
]
literals_sql = [sql.to_sql(sql.literal(None, column.dtype)) for column in columns]
select_sql = sql_template.format(*literals_sql)
dry_run_sql = f"SELECT {select_sql}"

Expand Down
16 changes: 8 additions & 8 deletions bigframes/core/compile/sqlglot/aggregations/unary_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@

from bigframes import dtypes
from bigframes.core import window_spec
from bigframes.core.compile.sqlglot import sql
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present
from bigframes.core.compile.sqlglot.expressions import constants
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
from bigframes.operations import aggregations as agg_ops

UNARY_OP_REGISTRATION = reg.OpRegistration()
Expand Down Expand Up @@ -157,9 +157,9 @@ def _cut_ops_w_int_bins(
for this_bin in range(bins):
value: sge.Expression
if op.labels is False:
value = ir._literal(this_bin, dtypes.INT_DTYPE)
value = sql.literal(this_bin, dtypes.INT_DTYPE)
elif isinstance(op.labels, typing.Iterable):
value = ir._literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE)
value = sql.literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE)
else:
left_adj: sge.Expression = (
adj if this_bin == 0 and op.right else sge.convert(0)
Expand Down Expand Up @@ -217,10 +217,10 @@ def _cut_ops_w_intervals(
) -> sge.Case:
case_expr = sge.Case()
for this_bin, interval in enumerate(bins):
left: sge.Expression = ir._literal(
left: sge.Expression = sql.literal(
interval[0], dtypes.infer_literal_type(interval[0])
)
right: sge.Expression = ir._literal(
right: sge.Expression = sql.literal(
interval[1], dtypes.infer_literal_type(interval[1])
)
condition: sge.Expression
Expand All @@ -237,9 +237,9 @@ def _cut_ops_w_intervals(

value: sge.Expression
if op.labels is False:
value = ir._literal(this_bin, dtypes.INT_DTYPE)
value = sql.literal(this_bin, dtypes.INT_DTYPE)
elif isinstance(op.labels, typing.Iterable):
value = ir._literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE)
value = sql.literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE)
else:
if op.right:
left_identifier = sge.Identifier(this="left_exclusive", quoted=True)
Expand Down Expand Up @@ -609,7 +609,7 @@ def _(

# Will be null if all inputs are null. Pandas defaults to zero sum though.
zero = pd.to_timedelta(0) if column.dtype == dtypes.TIMEDELTA_DTYPE else 0
return sge.func("IFNULL", expr, ir._literal(zero, column.dtype))
return sge.func("IFNULL", expr, sql.literal(zero, column.dtype))


@UNARY_OP_REGISTRATION.register(agg_ops.VarOp)
Expand Down
82 changes: 46 additions & 36 deletions bigframes/core/compile/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
sql_nodes,
)
from bigframes.core.compile import configs
from bigframes.core.compile.sqlglot import sql, sqlglot_ir
import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler
from bigframes.core.compile.sqlglot.aggregations import windows
import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler
from bigframes.core.compile.sqlglot.expressions import typed_expr
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
from bigframes.core.logging import data_types as data_type_logger
import bigframes.core.ordering as bf_ordering
from bigframes.core.rewrite import schema_binding
Expand Down Expand Up @@ -108,20 +108,20 @@ def _compile_result_node(root: nodes.ResultNode) -> str:
# Probably, should defer even further
root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))

sqlglot_ir = compile_node(rewrite.as_sql_nodes(root), uid_gen)
return sqlglot_ir.sql
sqlglot_ir_obj = compile_node(rewrite.as_sql_nodes(root), uid_gen)
return sqlglot_ir_obj.sql


def compile_node(
node: nodes.BigFrameNode, uid_gen: guid.SequentialUIDGenerator
) -> ir.SQLGlotIR:
) -> sqlglot_ir.SQLGlotIR:
"""Compiles the given BigFrameNode from the bottom up into SQLGlotIR."""
bf_to_sqlglot: dict[nodes.BigFrameNode, ir.SQLGlotIR] = {}
child_results: tuple[ir.SQLGlotIR, ...] = ()
bf_to_sqlglot: dict[nodes.BigFrameNode, sqlglot_ir.SQLGlotIR] = {}
child_results: tuple[sqlglot_ir.SQLGlotIR, ...] = ()
for current_node in list(node.iter_nodes_topo()):
if current_node.child_nodes == ():
# For a leaf node, generate a dummy child to pass along the UID generator.
child_results = tuple([ir.SQLGlotIR(uid_gen=uid_gen)])
child_results = tuple([sqlglot_ir.SQLGlotIR(uid_gen=uid_gen)])
else:
# Child nodes should have been compiled in the reverse topological order.
child_results = tuple(
Expand All @@ -135,14 +135,14 @@ def compile_node(

@functools.singledispatch
def _compile_node(
node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR
) -> ir.SQLGlotIR:
node: nodes.BigFrameNode, *compiled_children: sqlglot_ir.SQLGlotIR
) -> sqlglot_ir.SQLGlotIR:
"""Defines transformation but isn't cached, always use compile_node instead"""
raise ValueError(f"Can't compile unrecognized node: {node}")


@_compile_node.register
def compile_sql_select(node: sql_nodes.SqlSelectNode, child: ir.SQLGlotIR):
def compile_sql_select(node: sql_nodes.SqlSelectNode, child: sqlglot_ir.SQLGlotIR):
ordering_cols = tuple(
sge.Ordered(
this=expression_compiler.expression_compiler.compile_expression(
Expand Down Expand Up @@ -175,7 +175,9 @@ def compile_sql_select(node: sql_nodes.SqlSelectNode, child: ir.SQLGlotIR):


@_compile_node.register
def compile_readlocal(node: nodes.ReadLocalNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
def compile_readlocal(
node: nodes.ReadLocalNode, child: sqlglot_ir.SQLGlotIR
) -> sqlglot_ir.SQLGlotIR:
pa_table = node.local_data_source.data
pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
pa_table = pa_table.rename_columns([item.id.sql for item in node.scan_list.items])
Expand All @@ -184,16 +186,18 @@ def compile_readlocal(node: nodes.ReadLocalNode, child: ir.SQLGlotIR) -> ir.SQLG
if offsets:
pa_table = pyarrow_utils.append_offsets(pa_table, offsets)

return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema, uid_gen=child.uid_gen)
return sqlglot_ir.SQLGlotIR.from_pyarrow(
pa_table, node.schema, uid_gen=child.uid_gen
)


@_compile_node.register
def compile_readtable(node: sql_nodes.SqlDataSource, child: ir.SQLGlotIR):
table = node.source.table
return ir.SQLGlotIR.from_table(
table.project_id,
table.dataset_id,
table.table_id,
def compile_readtable(node: sql_nodes.SqlDataSource, child: sqlglot_ir.SQLGlotIR):
table_obj = node.source.table
return sqlglot_ir.SQLGlotIR.from_table(
table_obj.project_id,
table_obj.dataset_id,
table_obj.table_id,
uid_gen=child.uid_gen,
sql_predicate=node.source.sql_predicate,
system_time=node.source.at_time,
Expand All @@ -202,20 +206,20 @@ def compile_readtable(node: sql_nodes.SqlDataSource, child: ir.SQLGlotIR):

@_compile_node.register
def compile_join(
node: nodes.JoinNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
) -> ir.SQLGlotIR:
node: nodes.JoinNode, left: sqlglot_ir.SQLGlotIR, right: sqlglot_ir.SQLGlotIR
) -> sqlglot_ir.SQLGlotIR:
conditions = tuple(
(
typed_expr.TypedExpr(
expression_compiler.expression_compiler.compile_expression(left),
left.output_type,
expression_compiler.expression_compiler.compile_expression(left_expr),
left_expr.output_type,
),
typed_expr.TypedExpr(
expression_compiler.expression_compiler.compile_expression(right),
right.output_type,
expression_compiler.expression_compiler.compile_expression(right_expr),
right_expr.output_type,
),
)
for left, right in node.conditions
for left_expr, right_expr in node.conditions
)

return left.join(
Expand All @@ -228,8 +232,8 @@ def compile_join(

@_compile_node.register
def compile_isin_join(
node: nodes.InNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
) -> ir.SQLGlotIR:
node: nodes.InNode, left: sqlglot_ir.SQLGlotIR, right: sqlglot_ir.SQLGlotIR
) -> sqlglot_ir.SQLGlotIR:
right_field = node.right_child.fields[0]
conditions = (
typed_expr.TypedExpr(
Expand All @@ -253,7 +257,9 @@ def compile_isin_join(


@_compile_node.register
def compile_concat(node: nodes.ConcatNode, *children: ir.SQLGlotIR) -> ir.SQLGlotIR:
def compile_concat(
node: nodes.ConcatNode, *children: sqlglot_ir.SQLGlotIR
) -> sqlglot_ir.SQLGlotIR:
assert len(children) >= 1
uid_gen = children[0].uid_gen

Expand All @@ -264,24 +270,26 @@ def compile_concat(node: nodes.ConcatNode, *children: ir.SQLGlotIR) -> ir.SQLGlo
for default_output_id, output_id in zip(default_output_ids, node.output_ids)
]

return ir.SQLGlotIR.from_union(
return sqlglot_ir.SQLGlotIR.from_union(
[child._as_select() for child in children],
output_aliases=output_aliases,
uid_gen=uid_gen,
)


@_compile_node.register
def compile_explode(node: nodes.ExplodeNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
def compile_explode(
node: nodes.ExplodeNode, child: sqlglot_ir.SQLGlotIR
) -> sqlglot_ir.SQLGlotIR:
offsets_col = node.offsets_col.sql if (node.offsets_col is not None) else None
columns = tuple(ref.id.sql for ref in node.column_ids)
return child.explode(columns, offsets_col)


@_compile_node.register
def compile_fromrange(
node: nodes.FromRangeNode, start: ir.SQLGlotIR, end: ir.SQLGlotIR
) -> ir.SQLGlotIR:
node: nodes.FromRangeNode, start: sqlglot_ir.SQLGlotIR, end: sqlglot_ir.SQLGlotIR
) -> sqlglot_ir.SQLGlotIR:
start_col_id = node.start.fields[0].id
end_col_id = node.end.fields[0].id

Expand All @@ -291,20 +299,22 @@ def compile_fromrange(
end_expr = expression_compiler.expression_compiler.compile_expression(
expression.DerefOp(end_col_id)
)
step_expr = ir._literal(node.step, dtypes.INT_DTYPE)
step_expr = sql.literal(node.step, dtypes.INT_DTYPE)

return start.resample(end, node.output_id.sql, start_expr, end_expr, step_expr)


@_compile_node.register
def compile_random_sample(
node: nodes.RandomSampleNode, child: ir.SQLGlotIR
) -> ir.SQLGlotIR:
node: nodes.RandomSampleNode, child: sqlglot_ir.SQLGlotIR
) -> sqlglot_ir.SQLGlotIR:
return child.sample(node.fraction)


@_compile_node.register
def compile_aggregate(node: nodes.AggregateNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
def compile_aggregate(
node: nodes.AggregateNode, child: sqlglot_ir.SQLGlotIR
) -> sqlglot_ir.SQLGlotIR:
# BigQuery ordered aggregation does not support NULLS FIRST/LAST,
# so we need to add extra expressions to enforce the null ordering.
ordering_cols = windows.get_window_order_by(node.order_by, override_null_order=True)
Expand Down
4 changes: 2 additions & 2 deletions bigframes/core/compile/sqlglot/expression_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import bigframes_vendored.sqlglot.expressions as sge

import bigframes.core.agg_expressions as agg_exprs
from bigframes.core.compile.sqlglot import sql
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
import bigframes.core.expression as ex
import bigframes.operations as ops

Expand Down Expand Up @@ -77,7 +77,7 @@ def _(self, expr: ex.DerefOp) -> sge.Expression:

@compile_expression.register
def _(self, expr: ex.ScalarConstantExpression) -> sge.Expression:
return ir._literal(expr.value, expr.dtype)
return sql.literal(expr.value, expr.dtype)

@compile_expression.register
def _(self, expr: agg_exprs.WindowExpression) -> sge.Expression:
Expand Down
10 changes: 5 additions & 5 deletions bigframes/core/compile/sqlglot/expressions/comparison_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from bigframes import dtypes
from bigframes import operations as ops
from bigframes.core.compile.sqlglot import sqlglot_ir
from bigframes.core.compile.sqlglot import sql
import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr

Expand Down Expand Up @@ -59,9 +59,9 @@ def _(expr: TypedExpr, op: ops.IsInOp) -> sge.Expression:

@register_binary_op(ops.eq_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
if sqlglot_ir._is_null_literal(left.expr):
if sql.is_null_literal(left.expr):
return sge.Is(this=right.expr, expression=sge.Null())
if sqlglot_ir._is_null_literal(right.expr):
if sql.is_null_literal(right.expr):
return sge.Is(this=left.expr, expression=sge.Null())
left_expr = _coerce_bool_to_int(left)
right_expr = _coerce_bool_to_int(right)
Expand Down Expand Up @@ -140,12 +140,12 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:

@register_binary_op(ops.ne_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
if sqlglot_ir._is_null_literal(left.expr):
if sql.is_null_literal(left.expr):
return sge.Is(
this=sge.paren(right.expr, copy=False),
expression=sg.not_(sge.Null(), copy=False),
)
if sqlglot_ir._is_null_literal(right.expr):
if sql.is_null_literal(right.expr):
return sge.Is(
this=sge.paren(left.expr, copy=False),
expression=sg.not_(sge.Null(), copy=False),
Expand Down
Loading
Loading