Skip to content

Commit 9ac8bc8

Browse files
alexfiklinducer
authored andcommitted
dataclass: refactor evaluating string fields
1 parent c4f00b8 commit 9ac8bc8

File tree

1 file changed

+44
-20
lines changed

1 file changed

+44
-20
lines changed

arraycontext/container/dataclass.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,22 @@
3232
"""
3333

3434
from collections.abc import Mapping, Sequence
35-
from dataclasses import Field, fields, is_dataclass
36-
from typing import Union, get_args, get_origin
35+
from dataclasses import fields, is_dataclass
36+
from typing import NamedTuple, Union, get_args, get_origin
3737

3838
from arraycontext.container import is_array_container_type
3939

4040

4141
# {{{ dataclass containers
4242

43+
class _Field(NamedTuple):
44+
"""Small lookalike for :class:`dataclasses.Field`."""
45+
46+
init: bool
47+
name: str
48+
type: type
49+
50+
4351
def is_array_type(tp: type) -> bool:
4452
from arraycontext import Array
4553
return tp is Array or is_array_container_type(tp)
@@ -73,7 +81,9 @@ def dataclass_array_container(cls: type) -> type:
7381

7482
assert is_dataclass(cls)
7583

76-
def is_array_field(f: Field, field_type: type) -> bool:
84+
def is_array_field(f: _Field) -> bool:
85+
field_type = f.type
86+
7787
# NOTE: unions of array containers are treated separately to handle
7888
# unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
7989
# they can work seamlessly with arithmetic and traversal.
@@ -96,10 +106,8 @@ def is_array_field(f: Field, field_type: type) -> bool:
96106
f"Field '{f.name}' union contains non-array container "
97107
"arguments. All arguments must be array containers.")
98108

99-
if isinstance(field_type, str):
100-
raise TypeError(
101-
f"String annotation on field '{f.name}' not supported. "
102-
"(this may be due to 'from __future__ import annotations')")
109+
# NOTE: this should never happen due to using `inspect.get_annotations`
110+
assert not isinstance(field_type, str)
103111

104112
if __debug__:
105113
if not f.init:
@@ -127,36 +135,52 @@ def is_array_field(f: Field, field_type: type) -> bool:
127135

128136
return is_array_type(field_type)
129137

138+
from pytools import partition
139+
140+
array_fields = _get_annotated_fields(cls)
141+
array_fields, non_array_fields = partition(is_array_field, array_fields)
142+
143+
if not array_fields:
144+
raise ValueError(f"'{cls}' must have fields with array container type "
145+
"in order to use the 'dataclass_array_container' decorator")
146+
147+
return _inject_dataclass_serialization(cls, array_fields, non_array_fields)
148+
149+
150+
def _get_annotated_fields(cls: type) -> Sequence[_Field]:
151+
"""Get a list of fields in the class *cls* with evaluated types.
152+
153+
If any of the fields in *cls* have type annotations that are strings, e.g.
154+
from using ``from __future__ import annotations``, this function evaluates
155+
them using :func:`inspect.get_annotations`. Note that this requires the class
156+
to live in a module that is importable.
157+
158+
:return: a list of fields.
159+
"""
160+
130161
from inspect import get_annotations
131162

132-
array_fields: list[Field] = []
133-
non_array_fields: list[Field] = []
163+
result = []
134164
cls_ann: Mapping[str, type] | None = None
135165
for field in fields(cls):
136166
field_type_or_str = field.type
137167
if isinstance(field_type_or_str, str):
138168
if cls_ann is None:
139169
cls_ann = get_annotations(cls, eval_str=True)
170+
140171
field_type = cls_ann[field.name]
141172
else:
142173
field_type = field_type_or_str
143174

144-
if is_array_field(field, field_type):
145-
array_fields.append(field)
146-
else:
147-
non_array_fields.append(field)
148-
149-
if not array_fields:
150-
raise ValueError(f"'{cls}' must have fields with array container type "
151-
"in order to use the 'dataclass_array_container' decorator")
175+
result.append(_Field(init=field.init, name=field.name, type=field_type))
152176

153-
return _inject_dataclass_serialization(cls, array_fields, non_array_fields)
177+
return result
154178

155179

156180
def _inject_dataclass_serialization(
157181
cls: type,
158-
array_fields: Sequence[Field],
159-
non_array_fields: Sequence[Field]) -> type:
182+
array_fields: Sequence[_Field],
183+
non_array_fields: Sequence[_Field]) -> type:
160184
"""Implements :func:`~arraycontext.serialize_container` and
161185
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
162186

0 commit comments

Comments
 (0)