Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/s2python/common/power_forecast_element.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Dict
from typing_extensions import Self

from pydantic import model_validator
Expand Down Expand Up @@ -29,7 +29,7 @@ class PowerForecastElement(GenPowerForecastElement, S2MessageComponent):
def validate_values_at_most_one_per_commodity_quantity(self) -> Self:
"""Validates the power measurement values to check that there is at most 1 PowerValue per CommodityQuantity."""

has_value: dict[CommodityQuantity, bool] = {}
has_value: Dict[CommodityQuantity, bool] = {}

for value in self.power_values:
if has_value.get(value.commodity_quantity, False):
Expand Down
4 changes: 2 additions & 2 deletions src/s2python/common/power_measurement.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import uuid
from typing import List
from typing import List, Dict
from typing_extensions import Self

from pydantic import model_validator
Expand All @@ -26,7 +26,7 @@ class PowerMeasurement(GenPowerMeasurement, S2MessageComponent):
def validate_values_at_most_one_per_commodity_quantity(self) -> Self:
"""Validates the power measurement values to check that there is at most 1 PowerValue per CommodityQuantity."""

has_value: dict[CommodityQuantity, bool] = {}
has_value: Dict[CommodityQuantity, bool] = {}

for value in self.values:
if has_value.get(value.commodity_quantity, False):
Expand Down
2 changes: 1 addition & 1 deletion src/s2python/ombc/ombc_operation_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class OMBCOperationMode(GenOMBCOperationMode, S2MessageComponent):
model_config["validate_assignment"] = True

