Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ keywords = [
dependencies = [
"filelock>=3.12.2",
"hydra-core>=1.3",
"mlflow>=3.6",
"mlflow>=3.8",
"omegaconf>=2.3",
"polars>=1.26",
"python-ulid>=3.0.0",
Expand All @@ -55,19 +55,18 @@ hydraflow = "hydraflow.cli:app"

[dependency-groups]
dev = [
"basedpyright>=1.37.3",
"basedpyright>=1.38.1",
"hydra-joblib-launcher>=1.2",
"prek>=0.3.1",
"prek>=0.3.3",
"pytest>=9",
"pytest-clarity>=1",
"pytest-cov>=7",
"pytest-mock>=3.15.1",
"pytest-mock>=3",
"pytest-order>=1",
"pytest-randomly>=4",
"pytest-sugar>=1.1.1",
"pytest-sugar>=1",
"pytest-xdist>=3",
"ruff>=0.15.0",
"ty>=0.0.2",
"ruff>=0.15.1",
]
docs = ["markdown-exec[ansi]", "mkapi>=4.4", "mkdocs-material"]

Expand Down Expand Up @@ -132,7 +131,3 @@ reportAny = false
reportExplicitAny = false
reportImportCycles = false
reportUnusedCallResult = false

[tool.ty.rules]
unresolved-import = "ignore"
possibly-unbound-attribute = "ignore"
43 changes: 33 additions & 10 deletions src/hydraflow/core/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,14 @@ def __getitem__(self, index: slice) -> Self: ...
@overload
def __getitem__(self, index: Iterable[int]) -> Self: ...

@overload
def __getitem__(self, index: str) -> Series: ...

@override
def __getitem__(self, index: int | slice | Iterable[int]) -> I | Self:
def __getitem__(self, index: int | slice | Iterable[int] | str) -> I | Self | Any:
if isinstance(index, str):
return self.to_series(index)

if isinstance(index, int):
return self._items[index]

Expand All @@ -80,7 +86,7 @@ def __iter__(self) -> Iterator[I]:

def filter(
self,
*criteria: Callable[[I], bool] | tuple[str, Any],
*criteria: Callable[[I], bool] | Iterable[Any] | tuple[str, Any],
**kwargs: Any,
) -> Self:
"""Filter items based on criteria.
Expand Down Expand Up @@ -117,6 +123,9 @@ def filter(
# Filter using a key-value tuple
filtered = collection.filter(("age", 25))

# Filter using an iterable
filtered = collection.filter([True, False, True])

# Filter using keyword arguments
filtered = collection.filter(age=25, name="John")

Expand All @@ -128,19 +137,33 @@ def filter(
```

"""
items = self._items
if kwargs:
criteria = (*criteria, *kwargs.items())

index = set(range(len(self._items)))
for c in criteria:
if callable(c):
items = [i for i in items if c(i)]
else:
items = [i for i in items if matches(self._get(i, c[0], MISSING), c[1])]

for key, value in kwargs.items():
items = [i for i in items if matches(self._get(i, key, MISSING), value)]
index = index.intersection(self._filter(c, index))

items = [self._items[i] for i in index]
return self.__class__(items, self._get)

def _filter(
    self,
    c: Callable[[I], bool] | Iterable[Any] | tuple[str, Any],
    index: set[int],
) -> Iterator[int]:
    """Yield the positions in ``self._items`` selected by one criterion.

    The criterion ``c`` is interpreted by its kind:

    - callable: called with each item whose position is in ``index``;
      positions for which it returns a truthy value are yielded.
    - tuple: ``c[0]`` is looked up on each item whose position is in
      ``index`` via ``self._get(item, c[0], MISSING)`` and the result is
      compared against ``c[1]`` with ``matches``.
    - any other iterable: treated as a mask aligned element-by-element
      with *all* of ``self._items``; positions whose mask element is
      truthy are yielded. This branch does not consult ``index``, so the
      caller is expected to intersect the result with it.

    Args:
        c: The criterion to evaluate.
        index: Positions still under consideration.

    Yields:
        Positions of the selected items. For the callable and tuple
        branches the order follows the iteration order of ``index``.

    Raises:
        ValueError: If a mask's length differs from ``len(self._items)``
            (raised by ``zip(..., strict=True)``).

    """
    if callable(c):
        # Snapshot the (position, item) pairs before invoking the predicate.
        candidates = [(pos, self._items[pos]) for pos in index]
        for pos, item in candidates:
            if c(item):
                yield pos
        return

    if isinstance(c, tuple):
        candidates = [(pos, self._items[pos]) for pos in index]
        for pos, item in candidates:
            if matches(self._get(item, c[0], MISSING), c[1]):
                yield pos
        return

    # Mask case: pair every item with its mask element, in item order.
    for pos, (_, flag) in enumerate(zip(self._items, c, strict=True)):
        if flag:
            yield pos

def try_get(
self,
*criteria: Callable[[I], bool] | tuple[str, Any],
Expand Down
18 changes: 18 additions & 0 deletions tests/core/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TYPE_CHECKING, Any, Self, override

import numpy as np
import polars as pl
import pytest
from omegaconf import ListConfig

Expand Down Expand Up @@ -91,6 +92,10 @@ def test_getitem_iterable(rc: Rc) -> None:
assert rc._get(rc[0], "y", None) == "c"


def test_getitem_str(rc: Rc) -> None:
    """Indexing the collection with a string key returns a polars Series."""
    column = rc["x"]
    assert isinstance(column, pl.Series)


def test_iter(rc: Rc) -> None:
assert len(list(iter(rc))) == 12

Expand Down Expand Up @@ -119,6 +124,19 @@ def test_filter_tuple(rc: Rc) -> None:
assert len(rc.filter(("x", (1, 2)), ("y", ["a", "c"]))) == 4


def test_filter_iterable(rc: Rc) -> None:
    """Filter with mask criteria built from the collection's own columns."""
    x_mask = rc["x"] >= 2
    y_mask = rc["y"].is_in(["a", "b"])

    filtered = rc.filter(x_mask, y_mask)

    assert len(filtered) == 4
    # Every surviving item must satisfy both masks when they are recomputed.
    assert all(filtered["x"] >= 2)
    assert all(filtered["y"].is_in(["a", "b"]))


def test_filter_complex(rc: Rc) -> None:
    """Combine mask, key-value tuple, and keyword criteria in one call."""
    x_mask = rc["x"] <= 2
    y_mask = rc["y"] > "b"

    filtered = rc.filter(x_mask, ("x", (1, 3)), y_mask, y=["a", "b", "d"])

    # The expected lists also pin the order of the surviving items.
    assert filtered["x"].to_list() == [1, 2]
    assert filtered["y"].to_list() == ["d", "d"]


def test_try_get(rc: Rc) -> None:
assert rc.try_get(("x", 10)) is None

Expand Down