Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions csp/impl/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ def __new__(cls, name, bases, dct):
# Lists need to be normalized too as potentially we need to add a boolean flag to use FastList
if v == FastList:
raise TypeError(f"{v} annotation is not supported without args")
if CspTypingUtils.is_generic_container(v) or CspTypingUtils.is_union_type(v):
if (
CspTypingUtils.is_generic_container(v)
or CspTypingUtils.is_union_type(v)
or CspTypingUtils.is_literal_type(v)
):
actual_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(v)
if CspTypingUtils.is_generic_container(actual_type):
raise TypeError(f"{v} annotation is not supported as a struct field [{actual_type}]")
Expand Down Expand Up @@ -147,6 +151,17 @@ def serializer(val, handler):


class Struct(_csptypesimpl.PyStruct, metaclass=StructMeta):
@classmethod
def type_adapter(cls):
# Late import to avoid autogen issues
from pydantic import TypeAdapter

internal_type_adapter = getattr(cls, "_pydantic_type_adapter", None)
if internal_type_adapter:
return internal_type_adapter
cls._pydantic_type_adapter = TypeAdapter(cls)
return cls._pydantic_type_adapter

@classmethod
def metadata(cls, typed=False):
if typed:
Expand Down Expand Up @@ -191,7 +206,8 @@ def _obj_from_python(cls, json, obj_type):
if CspTypingUtils.is_generic_container(obj_type):
if CspTypingUtils.get_origin(obj_type) in (typing.List, typing.Set, typing.Tuple, FastList):
return_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(obj_type)
(expected_item_type,) = obj_type.__args__
# We only take the first item, so like for a Tuple, we would ignore arguments after
expected_item_type = obj_type.__args__[0]
return_type = list if isinstance(return_type, list) else return_type
return return_type(cls._obj_from_python(v, expected_item_type) for v in json)
elif CspTypingUtils.get_origin(obj_type) is typing.Dict:
Expand All @@ -206,6 +222,13 @@ def _obj_from_python(cls, json, obj_type):
return json
else:
raise NotImplementedError(f"Can not deserialize {obj_type} from json")
elif CspTypingUtils.is_union_type(obj_type):
return json ## no checks, just let it through
elif CspTypingUtils.is_literal_type(obj_type):
return_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(obj_type)
if isinstance(json, return_type):
return json
raise ValueError(f"Expected type {return_type} received {json.__class__}")
elif issubclass(obj_type, Struct):
if not isinstance(json, dict):
raise TypeError("Representation of struct as json is expected to be of dict type")
Expand All @@ -223,7 +246,9 @@ def _obj_from_python(cls, json, obj_type):
return obj_type(json)

@classmethod
def from_dict(cls, json: dict):
def from_dict(cls, json: dict, use_pydantic: bool = False):
if use_pydantic:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than an arg can we base this on whether we are using pydantic type checking ( which defaults to true now )?
My only concern is speed. If we can do some timing tests and pydantic is as fast as the current code, I would just default it to use pydantic ( and eventually remove the old impl )

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old way is a bit permissive and there will be some inconsistencies that might be annoying to fix, so I figured itd be easier to make it a flag for now. But eventually would like to switch to pydantic if its performant enough

return cls.type_adapter().validate_python(json)
return cls._obj_from_python(json, cls)

def to_dict_depr(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we have an equivalent pydantic to_dict?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure!

Copy link
Collaborator

@arhamchopra arhamchopra Mar 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@robambalu Should we just remove to_dict_depr now?

Expand Down
24 changes: 12 additions & 12 deletions csp/impl/types/container_type_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,21 @@ def normalized_type_to_actual_python_type(cls, typ, level=0):
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1), True]
if origin is typing.List and level == 0:
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1)]
if origin is typing.Literal:
# Import here to prevent circular import
from csp.impl.types.instantiation_type_resolver import UpcastRegistry