id: uuid.UUID = GenOMBCOperationMode.model_fields["id"] # type: ignore[assignment]
power_ranges: List[PowerRange] = GenOMBCOperationMode.model_fields[
power_ranges: List[PowerRange] = GenOMBCOperationMode.model_fields[ # type: ignore[reportIncompatibleVariableOverride]
"power_ranges"
] # type: ignore[assignment]
abnormal_condition_only: bool = GenOMBCOperationMode.model_fields[
Expand Down
2 changes: 1 addition & 1 deletion src/s2python/ombc/ombc_system_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class OMBCSystemDescription(GenOMBCSystemDescription, S2MessageComponent):
model_config["validate_assignment"] = True

message_id: uuid.UUID = GenOMBCSystemDescription.model_fields["message_id"] # type: ignore[assignment]
operation_modes: List[OMBCOperationMode] = GenOMBCSystemDescription.model_fields[
operation_modes: List[OMBCOperationMode] = GenOMBCSystemDescription.model_fields[ # type: ignore[reportIncompatibleVariableOverride]
"operation_modes"
] # type: ignore[assignment]
transitions: List[Transition] = GenOMBCSystemDescription.model_fields["transitions"] # type: ignore[assignment]
Expand Down
3 changes: 1 addition & 2 deletions src/s2python/s2_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,9 @@ def parse_as_any_message(unparsed_message: Union[dict, str, bytes]) -> S2Message
None,
message_json,
f"Unable to parse {message_type} as an S2 message. Type unknown.",
None,
)

return TYPE_TO_MESSAGE_CLASS[message_type].model_validate(message_json)
return TYPE_TO_MESSAGE_CLASS[message_type].from_dict(message_json)

@staticmethod
def parse_as_message(
Expand Down
8 changes: 1 addition & 7 deletions src/s2python/s2_validation_error.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
from dataclasses import dataclass
from typing import Union, Type, Optional

from pydantic import ValidationError
from pydantic.v1.error_wrappers import ValidationError as ValidationErrorV1
from typing import Type, Optional


@dataclass
class S2ValidationError(Exception):
class_: Optional[Type]
obj: object
msg: str
pydantic_validation_error: Union[
ValidationErrorV1, ValidationError, TypeError, None
]
35 changes: 23 additions & 12 deletions src/s2python/validate_values_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

from typing_extensions import Self

from pydantic.v1.error_wrappers import display_errors # pylint: disable=no-name-in-module

from pydantic import ( # pylint: disable=no-name-in-module
BaseModel,
ValidationError,
Expand All @@ -28,25 +26,43 @@


class S2MessageComponent(BaseModel):
def __setattr__(self, name: str, value: Any) -> None:
try:
super().__setattr__(name, value)
except (ValidationError, TypeError) as e:
raise S2ValidationError(
type(self), self, "Pydantic raised a validation error.",
) from e

def to_json(self) -> str:
try:
return self.model_dump_json(by_alias=True, exclude_none=True)
except (ValidationError, TypeError) as e:
raise S2ValidationError(
type(self), self, "Pydantic raised a format validation error.", e
type(self), self, "Pydantic raised a validation error.",
) from e

def to_dict(self) -> Dict[str, Any]:
return self.model_dump()

@classmethod
def from_json(cls, json_str: str) -> Self:
gen_model = cls.model_validate_json(json_str)
try:
gen_model = cls.model_validate_json(json_str)
except (ValidationError, TypeError) as e:
raise S2ValidationError(
type(cls), cls, "Pydantic raised a validation error.",
) from e
return gen_model

@classmethod
def from_dict(cls, json_dict: Dict[str, Any]) -> Self:
gen_model = cls.model_validate(json_dict)
try:
gen_model = cls.model_validate(json_dict)
except (ValidationError, TypeError) as e:
raise S2ValidationError(
type(cls), cls, "Pydantic raised a validation error.",
) from e
return gen_model


Expand All @@ -61,9 +77,9 @@ def inner(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
else:
class_type = None

raise S2ValidationError(class_type, args, display_errors(e.errors()), e) from e # type: ignore[arg-type]
raise S2ValidationError(class_type, args, str(e)) from e
except TypeError as e:
raise S2ValidationError(None, args, str(e), e) from e
raise S2ValidationError(None, args, str(e)) from e

inner.__doc__ = f.__doc__
inner.__annotations__ = f.__annotations__
Expand All @@ -76,10 +92,5 @@ def inner(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:

def catch_and_convert_exceptions(input_class: Type[S]) -> Type[S]:
input_class.__init__ = convert_to_s2exception(input_class.__init__) # type: ignore[method-assign]
input_class.__setattr__ = convert_to_s2exception(input_class.__setattr__) # type: ignore[method-assign]
input_class.model_validate_json = convert_to_s2exception( # type: ignore[method-assign]
input_class.model_validate_json
)
input_class.model_validate = convert_to_s2exception(input_class.model_validate) # type: ignore[method-assign]

return input_class
51 changes: 51 additions & 0 deletions tests/unit/inheritance_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import datetime
import unittest
import uuid
from typing import Optional

from pydantic import Field

from s2python.frbc import FRBCStorageStatus as FRBCStorageStatusOfficial
from s2python.s2_validation_error import S2ValidationError


class FRBCStorageStatus(FRBCStorageStatusOfficial):
measurement_timestamp: Optional[datetime.datetime] = Field(
default=None, description="Timestamp when fill level was measured."
)


class InheritanceTest(unittest.TestCase):
def test__inheritance__init(self):
# Arrange / Act
frbc_storage_status = FRBCStorageStatus(message_id=uuid.uuid4(),
present_fill_level=0.0,
measurement_timestamp=None)

# Assert
self.assertIsInstance(frbc_storage_status, FRBCStorageStatus)
self.assertIsNone(frbc_storage_status.measurement_timestamp)

def test__inheritance__init_wrong(self):
# Arrange / Act / Assert
with self.assertRaises(S2ValidationError):
FRBCStorageStatus(message_id=uuid.uuid4(),
present_fill_level=0.0,
measurement_timestamp=False) # pyright: ignore [reportArgumentType]

def test__inheritance__from_json(self):
# Arrange
json_str = """
{
"message_id": "6bad8186-9ebf-4647-ac45-1c6856511a2f",
"message_type": "FRBC.StorageStatus",
"present_fill_level": 2443.939298819414,
"measurement_timestamp": "2025-01-01T00:00:00Z"
}"""

# Act
frbc_storage_status = FRBCStorageStatus.from_json(json_str)

# Assert
self.assertIsInstance(frbc_storage_status, FRBCStorageStatus)
self.assertEqual(frbc_storage_status.measurement_timestamp, datetime.datetime.fromisoformat("2025-01-01T00:00:00+00:00"))