diff --git a/csp/impl/struct.py b/csp/impl/struct.py index c088e8548..dbacb32e9 100644 --- a/csp/impl/struct.py +++ b/csp/impl/struct.py @@ -35,7 +35,11 @@ def __new__(cls, name, bases, dct): # Lists need to be normalized too as potentially we need to add a boolean flag to use FastList if v == FastList: raise TypeError(f"{v} annotation is not supported without args") - if CspTypingUtils.is_generic_container(v) or CspTypingUtils.is_union_type(v): + if ( + CspTypingUtils.is_generic_container(v) + or CspTypingUtils.is_union_type(v) + or CspTypingUtils.is_literal_type(v) + ): actual_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(v) if CspTypingUtils.is_generic_container(actual_type): raise TypeError(f"{v} annotation is not supported as a struct field [{actual_type}]") @@ -147,6 +151,17 @@ def serializer(val, handler): class Struct(_csptypesimpl.PyStruct, metaclass=StructMeta): + @classmethod + def type_adapter(cls): + # Late import to avoid autogen issues + from pydantic import TypeAdapter + + internal_type_adapter = getattr(cls, "_pydantic_type_adapter", None) + if internal_type_adapter: + return internal_type_adapter + cls._pydantic_type_adapter = TypeAdapter(cls) + return cls._pydantic_type_adapter + @classmethod def metadata(cls, typed=False): if typed: @@ -191,7 +206,8 @@ def _obj_from_python(cls, json, obj_type): if CspTypingUtils.is_generic_container(obj_type): if CspTypingUtils.get_origin(obj_type) in (typing.List, typing.Set, typing.Tuple, FastList): return_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(obj_type) - (expected_item_type,) = obj_type.__args__ + # We only take the first item, so like for a Tuple, we would ignore arguments after + expected_item_type = obj_type.__args__[0] return_type = list if isinstance(return_type, list) else return_type return return_type(cls._obj_from_python(v, expected_item_type) for v in json) elif CspTypingUtils.get_origin(obj_type) is typing.Dict: @@ -206,6 +222,13 @@ def _obj_from_python(cls, json, obj_type): return json else: raise NotImplementedError(f"Can not deserialize {obj_type} from json") + elif CspTypingUtils.is_union_type(obj_type): + return json ## no checks, just let it through + elif CspTypingUtils.is_literal_type(obj_type): + return_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(obj_type) + if isinstance(json, return_type): + return json + raise ValueError(f"Expected type {return_type} received {json.__class__}") elif issubclass(obj_type, Struct): if not isinstance(json, dict): raise TypeError("Representation of struct as json is expected to be of dict type") @@ -223,7 +246,9 @@ def _obj_from_python(cls, json, obj_type): return obj_type(json) @classmethod - def from_dict(cls, json: dict): + def from_dict(cls, json: dict, use_pydantic: bool = False): + if use_pydantic: + return cls.type_adapter().validate_python(json) return cls._obj_from_python(json, cls) def to_dict_depr(self): diff --git a/csp/impl/types/container_type_normalizer.py b/csp/impl/types/container_type_normalizer.py index 5532cd66e..d8b9b8811 100644 --- a/csp/impl/types/container_type_normalizer.py +++ b/csp/impl/types/container_type_normalizer.py @@ -81,21 +81,21 @@ def normalized_type_to_actual_python_type(cls, typ, level=0): return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1), True] if origin is typing.List and level == 0: return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1)] - if origin is typing.Literal: - # Import here to prevent circular import - from csp.impl.types.instantiation_type_resolver import UpcastRegistry - - args = typing.get_args(typ) - typ = type(args[0]) - for arg in args[1:]: - typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False) - if typ: - return typ - else: - return object return cls._NORMALIZED_TYPE_MAPPING.get(CspTypingUtils.get_origin(typ), typ) elif CspTypingUtils.is_union_type(typ): return object + elif CspTypingUtils.is_literal_type(typ): + # Import here to prevent circular import + from csp.impl.types.instantiation_type_resolver import UpcastRegistry + + args = typing.get_args(typ) + typ = type(args[0]) + for arg in args[1:]: + typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False) + if typ: + return typ + else: + return object else: return typ diff --git a/csp/impl/types/pydantic_types.py b/csp/impl/types/pydantic_types.py index 6968a394d..4d4d27989 100644 --- a/csp/impl/types/pydantic_types.py +++ b/csp/impl/types/pydantic_types.py @@ -1,7 +1,7 @@ import sys import types import typing -from typing import Any, ForwardRef, Generic, Optional, Type, TypeVar, Union, get_args, get_origin +from typing import Any, ForwardRef, Generic, Literal, Optional, Type, TypeVar, Union, get_args, get_origin from pydantic import GetCoreSchemaHandler, ValidationInfo, ValidatorFunctionWrapHandler from pydantic_core import CoreSchema, core_schema @@ -184,6 +184,8 @@ def adjust_annotations( return TsType[ adjust_annotations(args[0], top_level=False, in_ts=True, make_optional=False, forced_tvars=forced_tvars) ] + if origin is Literal: # for literals, we stop converting + return Optional[annotation] if make_optional else annotation else: try: if origin is CspTypeVar or origin is CspTypeVarType: diff --git a/csp/impl/types/type_annotation_normalizer_transformer.py b/csp/impl/types/type_annotation_normalizer_transformer.py index 1194d77c7..00700b472 100644 --- a/csp/impl/types/type_annotation_normalizer_transformer.py +++ b/csp/impl/types/type_annotation_normalizer_transformer.py @@ -51,6 +51,8 @@ def visit_arg(self, node): return node def visit_Subscript(self, node): + # We choose to avoid parsing here + # to maintain current behavior of allowing empty lists in our types return node def visit_List(self, node): @@ -98,17 +100,13 @@ def visit_Call(self, node): return node def visit_Constant(self, node): - if not self._cur_arg: - return node - - if self._cur_arg: - return ast.Call( - func=ast.Attribute(value=ast.Name(id="typing", ctx=ast.Load()), attr="TypeVar", ctx=ast.Load()), - args=[node], - keywords=[], - ) - else: + if not self._cur_arg or not isinstance(node.value, str): return node + return ast.Call( + func=ast.Attribute(value=ast.Name(id="typing", ctx=ast.Load()), attr="TypeVar", ctx=ast.Load()), + args=[node], + keywords=[], + ) def visit_Str(self, node): return self.visit_Constant(node) diff --git a/csp/impl/types/typing_utils.py b/csp/impl/types/typing_utils.py index 6de3d8c38..03f248a10 100644 --- a/csp/impl/types/typing_utils.py +++ b/csp/impl/types/typing_utils.py @@ -15,6 +15,25 @@ class FastList(typing.List, typing.Generic[T]): # Need to inherit from Generic[ def __init__(self): raise NotImplementedError("Can not init FastList class") + @classmethod + def __get_pydantic_core_schema__(cls, source_type, handler): + from pydantic_core import core_schema + + # Late import to not interfere with autogen + args = typing.get_args(source_type) + if args: + inner_type = args[0] + list_schema = handler.generate_schema(typing.List[inner_type]) + else: + list_schema = handler.generate_schema(typing.List) + + def create_instance(raw_data, validator): + if isinstance(raw_data, FastList): + return raw_data + return validator(raw_data) # just return a list + + return core_schema.no_info_wrap_validator_function(function=create_instance, schema=list_schema) + class CspTypingUtils39: _ORIGIN_COMPAT_MAP = {list: typing.List, set: typing.Set, dict: typing.Dict, tuple: typing.Tuple} @@ -23,7 +42,7 @@ class CspTypingUtils39: @classmethod def is_generic_container(cls, typ): - return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ is not typing.Union + return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ not in (typing.Union, typing.Literal) @classmethod def is_type_spec(cls, val): @@ -56,6 +75,10 @@ def is_numpy_nd_array_type(cls, typ): def is_union_type(cls, typ): return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Union + @classmethod + def is_literal_type(cls, typ): + return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Literal + @classmethod def is_forward_ref(cls, typ): return isinstance(typ, typing.ForwardRef) diff --git a/csp/tests/impl/test_struct.py b/csp/tests/impl/test_struct.py index b246ff4f3..e88123c35 100644 --- a/csp/tests/impl/test_struct.py +++ b/csp/tests/impl/test_struct.py @@ -1,6 +1,8 @@ import enum import json +import os import pickle +import sys import typing import unittest from datetime import date, datetime, time, timedelta @@ -17,6 +19,8 @@ from csp.impl.types.typing_utils import FastList from csp.typing import Numpy1DArray +USE_PYDANTIC = os.environ.get("CSP_PYDANTIC", True) + class MyEnum(csp.Enum): A = 1 @@ -796,6 +800,8 @@ class MyStruct(csp.Struct): def test_from_dict_with_enum(self): struct = StructWithDefaults.from_dict({"e": MyEnum.A}) self.assertEqual(MyEnum.A, getattr(struct, "e")) + struct = StructWithDefaults.from_dict({"e": MyEnum.A}, use_pydantic=True) + self.assertEqual(MyEnum.A, getattr(struct, "e")) def test_from_dict_with_list_derived_type(self): class ListDerivedType(list): @@ -809,32 +815,40 @@ class StructWithListDerivedType(csp.Struct): self.assertTrue(isinstance(s1.to_dict()["ldt"], ListDerivedType)) s2 = StructWithListDerivedType.from_dict(s1.to_dict()) self.assertEqual(s1, s2) + s3 = StructWithListDerivedType.from_dict(s1.to_dict(), use_pydantic=True) + self.assertEqual(s1, s3) def test_from_dict_loop_no_defaults(self): looped = StructNoDefaults.from_dict(StructNoDefaults(a1=[9, 10]).to_dict()) self.assertEqual(looped, StructNoDefaults(a1=[9, 10])) + looped = StructNoDefaults.from_dict(StructNoDefaults(a1=[9, 10]).to_dict(), use_pydantic=True) + self.assertEqual(looped, StructNoDefaults(a1=[9, 10])) def test_from_dict_loop_with_defaults(self): - looped = StructWithDefaults.from_dict(StructWithDefaults().to_dict()) - # Note that we cant compare numpy arrays, so we check them independently - comp = StructWithDefaults() - self.assertTrue(np.array_equal(looped.np_arr, comp.np_arr)) + for use_pydantic in [True, False]: + looped = StructWithDefaults.from_dict(StructWithDefaults().to_dict(), use_pydantic=use_pydantic) + # Note that we cant compare numpy arrays, so we check them independently + comp = StructWithDefaults() + self.assertTrue(np.array_equal(looped.np_arr, comp.np_arr)) - del looped.np_arr - del comp.np_arr - self.assertEqual(looped, comp) + del looped.np_arr + del comp.np_arr + self.assertEqual(looped, comp) def test_from_dict_loop_with_generic_typing(self): class MyStruct(csp.Struct): foo: Set[int] - bar: Tuple[str] + bar: Tuple[str, ...] np_arr: csp.typing.NumpyNDArray[float] - looped = MyStruct.from_dict(MyStruct(foo=set((9, 10)), bar=("a", "b"), np_arr=np.array([1, 3])).to_dict()) - expected = MyStruct(foo=set((9, 10)), bar=("a", "b"), np_arr=np.array([1, 3])) - self.assertEqual(looped.foo, expected.foo) - self.assertEqual(looped.bar, expected.bar) - self.assertTrue(np.all(looped.np_arr == expected.np_arr)) + for use_pydantic in [True, False]: + looped = MyStruct.from_dict( + MyStruct(foo=set((9, 10)), bar=("a", "b"), np_arr=np.array([1, 3])).to_dict(), use_pydantic=use_pydantic + ) + expected = MyStruct(foo=set((9, 10)), bar=("a", "b"), np_arr=np.array([1, 3])) + self.assertEqual(looped.foo, expected.foo) + self.assertEqual(looped.bar, expected.bar) + self.assertTrue(np.all(looped.np_arr == expected.np_arr)) def test_struct_yaml_serialization(self): class S1(csp.Struct): @@ -3010,7 +3024,7 @@ class SimpleStruct(csp.Struct): # Valid data valid_data = {"value": 11, "name": "ya", "scores": [1.1, 2.2, 3.3]} - result = TypeAdapter(SimpleStruct).validate_python(valid_data) + result = SimpleStruct.from_dict(valid_data, use_pydantic=True) self.assertIsInstance(result, SimpleStruct) self.assertEqual(result.value, 11) self.assertEqual(result.name, "ya") @@ -3019,11 +3033,11 @@ class SimpleStruct(csp.Struct): invalid_data = valid_data.copy() invalid_data["missing"] = False with self.assertRaises(ValidationError): - TypeAdapter(SimpleStruct).validate_python(invalid_data) # extra fields throw an error + SimpleStruct.from_dict(invalid_data, use_pydantic=True) # extra fields throw an error # Test that we can validate existing structs existing = SimpleStruct(value=1, scores=[1]) - new = TypeAdapter(SimpleStruct).validate_python(existing) + new = SimpleStruct.from_dict(existing, use_pydantic=True) self.assertTrue(existing is new) # we do not revalidate self.assertEqual(existing.value, 1) @@ -3032,7 +3046,7 @@ class SimpleStruct(csp.Struct): "value": "42", # string should convert to int "scores": ["1.1", 2, "3.3"], # mixed types should convert to float } - result = TypeAdapter(SimpleStruct).validate_python(coercion_data) + result = SimpleStruct.from_dict(coercion_data, use_pydantic=True) self.assertEqual(result.value, 42) self.assertEqual(result.scores, [1.1, 2.0, 3.3]) @@ -3042,7 +3056,7 @@ class NestedStruct(csp.Struct): tags: List[str] nested_data = {"simple": {"value": 11, "name": "ya", "scores": [1.1, 2.2, 3.3]}, "tags": ["test1", "test2"]} - result = TypeAdapter(NestedStruct).validate_python(nested_data) + result = NestedStruct.from_dict(nested_data, use_pydantic=True) self.assertIsInstance(result, NestedStruct) self.assertIsInstance(result.simple, SimpleStruct) self.assertEqual(result.simple.value, 11) @@ -3050,7 +3064,7 @@ class NestedStruct(csp.Struct): # 3. Test validation errors with self.assertRaises(ValidationError) as exc_info: - TypeAdapter(SimpleStruct).validate_python({"value": "not an integer", "scores": [1.1, 2.2, "invalid"]}) + SimpleStruct.from_dict({"value": "not an integer", "scores": [1.1, 2.2, "invalid"]}, use_pydantic=True) self.assertIn("Input should be a valid integer", str(exc_info.exception)) # 4. Test with complex types @@ -3063,7 +3077,7 @@ class ComplexStruct(csp.Struct): "dates": ["2023-01-01", "2023-01-02"], # strings should convert to datetime "mapping": {"a": "1.1", "b": 2.2}, # mixed types should convert to float } - result = TypeAdapter(ComplexStruct).validate_python(complex_data) + result = ComplexStruct.from_dict(complex_data, use_pydantic=True) self.assertIsInstance(result.dates[0], datetime) self.assertEqual(result.mapping, {"a": 1.1, "b": 2.2}) @@ -3077,7 +3091,7 @@ class EnumStruct(csp.Struct): enum_list: List[MyEnum] enum_data = {"enum_field": "A", "enum_list": ["A", "B", "A"]} - result = TypeAdapter(EnumStruct).validate_python(enum_data) + result = EnumStruct.from_dict(enum_data, use_pydantic=True) self.assertEqual(result.enum_field, MyEnum.A) self.assertEqual(result.enum_list, [MyEnum.A, MyEnum.B, MyEnum.A]) @@ -3095,7 +3109,7 @@ class StructWithDummy(csp.Struct): val = DummyBlankClass() struct_as_dict = dict(x=12, y=val, z=[val], z1={val: val}, z2=None) - new_struct = TypeAdapter(StructWithDummy).validate_python(struct_as_dict) + new_struct = StructWithDummy.from_dict(struct_as_dict, use_pydantic=True) self.assertTrue(new_struct.y is val) self.assertTrue(new_struct.z[0] is val) self.assertTrue(new_struct.z1[val] is val) @@ -3113,7 +3127,7 @@ class StructWithDummy(csp.Struct): z3=z3_val, z4=z3_val, ) - new_struct = TypeAdapter(StructWithDummy).validate_python(struct_as_dict) + new_struct = StructWithDummy.from_dict(struct_as_dict, use_pydantic=True) self.assertTrue(new_struct.y is val) self.assertTrue(new_struct.z[0] is val) self.assertTrue(new_struct.z1[val] is val) @@ -3206,7 +3220,7 @@ class ProjectStruct(csp.Struct): } # 1. Test validation - result = TypeAdapter(ProjectStruct).validate_python(project_data) + result = ProjectStruct.from_dict(project_data, use_pydantic=True) # Verify the structure self.assertIsInstance(result, ProjectStruct) @@ -3249,14 +3263,14 @@ class ProjectStruct(csp.Struct): invalid_data["task_statuses"][99] = [] # Invalid enum value with self.assertRaises(ValidationError) as exc_info: - TypeAdapter(ProjectStruct).validate_python(invalid_data) + ProjectStruct.from_dict(invalid_data, use_pydantic=True) # 4. Test validation errors with invalid nested types invalid_task_data = project_data.copy() invalid_task_data["task_statuses"][1][0]["metadata"]["priority"] = 99 # Invalid priority with self.assertRaises(ValidationError) as exc_info: - TypeAdapter(ProjectStruct).validate_python(invalid_task_data) + ProjectStruct.from_dict(invalid_task_data, use_pydantic=True) def test_pydantic_models_with_csp_structs(self): """Test Pydantic BaseModels containing CSP Structs as attributes""" @@ -3405,13 +3419,13 @@ class OuterStruct(csp.Struct): inner: Annotated[InnerStruct, WrapValidator(struct_validator)] # Test simple value validation - inner = TypeAdapter(InnerStruct).validate_python({"value": "21"}) + inner = InnerStruct.from_dict({"value": "21"}, use_pydantic=True) self.assertEqual(inner.value, 42) # "21" -> 21 -> 42 self.assertEqual(inner.description, "default") self.assertFalse(hasattr(inner, "z")) # test existing instance - inner_new = TypeAdapter(InnerStruct).validate_python(inner) + inner_new = InnerStruct.from_dict(inner, use_pydantic=True) self.assertTrue(inner is inner_new) # No revalidation self.assertEqual(inner_new.value, 42) @@ -3419,26 +3433,26 @@ class OuterStruct(csp.Struct): # Test validation with invalid value in existing instance inner.value = -5 # Set invalid value # No revalidation, no error - self.assertTrue(inner is TypeAdapter(InnerStruct).validate_python(inner)) + self.assertTrue(inner is InnerStruct.from_dict(inner, use_pydantic=True)) with self.assertRaises(ValidationError) as cm: - TypeAdapter(InnerStruct).validate_python(inner.to_dict()) + InnerStruct.from_dict(inner.to_dict(), use_pydantic=True) self.assertIn("value must be positive", str(cm.exception)) # Test simple value validation - inner = TypeAdapter(InnerStruct).validate_python({"value": "21", "z": 17}) + inner = InnerStruct.from_dict({"value": "21", "z": 17}, use_pydantic=True) self.assertEqual(inner.value, 42) # "21" -> 21 -> 42 self.assertEqual(inner.description, "default") self.assertEqual(inner.z, 17) # Test struct validation with expansion - outer = TypeAdapter(OuterStruct).validate_python({"name": "test", "inner": {"value": 10, "z": 12}}) + outer = OuterStruct.from_dict({"name": "test", "inner": {"value": 10, "z": 12}}, use_pydantic=True) self.assertEqual(outer.inner.value, 20) # 10 -> 20 (doubled) self.assertEqual(outer.inner.description, "auto_generated") self.assertEqual(outer.inner.z, 12) # Test normal full structure still works - outer = TypeAdapter(OuterStruct).validate_python( - {"name": "test", "inner": {"value": "5", "description": "custom"}} + outer = OuterStruct.from_dict( + {"name": "test", "inner": {"value": "5", "description": "custom"}}, use_pydantic=True ) self.assertEqual(outer.inner.value, 10) # "5" -> 5 -> 10 (doubled) self.assertEqual(outer.inner.description, "custom") @@ -3457,50 +3471,55 @@ class MetricStruct(csp.Struct): tags: Union[str, List[str]] = "default" # Test with different value types - metric1 = TypeAdapter(MetricStruct).validate_python( + metric1 = MetricStruct.from_dict( { "value": 42, # int - } + }, + use_pydantic=True, ) self.assertEqual(metric1.value, 42) self.assertIsNone(metric1.name) self.assertEqual(metric1.tags, "default") - metric2 = TypeAdapter(MetricStruct).validate_python( + metric2 = MetricStruct.from_dict( { "value": 42.5, # float "name": "test", "tags": ["tag1", "tag2"], - } + }, + use_pydantic=True, ) self.assertEqual(metric2.value, 42.5) self.assertEqual(metric2.name, "test") self.assertEqual(metric2.tags, ["tag1", "tag2"]) # Test with string that should convert to float - metric3 = TypeAdapter(MetricStruct).validate_python( + metric3 = MetricStruct.from_dict( { "value": "42.5", # should convert to float "tags": "single_tag", # single string tag - } + }, + use_pydantic=True, ) self.assertEqual(metric3.value, 42.5) self.assertEqual(metric3.tags, "single_tag") # Test validation error with invalid type with self.assertRaises(ValidationError) as exc_info: - TypeAdapter(MetricStruct).validate_python( + MetricStruct.from_dict( { "value": "not a number", - } + }, + use_pydantic=True, ) self.assertIn("Input should be a valid number", str(exc_info.exception)) # Test with string that should convert to float - metric3 = TypeAdapter(MetricStruct).validate_python( + metric3 = MetricStruct.from_dict( { "tags": "single_tag" # single string tag - } + }, + use_pydantic=True, ) self.assertFalse(hasattr(metric3, "value")) self.assertEqual(metric3.tags, "single_tag") @@ -3527,7 +3546,7 @@ class DataPoint(csp.Struct): # Test with MetricStruct metric_data = {"id": "metric-1", "data": {"value": 42.5, "unit": "celsius"}} - result = TypeAdapter(DataPoint).validate_python(metric_data) + result = DataPoint.from_dict(metric_data, use_pydantic=True) self.assertIsInstance(result.data, MetricStruct) self.assertEqual(result.data.value, 42.5) self.assertEqual(result.data.unit, "celsius") @@ -3541,14 +3560,14 @@ class DataPoint(csp.Struct): {"name": "previous_event", "timestamp": "2023-01-01T11:00:00"}, ], } - result = TypeAdapter(DataPoint).validate_python(event_data) + result = DataPoint.from_dict(event_data, use_pydantic=True) self.assertIsInstance(result.data, EventStruct) self.assertEqual(result.data.name, "system_start") self.assertIsInstance(result.history[0], MetricStruct) self.assertIsInstance(result.history[1], EventStruct) # Test serialization and deserialization - result = TypeAdapter(DataPoint).validate_python(event_data) + result = DataPoint.from_dict(event_data, use_pydantic=True) json_data = result.to_json() restored = TypeAdapter(DataPoint).validate_json(json_data) @@ -3589,14 +3608,14 @@ class DataPoint(csp.Struct): "precision": 1, # specific to TemperatureMetric }, } - result = TypeAdapter(DataPoint).validate_python(temp_data) + result = DataPoint.from_dict(temp_data, use_pydantic=True) self.assertIsInstance(result.metric, TemperatureMetric) # Should be TemperatureMetric, not BaseMetric self.assertEqual(result.metric.unit, "celsius") self.assertEqual(result.metric.precision, 1) # Test with PressureMetric data pressure_data = {"id": "pressure-1", "metric": {"name": "pressure", "value": 101.325, "altitude": 0.0}} - result = TypeAdapter(DataPoint).validate_python(pressure_data) + result = DataPoint.from_dict(pressure_data, use_pydantic=True) self.assertIsInstance(result.metric, PressureMetric) # Should be PressureMetric, not BaseMetric self.assertEqual(result.metric.unit, "pascal") self.assertEqual(result.metric.altitude, 0.0) @@ -3617,7 +3636,7 @@ class DataPoint(csp.Struct): }, ], } - result = TypeAdapter(DataPoint).validate_python(mixed_data) + result = DataPoint.from_dict(mixed_data, use_pydantic=True) self.assertIsInstance(result.metric, BaseMetric) # Should be base metric self.assertIsInstance(result.history[0], TemperatureMetric) # Should be temperature self.assertIsInstance(result.history[1], PressureMetric) # Should be pressure @@ -3782,7 +3801,7 @@ class NestedStruct(csp.Struct): self.assertEqual(enum_as_enum.name, enum_as_str) self.assertEqual( - nested, TypeAdapter(NestedStruct).validate_python(TypeAdapter(NestedStruct).dump_python(nested)) + nested, NestedStruct.from_dict(TypeAdapter(NestedStruct).dump_python(nested), use_pydantic=True) ) json_native = nested.to_json() @@ -3802,7 +3821,7 @@ class NPStruct(csp.Struct): NPStruct(arr=np.array([1, 3, "ab"])) # No error, even though the types are wrong with self.assertRaises(ValidationError) as exc_info: - TypeAdapter(NPStruct).validate_python(dict(arr=[1, 3, "ab"])) + NPStruct.from_dict(dict(arr=[1, 3, "ab"]), use_pydantic=True) self.assertIn("could not convert string to float", str(exc_info.exception)) # We should be able to generate the json_schema TypeAdapter(NPStruct).json_schema() @@ -3851,7 +3870,7 @@ class DataPoint(csp.Struct): }, } - result = TypeAdapter(DataPoint).validate_python(metric_data) + result = DataPoint.from_dict(metric_data, use_pydantic=True) # Verify private fields are properly set including inherited ones self.assertEqual(result._last_updated, datetime(2023, 1, 1, 12, 0)) @@ -3893,7 +3912,7 @@ class DataPoint(csp.Struct): }, } - result = TypeAdapter(DataPoint).validate_python(event_data) + result = DataPoint.from_dict(event_data, use_pydantic=True) # Verify private fields are set but excluded from serialization self.assertEqual(result._last_updated, datetime(2023, 1, 1, 12, 0)) @@ -3903,6 +3922,174 @@ class DataPoint(csp.Struct): self.assertNotIn("_last_updated", json_data) self.assertNotIn("_source", json_data["data"]) + def test_literal_types_validation(self): + """Test that Literal type annotations correctly validate input values in CSP Structs""" + + # Define a simple class with various Literal types + class StructWithLiterals(csp.Struct): + # String literals + color: Literal["red", "green", "blue"] + # Integer literals + size: Literal[1, 2, 3] + # Mixed type literals + status: Literal["on", "off", 0, 1, True, False] + # Optional literal with default + mode: Optional[Literal["fast", "slow"]] = "fast" + + # Test valid assignments + s1 = StructWithLiterals(color="red", size=2, status="on") + self.assertEqual(s1.color, "red") + self.assertEqual(s1.size, 2) + self.assertEqual(s1.status, "on") + self.assertEqual(s1.mode, "fast") # Default value + + s2 = StructWithLiterals.from_dict(dict(color="blue", size=1, status=True, mode="slow")) + s2_dump = s2.to_json() + s2_looped = TypeAdapter(StructWithLiterals).validate_json(s2_dump) + self.assertEqual(s2, s2_looped) + s2_dict = s2.to_dict() + s2_looped_dict = s2.from_dict(s2_dict) + self.assertEqual(s2_looped_dict, s2) + + # Invalid color, but from_dict still accepts + StructWithLiterals.from_dict(dict(color="yellow", size=1, status="on"), use_pydantic=False) + + # Invalid size but from_dict still accepts + StructWithLiterals.from_dict(dict(color="red", size=4, status="on"), use_pydantic=False) + + # Invalid status but from_dict still accepts + StructWithLiterals.from_dict(dict(color="red", size=1, status="standby"), use_pydantic=False) + + # Invalid mode but from_dict still accepts + StructWithLiterals.from_dict(dict(color="red", size=1, mode=12), use_pydantic=False) + + # Invalid size and since the literals are all the same type + # If we give an incorrect type, we catch the error + with self.assertRaises(ValueError) as exc_info: + StructWithLiterals.from_dict(dict(color="red", size="adasd", mode=12), use_pydantic=False) + self.assertIn("Expected type received ", str(exc_info.exception)) + + # Test valid values + result = StructWithLiterals.from_dict({"color": "green", "size": 3, "status": 0}, use_pydantic=True) + self.assertEqual(result.color, "green") + self.assertEqual(result.size, 3) + self.assertEqual(result.status, 0) + + # Test invalid color with Pydantic validation + with self.assertRaises(ValidationError) as exc_info: + StructWithLiterals.from_dict({"color": "yellow", "size": 1, "status": "on"}, use_pydantic=True) + self.assertIn("1 validation error for", str(exc_info.exception)) + self.assertIn("color", str(exc_info.exception)) + + # Test invalid size with Pydantic validation + with self.assertRaises(ValidationError) as exc_info: + StructWithLiterals.from_dict({"color": "red", "size": 4, "status": "on"}, use_pydantic=True) + self.assertIn("1 validation error for", str(exc_info.exception)) + self.assertIn("size", str(exc_info.exception)) + + # Test invalid status with Pydantic validation + with self.assertRaises(ValidationError) as exc_info: + StructWithLiterals.from_dict({"color": "red", "size": 1, "status": "standby"}, use_pydantic=True) + self.assertIn("1 validation error for", str(exc_info.exception)) + self.assertIn("status", str(exc_info.exception)) + + # Test invalid mode with Pydantic validation + with self.assertRaises(ValidationError) as exc_info: + StructWithLiterals.from_dict( + {"color": "red", "size": 1, "status": "on", "mode": "medium"}, use_pydantic=True + ) + self.assertIn("1 validation error for", str(exc_info.exception)) + self.assertIn("mode", str(exc_info.exception)) + # Test serialization and deserialization preserves literal values + result = StructWithLiterals.from_dict({"color": "green", "size": 3, "status": 0}, use_pydantic=True) + json_data = TypeAdapter(StructWithLiterals).dump_json(result) + restored = TypeAdapter(StructWithLiterals).validate_json(json_data) + self.assertEqual(restored.color, "green") + self.assertEqual(restored.size, 3) + self.assertEqual(restored.status, 0) + + def test_pipe_operator_types(self): + """Test using the pipe operator for union types in Python 3.10+""" + if sys.version_info >= (3, 10): # Only run on Python 3.10+ + # Define a class using various pipe operator combinations + class PipeTypesConfig(csp.Struct): + # Basic primitive types with pipe + id_field: str | int + # Pipe with None (similar to Optional) + description: str | None = None + # Multiple types with pipe + value: str | int | float | bool + # Container with pipe + tags: List[str] | Dict[str, str] | None = None + # Pipe with literal for comparison + status: Literal["active", "inactive"] | None = "active" + + # Test with string ID + p1 = PipeTypesConfig(id_field="abc123", value="test_value") + self.assertEqual(p1.id_field, "abc123") + self.assertIsNone(p1.description) + self.assertEqual(p1.value, "test_value") + self.assertIsNone(p1.tags) + self.assertEqual(p1.status, "active") + + # Test with integer ID + p2 = PipeTypesConfig(id_field=42, value=3.14, description="A config") + self.assertEqual(p2.id_field, 42) + self.assertEqual(p2.description, "A config") + self.assertEqual(p2.value, 3.14) + + # Test with boolean value and list tags + p3 = PipeTypesConfig(id_field=99, value=True, tags=["tag1", "tag2"]) + self.assertEqual(p3.id_field, 99) + self.assertEqual(p3.value, True) + self.assertEqual(p3.tags, ["tag1", "tag2"]) + + # Test with dict tags + p4 = PipeTypesConfig(id_field="xyz", value=42, tags={"key1": "val1", "key2": "val2"}) + self.assertEqual(p4.id_field, "xyz") + self.assertEqual(p4.value, 42) + self.assertEqual(p4.tags, {"key1": "val1", "key2": "val2"}) + + # Test direct assignment + p5 = PipeTypesConfig(id_field="test", value=1) + p5.id_field = 100 + p5.value = False + p5.tags = ["new", "tags"] + p5.description = "Updated" + self.assertEqual(p5.id_field, 100) + self.assertEqual(p5.value, False) + self.assertEqual(p5.tags, ["new", "tags"]) + self.assertEqual(p5.description, "Updated") + + # Test all valid types + valid_cases = [ + {"id_field": "string_id", "value": "string_value"}, + {"id_field": 42, "value": 123}, + {"id_field": "mixed", "value": 3.14}, + {"id_field": 999, "value": True}, + {"id_field": "with_desc", "value": 1, "description": "Description"}, + {"id_field": "with_dict", "value": 1, "tags": None}, + ] + + for case in valid_cases: + for use_pydantic in [True, False]: + result = PipeTypesConfig.from_dict(case, use_pydantic=use_pydantic) + # use the other route to get back the result + result_to_dict_loop = PipeTypesConfig.from_dict(result.to_dict(), use_pydantic=not use_pydantic) + self.assertEqual(result, result_to_dict_loop) + + # Test invalid values + invalid_cases = [ + {"id_field": 3.14, "value": 1}, # Float for id_field + {"id_field": None, "value": 1}, # None for required id_field + {"id_field": "test", "value": {}}, # Dict for value + {"id_field": "test", "value": None}, # None for required value + {"id_field": "test", "value": 1, "status": "unknown"}, # Invalid literal + ] + for case in invalid_cases: + with self.assertRaises(ValidationError): + PipeTypesConfig.from_dict(case, use_pydantic=True) + if __name__ == "__main__": unittest.main() diff --git a/csp/tests/impl/types/test_pydantic_types.py b/csp/tests/impl/types/test_pydantic_types.py index 945919d62..9ef2fb222 100644 --- a/csp/tests/impl/types/test_pydantic_types.py +++ b/csp/tests/impl/types/test_pydantic_types.py @@ -1,6 +1,6 @@ import sys from inspect import isclass -from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union, get_args, get_origin +from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union, get_args, get_origin from unittest import TestCase import csp @@ -160,3 +160,12 @@ def test_force_tvars(self): self.assertAnnotationsEqual( adjust_annotations(CspTypeVarType[T], forced_tvars={"T": float}), Union[Type[float], Type[int]] ) + + def test_literal(self): + self.assertAnnotationsEqual(adjust_annotations(Literal["a", "b"]), Literal["a", "b"]) + self.assertAnnotationsEqual( + adjust_annotations(Literal["a", "b"], make_optional=True), Optional[Literal["a", "b"]] + ) + self.assertAnnotationsEqual(adjust_annotations(Literal[123, "a"]), Literal[123, "a"]) + self.assertAnnotationsEqual(adjust_annotations(Literal[123, None]), Literal[123, None]) + self.assertAnnotationsEqual(adjust_annotations(ts[Literal[123, None]]), ts[Literal[123, None]]) diff --git a/csp/tests/test_type_checking.py b/csp/tests/test_type_checking.py index 37a8c4d1e..f2a1d1c84 100644 --- a/csp/tests/test_type_checking.py +++ b/csp/tests/test_type_checking.py @@ -1,6 +1,7 @@ import os import pickle import re +import sys import typing import unittest from datetime import datetime, time, timedelta @@ -938,6 +939,106 @@ def test_is_callable(self): result = CspTypingUtils.is_callable(input_type) self.assertEqual(result, expected) + def test_literal_typing(self): + """Test using Literal types for type checking in CSP nodes.""" + from typing import Literal + + @csp.node + def node_with_literal(x: ts[int], choice: Literal["a", "b", "c"]) -> ts[str]: + if csp.ticked(x): + return str(choice) + + @csp.graph + def graph_with_literal(choice: Literal["a", "b", "c"]) -> ts[str]: + return csp.const(str(choice)) + + @csp.node + def dummy_node(x: ts["T"]): # to avoid pruning + if csp.ticked(x): + pass + + def graph(): + # These should work - valid literal values + dummy_node(node_with_literal(csp.const(10), "a")) + dummy_node(node_with_literal(csp.const(10), "b")) + dummy_node(node_with_literal(csp.const(10), "c")) + + graph_with_literal("a") + graph_with_literal("b") + graph_with_literal("c") + + # This should fail with invalid literal value + # But only pydantic type checking catches this + if USE_PYDANTIC: + msg = "(?s)1 validation error for node_with_literal.*choice.*" + with self.assertRaisesRegex(TypeError, msg): + dummy_node(node_with_literal(csp.const(10), "d")) + + csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1)) + + # Test direct graph building + csp.build_graph(graph_with_literal, "a") + + # This should fail with invalid literal value + # But only pydantic type checking catches this + if USE_PYDANTIC: + msg = "(?s)1 validation error for graph_with_literal.*choice.*" + with self.assertRaisesRegex(TypeError, msg): + csp.build_graph(graph_with_literal, "d") + + def test_union_with_pipe_operator(self): + """Test using the pipe operator for Union types in Python 3.10+.""" + if sys.version_info >= (3, 10): # pipe operator was introduced in Python 3.10 + + @csp.node + def node_with_pipe_union(x: ts[int], value: str | int | None) -> ts[str]: + if csp.ticked(x): + return str(value) if value is not None else "none" + + @csp.graph + def graph_with_pipe_union(value: str | int | None) -> ts[str]: + return csp.const(str(value) if value is not None else "none") + + @csp.node + def dummy_node(x: ts["T"]): # to avoid pruning + if csp.ticked(x): + pass + + def graph(): + # These should work - valid union types (str, int, None) + dummy_node(node_with_pipe_union(csp.const(10), "hello")) + dummy_node(node_with_pipe_union(csp.const(10), 42)) + dummy_node(node_with_pipe_union(csp.const(10), None)) + + graph_with_pipe_union("world") + graph_with_pipe_union(123) + graph_with_pipe_union(None) + + # This should fail - float is not part of the union + if USE_PYDANTIC: + # Pydantic provides a structured error message + msg = "(?s)2 validation errors for node_with_pipe_union.*value.*" + else: + # Non-Pydantic error has specific format to match + msg = r"In function node_with_pipe_union: Expected str \| int \| None for argument 'value', got .* \(float\)" + with self.assertRaisesRegex(TypeError, msg): + dummy_node(node_with_pipe_union(csp.const(10), 3.14)) + + csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1)) + + # Test direct graph building + csp.build_graph(graph_with_pipe_union, "test") + csp.build_graph(graph_with_pipe_union, 42) + csp.build_graph(graph_with_pipe_union, None) + + # This should fail - bool is not explicitly included in the union + if USE_PYDANTIC: + msg = "(?s)2 validation errors for graph_with_pipe_union.*value.*" + else: + msg = r"In function graph_with_pipe_union: Expected str \| int \| None for argument 'value', got .*" + with self.assertRaisesRegex(TypeError, msg): + csp.build_graph(graph_with_pipe_union, 3.14) + if __name__ == "__main__": unittest.main()