3232"""
3333
3434from collections .abc import Mapping , Sequence
35- from dataclasses import Field , fields , is_dataclass
36- from typing import Union , get_args , get_origin
35+ from dataclasses import fields , is_dataclass
36+ from typing import NamedTuple , Union , get_args , get_origin
3737
3838from arraycontext .container import is_array_container_type
3939
4040
4141# {{{ dataclass containers
4242
43+ class _Field (NamedTuple ):
44+ """Small lookalike for :class:`dataclasses.Field`."""
45+
46+ init : bool
47+ name : str
48+ type : type
49+
50+
4351def is_array_type (tp : type ) -> bool :
4452 from arraycontext import Array
4553 return tp is Array or is_array_container_type (tp )
@@ -73,7 +81,9 @@ def dataclass_array_container(cls: type) -> type:
7381
7482 assert is_dataclass (cls )
7583
76- def is_array_field (f : Field , field_type : type ) -> bool :
84+ def is_array_field (f : _Field ) -> bool :
85+ field_type = f .type
86+
7787 # NOTE: unions of array containers are treated separately to handle
7888 # unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
7989 # they can work seamlessly with arithmetic and traversal.
@@ -96,10 +106,8 @@ def is_array_field(f: Field, field_type: type) -> bool:
96106 f"Field '{ f .name } ' union contains non-array container "
97107 "arguments. All arguments must be array containers." )
98108
99- if isinstance (field_type , str ):
100- raise TypeError (
101- f"String annotation on field '{ f .name } ' not supported. "
102- "(this may be due to 'from __future__ import annotations')" )
109+ # NOTE: this should never happen due to using `inspect.get_annotations`
110+ assert not isinstance (field_type , str )
103111
104112 if __debug__ :
105113 if not f .init :
@@ -127,36 +135,52 @@ def is_array_field(f: Field, field_type: type) -> bool:
127135
128136 return is_array_type (field_type )
129137
138+ from pytools import partition
139+
140+ array_fields = _get_annotated_fields (cls )
141+ array_fields , non_array_fields = partition (is_array_field , array_fields )
142+
143+ if not array_fields :
144+ raise ValueError (f"'{ cls } ' must have fields with array container type "
145+ "in order to use the 'dataclass_array_container' decorator" )
146+
147+ return _inject_dataclass_serialization (cls , array_fields , non_array_fields )
148+
149+
150+ def _get_annotated_fields (cls : type ) -> Sequence [_Field ]:
151+ """Get a list of fields in the class *cls* with evaluated types.
152+
153+ If any of the fields in *cls* have type annotations that are strings, e.g.
154+ from using ``from __future__ import annotations``, this function evaluates
155+ them using :func:`inspect.get_annotations`. Note that this requires the class
156+ to live in a module that is importable.
157+
158+ :return: a list of fields.
159+ """
160+
130161 from inspect import get_annotations
131162
132- array_fields : list [Field ] = []
133- non_array_fields : list [Field ] = []
163+ result = []
134164 cls_ann : Mapping [str , type ] | None = None
135165 for field in fields (cls ):
136166 field_type_or_str = field .type
137167 if isinstance (field_type_or_str , str ):
138168 if cls_ann is None :
139169 cls_ann = get_annotations (cls , eval_str = True )
170+
140171 field_type = cls_ann [field .name ]
141172 else :
142173 field_type = field_type_or_str
143174
144- if is_array_field (field , field_type ):
145- array_fields .append (field )
146- else :
147- non_array_fields .append (field )
148-
149- if not array_fields :
150- raise ValueError (f"'{ cls } ' must have fields with array container type "
151- "in order to use the 'dataclass_array_container' decorator" )
175+ result .append (_Field (init = field .init , name = field .name , type = field_type ))
152176
153- return _inject_dataclass_serialization ( cls , array_fields , non_array_fields )
177+ return result
154178
155179
156180def _inject_dataclass_serialization (
157181 cls : type ,
158- array_fields : Sequence [Field ],
159- non_array_fields : Sequence [Field ]) -> type :
182+ array_fields : Sequence [_Field ],
183+ non_array_fields : Sequence [_Field ]) -> type :
160184 """Implements :func:`~arraycontext.serialize_container` and
161185 :func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
162186
0 commit comments