Skip to content

Commit 63f9710

Browse files
committed
add dml unit tests
1 parent d0fba03 commit 63f9710

File tree

7 files changed

+105
-5
lines changed

7 files changed

+105
-5
lines changed

bigframes/core/compile/sqlglot/sql/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Google LLC
1+
# Copyright 2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -58,6 +58,7 @@ def identifier(id: str) -> sge.Identifier:
5858

5959

6060
def literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
61+
"""Return a string representing column reference in a SQL."""
6162
sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None
6263
if sqlglot_type is None:
6364
if not pd.isna(value):
@@ -110,13 +111,15 @@ def literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
110111

111112

112113
def cast(arg: typing.Any, to: str, safe: bool = False) -> sge.Cast | sge.TryCast:
114+
"""Return a SQL expression that casts the given argument to the specified type."""
113115
if safe:
114116
return sge.TryCast(this=arg, to=to)
115117
else:
116118
return sge.Cast(this=arg, to=to)
117119

118120

119121
def table(table: bigquery.TableReference) -> sge.Table:
122+
"""Return a SQLGlot Table expression representing the given BigQuery table reference."""
120123
return sge.Table(
121124
this=sge.to_identifier(table.table_id, quoted=True),
122125
db=sge.to_identifier(table.dataset_id, quoted=True),

bigframes/core/compile/sqlglot/sql/dml.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,12 @@ def replace(
4141
this=base.table(destination),
4242
using=_as_from_item(query_or_table),
4343
on=base.literal(False, dtypes.BOOL_DTYPE),
44-
whens=[
45-
sge.When(matched=False, source=True, then=sge.Delete()),
46-
sge.When(matched=False, then=sge.Insert(this=sge.Var(this="ROW"))),
47-
],
44+
whens=sge.Whens(
45+
expressions=[
46+
sge.When(matched=False, source=True, then=sge.Delete()),
47+
sge.When(matched=False, then=sge.Insert(this=sge.Var(this="ROW"))),
48+
]
49+
),
4850
)
4951

5052

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
INSERT INTO `bigframes-dev`.`sqlglot_test`.`dest_table`
2+
(
3+
SELECT
4+
*
5+
FROM `source_table`
6+
)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
INSERT INTO `bigframes-dev`.`sqlglot_test`.`dest_table`
2+
`source_table`
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
MERGE INTO `bigframes-dev`.`sqlglot_test`.`dest_table`
2+
USING (
3+
SELECT
4+
*
5+
FROM `source_table`
6+
)
7+
ON FALSE
8+
WHEN NOT MATCHED BY SOURCE THEN DELETE
9+
WHEN NOT MATCHED THEN INSERT ROW
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
MERGE INTO `bigframes-dev`.`sqlglot_test`.`dest_table`
2+
USING `source_table`
3+
ON FALSE
4+
WHEN NOT MATCHED BY SOURCE THEN DELETE
5+
WHEN NOT MATCHED THEN INSERT ROW
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import bigframes_vendored.sqlglot.expressions as sge
16+
from google.cloud import bigquery
17+
import pytest
18+
19+
from bigframes.core.compile.sqlglot.sql import base, dml
20+
21+
pytest.importorskip("pytest_snapshot")
22+
23+
24+
def test_insert_from_select(snapshot):
25+
query = sge.select("*").from_(
26+
sge.Table(this=sge.Identifier(this="source_table", quoted=True))
27+
)
28+
destination = bigquery.TableReference.from_string(
29+
"bigframes-dev.sqlglot_test.dest_table"
30+
)
31+
32+
expr = dml.insert(query, destination)
33+
sql = base.to_sql(expr)
34+
35+
snapshot.assert_match(sql, "out.sql")
36+
37+
38+
def test_insert_from_table(snapshot):
39+
query = sge.Table(this=sge.Identifier(this="source_table", quoted=True))
40+
destination = bigquery.TableReference.from_string(
41+
"bigframes-dev.sqlglot_test.dest_table"
42+
)
43+
44+
expr = dml.insert(query, destination)
45+
sql = base.to_sql(expr)
46+
47+
snapshot.assert_match(sql, "out.sql")
48+
49+
50+
def test_replace_from_select(snapshot):
51+
query = sge.select("*").from_(
52+
sge.Table(this=sge.Identifier(this="source_table", quoted=True))
53+
)
54+
destination = bigquery.TableReference.from_string(
55+
"bigframes-dev.sqlglot_test.dest_table"
56+
)
57+
58+
expr = dml.replace(query, destination)
59+
sql = base.to_sql(expr)
60+
61+
snapshot.assert_match(sql, "out.sql")
62+
63+
64+
def test_replace_from_table(snapshot):
65+
query = sge.Table(this=sge.Identifier(this="source_table", quoted=True))
66+
destination = bigquery.TableReference.from_string(
67+
"bigframes-dev.sqlglot_test.dest_table"
68+
)
69+
70+
expr = dml.replace(query, destination)
71+
sql = base.to_sql(expr)
72+
73+
snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)