args = typing.get_args(typ)
typ = type(args[0])
for arg in args[1:]:
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
if typ:
return typ
else:
return object
return cls._NORMALIZED_TYPE_MAPPING.get(CspTypingUtils.get_origin(typ), typ)
elif CspTypingUtils.is_union_type(typ):
return object
elif CspTypingUtils.is_literal_type(typ):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize we just moved code here, but looks like we can probably use normalized_type_to_actual_python_type here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But literal takes instances of types, not types

(so like Literal[7] versus Literal[int]) and normalized_type_to_actual_python_type doesnt seem to actually pull the true base class out except in this case for Literal

# Import here to prevent circular import
from csp.impl.types.instantiation_type_resolver import UpcastRegistry

args = typing.get_args(typ)
typ = type(args[0])
for arg in args[1:]:
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
if typ:
return typ
else:
return object
else:
return typ

Expand Down
4 changes: 3 additions & 1 deletion csp/impl/types/pydantic_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
import types
import typing
from typing import Any, ForwardRef, Generic, Optional, Type, TypeVar, Union, get_args, get_origin
from typing import Any, ForwardRef, Generic, Literal, Optional, Type, TypeVar, Union, get_args, get_origin

from pydantic import GetCoreSchemaHandler, ValidationInfo, ValidatorFunctionWrapHandler
from pydantic_core import CoreSchema, core_schema
Expand Down Expand Up @@ -184,6 +184,8 @@ def adjust_annotations(
return TsType[
adjust_annotations(args[0], top_level=False, in_ts=True, make_optional=False, forced_tvars=forced_tvars)
]
if origin is Literal: # for literals, we stop converting
return Optional[annotation] if make_optional else annotation
else:
try:
if origin is CspTypeVar or origin is CspTypeVarType:
Expand Down
18 changes: 8 additions & 10 deletions csp/impl/types/type_annotation_normalizer_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def visit_arg(self, node):
return node

def visit_Subscript(self, node):
# We choose to avoid parsing here
# to maintain current behavior of allowing empty lists in our types
return node

def visit_List(self, node):
Expand Down Expand Up @@ -98,17 +100,13 @@ def visit_Call(self, node):
return node

def visit_Constant(self, node):
if not self._cur_arg:
return node

if self._cur_arg:
return ast.Call(
func=ast.Attribute(value=ast.Name(id="typing", ctx=ast.Load()), attr="TypeVar", ctx=ast.Load()),
args=[node],
keywords=[],
)
else:
if not self._cur_arg or not isinstance(node.value, str):
return node
return ast.Call(
func=ast.Attribute(value=ast.Name(id="typing", ctx=ast.Load()), attr="TypeVar", ctx=ast.Load()),
args=[node],
keywords=[],
)

def visit_Str(self, node):
return self.visit_Constant(node)
25 changes: 24 additions & 1 deletion csp/impl/types/typing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,25 @@ class FastList(typing.List, typing.Generic[T]): # Need to inherit from Generic[
def __init__(self):
raise NotImplementedError("Can not init FastList class")

@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
from pydantic_core import core_schema

# Late import to not interfere with autogen
args = typing.get_args(source_type)
if args:
inner_type = args[0]
list_schema = handler.generate_schema(typing.List[inner_type])
else:
list_schema = handler.generate_schema(typing.List)

def create_instance(raw_data, validator):
if isinstance(raw_data, FastList):
return raw_data
return validator(raw_data) # just return a list

return core_schema.no_info_wrap_validator_function(function=create_instance, schema=list_schema)


class CspTypingUtils39:
_ORIGIN_COMPAT_MAP = {list: typing.List, set: typing.Set, dict: typing.Dict, tuple: typing.Tuple}
Expand All @@ -23,7 +42,7 @@ class CspTypingUtils39:

@classmethod
def is_generic_container(cls, typ):
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ is not typing.Union
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ not in (typing.Union, typing.Literal)

@classmethod
def is_type_spec(cls, val):
Expand Down Expand Up @@ -56,6 +75,10 @@ def is_numpy_nd_array_type(cls, typ):
def is_union_type(cls, typ):
return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Union

@classmethod
def is_literal_type(cls, typ):
return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Literal

@classmethod
def is_forward_ref(cls, typ):
return isinstance(typ, typing.ForwardRef)
Expand Down
Loading
Loading