diff --git a/docs/source/CHANGELOG.rst b/docs/source/CHANGELOG.rst index 9f628311..63a4d21f 100644 --- a/docs/source/CHANGELOG.rst +++ b/docs/source/CHANGELOG.rst @@ -9,13 +9,13 @@ Primary (X.-.-) version numbers are used to denote backwards incompatibilities between versions, while minor (-.X.-) numbers primarily indicate new features and documentation. -5.0.0-per3 (XXXX-XX-XX) +5.0.0-pre3 (2025-04-25) ^^^^^^^^^^^^^^^^^^^^^^^^ * Remove 'GStepBStdevModel', which is useless. Use GstepBModel. -5.0.0-pre2 (2024-05-24) +5.0.0-pre2 (2025-03-24) ^^^^^^^^^^^^^^^^^^^^^^^^ * Fix problem in kz-conversion.py diff --git a/src/arpes/io.py b/src/arpes/io.py index c9db6cc3..17a799de 100644 --- a/src/arpes/io.py +++ b/src/arpes/io.py @@ -11,35 +11,30 @@ of data. """ -from __future__ import annotations - -import pickle +import json import warnings from collections.abc import Iterable from dataclasses import dataclass +from functools import singledispatch from logging import DEBUG, INFO from pathlib import Path -from typing import TYPE_CHECKING +from typing import Any, Literal, TypedDict, Unpack import numpy as np import pandas as pd import xarray as xr +from arpes._typing import DataType, XrTypes + from .debug import setup_logger from .endstations import ScanDesc, load_scan from .example_data.mock import build_mock_tarpes -if TYPE_CHECKING: - from _typeshed import Incomplete - - from arpes._typing import XrTypes - - __all__ = ( - "easy_pickle", - "list_pickles", + "load_custom_netcdf", "load_data", "load_example_data", + "save_custom_netcdf", "stitch", ) @@ -52,7 +47,7 @@ def load_data( file: str | Path, location: str | None = None, - **kwargs: Incomplete, + **kwargs: Any, ) -> xr.Dataset: """Loads a piece of data using available plugins. This the user facing API for data loading. 
@@ -260,71 +255,130 @@ def _df_or_list_to_files( return list(df_or_list) -def file_for_pickle(name: str) -> Path | str: - here = Path() - from .config import CONFIG +class ToNetCDFParams(TypedDict, total=False): + """Typed dictionary for the parameters of the to_netcdf function. - if CONFIG["WORKSPACE"] and "path" in CONFIG["WORKSPACE"]: - here = Path(CONFIG["WORKSPACE"]["path"]) - path = here / "picklejar" / f"{name}.pickle" - path.parent.mkdir(exist_ok=True) - return str(path) + Many parameters are available, but only two are defined here. + """ + mode: Literal["w", "a"] + engine: Literal["netcdf4", "h5netcdf"] -def load_pickle(name: str) -> object: - """Loads a workspace local pickle. Inverse to `save_pickle`.""" - with Path(file_for_pickle(name)).open("rb") as file: - return pickle.load(file) # noqa: S301 +@singledispatch +def save_custom_netcdf( + obj: xr.DataArray | xr.Dataset | xr.DataTree, + path: str | Path, + **kwargs: Unpack[ToNetCDFParams], +) -> None: + """Save an xarray object (DataArray, Dataset, or DataTree) to a NetCDF file. -def save_pickle(data: object, name: str) -> None: - """Saves a workspace local pickle. Inverse to `load_pickle`.""" - with Path(file_for_pickle(name)).open("wb") as pickle_file: - pickle.dump(data, pickle_file) + Args: + obj (xr.DataArray | xr.Dataset | xr.DataTree): The xarray object to save. + path (str | Path): The file path to write the NetCDF file to. + **kwargs: Additional keyword arguments passed to `to_netcdf()`. + + Raises: + NotImplementedError: Always; a registered type-specific implementation handles each supported type. + """ + del path, kwargs + msg = f"Unsupported type: {type(obj)}" + raise NotImplementedError(msg) -def easy_pickle(data_or_str: str | object, name: str = "") -> object: - """A convenience function around pickling. 
+def _jsonify_attrs(obj: DataType) -> DataType: + if hasattr(obj, "attrs"): + obj.attrs = {"__json__": json.dumps(obj.attrs)} + return obj - Provides a workspace scoped associative set of named pickles which - can be used for - Examples: - Retaining analysis results between sessions. +@save_custom_netcdf.register +def _( + obj: xr.DataArray, + path: str | Path, + **kwargs: Unpack[ToNetCDFParams], +) -> None: + """Save an xarray object (DataArray, Dataset, or DataTree) to a NetCDF file. - Sharing results between workspaces. + encoding all attrs as JSON strings. - Caching expensive or interim work. + Args: + obj (xr.DataArray | xr.Dataset | xr.DataTree): The xarray object to save. + path (str | Path): The file path to write the NetCDF file to. + **kwargs: Additional keyword arguments passed to `to_netcdf()`. - For reproducibility reasons, you should generally prefer to - duplicate anaysis results using common code to prevent stale data - dependencies, but there are good reasons to use pickling as well. + Returns: + None + """ + _jsonify_attrs(obj).to_netcdf(path, **kwargs) - This function knows whether we are pickling or unpickling depending on - whether one or two arguments are provided. + +@save_custom_netcdf.register +def _( + obj: xr.Dataset, + path: str | Path, + **kwargs: Unpack[ToNetCDFParams], +) -> None: + """Save an xarray object (DataArray, Dataset, or DataTree) to a NetCDF file. + + encoding all attrs as JSON strings. Args: - data_or_str: If saving, the data to be pickled. If loading, the name of the pickle to load. - name: If saving (non-None value), the name to associate. Defaults to None. + obj (xr.DataArray | xr.Dataset | xr.DataTree): The xarray object to save. + path (str | Path): The file path to write the NetCDF file to. + **kwargs: Additional keyword arguments passed to `to_netcdf()`. Returns: - None if name is not None, which indicates that we are saving data. - Otherwise, returns the unpickled value associated to `name`. 
+ None """ - # we are loading data - if isinstance(data_or_str, str) or not name: - assert isinstance(data_or_str, str) - return load_pickle(data_or_str) - # we are saving data - assert isinstance(name, str) - save_pickle(data_or_str, name) - return None + ds_jsonified = xr.Dataset({name: _jsonify_attrs(var) for name, var in obj.data_vars.items()}) + _jsonify_attrs(ds_jsonified).to_netcdf(path, **kwargs) + + +@save_custom_netcdf.register +def _( + obj: xr.DataTree, + path: str | Path, + **kwargs: Unpack[ToNetCDFParams], +) -> None: + def jsonify_attrs_dataset(obj: xr.Dataset) -> xr.Dataset: + ds_jsonified = xr.Dataset( + {name: _jsonify_attrs(var) for name, var in obj.data_vars.items()}, + ) + assert isinstance(ds_jsonified, xr.Dataset) + return _jsonify_attrs(ds_jsonified) + + new_tree = obj.map_over_datasets(jsonify_attrs_dataset) + assert isinstance(new_tree, xr.DataTree) + new_tree.to_netcdf(path, **kwargs) + + +def load_custom_netcdf( + path: str | Path, + **kwargs: Unpack[ToNetCDFParams], +) -> xr.DataArray | xr.Dataset | xr.DataTree: + """Load an xarray object from a NetCDF file and decode all JSON-encoded attrs. + The object type (DataArray, Dataset, or DataTree) is determined automatically -def list_pickles() -> list[str]: - """Generates a summary list of (workspace-local) pickled results and data. + + from the saved metadata. + + Args: + path (str | Path): The file path to read the NetCDF file from. + **kwargs: Additional keyword arguments passed to the appropriate `open_*` function. Returns: - A list of the named pickles, suitable for passing to `easy_pickle`. + xr.DataArray | xr.Dataset | xr.DataTree: The loaded and decoded xarray object. 
""" - return [str(s.stem) for s in Path(file_for_pickle("just-a-pickle")).parent.glob("*.pickle")] + obj = xr.open_dataset(path, **kwargs) + + # Check and decode attrs if necessary + if hasattr(obj, "attrs"): + obj.attrs = {k: (json.loads(v) if isinstance(v, str) else v) for k, v in obj.attrs.items()} + + # If it's a DataTree, reconstruct it + if isinstance(obj, xr.DataTree): + obj = xr.open_datatree(path, **kwargs) + + return obj diff --git a/tests/test_io.py b/tests/test_io.py index 659784fd..0fc9b3a2 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,11 +1,122 @@ """Unit test of io module in aryspes.""" +import json +from pathlib import Path + import pytest +import xarray as xr -from arpes.io import load_example_data +from arpes.io import load_custom_netcdf, load_example_data, save_custom_netcdf def test_load_example_raises_kye_error() -> None: msg = "Could not find requested example_name: cut0.*" with pytest.raises(KeyError, match=msg): load_example_data("cut0") + + +@pytest.fixture +def sample_dataarray() -> xr.DataArray: + """Fixture to provide a sample xarray.DataArray for testing.""" + data = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [0, 1, 2]}) + data.attrs = {"description": "Test DataArray", "info": {"nested": "value"}} + return data + + +@pytest.fixture +def sample_dataset() -> xr.Dataset: + """Fixture to provide a sample xarray.Dataset for testing.""" + data = xr.Dataset({"var": ("x", [1, 2, 3])}, coords={"x": [0, 1, 2]}) + data.attrs = {"description": "Test Dataset", "info": {"nested": "value"}} + return data + + +@pytest.fixture +def sample_datatree() -> xr.DataTree: + """Fixture to provide a sample xarray.DataTree for testing.""" + node1 = xr.DataArray([1, 2], name="node1") + node2 = xr.DataArray([3, 4], name="node2") + dataset = xr.Dataset({"node1": node1, "node2": node2}) + + data = xr.DataTree(dataset) + + data.attrs = {"description": "Test DataTree", "info": {"nested": "value"}} + return data + + +@pytest.fixture +def 
sample_hierarchy_datatree() -> xr.DataTree: + """Sample hierarchy datatree taken from xarray documentation. + + https://docs.xarray.dev/en/latest/user-guide/hierarchical-data.html + + """ + bart = xr.DataTree(name="Bart") + lisa = xr.DataTree(name="Lisa") + homer = xr.DataTree(name="Homer", children={"Bart": bart, "Lisa": lisa}) + maggie = xr.DataTree(name="Maggie") + homer.children = {"Bart": bart, "Lisa": lisa, "Maggie": maggie} + abe = xr.DataTree(name="Abe") + abe.children = {"Homer": homer} + herbert = xr.DataTree(name="Herb") + return abe.assign({"Herbert": herbert}) + + +def test_save_load_dataarray(sample_dataarray: xr.DataArray, tmp_path: Path): + """Test saving and loading a DataArray with JSON-encoded attrs.""" + file_path = tmp_path / "dataarray.nc" + + # Save the DataArray + save_custom_netcdf(sample_dataarray, file_path) + + # Load the DataArray + loaded = load_custom_netcdf(file_path) + + assert isinstance(loaded, xr.DataArray) + assert loaded.attrs["description"] == "Test DataArray" + assert json.loads(loaded.attrs["info"]) == {"nested": "value"} + + +def test_save_load_dataset(sample_dataset: xr.Dataset, tmp_path: Path): + """Test saving and loading a Dataset with JSON-encoded attrs.""" + file_path = tmp_path / "dataset.nc" + + # Save the Dataset + save_custom_netcdf(sample_dataset, file_path) + + # Load the Dataset + loaded = load_custom_netcdf(file_path) + + assert isinstance(loaded, xr.Dataset) + assert loaded.attrs["description"] == "Test Dataset" + assert json.loads(loaded.attrs["info"]) == {"nested": "value"} + + +def test_save_load_datatree(sample_datatree: xr.DataTree, tmp_path: Path): + """Test saving and loading a DataTree with JSON-encoded attrs.""" + file_path = tmp_path / "datatree.nc" + + # Save the DataTree + save_custom_netcdf(sample_datatree, file_path) + + # Load the DataTree + loaded = load_custom_netcdf(file_path) + + assert isinstance(loaded, xr.DataTree) + assert loaded.attrs["description"] == "Test DataTree" + assert 
json.loads(loaded.attrs["info"]) == {"nested": "value"} + + +def test_save_load_with_kwargs(sample_dataarray: xr.DataArray, tmp_path: Path): + """Test saving and loading with additional kwargs passed to `to_netcdf`.""" + file_path = tmp_path / "dataarray_with_kwargs.nc" + + # Save with engine kwargs + save_custom_netcdf(sample_dataarray, file_path, engine="h5netcdf") + + # Load the DataArray + loaded = load_custom_netcdf(file_path) + + assert isinstance(loaded, xr.DataArray) + assert loaded.attrs["description"] == "Test DataArray" + assert json.loads(loaded.attrs["info"]) == {"nested": "value"}