Skip to content

Commit 3cd4357

Browse files
feat: Support pd.col expressions with .loc and getitem
1 parent: 61a9484 · commit: 3cd4357

File tree

4 files changed

+45
-3
lines changed

4 files changed

+45
-3
lines changed

bigframes/core/array_value.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -204,7 +204,12 @@ def filter_by_id(self, predicate_id: str, keep_null: bool = False) -> ArrayValue
204204
return self.filter(predicate)
205205

206206
def filter(self, predicate: ex.Expression):
207-
return ArrayValue(nodes.FilterNode(child=self.node, predicate=predicate))
207+
if predicate.is_scalar_expr:
208+
return ArrayValue(nodes.FilterNode(child=self.node, predicate=predicate))
209+
else:
210+
arr, filter_ids = self.compute_general_expression([predicate])
211+
arr = arr.filter_by_id(filter_ids[0])
212+
return arr.drop_columns(filter_ids)
208213

209214
def order_by(
210215
self, by: Sequence[OrderingExpression], is_total_order: bool = False

bigframes/core/indexers.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import pandas as pd
2424

2525
import bigframes.core.blocks
26+
import bigframes.core.col
2627
import bigframes.core.expression as ex
2728
import bigframes.core.guid as guid
2829
import bigframes.core.indexes as indexes
@@ -36,7 +37,11 @@
3637

3738
if typing.TYPE_CHECKING:
3839
LocSingleKey = Union[
39-
bigframes.series.Series, indexes.Index, slice, bigframes.core.scalar.Scalar
40+
bigframes.series.Series,
41+
indexes.Index,
42+
slice,
43+
bigframes.core.scalar.Scalar,
44+
bigframes.core.col.Expression,
4045
]
4146

4247

@@ -309,6 +314,15 @@ def _loc_getitem_series_or_dataframe(
309314
raise NotImplementedError(
310315
f"loc does not yet support indexing with a slice. {constants.FEEDBACK_LINK}"
311316
)
317+
if isinstance(key, bigframes.core.col.Expression):
318+
label_to_col_ref = {
319+
label: ex.deref(id)
320+
for id, label in series_or_dataframe._block.col_id_to_label.items()
321+
}
322+
resolved_expr = key._value.bind_variables(label_to_col_ref)
323+
result = series_or_dataframe.copy()
324+
result._set_block(series_or_dataframe._block.filter(resolved_expr))
325+
return result
312326
if callable(key):
313327
raise NotImplementedError(
314328
f"loc does not yet support indexing with a callable. {constants.FEEDBACK_LINK}"

bigframes/dataframe.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -623,13 +623,18 @@ def __getitem__(
623623
): # No return type annotations (like pandas) as type cannot always be determined statically
624624
# NOTE: This implements the operations described in
625625
# https://pandas.pydata.org/docs/getting_started/intro_tutorials/03_subset_data.html
626+
import bigframes.core.col
627+
import bigframes.pandas
626628

627-
if isinstance(key, bigframes.series.Series):
629+
if isinstance(key, bigframes.pandas.Series):
628630
return self._getitem_bool_series(key)
629631

630632
if isinstance(key, slice):
631633
return self.iloc[key]
632634

635+
if isinstance(key, bigframes.core.col.Expression):
636+
return self.loc[key]
637+
633638
# TODO(tswast): Fix this pylance warning: Class overlaps "Hashable"
634639
# unsafely and could produce a match at runtime
635640
if isinstance(key, blocks.Label):

tests/unit/test_col.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -158,3 +158,21 @@ def test_pd_col_binary_bool_operators(scalars_dfs, op):
158158
pd_result = scalars_pandas_df.assign(**pd_kwargs)
159159

160160
assert_frame_equal(bf_result, pd_result)
161+
162+
163+
def test_loc_with_pd_col(scalars_dfs):
164+
scalars_df, scalars_pandas_df = scalars_dfs
165+
166+
bf_result = scalars_df.loc[bpd.col("float64_col") > 4].to_pandas()
167+
pd_result = scalars_pandas_df.loc[pd.col("float64_col") > 4] # type: ignore
168+
169+
assert_frame_equal(bf_result, pd_result)
170+
171+
172+
def test_getitem_with_pd_col(scalars_dfs):
173+
scalars_df, scalars_pandas_df = scalars_dfs
174+
175+
bf_result = scalars_df[bpd.col("float64_col") > 4].to_pandas()
176+
pd_result = scalars_pandas_df[pd.col("float64_col") > 4] # type: ignore
177+
178+
assert_frame_equal(bf_result, pd_result)

0 commit comments

Comments
 (0)