From 79aa21665222ee7e1957885186ef3bbbf976438f Mon Sep 17 00:00:00 2001 From: Daizu Date: Wed, 18 Feb 2026 23:08:05 +0000 Subject: [PATCH] feat: Add string indexing and enhance filter capabilities for Collection --- pyproject.toml | 17 +++++-------- src/hydraflow/core/collection.py | 43 ++++++++++++++++++++++++-------- tests/core/test_collection.py | 18 +++++++++++++ 3 files changed, 57 insertions(+), 21 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 610246d5..75a849f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ keywords = [ dependencies = [ "filelock>=3.12.2", "hydra-core>=1.3", - "mlflow>=3.6", + "mlflow>=3.8", "omegaconf>=2.3", "polars>=1.26", "python-ulid>=3.0.0", @@ -55,19 +55,18 @@ hydraflow = "hydraflow.cli:app" [dependency-groups] dev = [ - "basedpyright>=1.37.3", + "basedpyright>=1.38.1", "hydra-joblib-launcher>=1.2", - "prek>=0.3.1", + "prek>=0.3.3", "pytest>=9", "pytest-clarity>=1", "pytest-cov>=7", - "pytest-mock>=3.15.1", + "pytest-mock>=3", "pytest-order>=1", "pytest-randomly>=4", - "pytest-sugar>=1.1.1", + "pytest-sugar>=1", "pytest-xdist>=3", - "ruff>=0.15.0", - "ty>=0.0.2", + "ruff>=0.15.1", ] docs = ["markdown-exec[ansi]", "mkapi>=4.4", "mkdocs-material"] @@ -132,7 +131,3 @@ reportAny = false reportExplicitAny = false reportImportCycles = false reportUnusedCallResult = false - -[tool.ty.rules] -unresolved-import = "ignore" -possibly-unbound-attribute = "ignore" diff --git a/src/hydraflow/core/collection.py b/src/hydraflow/core/collection.py index fdebe1c2..afdf5198 100644 --- a/src/hydraflow/core/collection.py +++ b/src/hydraflow/core/collection.py @@ -64,8 +64,14 @@ def __getitem__(self, index: slice) -> Self: ... @overload def __getitem__(self, index: Iterable[int]) -> Self: ... + @overload + def __getitem__(self, index: str) -> Series: ... + @override - def __getitem__(self, index: int | slice | Iterable[int]) -> I | Self: + def __getitem__(self, index: int | slice | Iterable[int] | str) -> I | Self | Any: + if isinstance(index, str): + return self.to_series(index) + if isinstance(index, int): return self._items[index] @@ -80,7 +86,7 @@ def __iter__(self) -> Iterator[I]: def filter( self, - *criteria: Callable[[I], bool] | tuple[str, Any], + *criteria: Callable[[I], bool] | Iterable[Any] | tuple[str, Any], **kwargs: Any, ) -> Self: """Filter items based on criteria. @@ -117,6 +123,9 @@ def filter( # Filter using a key-value tuple filtered = collection.filter(("age", 25)) + # Filter using an iterable + filtered = collection.filter([True, False, True]) + # Filter using keyword arguments filtered = collection.filter(age=25, name="John") @@ -128,19 +137,33 @@ def filter( ``` """ - items = self._items + if kwargs: + criteria = (*criteria, *kwargs.items()) + index = set(range(len(self._items))) for c in criteria: - if callable(c): - items = [i for i in items if c(i)] - else: - items = [i for i in items if matches(self._get(i, c[0], MISSING), c[1])] - - for key, value in kwargs.items(): - items = [i for i in items if matches(self._get(i, key, MISSING), value)] + index = index.intersection(self._filter(c, index)) + items = [self._items[i] for i in index] return self.__class__(items, self._get) + def _filter( + self, + c: Callable[[I], bool] | Iterable[Any] | tuple[str, Any], + index: set[int], + ) -> Iterator[int]: + if callable(c) or isinstance(c, tuple): + it = [(i, self._items[i]) for i in index] + else: + it = enumerate(self._items) + + if callable(c): + yield from (k for k, i in it if c(i)) + elif isinstance(c, tuple): + yield from (k for k, i in it if matches(self._get(i, c[0], MISSING), c[1])) + else: + yield from (k for (k, _), v in zip(it, c, strict=True) if v) + def try_get( self, *criteria: Callable[[I], bool] | tuple[str, Any], diff --git a/tests/core/test_collection.py b/tests/core/test_collection.py index f1f8d4e4..02d3742c 100644 --- a/tests/core/test_collection.py +++ b/tests/core/test_collection.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Self, override import numpy as np +import polars as pl import pytest from omegaconf import ListConfig @@ -91,6 +92,10 @@ def test_getitem_iterable(rc: Rc) -> None: assert rc._get(rc[0], "y", None) == "c" +def test_getitem_str(rc: Rc) -> None: + assert isinstance(rc["x"], pl.Series) + + def test_iter(rc: Rc) -> None: assert len(list(iter(rc))) == 12 @@ -119,6 +124,19 @@ def test_filter_tuple(rc: Rc) -> None: assert len(rc.filter(("x", (1, 2)), ("y", ["a", "c"]))) == 4 +def test_filter_iterable(rc: Rc) -> None: + rc = rc.filter(rc["x"] >= 2, rc["y"].is_in(["a", "b"])) + assert len(rc) == 4 + assert all(rc["x"] >= 2) + assert all(rc["y"].is_in(["a", "b"])) + + +def test_filter_complex(rc: Rc) -> None: + rc = rc.filter(rc["x"] <= 2, ("x", (1, 3)), rc["y"] > "b", y=["a", "b", "d"]) + assert rc["x"].to_list() == [1, 2] + assert rc["y"].to_list() == ["d", "d"] + + def test_try_get(rc: Rc) -> None: assert rc.try_get(("x", 10)) is None