Skip to content

Commit d146465

Browse files
committed
refactor: add brackets on sqlglot expression
1 parent 61c17e3 commit d146465

File tree

14 files changed

+76
-40
lines changed

14 files changed

+76
-40
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import typing
1818

19+
import bigframes_vendored.sqlglot as sg
1920
import bigframes_vendored.sqlglot.expressions as sge
2021
import pandas as pd
2122

@@ -189,7 +190,10 @@ def _cut_ops_w_int_bins(
189190

190191
condition: sge.Expression
191192
if this_bin == bins - 1:
192-
condition = sge.Is(this=column.expr, expression=sge.Not(this=sge.Null()))
193+
condition = sge.Is(
194+
this=sge.paren(column.expr, copy=False),
195+
expression=sg.not_(sge.Null(), copy=False),
196+
)
193197
else:
194198
if op.right:
195199
condition = sge.LTE(

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,10 @@ def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression:
125125

126126
@register_unary_op(ops.notnull_op)
127127
def _(expr: TypedExpr) -> sge.Expression:
128-
return sge.Not(this=sge.Is(this=sge.paren(expr.expr), expression=sge.Null()))
128+
return sge.Is(
129+
this=sge.paren(expr.expr, copy=False),
130+
expression=sg.not_(sge.Null(), copy=False),
131+
)
129132

130133

131134
@register_ternary_op(ops.where_op)

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def _float_pow_op(
362362
sge.If(
363363
this=sge.and_(
364364
sge.LT(this=left_expr, expression=constants._ZERO),
365-
sge.Not(this=exponent_is_whole),
365+
sge.Not(this=sge.paren(exponent_is_whole)),
366366
),
367367
true=constants._NAN,
368368
),

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ SELECT
3030
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
3131
) + 0 AS `right_inclusive`
3232
)
33-
WHEN `int64_col` IS NOT NULL
33+
WHEN (
34+
`int64_col`
35+
) IS NOT NULL
3436
THEN STRUCT(
3537
(
3638
MIN(`int64_col`) OVER () + (

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ SELECT
88
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
99
)
1010
THEN 'b'
11-
WHEN `int64_col` IS NOT NULL
11+
WHEN (
12+
`int64_col`
13+
) IS NOT NULL
1214
THEN 'c'
1315
END AS `int_bins_labels`
1416
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,47 +2,47 @@ SELECT
22
`rowindex`,
33
`int64_col`,
44
IF(
5-
NOT (
5+
(
66
`int64_col`
7-
) IS NULL,
7+
) IS NOT NULL,
88
IF(
99
`int64_col` IS NULL,
1010
NULL,
1111
CAST(GREATEST(
1212
CEIL(
13-
PERCENT_RANK() OVER (PARTITION BY NOT (
13+
PERCENT_RANK() OVER (PARTITION BY (
1414
`int64_col`
15-
) IS NULL ORDER BY `int64_col` ASC) * 4
15+
) IS NOT NULL ORDER BY `int64_col` ASC) * 4
1616
) - 1,
1717
0
1818
) AS INT64)
1919
),
2020
NULL
2121
) AS `qcut_w_int`,
2222
IF(
23-
NOT (
23+
(
2424
`int64_col`
25-
) IS NULL,
25+
) IS NOT NULL,
2626
CASE
27-
WHEN PERCENT_RANK() OVER (PARTITION BY NOT (
27+
WHEN PERCENT_RANK() OVER (PARTITION BY (
2828
`int64_col`
29-
) IS NULL ORDER BY `int64_col` ASC) < 0
29+
) IS NOT NULL ORDER BY `int64_col` ASC) < 0
3030
THEN NULL
31-
WHEN PERCENT_RANK() OVER (PARTITION BY NOT (
31+
WHEN PERCENT_RANK() OVER (PARTITION BY (
3232
`int64_col`
33-
) IS NULL ORDER BY `int64_col` ASC) <= 0.25
33+
) IS NOT NULL ORDER BY `int64_col` ASC) <= 0.25
3434
THEN 0
35-
WHEN PERCENT_RANK() OVER (PARTITION BY NOT (
35+
WHEN PERCENT_RANK() OVER (PARTITION BY (
3636
`int64_col`
37-
) IS NULL ORDER BY `int64_col` ASC) <= 0.5
37+
) IS NOT NULL ORDER BY `int64_col` ASC) <= 0.5
3838
THEN 1
39-
WHEN PERCENT_RANK() OVER (PARTITION BY NOT (
39+
WHEN PERCENT_RANK() OVER (PARTITION BY (
4040
`int64_col`
41-
) IS NULL ORDER BY `int64_col` ASC) <= 0.75
41+
) IS NOT NULL ORDER BY `int64_col` ASC) <= 0.75
4242
THEN 2
43-
WHEN PERCENT_RANK() OVER (PARTITION BY NOT (
43+
WHEN PERCENT_RANK() OVER (PARTITION BY (
4444
`int64_col`
45-
) IS NULL ORDER BY `int64_col` ASC) <= 1
45+
) IS NOT NULL ORDER BY `int64_col` ASC) <= 1
4646
THEN 3
4747
ELSE NULL
4848
END,
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
SELECT
2-
NOT (
2+
(
33
`float64_col`
4-
) IS NULL AS `float64_col`
4+
) IS NOT NULL AS `float64_col`
55
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
NOT IS_INF(`float64_col`) OR IS_NAN(`float64_col`) AS `float64_col`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_pow/out.sql

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ SELECT
3333
END
3434
)
3535
WHEN `int64_col` < CAST(0 AS INT64)
36-
AND NOT CAST(`float64_col` AS INT64) = `float64_col`
36+
AND NOT (
37+
CAST(`float64_col` AS INT64) = `float64_col`
38+
)
3739
THEN CAST('NaN' AS FLOAT64)
3840
WHEN `int64_col` <> CAST(0 AS INT64) AND `float64_col` * LN(ABS(`int64_col`)) > 709.78
3941
THEN CAST('Infinity' AS FLOAT64) * CASE
@@ -75,7 +77,10 @@ SELECT
7577
ELSE `int64_col`
7678
END
7779
)
78-
WHEN `float64_col` < CAST(0 AS INT64) AND NOT CAST(`int64_col` AS INT64) = `int64_col`
80+
WHEN `float64_col` < CAST(0 AS INT64)
81+
AND NOT (
82+
CAST(`int64_col` AS INT64) = `int64_col`
83+
)
7984
THEN CAST('NaN' AS FLOAT64)
8085
WHEN `float64_col` <> CAST(0 AS INT64)
8186
AND `int64_col` * LN(ABS(`float64_col`)) > 709.78
@@ -119,7 +124,9 @@ SELECT
119124
END
120125
)
121126
WHEN `float64_col` < CAST(0 AS INT64)
122-
AND NOT CAST(`float64_col` AS INT64) = `float64_col`
127+
AND NOT (
128+
CAST(`float64_col` AS INT64) = `float64_col`
129+
)
123130
THEN CAST('NaN' AS FLOAT64)
124131
WHEN `float64_col` <> CAST(0 AS INT64)
125132
AND `float64_col` * LN(ABS(`float64_col`)) > 709.78
@@ -167,7 +174,9 @@ SELECT
167174
ELSE 0
168175
END
169176
)
170-
WHEN `float64_col` < CAST(0 AS INT64) AND NOT CAST(0 AS INT64) = 0
177+
WHEN `float64_col` < CAST(0 AS INT64) AND NOT (
178+
CAST(0 AS INT64) = 0
179+
)
171180
THEN CAST('NaN' AS FLOAT64)
172181
WHEN `float64_col` <> CAST(0 AS INT64) AND 0 * LN(ABS(`float64_col`)) > 709.78
173182
THEN CAST('Infinity' AS FLOAT64) * CASE
@@ -214,7 +223,9 @@ SELECT
214223
ELSE 1
215224
END
216225
)
217-
WHEN `float64_col` < CAST(0 AS INT64) AND NOT CAST(1 AS INT64) = 1
226+
WHEN `float64_col` < CAST(0 AS INT64) AND NOT (
227+
CAST(1 AS INT64) = 1
228+
)
218229
THEN CAST('NaN' AS FLOAT64)
219230
WHEN `float64_col` <> CAST(0 AS INT64) AND 1 * LN(ABS(`float64_col`)) > 709.78
220231
THEN CAST('Infinity' AS FLOAT64) * CASE

tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from bigframes import operations as ops
1919
import bigframes.core.expression as ex
20+
from bigframes.operations import numeric_ops
2021
import bigframes.pandas as bpd
2122
from bigframes.testing import utils
2223

@@ -156,6 +157,16 @@ def test_floor(scalar_types_df: bpd.DataFrame, snapshot):
156157
snapshot.assert_match(sql, "out.sql")
157158

158159

160+
def test_isfinite(scalar_types_df: bpd.DataFrame, snapshot):
161+
col_name = "float64_col"
162+
bf_df = scalar_types_df[[col_name]]
163+
sql = utils._apply_ops_to_sql(
164+
bf_df, [numeric_ops.isfinite_op.as_expr(col_name)], [col_name]
165+
)
166+
167+
snapshot.assert_match(sql, "out.sql")
168+
169+
159170
def test_ln(scalar_types_df: bpd.DataFrame, snapshot):
160171
col_name = "float64_col"
161172
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)