Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/power_grid_model_ds/_core/model/grids/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,18 +379,24 @@ def cache(self, cache_dir: Path, cache_name: str, compress: bool = True):
)
return save_grid_to_pickle(self, cache_dir=cache_dir, cache_name=cache_name, compress=compress)

def serialize(self, path: Path, **kwargs) -> Path:
def serialize(self, path: Path, strict: bool = True, **kwargs) -> Path:
"""Serialize the grid.

Args:
path: Destination file path to write JSON to.
strict: Whether to raise an error if the grid object cannot be fully serialized.
**kwargs: Additional keyword arguments forwarded to ``json.dump``
Returns:
Path: The path where the file was saved.
"""
return serialize_to_json(grid=self, path=path, strict=True, **kwargs)
return serialize_to_json(grid=self, path=path, strict=strict, **kwargs)

@classmethod
def deserialize(cls: Type[Self], path: Path) -> Self:
"""Deserialize the grid."""
return deserialize_from_json(path=path, target_grid_class=cls)
def deserialize(cls: Type[Self], path: Path, strict: bool = True) -> Self:
"""Deserialize the grid.

Args:
path: Source file path to read JSON from.
strict: Whether to raise an error if the grid object cannot be fully restored.
"""
return deserialize_from_json(path=path, target_grid_class=cls, strict=strict)
11 changes: 11 additions & 0 deletions src/power_grid_model_ds/_core/model/grids/serialization/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: Contributors to the Power Grid Model project <powergridmodel@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0


class JSONSerializationError(Exception):
"""Exception raised for errors during JSON serialization of grid attributes."""


class JSONDeserializationError(Exception):
"""Exception raised for errors during JSON deserialization of grid attributes."""
30 changes: 23 additions & 7 deletions src/power_grid_model_ds/_core/model/grids/serialization/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import TYPE_CHECKING, Any, TypeVar

from power_grid_model_ds._core.model.arrays.base.array import FancyArray
from power_grid_model_ds._core.model.grids.serialization.errors import JSONDeserializationError, JSONSerializationError

if TYPE_CHECKING:
# Import only for type checking to avoid circular imports at runtime
Expand Down Expand Up @@ -49,6 +50,9 @@ def serialize_to_json(grid: G, path: Path, strict: bool = True, **kwargs) -> Pat
serialized_data[field.name] = _serialize_array(field_value)
continue

if hasattr(field_value, "to_dict"):
field_value = field_value.to_dict()

if _is_serializable(field_value, strict):
serialized_data[field.name] = field_value

Expand All @@ -61,12 +65,13 @@ def serialize_to_json(grid: G, path: Path, strict: bool = True, **kwargs) -> Pat
return path


