Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ Primary (X.-.-) version numbers are used to denote backwards
incompatibilities between versions, while minor (-.X.-) numbers
primarily indicate new features and documentation.

5.0.0-per3 (XXXX-XX-XX)
5.0.0-pre3 (2025-04-25)
^^^^^^^^^^^^^^^^^^^^^^^^

* Remove 'GStepBStdevModel', which is useless. Use GstepBModel.


5.0.0-pre2 (2024-05-24)
5.0.0-pre2 (2025-03-24)
^^^^^^^^^^^^^^^^^^^^^^^^

* Fix problem in kz-conversion.py
Expand Down
170 changes: 112 additions & 58 deletions src/arpes/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,30 @@
of data.
"""

from __future__ import annotations

import pickle
import json
import warnings
from collections.abc import Iterable
from dataclasses import dataclass
from functools import singledispatch
from logging import DEBUG, INFO
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any, Literal, TypedDict, Unpack

import numpy as np
import pandas as pd
import xarray as xr

from arpes._typing import DataType, XrTypes

from .debug import setup_logger
from .endstations import ScanDesc, load_scan
from .example_data.mock import build_mock_tarpes

if TYPE_CHECKING:
from _typeshed import Incomplete

from arpes._typing import XrTypes


__all__ = (
"easy_pickle",
"list_pickles",
"load_custom_netcdf",
"load_data",
"load_example_data",
"save_custom_netcdf",
"stitch",
)

Expand All @@ -52,7 +47,7 @@
def load_data(
file: str | Path,
location: str | None = None,
**kwargs: Incomplete,
**kwargs: Any,
) -> xr.Dataset:
"""Loads a piece of data using available plugins. This the user facing API for data loading.

