Skip to content

Commit 3003e3f

Browse files
committed
refactor: split sql builders from the sqlglot_ir class
1 parent cb00daa commit 3003e3f

File tree

17 files changed

+436
-344
lines changed

17 files changed

+436
-344
lines changed

bigframes/bigquery/_operations/sql.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import google.cloud.bigquery
2222

23-
from bigframes.core.compile.sqlglot import sqlglot_ir
23+
from bigframes.core.compile.sqlglot import sql
2424
import bigframes.dtypes
2525
import bigframes.operations
2626
import bigframes.series
@@ -68,10 +68,7 @@ def sql_scalar(
6868
# Another benefit of this is that if there is a syntax error in the SQL
6969
# template, then this will fail with an error earlier in the process,
7070
# aiding users in debugging.
71-
literals_sql = [
72-
sqlglot_ir._literal(None, column.dtype).sql(dialect="bigquery")
73-
for column in columns
74-
]
71+
literals_sql = [sql.to_sql(sql.literal(None, column.dtype)) for column in columns]
7572
select_sql = sql_template.format(*literals_sql)
7673
dry_run_sql = f"SELECT {select_sql}"
7774

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222

2323
from bigframes import dtypes
2424
from bigframes.core import window_spec
25+
from bigframes.core.compile.sqlglot import sql
2526
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
2627
from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present
2728
from bigframes.core.compile.sqlglot.expressions import constants
2829
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
29-
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
3030
from bigframes.operations import aggregations as agg_ops
3131

3232
UNARY_OP_REGISTRATION = reg.OpRegistration()
@@ -157,9 +157,9 @@ def _cut_ops_w_int_bins(
157157
for this_bin in range(bins):
158158
value: sge.Expression
159159
if op.labels is False:
160-
value = ir._literal(this_bin, dtypes.INT_DTYPE)
160+
value = sql.literal(this_bin, dtypes.INT_DTYPE)
161161
elif isinstance(op.labels, typing.Iterable):
162-
value = ir._literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE)
162+
value = sql.literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE)
163163
else:
164164
left_adj: sge.Expression = (
165165
adj if this_bin == 0 and op.right else sge.convert(0)
@@ -217,10 +217,10 @@ def _cut_ops_w_intervals(
217217
) -> sge.Case:
218218
case_expr = sge.Case()
219219
for this_bin, interval in enumerate(bins):
220-
left: sge.Expression = ir._literal(
220+
left: sge.Expression = sql.literal(
221221
interval[0], dtypes.infer_literal_type(interval[0])
222222
)
223-
right: sge.Expression = ir._literal(
223+
right: sge.Expression = sql.literal(
224224
interval[1], dtypes.infer_literal_type(interval[1])
225225
)
226226
condition: sge.Expression
@@ -237,9 +237,9 @@ def _cut_ops_w_intervals(
237237

238238
value: sge.Expression
239239
if op.labels is False:
240-
value = ir._literal(this_bin, dtypes.INT_DTYPE)
240+
value = sql.literal(this_bin, dtypes.INT_DTYPE)
241241
elif isinstance(op.labels, typing.Iterable):
242-
value = ir._literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE)
242+
value = sql.literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE)
243243
else:
244244
if op.right:
245245
left_identifier = sge.Identifier(this="left_exclusive", quoted=True)
@@ -609,7 +609,7 @@ def _(
609609

610610
# Will be null if all inputs are null. Pandas defaults to zero sum though.
611611
zero = pd.to_timedelta(0) if column.dtype == dtypes.TIMEDELTA_DTYPE else 0
612-
return sge.func("IFNULL", expr, ir._literal(zero, column.dtype))
612+
return sge.func("IFNULL", expr, sql.literal(zero, column.dtype))
613613

614614

615615
@UNARY_OP_REGISTRATION.register(agg_ops.VarOp)

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@
3030
sql_nodes,
3131
)
3232
from bigframes.core.compile import configs
33+
from bigframes.core.compile.sqlglot import sql, sqlglot_ir
3334
import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler
3435
from bigframes.core.compile.sqlglot.aggregations import windows
3536
import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler
3637
from bigframes.core.compile.sqlglot.expressions import typed_expr
37-
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
3838
from bigframes.core.logging import data_types as data_type_logger
3939
import bigframes.core.ordering as bf_ordering
4040
from bigframes.core.rewrite import schema_binding
@@ -108,20 +108,20 @@ def _compile_result_node(root: nodes.ResultNode) -> str:
108108
# Probably, should defer even further
109109
root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))
110110

