Skip to content

Commit 4aa642d

Browse files
committed
feat: add ingress decorator for automatic argument transformation with various usage patterns
1 parent 99be6b4 commit 4aa642d

File tree

4 files changed

+741
-0
lines changed

4 files changed

+741
-0
lines changed

i2/castgraph.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ def transform_func(obj, ctx): ...
177177
)
178178
from collections.abc import Callable, Hashable, Iterable, MutableMapping
179179

180+
from i2.wrapper import Wrap
181+
from i2.signatures import Sig
182+
180183

181184
T = TypeVar("T")
182185
U = TypeVar("U")
@@ -942,6 +945,127 @@ def kinds(self) -> set[Hashable]:
942945
"""
943946
return self._kinds.copy()
944947

948+
# -----------------------------
949+
# Ingress decorator for automatic argument transformation
950+
# -----------------------------
951+
952+
@property
953+
def ingress(self):
954+
"""Return decorator factory with attribute access for kinds.
955+
956+
This property provides a flexible interface for decorating functions to
957+
automatically transform their arguments to specified kinds.
958+
959+
Usage patterns:
960+
961+
1. Specify kind and argument name:
962+
@graph.ingress('text', 'content')
963+
def func(content): ...
964+
965+
2. Specify kind only (transforms first argument):
966+
@graph.ingress('text')
967+
def func(arg): ...
968+
969+
3. Use keyword argument:
970+
@graph.ingress(arg_name='text')
971+
def func(arg_name): ...
972+
973+
4. Attribute-based syntax for registered kinds:
974+
@graph.ingress.text('content')
975+
def func(content): ...
976+
977+
5. Attribute-based for first argument:
978+
@graph.ingress.text
979+
def func(arg): ...
980+
981+
Examples
982+
--------
983+
>>> graph = TransformationGraph()
984+
>>> graph.add_node('text', isa=lambda x: isinstance(x, str))
985+
>>> graph.add_node(int)
986+
>>> @graph.register_edge('text', int)
987+
... def text_to_int(s, ctx): return int(s)
988+
>>> @graph.ingress('text')
989+
... def process(x):
990+
... return x + ' processed'
991+
>>> # Can now pass int, will be transformed to text first
992+
"""
993+
return _IngressProxy(self)
994+
995+
def _ingress_decorator(
996+
self,
997+
kind_or_arg: Hashable | None = None,
998+
arg_name: str | None = None,
999+
*,
1000+
context: dict | None = None,
1001+
):
1002+
"""Internal method to create ingress decorator.
1003+
1004+
Parameters
1005+
----------
1006+
kind_or_arg : Hashable | None
1007+
The target kind for transformation, or argument name if arg_name is provided
1008+
arg_name : str | None
1009+
The name of the argument to transform. If None, transforms first argument.
1010+
context : dict | None
1011+
Optional context to pass to transformations
1012+
1013+
Returns
1014+
-------
1015+
Callable
1016+
Decorator function
1017+
"""
1018+
# Determine target kind and argument name
1019+
if arg_name is not None:
1020+
# @graph.ingress(str, 'x') or @graph.ingress('text', 'x')
1021+
target_kind = kind_or_arg
1022+
target_arg = arg_name
1023+
elif isinstance(kind_or_arg, str) or isinstance(kind_or_arg, type):
1024+
# @graph.ingress('text') or @graph.ingress(int)
1025+
target_kind = kind_or_arg
1026+
target_arg = None # Will use first arg
1027+
else:
1028+
# kind_or_arg could be None or some other hashable
1029+
target_kind = kind_or_arg
1030+
target_arg = None
1031+
1032+
def decorator(func):
1033+
nonlocal target_arg
1034+
sig = Sig(func)
1035+
1036+
# Default to first arg if not specified
1037+
if target_arg is None:
1038+
if not sig.names:
1039+
raise ValueError(
1040+
f"Function {func.__name__} has no parameters to transform"
1041+
)
1042+
target_arg = sig.names[0]
1043+
1044+
# Validate target_arg exists
1045+
if target_arg not in sig.names:
1046+
raise ValueError(
1047+
f"Argument '{target_arg}' not found in function {func.__name__}. "
1048+
f"Available arguments: {sig.names}"
1049+
)
1050+
1051+
# Create ingress function
1052+
def ingress_func(*args, **kwargs):
1053+
# Map to all kwargs
1054+
all_kwargs = sig.map_arguments(args, kwargs, apply_defaults=False)
1055+
1056+
# Transform the target argument if present
1057+
if target_arg in all_kwargs:
1058+
all_kwargs[target_arg] = self.transform_any(
1059+
all_kwargs[target_arg], target_kind, context=context
1060+
)
1061+
1062+
# Convert back to args/kwargs respecting signature
1063+
return sig.mk_args_and_kwargs(all_kwargs, allow_partial=True)
1064+
1065+
return Wrap(func, ingress=ingress_func)
1066+
1067+
return decorator
1068+
9451069
# -----------------------------
9461070
# Backward compatibility methods (deprecated)
9471071
# -----------------------------
@@ -1558,6 +1682,69 @@ def _find_min_cost_path(self, src: type[Any], dst: type[Any]) -> list[Edge]:
15581682
# ----------------------------------------------------------------------
15591683

15601684

1685+
class _IngressProxy:
1686+
"""Helper class to provide attribute-based access to kinds for ingress decorator.
1687+
1688+
This class enables syntax like @graph.ingress.text or @graph.ingress.int
1689+
by dynamically looking up kinds and creating decorators.
1690+
"""
1691+
1692+
def __init__(self, graph: TransformationGraph):
1693+
self._graph = graph
1694+
1695+
def __call__(
1696+
self,
1697+
kind_or_arg: Hashable | None = None,
1698+
arg_name: str | None = None,
1699+
*,
1700+
context: dict | None = None,
1701+
):
1702+
"""Allow calling as @graph.ingress(kind, arg_name)."""
1703+
return self._graph._ingress_decorator(kind_or_arg, arg_name, context=context)
1704+
1705+
def __getattr__(self, kind_name: str):
1706+
"""Enable attribute access like @graph.ingress.text or @graph.ingress.int.
1707+
1708+
Looks up the kind by string name or by type.__name__.
1709+
Returns a decorator or a decorator factory depending on usage.
1710+
"""
1711+
# Look up kind by string name or type.__name__
1712+
matching_kind = None
1713+
for kind in self._graph.kinds():
1714+
if isinstance(kind, str) and kind == kind_name:
1715+
matching_kind = kind
1716+
break
1717+
elif isinstance(kind, type) and kind.__name__ == kind_name:
1718+
matching_kind = kind
1719+
break
1720+
1721+
if matching_kind is None:
1722+
raise AttributeError(
1723+
f"Kind '{kind_name}' not found in graph. "
1724+
f"Available kinds: {self._graph.kinds()}"
1725+
)
1726+
1727+
# Return a factory that can be used as @graph.ingress.kind or @graph.ingress.kind(arg_name)
1728+
return _KindIngressFactory(self._graph, matching_kind)
1729+
1730+
1731+
class _KindIngressFactory:
1732+
"""Factory to handle @graph.ingress.kind and @graph.ingress.kind(arg_name) syntax."""
1733+
1734+
def __init__(self, graph: TransformationGraph, kind: Hashable):
1735+
self._graph = graph
1736+
self._kind = kind
1737+
1738+
def __call__(self, func_or_arg_name):
1739+
"""Handle both @graph.ingress.kind and @graph.ingress.kind(arg_name) patterns."""
1740+
if callable(func_or_arg_name):
1741+
# Used as @graph.ingress.kind (no parentheses, applied to first arg)
1742+
return self._graph._ingress_decorator(self._kind, None)(func_or_arg_name)
1743+
else:
1744+
# Used as @graph.ingress.kind(arg_name) or @graph.ingress.kind('arg')
1745+
return self._graph._ingress_decorator(self._kind, func_or_arg_name)
1746+
1747+
15611748
def design_guidelines() -> str:
15621749
"""
15631750
Returns concise guidance for organizing casting in Python.