def deserialize_from_json(path: Path, target_grid_class: type[G]) -> G:
def deserialize_from_json(path: Path, target_grid_class: type[G], strict: bool = True) -> G:
"""Load a Grid object from JSON format with cross-type loading support.

Args:
path: The file path to load from
target_grid_class: Grid class to load into.
strict: Whether to raise an error if the grid object cannot be fully restored.

Returns:
Grid: The deserialized Grid object of the specified target class
Expand All @@ -75,13 +80,13 @@ def deserialize_from_json(path: Path, target_grid_class: type[G]) -> G:
json_data = json.load(f)

grid = target_grid_class.empty()
_restore_grid_values(grid, json_data["data"])
_restore_grid_values(grid, json_data["data"], strict=strict)
graph_class = grid.graphs.__class__
grid.graphs = graph_class.from_arrays(grid)
return grid


def _restore_grid_values(grid: G, json_data: dict) -> None:
def _restore_grid_values(grid: G, json_data: dict, strict: bool) -> None:
"""Restore arrays to the grid."""
for attr_name, attr_values in json_data.items():
if not hasattr(grid, attr_name):
Expand All @@ -94,9 +99,19 @@ def _restore_grid_values(grid: G, json_data: dict) -> None:
array = _deserialize_array(array_data=attr_values, array_class=attr_class)
setattr(grid, attr_name, array)
continue
if hasattr(grid_attr, "from_dict"):
attr_value = grid_attr.from_dict(attr_values)
setattr(grid, attr_name, attr_value)
continue

# load other values
setattr(grid, attr_name, attr_class(attr_values))
try:
setattr(grid, attr_name, attr_class(attr_values))
except TypeError as error:
msg = f"Failed to set attribute '{attr_name}' on grid of type '{grid.__class__.__name__}'."
if strict:
msg += " Set strict=False to skip it or add a .from_dict() class method to the attribute's class."
raise JSONDeserializationError(msg) from error
logger.warning(msg)


def _serialize_array(array: FancyArray) -> list[dict[str, Any]]:
Expand Down Expand Up @@ -132,9 +147,10 @@ def _is_serializable(value: Any, strict: bool) -> bool:
try:
json.dumps(value)
except TypeError as error:
msg = f"Failed to serialize '{value}'. You can set strict=False to ignore this attribute."
msg = f"Failed to serialize '{value.__class__.__name__}'. "
if strict:
raise TypeError(msg) from error
msg += "Set strict=False to skip this attribute or add a .to_dict() method to the attribute's class."
raise JSONSerializationError(msg) from error
logger.warning(msg)
return False
return True
103 changes: 83 additions & 20 deletions tests/unit/model/grids/serialization/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from power_grid_model_ds import Grid, PowerGridModelInterface
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
from power_grid_model_ds._core.model.containers.helpers import container_equal
from power_grid_model_ds._core.model.grids.serialization.errors import JSONDeserializationError, JSONSerializationError
from power_grid_model_ds._core.utils.misc import array_equal_with_nan
from power_grid_model_ds.arrays import LineArray
from power_grid_model_ds.arrays import NodeArray as BaseNodeArray
Expand Down Expand Up @@ -52,7 +53,7 @@ class NonSerializableExtension:
"""A non-serializable extension class"""

def __init__(self):
self.data = "non_serializable"
self.data = "the data"


@dataclass
Expand All @@ -62,6 +63,29 @@ class GridWithNonSerializableExtension(Grid):
non_serializable: NonSerializableExtension = NonSerializableExtension()


class SerializableExtension:
"""A non-serializable extension class"""

def __init__(self):
self.data = "the data"

def to_dict(self):
return {"data": self.data}

@classmethod
def from_dict(cls, data: dict[str, str]):
instance = cls()
instance.data = data["data"]
return instance


@dataclass
class GridWithSerializableExtension(Grid):
"""Grid with a non-serializable extension attribute"""

serializable: SerializableExtension = SerializableExtension()


@pytest.fixture
def basic_grid():
"""Basic grid fixture"""
Expand Down Expand Up @@ -225,7 +249,7 @@ class GridWithCustomArray(Grid):

class TestDeserialize:
def test_deserialize(self, tmp_path: Path):
path = tmp_path / "json_data.json"
path = tmp_path / "grid.json"

data = {"node": [{"id": 1, "u_rated": 10000}, {"id": 2, "u_rated": 20000}]}

Expand All @@ -239,84 +263,123 @@ def test_deserialize(self, tmp_path: Path):
assert grid.node.u_rated.tolist() == [10000, 20000]

def test_extended_grid(self, tmp_path: Path, extended_grid: ExtendedGrid):
extended_data = {
data = {
"node": [
{"id": 1, "u_rated": 10000, "analysis_flag": 42},
{"id": 2, "u_rated": 10000, "analysis_flag": 43},
],
"value_extension": 4.2,
}

path = tmp_path / "json_data.json"
path = tmp_path / "grid.json"
with open(path, "w", encoding="utf-8") as f:
json.dump({"data": extended_data}, f)
json.dump({"data": data}, f)

grid = ExtendedGrid.deserialize(path)
assert grid.value_extension == 4.2
assert grid.node.analysis_flag.tolist() == [42, 43]

def test_unexpected_field(self, tmp_path: Path):
path = tmp_path / "incompatible.json"
path = tmp_path / "grid.json"

# Create incompatible JSON data
incompatible_data = {
data = {
"node": [{"id": 1, "u_rated": 10000}, {"id": 2, "u_rated": 10000}],
"unexpected_field": "unexpected_value",
}

# Write incompatible data to file
with open(path, "w", encoding="utf-8") as f:
json.dump({"data": incompatible_data}, f)
json.dump({"data": data}, f)

grid = Grid.deserialize(path)
assert not hasattr(grid, "unexpected_field")

def test_missing_defaulted_array_field(self, tmp_path: Path):
path = tmp_path / "missing_array.json"
path = tmp_path / "grid.json"

# Node data does not contain 'id' field, but there is a default
missing_array_data = {
data = {
"node": [{"u_rated": 10000}, {"u_rated": 10000}],
}

# Write data to file
with open(path, "w", encoding="utf-8") as f:
json.dump({"data": missing_array_data}, f)
json.dump({"data": data}, f)

Grid.deserialize(path)

def test_missing_required_array_field(self, tmp_path: Path):
path = tmp_path / "missing_array.json"
path = tmp_path / "grid.json"

# Node data does not contain 'id' field, but there is a default
missing_array_data = {
data = {
"node": [{"id": 10000}, {"id": 123}],
}

# Write data to file
with open(path, "w", encoding="utf-8") as f:
json.dump({"data": missing_array_data}, f)
json.dump({"data": data}, f)

with pytest.raises(ValueError):
Grid.deserialize(path)

def test_some_records_miss_data(self, tmp_path):
path = tmp_path / "incomplete_array.json"
incomplete_data = {
"node": [{"id": 1, "u_rated": 10000}, {"u_rated": 10000}, {"id": 3}],
path = tmp_path / "grid.json"
data = {
"node": [
{"id": 1, "u_rated": 10000},
{"u_rated": 10000},
{"id": 3},
],
}

with open(path, "w", encoding="utf-8") as f:
json.dump({"data": incomplete_data}, f)
json.dump({"data": data}, f)

with pytest.raises(ValueError):
Grid.deserialize(path)

def test_non_serializable_extension(self, tmp_path: Path):
path = tmp_path / "non_serializable.json"
path = tmp_path / "grid.json"

grid = GridWithNonSerializableExtension.empty()
grid.non_serializable = NonSerializableExtension()

with pytest.raises(TypeError):
with pytest.raises(JSONSerializationError):
grid.serialize(path)

def test_deserialize_non_serializable_extension(self, tmp_path: Path):
path = tmp_path / "grid.json"

data = {"non_serializable": {"data": "some_data"}}
with open(path, "w", encoding="utf-8") as f:
json.dump({"data": data}, f)

with pytest.raises(JSONDeserializationError):
GridWithNonSerializableExtension.deserialize(path)

def test_deserialize_non_serializable_extension_non_strict(self, tmp_path: Path):
path = tmp_path / "grid.json"

data = {"non_serializable": {"data": "some_data"}}
with open(path, "w", encoding="utf-8") as f:
json.dump({"data": data}, f)

grid = GridWithNonSerializableExtension.deserialize(path, strict=False)
assert grid.non_serializable.data == "the data" # Default value

def test_serializable_extension(self, tmp_path: Path):
path = tmp_path / "grid.json"
grid = GridWithSerializableExtension.empty()
grid.serialize(path)

def test_deserialize_serializable_extension(self, tmp_path: Path):
path = tmp_path / "grid.json"

data = {"serializable": {"data": "some_data"}}
with open(path, "w", encoding="utf-8") as f:
json.dump({"data": data}, f)

grid = GridWithSerializableExtension.deserialize(path)
assert grid.serializable.data == "some_data"
Loading