111-
sqlglot_ir = compile_node(rewrite.as_sql_nodes(root), uid_gen)
112-
return sqlglot_ir.sql
111+
sqlglot_ir_obj = compile_node(rewrite.as_sql_nodes(root), uid_gen)
112+
return sqlglot_ir_obj.sql
113113

114114

115115
def compile_node(
116116
node: nodes.BigFrameNode, uid_gen: guid.SequentialUIDGenerator
117-
) -> ir.SQLGlotIR:
117+
) -> sqlglot_ir.SQLGlotIR:
118118
"""Compiles the given BigFrameNode from bottem-up into SQLGlotIR."""
119-
bf_to_sqlglot: dict[nodes.BigFrameNode, ir.SQLGlotIR] = {}
120-
child_results: tuple[ir.SQLGlotIR, ...] = ()
119+
bf_to_sqlglot: dict[nodes.BigFrameNode, sqlglot_ir.SQLGlotIR] = {}
120+
child_results: tuple[sqlglot_ir.SQLGlotIR, ...] = ()
121121
for current_node in list(node.iter_nodes_topo()):
122122
if current_node.child_nodes == ():
123123
# For a leaf node, generate a dummy child to pass the UID generator.
124-
child_results = tuple([ir.SQLGlotIR(uid_gen=uid_gen)])
124+
child_results = tuple([sqlglot_ir.SQLGlotIR(uid_gen=uid_gen)])
125125
else:
126126
# Child nodes should have been compiled in the reverse topological order.
127127
child_results = tuple(
@@ -135,14 +135,14 @@ def compile_node(
135135

136136
@functools.singledispatch
137137
def _compile_node(
138-
node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR
139-
) -> ir.SQLGlotIR:
138+
node: nodes.BigFrameNode, *compiled_children: sqlglot_ir.SQLGlotIR
139+
) -> sqlglot_ir.SQLGlotIR:
140140
"""Defines transformation but isn't cached, always use compile_node instead"""
141141
raise ValueError(f"Can't compile unrecognized node: {node}")
142142

143143

144144
@_compile_node.register
145-
def compile_sql_select(node: sql_nodes.SqlSelectNode, child: ir.SQLGlotIR):
145+
def compile_sql_select(node: sql_nodes.SqlSelectNode, child: sqlglot_ir.SQLGlotIR):
146146
ordering_cols = tuple(
147147
sge.Ordered(
148148
this=expression_compiler.expression_compiler.compile_expression(
@@ -175,7 +175,9 @@ def compile_sql_select(node: sql_nodes.SqlSelectNode, child: ir.SQLGlotIR):
175175

176176

177177
@_compile_node.register
178-
def compile_readlocal(node: nodes.ReadLocalNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
178+
def compile_readlocal(
179+
node: nodes.ReadLocalNode, child: sqlglot_ir.SQLGlotIR
180+
) -> sqlglot_ir.SQLGlotIR:
179181
pa_table = node.local_data_source.data
180182
pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
181183
pa_table = pa_table.rename_columns([item.id.sql for item in node.scan_list.items])
@@ -184,16 +186,18 @@ def compile_readlocal(node: nodes.ReadLocalNode, child: ir.SQLGlotIR) -> ir.SQLG
184186
if offsets:
185187
pa_table = pyarrow_utils.append_offsets(pa_table, offsets)
186188

187-
return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema, uid_gen=child.uid_gen)
189+
return sqlglot_ir.SQLGlotIR.from_pyarrow(
190+
pa_table, node.schema, uid_gen=child.uid_gen
191+
)
188192

189193

190194
@_compile_node.register
191-
def compile_readtable(node: sql_nodes.SqlDataSource, child: ir.SQLGlotIR):
192-
table = node.source.table
193-
return ir.SQLGlotIR.from_table(
194-
table.project_id,
195-
table.dataset_id,
196-
table.table_id,
195+
def compile_readtable(node: sql_nodes.SqlDataSource, child: sqlglot_ir.SQLGlotIR):
196+
table_obj = node.source.table
197+
return sqlglot_ir.SQLGlotIR.from_table(
198+
table_obj.project_id,
199+
table_obj.dataset_id,
200+
table_obj.table_id,
197201
uid_gen=child.uid_gen,
198202
sql_predicate=node.source.sql_predicate,
199203
system_time=node.source.at_time,
@@ -202,20 +206,20 @@ def compile_readtable(node: sql_nodes.SqlDataSource, child: ir.SQLGlotIR):
202206

203207
@_compile_node.register
204208
def compile_join(
205-
node: nodes.JoinNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
206-
) -> ir.SQLGlotIR:
209+
node: nodes.JoinNode, left: sqlglot_ir.SQLGlotIR, right: sqlglot_ir.SQLGlotIR
210+
) -> sqlglot_ir.SQLGlotIR:
207211
conditions = tuple(
208212
(
209213
typed_expr.TypedExpr(
210-
expression_compiler.expression_compiler.compile_expression(left),
211-
left.output_type,
214+
expression_compiler.expression_compiler.compile_expression(left_expr),
215+
left_expr.output_type,
212216
),
213217
typed_expr.TypedExpr(
214-
expression_compiler.expression_compiler.compile_expression(right),
215-
right.output_type,
218+
expression_compiler.expression_compiler.compile_expression(right_expr),
219+
right_expr.output_type,
216220
),
217221
)
218-
for left, right in node.conditions
222+
for left_expr, right_expr in node.conditions
219223
)
220224

221225
return left.join(
@@ -228,8 +232,8 @@ def compile_join(
228232

229233
@_compile_node.register
230234
def compile_isin_join(
231-
node: nodes.InNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
232-
) -> ir.SQLGlotIR:
235+
node: nodes.InNode, left: sqlglot_ir.SQLGlotIR, right: sqlglot_ir.SQLGlotIR
236+
) -> sqlglot_ir.SQLGlotIR:
233237
right_field = node.right_child.fields[0]
234238
conditions = (
235239
typed_expr.TypedExpr(
@@ -253,7 +257,9 @@ def compile_isin_join(
253257

254258

255259
@_compile_node.register
256-
def compile_concat(node: nodes.ConcatNode, *children: ir.SQLGlotIR) -> ir.SQLGlotIR:
260+
def compile_concat(
261+
node: nodes.ConcatNode, *children: sqlglot_ir.SQLGlotIR
262+
) -> sqlglot_ir.SQLGlotIR:
257263
assert len(children) >= 1
258264
uid_gen = children[0].uid_gen
259265

@@ -264,24 +270,26 @@ def compile_concat(node: nodes.ConcatNode, *children: ir.SQLGlotIR) -> ir.SQLGlo
264270
for default_output_id, output_id in zip(default_output_ids, node.output_ids)
265271
]
266272

267-
return ir.SQLGlotIR.from_union(
273+
return sqlglot_ir.SQLGlotIR.from_union(
268274
[child._as_select() for child in children],
269275
output_aliases=output_aliases,
270276
uid_gen=uid_gen,
271277
)
272278

273279

274280
@_compile_node.register
275-
def compile_explode(node: nodes.ExplodeNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
281+
def compile_explode(
282+
node: nodes.ExplodeNode, child: sqlglot_ir.SQLGlotIR
283+
) -> sqlglot_ir.SQLGlotIR:
276284
offsets_col = node.offsets_col.sql if (node.offsets_col is not None) else None
277285
columns = tuple(ref.id.sql for ref in node.column_ids)
278286
return child.explode(columns, offsets_col)
279287

280288

281289
@_compile_node.register
282290
def compile_fromrange(
283-
node: nodes.FromRangeNode, start: ir.SQLGlotIR, end: ir.SQLGlotIR
284-
) -> ir.SQLGlotIR:
291+
node: nodes.FromRangeNode, start: sqlglot_ir.SQLGlotIR, end: sqlglot_ir.SQLGlotIR
292+
) -> sqlglot_ir.SQLGlotIR:
285293
start_col_id = node.start.fields[0].id
286294
end_col_id = node.end.fields[0].id
287295

@@ -291,20 +299,22 @@ def compile_fromrange(
291299
end_expr = expression_compiler.expression_compiler.compile_expression(
292300
expression.DerefOp(end_col_id)
293301
)
294-
step_expr = ir._literal(node.step, dtypes.INT_DTYPE)
302+
step_expr = sql.literal(node.step, dtypes.INT_DTYPE)
295303

296304
return start.resample(end, node.output_id.sql, start_expr, end_expr, step_expr)
297305

298306

299307
@_compile_node.register
300308
def compile_random_sample(
301-
node: nodes.RandomSampleNode, child: ir.SQLGlotIR
302-
) -> ir.SQLGlotIR:
309+
node: nodes.RandomSampleNode, child: sqlglot_ir.SQLGlotIR
310+
) -> sqlglot_ir.SQLGlotIR:
303311
return child.sample(node.fraction)
304312

305313

306314
@_compile_node.register
307-
def compile_aggregate(node: nodes.AggregateNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
315+
def compile_aggregate(
316+
node: nodes.AggregateNode, child: sqlglot_ir.SQLGlotIR
317+
) -> sqlglot_ir.SQLGlotIR:
308318
# BigQuery ordered aggregation does not support NULLS FIRST/LAST,
309319
# so we need to add extra expressions to enforce the null ordering.
310320
ordering_cols = windows.get_window_order_by(node.order_by, override_null_order=True)

bigframes/core/compile/sqlglot/expression_compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
import bigframes_vendored.sqlglot.expressions as sge
2020

2121
import bigframes.core.agg_expressions as agg_exprs
22+
from bigframes.core.compile.sqlglot import sql
2223
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
23-
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
2424
import bigframes.core.expression as ex
2525
import bigframes.operations as ops
2626

@@ -77,7 +77,7 @@ def _(self, expr: ex.DerefOp) -> sge.Expression:
7777

7878
@compile_expression.register
7979
def _(self, expr: ex.ScalarConstantExpression) -> sge.Expression:
80-
return ir._literal(expr.value, expr.dtype)
80+
return sql.literal(expr.value, expr.dtype)
8181

8282
@compile_expression.register
8383
def _(self, expr: agg_exprs.WindowExpression) -> sge.Expression:

bigframes/core/compile/sqlglot/expressions/comparison_ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from bigframes import dtypes
2424
from bigframes import operations as ops
25-
from bigframes.core.compile.sqlglot import sqlglot_ir
25+
from bigframes.core.compile.sqlglot import sql
2626
import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler
2727
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2828

@@ -59,9 +59,9 @@ def _(expr: TypedExpr, op: ops.IsInOp) -> sge.Expression:
5959

6060
@register_binary_op(ops.eq_op)
6161
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
62-
if sqlglot_ir._is_null_literal(left.expr):
62+
if sql.is_null_literal(left.expr):
6363
return sge.Is(this=right.expr, expression=sge.Null())
64-
if sqlglot_ir._is_null_literal(right.expr):
64+
if sql.is_null_literal(right.expr):
6565
return sge.Is(this=left.expr, expression=sge.Null())
6666
left_expr = _coerce_bool_to_int(left)
6767
right_expr = _coerce_bool_to_int(right)
@@ -140,12 +140,12 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
140140

141141
@register_binary_op(ops.ne_op)
142142
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
143-
if sqlglot_ir._is_null_literal(left.expr):
143+
if sql.is_null_literal(left.expr):
144144
return sge.Is(
145145
this=sge.paren(right.expr, copy=False),
146146
expression=sg.not_(sge.Null(), copy=False),
147147
)
148-
if sqlglot_ir._is_null_literal(right.expr):
148+
if sql.is_null_literal(right.expr):
149149
return sge.Is(
150150
this=sge.paren(left.expr, copy=False),
151151
expression=sg.not_(sge.Null(), copy=False),

0 commit comments

Comments
 (0)