Skip to content

Commit e4dfcfc

Browse files
committed
refactor: Ensure only valid IDs are propagated in identifier remapping
Corrected `remap_variables` to propagate only the column IDs that are actually present in the current node's output fields. This prevents parent nodes from seeing internal or leaked column IDs from child nodes, which was particularly problematic for aggregate nodes. Added unit tests to verify correct propagation for `AggregateNode` with and without grouping.
1 parent 61c17e3 commit e4dfcfc

File tree

2 files changed

+54
-13
lines changed

2 files changed

+54
-13
lines changed

bigframes/core/rewrite/identifiers.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,6 @@ def remap_variables(
5757
new_root = root.transform_children(lambda node: remapped_children[node])
5858

5959
# Step 3: Transform the current node using the mappings from its children.
60-
# "reversed" is required for InNode so that in case of a duplicate column ID,
61-
# the left child's mapping is the one that's kept.
62-
downstream_mappings: dict[identifiers.ColumnId, identifiers.ColumnId] = {
63-
k: v for mapping in reversed(new_child_mappings) for k, v in mapping.items()
64-
}
6560
if isinstance(new_root, nodes.InNode):
6661
new_root = typing.cast(nodes.InNode, new_root)
6762
new_root = dataclasses.replace(
@@ -71,6 +66,9 @@ def remap_variables(
7166
),
7267
)
7368
else:
69+
downstream_mappings: dict[identifiers.ColumnId, identifiers.ColumnId] = {
70+
k: v for mapping in new_child_mappings for k, v in mapping.items()
71+
}
7472
new_root = new_root.remap_refs(downstream_mappings)
7573

7674
# Step 4: Create new IDs for columns defined by the current node.
@@ -82,12 +80,8 @@ def remap_variables(
8280
new_root._validate()
8381

8482
# Step 5: Determine which mappings to propagate up to the parent.
85-
if root.defines_namespace:
86-
# If a node defines a new namespace (e.g., a join), mappings from its
87-
# children are not visible to its parents.
88-
mappings_for_parent = node_defined_mappings
89-
else:
90-
# Otherwise, pass up the combined mappings from children and the current node.
91-
mappings_for_parent = downstream_mappings | node_defined_mappings
83+
propagated_mappings = {
84+
old_id: new_id for old_id, new_id in zip(root.ids, new_root.ids)
85+
}
9286

93-
return new_root, mappings_for_parent
87+
return new_root, propagated_mappings

tests/unit/core/rewrite/test_identifiers.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
import typing
1515

1616
import bigframes.core as core
17+
import bigframes.core.agg_expressions as agg_ex
1718
import bigframes.core.expression as ex
1819
import bigframes.core.identifiers as identifiers
1920
import bigframes.core.nodes as nodes
2021
import bigframes.core.rewrite.identifiers as id_rewrite
22+
import bigframes.operations.aggregations as agg_ops
2123

2224

2325
def test_remap_variables_single_node(leaf):
@@ -51,6 +53,51 @@ def test_remap_variables_projection(leaf):
5153
assert set(mapping.values()) == {identifiers.ColumnId(f"id_{i}") for i in range(3)}
5254

5355

56+
def test_remap_variables_aggregate(leaf):
57+
# Aggregation: sum(col_a) AS sum_a
58+
# Group by nothing
59+
agg_op = agg_ex.UnaryAggregation(
60+
op=agg_ops.sum_op,
61+
arg=ex.DerefOp(leaf.fields[0].id),
62+
)
63+
node = nodes.AggregateNode(
64+
child=leaf,
65+
aggregations=((agg_op, identifiers.ColumnId("sum_a")),),
66+
by_column_ids=(),
67+
)
68+
69+
id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100))
70+
_, mapping = id_rewrite.remap_variables(node, id_generator)
71+
72+
# leaf has 2 columns: col_a, col_b
73+
# AggregateNode defines 1 column: sum_a
74+
# Output of AggregateNode should only be sum_a
75+
assert len(mapping) == 1
76+
assert identifiers.ColumnId("sum_a") in mapping
77+
78+
79+
def test_remap_variables_aggregate_with_grouping(leaf):
80+
# Aggregation: sum(col_b) AS sum_b
81+
# Group by col_a
82+
agg_op = agg_ex.UnaryAggregation(
83+
op=agg_ops.sum_op,
84+
arg=ex.DerefOp(leaf.fields[1].id),
85+
)
86+
node = nodes.AggregateNode(
87+
child=leaf,
88+
aggregations=((agg_op, identifiers.ColumnId("sum_b")),),
89+
by_column_ids=(ex.DerefOp(leaf.fields[0].id),),
90+
)
91+
92+
id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100))
93+
_, mapping = id_rewrite.remap_variables(node, id_generator)
94+
95+
# Output should have 2 columns: col_a (grouping) and sum_b (agg)
96+
assert len(mapping) == 2
97+
assert leaf.fields[0].id in mapping
98+
assert identifiers.ColumnId("sum_b") in mapping
99+
100+
54101
def test_remap_variables_nested_join_stability(leaf, fake_session, table):
55102
# Create two more distinct leaf nodes
56103
leaf2_uncached = core.ArrayValue.from_table(

0 commit comments

Comments
 (0)