diff --git a/src/power_grid_model_ds/_core/model/grids/base.py b/src/power_grid_model_ds/_core/model/grids/base.py index 2082e9b..2923bdb 100644 --- a/src/power_grid_model_ds/_core/model/grids/base.py +++ b/src/power_grid_model_ds/_core/model/grids/base.py @@ -379,18 +379,24 @@ def cache(self, cache_dir: Path, cache_name: str, compress: bool = True): ) return save_grid_to_pickle(self, cache_dir=cache_dir, cache_name=cache_name, compress=compress) - def serialize(self, path: Path, **kwargs) -> Path: + def serialize(self, path: Path, strict: bool = True, **kwargs) -> Path: """Serialize the grid. Args: path: Destination file path to write JSON to. + strict: Whether to raise an error if the grid object cannot be fully serialized. **kwargs: Additional keyword arguments forwarded to ``json.dump`` Returns: Path: The path where the file was saved. """ - return serialize_to_json(grid=self, path=path, strict=True, **kwargs) + return serialize_to_json(grid=self, path=path, strict=strict, **kwargs) @classmethod - def deserialize(cls: Type[Self], path: Path) -> Self: - """Deserialize the grid.""" - return deserialize_from_json(path=path, target_grid_class=cls) + def deserialize(cls: Type[Self], path: Path, strict: bool = True) -> Self: + """Deserialize the grid. + + Args: + path: Source file path to read JSON from. + strict: Whether to raise an error if the grid object cannot be fully restored. + """ + return deserialize_from_json(path=path, target_grid_class=cls, strict=strict) diff --git a/src/power_grid_model_ds/_core/model/grids/serialization/errors.py b/src/power_grid_model_ds/_core/model/grids/serialization/errors.py new file mode 100644 index 0000000..a07e66f --- /dev/null +++ b/src/power_grid_model_ds/_core/model/grids/serialization/errors.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Contributors to the Power Grid Model project +# +# SPDX-License-Identifier: MPL-2.0 + + +class JSONSerializationError(Exception): + """Exception raised for errors during JSON serialization of grid attributes.""" + + +class JSONDeserializationError(Exception): + """Exception raised for errors during JSON deserialization of grid attributes.""" diff --git a/src/power_grid_model_ds/_core/model/grids/serialization/json.py b/src/power_grid_model_ds/_core/model/grids/serialization/json.py index 8796fce..83893bc 100644 --- a/src/power_grid_model_ds/_core/model/grids/serialization/json.py +++ b/src/power_grid_model_ds/_core/model/grids/serialization/json.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, TypeVar from power_grid_model_ds._core.model.arrays.base.array import FancyArray +from power_grid_model_ds._core.model.grids.serialization.errors import JSONDeserializationError, JSONSerializationError if TYPE_CHECKING: # Import only for type checking to avoid circular imports at runtime @@ -49,6 +50,9 @@ def serialize_to_json(grid: G, path: Path, strict: bool = True, **kwargs) -> Pat serialized_data[field.name] = _serialize_array(field_value) continue + if hasattr(field_value, "to_dict"): + field_value = field_value.to_dict() + if _is_serializable(field_value, strict): serialized_data[field.name] = field_value @@ -61,12 +65,13 @@ def serialize_to_json(grid: G, path: Path, strict: bool = True, **kwargs) -> Pat return path -def deserialize_from_json(path: Path, target_grid_class: type[G]) -> G: +def deserialize_from_json(path: Path, target_grid_class: type[G], strict: bool = True) -> G: """Load a Grid object from JSON format with cross-type loading support. Args: path: The file path to load from target_grid_class: Grid class to load into. + strict: Whether to raise an error if the grid object cannot be fully restored. Returns: Grid: The deserialized Grid object of the specified target class @@ -75,13 +80,13 @@ def deserialize_from_json(path: Path, target_grid_class: type[G]) -> G: json_data = json.load(f) grid = target_grid_class.empty() - _restore_grid_values(grid, json_data["data"]) + _restore_grid_values(grid, json_data["data"], strict=strict) graph_class = grid.graphs.__class__ grid.graphs = graph_class.from_arrays(grid) return grid -def _restore_grid_values(grid: G, json_data: dict) -> None: +def _restore_grid_values(grid: G, json_data: dict, strict: bool) -> None: """Restore arrays to the grid.""" for attr_name, attr_values in json_data.items(): if not hasattr(grid, attr_name): @@ -94,9 +99,19 @@ def _restore_grid_values(grid: G, json_data: dict) -> None: array = _deserialize_array(array_data=attr_values, array_class=attr_class) setattr(grid, attr_name, array) continue + if hasattr(grid_attr, "from_dict"): + attr_value = grid_attr.from_dict(attr_values) + setattr(grid, attr_name, attr_value) + continue - # load other values - setattr(grid, attr_name, attr_class(attr_values)) + try: + setattr(grid, attr_name, attr_class(attr_values)) + except TypeError as error: + msg = f"Failed to set attribute '{attr_name}' on grid of type '{grid.__class__.__name__}'." + if strict: + msg += " Set strict=False to skip it or add a .from_dict() class method to the attribute's class." + raise JSONDeserializationError(msg) from error + logger.warning(msg) def _serialize_array(array: FancyArray) -> list[dict[str, Any]]: @@ -132,9 +147,10 @@ def _is_serializable(value: Any, strict: bool) -> bool: try: json.dumps(value) except TypeError as error: - msg = f"Failed to serialize '{value}'. You can set strict=False to ignore this attribute." + msg = f"Failed to serialize '{value.__class__.__name__}'. " if strict: - raise TypeError(msg) from error + msg += "Set strict=False to skip this attribute or add a .to_dict() method to the attribute's class." + raise JSONSerializationError(msg) from error logger.warning(msg) return False return True diff --git a/tests/unit/model/grids/serialization/test_json.py b/tests/unit/model/grids/serialization/test_json.py index 866150a..64868db 100644 --- a/tests/unit/model/grids/serialization/test_json.py +++ b/tests/unit/model/grids/serialization/test_json.py @@ -16,6 +16,7 @@ from power_grid_model_ds import Grid, PowerGridModelInterface from power_grid_model_ds._core.model.arrays.base.array import FancyArray from power_grid_model_ds._core.model.containers.helpers import container_equal +from power_grid_model_ds._core.model.grids.serialization.errors import JSONDeserializationError, JSONSerializationError from power_grid_model_ds._core.utils.misc import array_equal_with_nan from power_grid_model_ds.arrays import LineArray from power_grid_model_ds.arrays import NodeArray as BaseNodeArray @@ -52,7 +53,7 @@ class NonSerializableExtension: """A non-serializable extension class""" def __init__(self): - self.data = "non_serializable" + self.data = "the data" @dataclass @@ -62,6 +63,29 @@ class GridWithNonSerializableExtension(Grid): non_serializable: NonSerializableExtension = NonSerializableExtension() +class SerializableExtension: + """A non-serializable extension class""" + + def __init__(self): + self.data = "the data" + + def to_dict(self): + return {"data": self.data} + + @classmethod + def from_dict(cls, data: dict[str, str]): + instance = cls() + instance.data = data["data"] + return instance + + +@dataclass +class GridWithSerializableExtension(Grid): + """Grid with a non-serializable extension attribute""" + + serializable: SerializableExtension = SerializableExtension() + + @pytest.fixture def basic_grid(): """Basic grid fixture""" @@ -225,7 +249,7 @@ class GridWithCustomArray(Grid): class TestDeserialize: def test_deserialize(self, tmp_path: Path): - path = tmp_path / "json_data.json" + path = tmp_path / "grid.json" data = {"node": [{"id": 1, "u_rated": 10000}, {"id": 2, "u_rated": 20000}]} @@ -239,7 +263,7 @@ def test_deserialize(self, tmp_path: Path): assert grid.node.u_rated.tolist() == [10000, 20000] def test_extended_grid(self, tmp_path: Path, extended_grid: ExtendedGrid): - extended_data = { + data = { "node": [ {"id": 1, "u_rated": 10000, "analysis_flag": 42}, {"id": 2, "u_rated": 10000, "analysis_flag": 43}, @@ -247,76 +271,115 @@ def test_extended_grid(self, tmp_path: Path, extended_grid: ExtendedGrid): "value_extension": 4.2, } - path = tmp_path / "json_data.json" + path = tmp_path / "grid.json" with open(path, "w", encoding="utf-8") as f: - json.dump({"data": extended_data}, f) + json.dump({"data": data}, f) grid = ExtendedGrid.deserialize(path) assert grid.value_extension == 4.2 assert grid.node.analysis_flag.tolist() == [42, 43] def test_unexpected_field(self, tmp_path: Path): - path = tmp_path / "incompatible.json" + path = tmp_path / "grid.json" # Create incompatible JSON data - incompatible_data = { + data = { "node": [{"id": 1, "u_rated": 10000}, {"id": 2, "u_rated": 10000}], "unexpected_field": "unexpected_value", } # Write incompatible data to file with open(path, "w", encoding="utf-8") as f: - json.dump({"data": incompatible_data}, f) + json.dump({"data": data}, f) grid = Grid.deserialize(path) assert not hasattr(grid, "unexpected_field") def test_missing_defaulted_array_field(self, tmp_path: Path): - path = tmp_path / "missing_array.json" + path = tmp_path / "grid.json" # Node data does not contain 'id' field, but there is a default - missing_array_data = { + data = { "node": [{"u_rated": 10000}, {"u_rated": 10000}], } # Write data to file with open(path, "w", encoding="utf-8") as f: - json.dump({"data": missing_array_data}, f) + json.dump({"data": data}, f) Grid.deserialize(path) def test_missing_required_array_field(self, tmp_path: Path): - path = tmp_path / "missing_array.json" + path = tmp_path / "grid.json" # Node data does not contain 'id' field, but there is a default - missing_array_data = { + data = { "node": [{"id": 10000}, {"id": 123}], } # Write data to file with open(path, "w", encoding="utf-8") as f: - json.dump({"data": missing_array_data}, f) + json.dump({"data": data}, f) with pytest.raises(ValueError): Grid.deserialize(path) def test_some_records_miss_data(self, tmp_path): - path = tmp_path / "incomplete_array.json" - incomplete_data = { - "node": [{"id": 1, "u_rated": 10000}, {"u_rated": 10000}, {"id": 3}], + path = tmp_path / "grid.json" + data = { + "node": [ + {"id": 1, "u_rated": 10000}, + {"u_rated": 10000}, + {"id": 3}, + ], } with open(path, "w", encoding="utf-8") as f: - json.dump({"data": incomplete_data}, f) + json.dump({"data": data}, f) with pytest.raises(ValueError): Grid.deserialize(path) def test_non_serializable_extension(self, tmp_path: Path): - path = tmp_path / "non_serializable.json" + path = tmp_path / "grid.json" grid = GridWithNonSerializableExtension.empty() grid.non_serializable = NonSerializableExtension() - with pytest.raises(TypeError): + with pytest.raises(JSONSerializationError): grid.serialize(path) + + def test_deserialize_non_serializable_extension(self, tmp_path: Path): + path = tmp_path / "grid.json" + + data = {"non_serializable": {"data": "some_data"}} + with open(path, "w", encoding="utf-8") as f: + json.dump({"data": data}, f) + + with pytest.raises(JSONDeserializationError): + GridWithNonSerializableExtension.deserialize(path) + + def test_deserialize_non_serializable_extension_non_strict(self, tmp_path: Path): + path = tmp_path / "grid.json" + + data = {"non_serializable": {"data": "some_data"}} + with open(path, "w", encoding="utf-8") as f: + json.dump({"data": data}, f) + + grid = GridWithNonSerializableExtension.deserialize(path, strict=False) + assert grid.non_serializable.data == "the data" # Default value + + def test_serializable_extension(self, tmp_path: Path): + path = tmp_path / "grid.json" + grid = GridWithSerializableExtension.empty() + grid.serialize(path) + + def test_deserialize_serializable_extension(self, tmp_path: Path): + path = tmp_path / "grid.json" + + data = {"serializable": {"data": "some_data"}} + with open(path, "w", encoding="utf-8") as f: + json.dump({"data": data}, f) + + grid = GridWithSerializableExtension.deserialize(path) + assert grid.serializable.data == "some_data"