i2/tests/test_castgraph.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,3 +332,83 @@ class M2: ...
332332
@reg.register()
333333
def bad(x):
334334
return None
335+
336+
337+
# --- Tests for ingress decorator ---
338+
339+
340+
def test_ingress_with_string_kinds():
341+
"""Test ingress decorator with string-based kinds."""
342+
from i2.castgraph import TransformationGraph
343+
344+
graph = TransformationGraph()
345+
graph.add_node('text', isa=lambda x: isinstance(x, str))
346+
graph.add_node('number', isa=lambda x: isinstance(x, (int, float)))
347+
348+
@graph.register_edge('text', 'number')
349+
def text_to_number(t, ctx):
350+
return float(t)
351+
352+
@graph.ingress('number', 'x')
353+
def double(x):
354+
return x * 2
355+
356+
result = double("21")
357+
assert result == 42.0
358+
359+
360+
def test_ingress_with_type_kinds():
361+
"""Test ingress decorator with type-based kinds."""
362+
from i2.castgraph import TransformationGraph
363+
364+
graph = TransformationGraph()
365+
366+
@graph.register_edge(str, int)
367+
def str_to_int(s, ctx):
368+
return int(s)
369+
370+
@graph.ingress(int)
371+
def square(n):
372+
return n**2
373+
374+
result = square("5")
375+
assert result == 25
376+
377+
378+
def test_ingress_attribute_syntax():
379+
"""Test ingress decorator with attribute-based syntax."""
380+
from i2.castgraph import TransformationGraph
381+
382+
graph = TransformationGraph()
383+
graph.add_node('data', isa=lambda x: isinstance(x, dict))
384+
385+
@graph.register_edge(str, 'data')
386+
def str_to_data(s, ctx):
387+
return json.loads(s)
388+
389+
@graph.ingress.data('obj')
390+
def get_value(obj):
391+
return obj.get('value', 0)
392+
393+
result = get_value('{"value": 42}')
394+
assert result == 42
395+
396+
397+
def test_ingress_preserves_function_signature():
398+
"""Test that ingress decorator preserves function behavior."""
399+
from i2.castgraph import TransformationGraph
400+
401+
graph = TransformationGraph()
402+
403+
@graph.register_edge(str, int)
404+
def str_to_int(s, ctx):
405+
return int(s)
406+
407+
@graph.ingress(int, 'a')
408+
def add(a, b):
409+
return a + b
410+
411+
# Test with various call patterns
412+
assert add("10", 5) == 15
413+
assert add(a="20", b=22) == 42
414+
assert add("7", b=8) == 15

0 commit comments

Comments
 (0)