Expand Down Expand Up @@ -260,71 +255,130 @@ def _df_or_list_to_files(
return list(df_or_list)


def file_for_pickle(name: str) -> Path | str:
here = Path()
from .config import CONFIG
class ToNetCDFParams(TypedDict, total=False):
"""Typed dictionary for the parameters of the to_netcdf function.

if CONFIG["WORKSPACE"] and "path" in CONFIG["WORKSPACE"]:
here = Path(CONFIG["WORKSPACE"]["path"])
path = here / "picklejar" / f"{name}.pickle"
path.parent.mkdir(exist_ok=True)
return str(path)
Many parameters are available, but only two are defined here.
"""

mode: Literal["w", "a"]
engine: Literal["netcdf4", "h5netcdf"]

def load_pickle(name: str) -> object:
    """Read back a workspace-local pickle previously written by `save_pickle`."""
    pickle_path = Path(file_for_pickle(name))
    with pickle_path.open("rb") as stream:
        return pickle.load(stream)  # noqa: S301

@singledispatch
def save_custom_netcdf(
obj: xr.DataArray | xr.Dataset | xr.DataTree,
path: str | Path,
**kwargs: Unpack[ToNetCDFParams],
) -> None:
"""Save an xarray object (DataArray, Dataset, or DataTree) to a NetCDF file.

def save_pickle(data: object, name: str) -> None:
    """Write *data* as a workspace-local pickle; read it back with `load_pickle`."""
    target = Path(file_for_pickle(name))
    with target.open("wb") as stream:
        pickle.dump(data, stream)
Args:
obj (xr.DataArray | xr.Dataset | xr.DataTree): The xarray object to save.
path (str | Path): The file path to write the NetCDF file to.
**kwargs: Additional keyword arguments passed to `to_netcdf()`.

Returns:
None
"""
del path, kwargs
msg = f"Unsupported type: {type(obj)}"
raise NotImplementedError(msg)


def easy_pickle(data_or_str: str | object, name: str = "") -> object:
"""A convenience function around pickling.
def _jsonify_attrs(obj: DataType) -> DataType:
if hasattr(obj, "attrs"):
obj.attrs = {"__json__": json.dumps(obj.attrs)}
return obj

Provides a workspace scoped associative set of named pickles which
can be used for

Examples:
Retaining analysis results between sessions.
@save_custom_netcdf.register
def _(
obj: xr.DataArray,
path: str | Path,
**kwargs: Unpack[ToNetCDFParams],
) -> None:
"""Save an xarray object (DataArray, Dataset, or DataTree) to a NetCDF file.

Sharing results between workspaces.
encoding all attrs as JSON strings.

Caching expensive or interim work.
Args:
obj (xr.DataArray | xr.Dataset | xr.DataTree): The xarray object to save.
path (str | Path): The file path to write the NetCDF file to.
**kwargs: Additional keyword arguments passed to `to_netcdf()`.

For reproducibility reasons, you should generally prefer to
duplicate analysis results using common code to prevent stale data
dependencies, but there are good reasons to use pickling as well.
Returns:
None
"""
_jsonify_attrs(obj).to_netcdf(path, **kwargs)

This function knows whether we are pickling or unpickling depending on
whether one or two arguments are provided.

@save_custom_netcdf.register
def _(
obj: xr.Dataset,
path: str | Path,
**kwargs: Unpack[ToNetCDFParams],
) -> None:
"""Save an xarray object (DataArray, Dataset, or DataTree) to a NetCDF file.

encoding all attrs as JSON strings.

Args:
data_or_str: If saving, the data to be pickled. If loading, the name of the pickle to load.
name: If saving (non-None value), the name to associate. Defaults to None.
obj (xr.DataArray | xr.Dataset | xr.DataTree): The xarray object to save.
path (str | Path): The file path to write the NetCDF file to.
**kwargs: Additional keyword arguments passed to `to_netcdf()`.

Returns:
None if name is not None, which indicates that we are saving data.
Otherwise, returns the unpickled value associated to `name`.
None
"""
# we are loading data
if isinstance(data_or_str, str) or not name:
assert isinstance(data_or_str, str)
return load_pickle(data_or_str)
# we are saving data
assert isinstance(name, str)
save_pickle(data_or_str, name)
return None
ds_jsonified = xr.Dataset({name: _jsonify_attrs(var) for name, var in obj.data_vars.items()})
_jsonify_attrs(ds_jsonified).to_netcdf(path, **kwargs)


@save_custom_netcdf.register
def _(
    obj: xr.DataTree,
    path: str | Path,
    **kwargs: Unpack[ToNetCDFParams],
) -> None:
    """Save a DataTree to a NetCDF file, JSON-encoding the attrs of every node.

    Args:
        obj (xr.DataTree): The tree to save.
        path (str | Path): The file path to write the NetCDF file to.
        **kwargs: Additional keyword arguments passed to `to_netcdf()`.

    Returns:
        None
    """

    def _encode_node(node: xr.Dataset) -> xr.Dataset:
        # Rebuild the dataset so each variable's attrs are JSON-encoded,
        # then encode the dataset-level attrs as well.
        encoded = xr.Dataset(
            {key: _jsonify_attrs(variable) for key, variable in node.data_vars.items()},
        )
        assert isinstance(encoded, xr.Dataset)
        return _jsonify_attrs(encoded)

    encoded_tree = obj.map_over_datasets(_encode_node)
    assert isinstance(encoded_tree, xr.DataTree)
    encoded_tree.to_netcdf(path, **kwargs)


def load_custom_netcdf(
path: str | Path,
**kwargs: Unpack[ToNetCDFParams],
) -> xr.DataArray | xr.Dataset | xr.DataTree:
"""Load an xarray object from a NetCDF file and decode all JSON-encoded attrs.

The object type (DataArray, Dataset, or DataTree) is determined automatically

def list_pickles() -> list[str]:
"""Generates a summary list of (workspace-local) pickled results and data.

from the saved metadata.

Args:
path (str | Path): The file path to read the NetCDF file from.
**kwargs: Additional keyword arguments passed to the appropriate `open_*` function.

Returns:
A list of the named pickles, suitable for passing to `easy_pickle`.
xr.DataArray | xr.Dataset | xr.DataTree: The loaded and decoded xarray object.
"""
return [str(s.stem) for s in Path(file_for_pickle("just-a-pickle")).parent.glob("*.pickle")]
obj = xr.open_dataset(path, **kwargs)

# Check and decode attrs if necessary
if hasattr(obj, "attrs"):
obj.attrs = {k: (json.loads(v) if isinstance(v, str) else v) for k, v in obj.attrs.items()}

# If it's a DataTree, reconstruct it
if isinstance(obj, xr.DataTree):
obj = xr.open_datatree(path, **kwargs)

return obj
113 changes: 112 additions & 1 deletion tests/test_io.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,122 @@
"""Unit test of io module in aryspes."""

import json
from pathlib import Path

import pytest
import xarray as xr

from arpes.io import load_example_data
from arpes.io import load_custom_netcdf, load_example_data, save_custom_netcdf


def test_load_example_raises_key_error() -> None:
    """load_example_data raises KeyError when the example name is unknown."""
    # Fixed typo in test name: "kye_error" -> "key_error" (pytest still discovers it).
    msg = "Could not find requested example_name: cut0.*"
    with pytest.raises(KeyError, match=msg):
        load_example_data("cut0")


@pytest.fixture
def sample_dataarray() -> xr.DataArray:
    """Return a small 1-D DataArray carrying nested attrs for round-trip tests."""
    return xr.DataArray(
        [1, 2, 3],
        dims=["x"],
        coords={"x": [0, 1, 2]},
        attrs={"description": "Test DataArray", "info": {"nested": "value"}},
    )


@pytest.fixture
def sample_dataset() -> xr.Dataset:
    """Return a one-variable Dataset carrying nested attrs for round-trip tests."""
    return xr.Dataset(
        data_vars={"var": ("x", [1, 2, 3])},
        coords={"x": [0, 1, 2]},
        attrs={"description": "Test Dataset", "info": {"nested": "value"}},
    )


@pytest.fixture
def sample_datatree() -> xr.DataTree:
    """Return a flat DataTree (two variables at the root) with nested attrs."""
    root_dataset = xr.Dataset(
        {
            "node1": xr.DataArray([1, 2], name="node1"),
            "node2": xr.DataArray([3, 4], name="node2"),
        },
    )
    tree = xr.DataTree(root_dataset)
    tree.attrs = {"description": "Test DataTree", "info": {"nested": "value"}}
    return tree


@pytest.fixture
def sample_hierarchy_datatree() -> xr.DataTree:
    """Sample hierarchical DataTree taken from the xarray documentation.

    https://docs.xarray.dev/en/latest/user-guide/hierarchical-data.html

    """
    bart = xr.DataTree(name="Bart")
    lisa = xr.DataTree(name="Lisa")
    maggie = xr.DataTree(name="Maggie")
    homer = xr.DataTree(
        name="Homer",
        children={"Bart": bart, "Lisa": lisa, "Maggie": maggie},
    )
    abe = xr.DataTree(name="Abe")
    abe.children = {"Homer": homer}
    herbert = xr.DataTree(name="Herb")
    return abe.assign({"Herbert": herbert})


def test_save_load_dataarray(sample_dataarray: xr.DataArray, tmp_path: Path):
    """Round-trip a DataArray through save_custom_netcdf/load_custom_netcdf."""
    target = tmp_path / "dataarray.nc"

    save_custom_netcdf(sample_dataarray, target)
    loaded = load_custom_netcdf(target)

    # Type survives the round trip and JSON-encoded attrs decode correctly.
    assert isinstance(loaded, xr.DataArray)
    assert loaded.attrs["description"] == "Test DataArray"
    assert json.loads(loaded.attrs["info"]) == {"nested": "value"}


def test_save_load_dataset(sample_dataset: xr.Dataset, tmp_path: Path):
    """Round-trip a Dataset through save_custom_netcdf/load_custom_netcdf."""
    target = tmp_path / "dataset.nc"

    save_custom_netcdf(sample_dataset, target)
    loaded = load_custom_netcdf(target)

    # Type survives the round trip and JSON-encoded attrs decode correctly.
    assert isinstance(loaded, xr.Dataset)
    assert loaded.attrs["description"] == "Test Dataset"
    assert json.loads(loaded.attrs["info"]) == {"nested": "value"}


def test_save_load_datatree(sample_datatree: xr.DataTree, tmp_path: Path):
    """Round-trip a DataTree through save_custom_netcdf/load_custom_netcdf."""
    target = tmp_path / "datatree.nc"

    save_custom_netcdf(sample_datatree, target)
    loaded = load_custom_netcdf(target)

    # Type survives the round trip and JSON-encoded attrs decode correctly.
    assert isinstance(loaded, xr.DataTree)
    assert loaded.attrs["description"] == "Test DataTree"
    assert json.loads(loaded.attrs["info"]) == {"nested": "value"}


def test_save_load_with_kwargs(sample_dataarray: xr.DataArray, tmp_path: Path):
    """Keyword arguments such as ``engine`` are forwarded to ``to_netcdf``."""
    target = tmp_path / "dataarray_with_kwargs.nc"

    save_custom_netcdf(sample_dataarray, target, engine="h5netcdf")
    loaded = load_custom_netcdf(target)

    # The file written via the alternate engine still round-trips correctly.
    assert isinstance(loaded, xr.DataArray)
    assert loaded.attrs["description"] == "Test DataArray"
    assert json.loads(loaded.attrs["info"]) == {"nested": "value"}
Loading