diff --git a/docs/key_functionality.rst b/docs/key_functionality.rst index f5e29c4..b87b6fd 100644 --- a/docs/key_functionality.rst +++ b/docs/key_functionality.rst @@ -155,8 +155,8 @@ and continue loading the SDF file as normal. This file contains the initial simulation setup information which is not present in SDF outputs. By loading this file, you can access these parameters as part of your dataset's metadata. To do this, use the ``deck_path`` parameter when loading an SDF file with -`xarray.open_dataset`, `sdf_xarray.open_datatree`, `sdf_xarray.open_mfdataset` -or `sdf_xarray.open_mfdatatree`. +`sdf_xarray.open_dataset`, `xarray.open_dataset`, `sdf_xarray.open_datatree`, +`xarray.open_datatree`, `sdf_xarray.open_mfdataset` or `sdf_xarray.open_mfdatatree`. There are a few ways you can load an input deck: diff --git a/src/sdf_xarray/__init__.py b/src/sdf_xarray/__init__.py index 3a72789..55068ec 100644 --- a/src/sdf_xarray/__init__.py +++ b/src/sdf_xarray/__init__.py @@ -268,7 +268,7 @@ def open_dataset( def open_mfdataset( - path_glob: Iterable | str | Path | Callable[..., Iterable[Path]], + paths: Iterable | str | Path | Callable[..., Iterable[Path]], *, separate_times: bool = False, keep_particles: bool = False, @@ -301,7 +301,7 @@ def open_mfdataset( Parameters ---------- - path_glob + paths List of filenames or string glob pattern separate_times If ``True``, create separate time dimensions for variables defined at @@ -326,11 +326,11 @@ def open_mfdataset( from a relative or absolute file path. See :ref:`loading-input-deck` for details. """ - path_glob = _resolve_glob(path_glob) + paths = _resolve_glob(paths) if not separate_times: return combine_datasets( - path_glob, + paths, data_vars=data_vars, keep_particles=keep_particles, probe_names=probe_names, @@ -338,10 +338,10 @@ def open_mfdataset( deck_path=deck_path, ) - _, var_times_map = make_time_dims(path_glob) + _, var_times_map = make_time_dims(paths) all_dfs = [] - for f in path_glob: + for f in paths: ds = xr.open_dataset( f, keep_particles=keep_particles, @@ -451,7 +451,7 @@ def open_datatree( def open_mfdatatree( - path_glob: Iterable | str | Path | Callable[..., Iterable[Path]], + paths: Iterable | str | Path | Callable[..., Iterable[Path]], *, separate_times: bool = False, keep_particles: bool = False, @@ -512,7 +512,7 @@ def open_mfdatatree( Parameters ---------- - path_glob + paths List of filenames or string glob pattern separate_times If ``True``, create separate time dimensions for variables defined at @@ -536,7 +536,7 @@ def open_mfdatatree( """ # First, combine the datasets as usual combined_ds = open_mfdataset( - path_glob, + paths, separate_times=separate_times, keep_particles=keep_particles, probe_names=probe_names, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index a4da3e9..537ce9c 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,15 +1,16 @@ +from typing import Protocol + import numpy as np import numpy.testing as npt import pytest import xarray as xr +import sdf_xarray as sdfxr from sdf_xarray import ( SDFPreprocess, _process_latex_name, _resolve_glob, download, - open_dataset, - open_mfdataset, ) TEST_FILES_DIR = download.fetch_dataset("test_files_1D") @@ -19,22 +20,16 @@ TEST_2D_PARTICLE_DATA = download.fetch_dataset("test_two_probes_2D") -def test_basic(): - with open_dataset(TEST_FILES_DIR / "0000.sdf") as df: - ex_field = "Electric_Field_Ex" - assert ex_field in df - x_coord = "X_Grid_mid" - assert x_coord in df[ex_field].coords - assert df[x_coord].attrs["long_name"] == "X" +# Type hinting support +class XRLibrary(Protocol): + def open_dataset(self, *args, **kwargs) -> xr.Dataset: ... - px_protons = "Particles_Px_proton" - assert px_protons not in df - x_coord = "X_Particles_proton" - assert x_coord not in df.coords + def open_mfdataset(self, *args, **kwargs) -> xr.Dataset: ... -def test_xr_basic(): - with xr.open_dataset(TEST_FILES_DIR / "0000.sdf") as df: +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_basic(xrlib: XRLibrary): + with xrlib.open_dataset(TEST_FILES_DIR / "0000.sdf") as df: ex_field = "Electric_Field_Ex" assert ex_field in df x_coord = "X_Grid_mid" @@ -47,17 +42,9 @@ def test_xr_basic(): assert x_coord not in df.coords -def test_constant_name_and_units(): - with open_dataset(TEST_FILES_DIR / "0000.sdf") as df: - name = "Absorption_Total_Laser_Energy_Injected" - full_name = "Absorption/Total Laser Energy Injected" - assert name in df - assert df[name].units == "J" - assert df[name].attrs["full_name"] == full_name - - -def test_xr_constant_name_and_units(): - with xr.open_dataset(TEST_FILES_DIR / "0000.sdf") as df: +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_constant_name_and_units(xrlib: XRLibrary): + with xrlib.open_dataset(TEST_FILES_DIR / "0000.sdf") as df: name = "Absorption_Total_Laser_Energy_Injected" full_name = "Absorption/Total Laser Energy Injected" assert name in df @@ -65,49 +52,26 @@ def test_xr_constant_name_and_units(): assert df[name].attrs["full_name"] == full_name -def test_preferred_chunks_metadata(): - with open_dataset(TEST_FILES_DIR / "0000.sdf") as df: - for var in df.data_vars: - assert "preferred_chunks" in df[var].encoding - - -def test_xr_preferred_chunks_metadata(): - with xr.open_dataset(TEST_FILES_DIR / "0000.sdf") as df: +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_preferred_chunks_metadata(xrlib: XRLibrary): + with xrlib.open_dataset(TEST_FILES_DIR / "0000.sdf") as df: for var in df.data_vars: assert "preferred_chunks" in df[var].encoding -def test_coords(): - with open_dataset(TEST_FILES_DIR / "0010.sdf") as df: - px_electron = "dist_fn_x_px_electron" - assert px_electron in df - print(df[px_electron].coords) - x_coord = "Px_x_px_electron" - assert x_coord in df[px_electron].coords - assert df[x_coord].attrs["full_name"] == "Grid/x_px/electron" - - -def test_xr_coords(): - with xr.open_dataset(TEST_FILES_DIR / "0010.sdf") as df: +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_coords(xrlib: XRLibrary): + with xrlib.open_dataset(TEST_FILES_DIR / "0010.sdf") as df: px_electron = "dist_fn_x_px_electron" assert px_electron in df - print(df[px_electron].coords) x_coord = "Px_x_px_electron" assert x_coord in df[px_electron].coords assert df[x_coord].attrs["full_name"] == "Grid/x_px/electron" -def test_particles(): - with open_dataset(TEST_FILES_DIR / "0010.sdf", keep_particles=True) as df: - px_protons = "Particles_Px_proton" - assert px_protons in df - x_coord = "X_Particles_proton" - assert x_coord in df[px_protons].coords - assert df[x_coord].attrs["long_name"] == "X" - - -def test_xr_particles(): - with xr.open_dataset(TEST_FILES_DIR / "0010.sdf", keep_particles=True) as df: +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_particles(xrlib: XRLibrary): + with xrlib.open_dataset(TEST_FILES_DIR / "0010.sdf", keep_particles=True) as df: px_protons = "Particles_Px_proton" assert px_protons in df x_coord = "X_Particles_proton" @@ -115,20 +79,33 @@ def test_xr_particles(): assert df[x_coord].attrs["long_name"] == "X" -def test_no_particles(): - with open_dataset(TEST_FILES_DIR / "0010.sdf", keep_particles=False) as df: - px_protons = "Particles_Px_proton" - assert px_protons not in df - - -def test_xr_no_particles(): - with xr.open_dataset(TEST_FILES_DIR / "0010.sdf", keep_particles=False) as df: +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_no_particles(xrlib: XRLibrary): + with xrlib.open_dataset(TEST_FILES_DIR / "0010.sdf", keep_particles=False) as df: px_protons = "Particles_Px_proton" assert px_protons not in df -def test_multiple_files_one_time_dim(): - with open_mfdataset(TEST_FILES_DIR.glob("*.sdf"), keep_particles=True) as df: +@pytest.mark.parametrize( + ("xrlib", "params"), + [ + ( + xr, + { + "compat": "no_conflicts", + "join": "outer", + "preprocess": SDFPreprocess(), + }, + ), + (sdfxr, {}), + ], +) +def test_multiple_files_one_time_dim(xrlib: XRLibrary, params): + with xrlib.open_mfdataset( + paths=TEST_FILES_DIR.glob("*.sdf"), + keep_particles=True, + **params, + ) as df: ex_field = df["Electric_Field_Ex"] assert sorted(ex_field.coords) == sorted(("X_Grid_mid", "time")) assert ex_field.shape == (11, 16) @@ -215,7 +192,7 @@ def test_multiple_files_one_time_dim(): def test_multiple_files_multiple_time_dims(): - with open_mfdataset( + with sdfxr.open_mfdataset( TEST_FILES_DIR.glob("*.sdf"), separate_times=True, keep_particles=True ) as df: assert list(df["Electric_Field_Ex"].coords) != list( @@ -280,111 +257,28 @@ def test_resolve_glob_from_path_list_multiple_duplicates(): assert result == expected -def test_xr_erroring_on_mismatched_jobid_files(): +@pytest.mark.parametrize( + ("xrlib", "params"), + [ + ( + xr, + { + "compat": "no_conflicts", + "join": "outer", + "preprocess": SDFPreprocess(), + }, + ), + (sdfxr, {}), + ], +) +def test_erroring_on_mismatched_jobid_files(xrlib, params): with pytest.raises(ValueError): # noqa: PT011 - xr.open_mfdataset( - TEST_MISMATCHED_FILES_DIR.glob("*.sdf"), - combine="nested", - data_vars="minimal", - coords="minimal", - compat="override", - join="outer", - preprocess=SDFPreprocess(), - ) - - -def test_xr_multiple_files_data(): - with xr.open_mfdataset( - TEST_FILES_DIR.glob("*.sdf"), - compat="no_conflicts", - join="outer", - preprocess=SDFPreprocess(), - ) as df: - ex = df.isel(time=10)["Electric_Field_Ex"] - ex_values = ex.values - ex_x_coords = ex.coords["X_Grid_mid"].values - - expected_ex = np.array( - [ - -3126528.47057157754898071289062500000000, - -3249643.37612255383282899856567382812500, - -6827013.11566223856061697006225585937500, - -9350267.99022011645138263702392578125000, - -1643592.58487333403900265693664550781250, - -2044751.41207189299166202545166015625000, - -4342811.34666103497147560119628906250000, - -10420841.38402196019887924194335937500000, - -7038801.83154528774321079254150390625000, - 781649.31791684380732476711273193359375, - 4476555.84853181242942810058593750000000, - 5873312.79385650344192981719970703125000, - -95930.60501570138148963451385498046875, - -8977898.96547995693981647491455078125000, - -7951712.64987809769809246063232421875000, - -5655667.11171338520944118499755859375000, - ] - ) - expected_ex_coords = np.array( - [ - 1.72522447e-05, - 5.17567340e-05, - 8.62612233e-05, - 1.20765713e-04, - 1.55270202e-04, - 1.89774691e-04, - 2.24279181e-04, - 2.58783670e-04, - 2.93288159e-04, - 3.27792649e-04, - 3.62297138e-04, - 3.96801627e-04, - 4.31306117e-04, - 4.65810606e-04, - 5.00315095e-04, - 5.34819585e-04, - ] - ) - npt.assert_allclose(ex_values, expected_ex) - npt.assert_allclose(ex_x_coords, expected_ex_coords) - - -def test_xr_time_dim(): - with xr.open_mfdataset( - TEST_FILES_DIR.glob("*.sdf"), - join="outer", - preprocess=SDFPreprocess(), - ) as df: - time = df["time"] - assert time.units == "s" - assert time.long_name == "Time" - assert time.full_name == "time" - - time_values = np.array( - [ - 5.466993e-14, - 2.417504e-10, - 4.833915e-10, - 7.251419e-10, - 9.667830e-10, - 1.208533e-09, - 1.450175e-09, - 1.691925e-09, - 1.933566e-09, - 2.175316e-09, - 2.416958e-09, - ] - ) - - npt.assert_allclose(time_values, time.values, rtol=1e-6) + xrlib.open_mfdataset(paths=TEST_MISMATCHED_FILES_DIR.glob("*.sdf"), **params) -def test_xr_latex_rename_variables(): - with xr.open_mfdataset( - TEST_ARRAYS_DIR.glob("*.sdf"), - join="outer", - preprocess=SDFPreprocess(), - keep_particles=True, - ) as df: +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_latex_rename_variables(xrlib: XRLibrary): + with xrlib.open_dataset(TEST_ARRAYS_DIR / "0001.sdf", keep_particles=True) as df: assert df["Electric_Field_Ex"].attrs["long_name"] == "Electric Field $E_x$" assert df["Electric_Field_Ey"].attrs["long_name"] == "Electric Field $E_y$" assert df["Electric_Field_Ez"].attrs["long_name"] == "Electric Field $E_z$" @@ -417,7 +311,7 @@ def test_xr_latex_rename_variables(): ) -def test_xr_arrays_with_no_grids(): +def test_arrays_with_no_grids(): with xr.open_dataset(TEST_ARRAYS_DIR / "0001.sdf") as df: laser_phase = "laser_x_min_phase" assert laser_phase in df @@ -428,7 +322,7 @@ def test_xr_arrays_with_no_grids(): assert df[random_states].shape == (8,) -def test_xr_arrays_with_no_grids_multifile(): +def test_arrays_with_no_grids_multifile(): with xr.open_mfdataset( TEST_ARRAYS_DIR.glob("*.sdf"), join="outer", @@ -443,37 +337,23 @@ def test_xr_arrays_with_no_grids_multifile(): assert df[random_states].shape == (1, 8) -def test_xr_3d_distribution_function(): +def test_3d_distribution_function(): with xr.open_dataset(TEST_3D_DIST_FN / "0000.sdf") as df: distribution_function = "dist_fn_x_px_py_Electron" assert df[distribution_function].shape == (16, 20, 20) -def test_drop_variables(): - with open_dataset( - TEST_FILES_DIR / "0000.sdf", drop_variables=["Electric_Field_Ex"] - ) as df: - assert "Electric_Field_Ex" not in df - - -def test_xr_drop_variables(): - with xr.open_dataset( +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_drop_variables(xrlib: XRLibrary): + with xrlib.open_dataset( TEST_FILES_DIR / "0000.sdf", drop_variables=["Electric_Field_Ex"] ) as df: assert "Electric_Field_Ex" not in df -def test_drop_variables_multiple(): - with open_dataset( - TEST_FILES_DIR / "0000.sdf", - drop_variables=["Electric_Field_Ex", "Electric_Field_Ey"], - ) as df: - assert "Electric_Field_Ex" not in df - assert "Electric_Field_Ey" not in df - - -def test_xr_drop_variables_multiple(): - with xr.open_dataset( +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_drop_variables_multiple(xrlib: XRLibrary): + with xrlib.open_dataset( TEST_FILES_DIR / "0000.sdf", drop_variables=["Electric_Field_Ex", "Electric_Field_Ey"], ) as df: @@ -481,17 +361,9 @@ def test_xr_drop_variables_multiple(): assert "Electric_Field_Ey" not in df -def test_drop_variables_original(): - with open_dataset( - TEST_FILES_DIR / "0000.sdf", - drop_variables=["Electric_Field/Ex", "Electric_Field/Ey"], - ) as df: - assert "Electric_Field_Ex" not in df - assert "Electric_Field_Ey" not in df - - -def test_xr_drop_variables_original(): - with xr.open_dataset( +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_drop_variables_original(xrlib: XRLibrary): + with xrlib.open_dataset( TEST_FILES_DIR / "0000.sdf", drop_variables=["Electric_Field/Ex", "Electric_Field/Ey"], ) as df: @@ -499,17 +371,9 @@ def test_xr_drop_variables_original(): assert "Electric_Field_Ey" not in df -def test_drop_variables_mixed(): - with open_dataset( - TEST_FILES_DIR / "0000.sdf", - drop_variables=["Electric_Field/Ex", "Electric_Field_Ey"], - ) as df: - assert "Electric_Field_Ex" not in df - assert "Electric_Field_Ey" not in df - - -def test_xr_drop_variables_mixed(): - with xr.open_dataset( +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_drop_variables_mixed(xrlib: XRLibrary): + with xrlib.open_dataset( TEST_FILES_DIR / "0000.sdf", drop_variables=["Electric_Field/Ex", "Electric_Field_Ey"], ) as df: @@ -517,32 +381,17 @@ def test_xr_drop_variables_mixed(): assert "Electric_Field_Ey" not in df -def test_erroring_drop_variables(): - with pytest.raises(KeyError): - open_dataset(TEST_FILES_DIR / "0000.sdf", drop_variables=["Electric_Field/E"]) - - -def test_xr_erroring_drop_variables(): +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_erroring_drop_variables(xrlib: XRLibrary): with pytest.raises(KeyError): - xr.open_dataset( + xrlib.open_dataset( TEST_FILES_DIR / "0000.sdf", drop_variables=["Electric_Field/E"] ) -def test_loading_multiple_probes(): - with open_dataset( - TEST_2D_PARTICLE_DATA / "0002.sdf", - keep_particles=True, - probe_names=["Electron_Front_Probe", "Electron_Back_Probe"], - ) as df: - assert "X_Probe_Electron_Front_Probe" in df.coords - assert "X_Probe_Electron_Back_Probe" in df.coords - assert "ID_Electron_Front_Probe_Px" in df.dims - assert "ID_Electron_Back_Probe_Px" in df.dims - - -def test_xr_loading_multiple_probes(): - with xr.open_dataset( +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_loading_multiple_probes(xrlib: XRLibrary): + with xrlib.open_dataset( TEST_2D_PARTICLE_DATA / "0002.sdf", keep_particles=True, probe_names=["Electron_Front_Probe", "Electron_Back_Probe"], @@ -553,25 +402,9 @@ def test_xr_loading_multiple_probes(): assert "ID_Electron_Back_Probe_Px" in df.dims -def test_loading_one_probe_drop_second_probe(): - with open_dataset( - TEST_2D_PARTICLE_DATA / "0002.sdf", - keep_particles=True, - drop_variables=[ - "Electron_Back_Probe_Px", - "Electron_Back_Probe_Py", - "Electron_Back_Probe_Pz", - "Electron_Back_Probe_weight", - ], - probe_names=["Electron_Front_Probe"], - ) as df: - assert "X_Probe_Electron_Front_Probe" in df.coords - assert "ID_Electron_Front_Probe_Px" in df.dims - assert "ID_Electron_Back_Probe_Px" not in df.dims - - -def test_xr_loading_one_probe_drop_second_probe(): - with xr.open_dataset( +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_loading_one_probe_drop_second_probe(xrlib: XRLibrary): + with xrlib.open_dataset( TEST_2D_PARTICLE_DATA / "0002.sdf", keep_particles=True, drop_variables=[ @@ -588,7 +421,7 @@ def test_xr_loading_one_probe_drop_second_probe(): def test_open_mfdataset_data_vars_single(): - with open_mfdataset( + with sdfxr.open_mfdataset( TEST_FILES_DIR.glob("*.sdf"), data_vars=["Electric_Field_Ex"], ) as df: @@ -603,7 +436,7 @@ def test_open_mfdataset_data_vars_single(): def test_open_mfdataset_data_vars_multiple(): - with open_mfdataset( + with sdfxr.open_mfdataset( TEST_FILES_DIR.glob("*.sdf"), data_vars=["Electric_Field_Ex", "Electric_Field_Ey"], ) as df: @@ -623,7 +456,7 @@ def test_open_mfdataset_data_vars_multiple(): def test_open_mfdataset_data_vars_sparse_multiple(): - with open_mfdataset( + with sdfxr.open_mfdataset( TEST_FILES_DIR.glob("*.sdf"), keep_particles=True, data_vars=[ @@ -657,7 +490,7 @@ def test_open_mfdataset_data_vars_sparse_multiple(): def test_open_mfdataset_data_vars_invalid_var(): - with open_mfdataset( + with sdfxr.open_mfdataset( TEST_FILES_DIR.glob("*.sdf"), data_vars=["Electric_Field"], ) as df: @@ -666,7 +499,7 @@ def test_open_mfdataset_data_vars_invalid_var(): def test_open_mfdataset_data_vars_time(): - with open_mfdataset( + with sdfxr.open_mfdataset( TEST_FILES_DIR.glob("*.sdf"), data_vars=["Electric_Field_Ex"], ) as df: @@ -695,7 +528,7 @@ def test_open_mfdataset_data_vars_time(): def test_open_mfdataset_data_vars_sparse_time(): - with open_mfdataset( + with sdfxr.open_mfdataset( TEST_FILES_DIR.glob("*.sdf"), data_vars=["Particles_Particles_Per_Cell_proton"], ) as df: @@ -724,7 +557,7 @@ def test_open_mfdataset_data_vars_sparse_time(): def test_open_mfdataset_data_vars_separate_times_single(): - with open_mfdataset( + with sdfxr.open_mfdataset( TEST_FILES_DIR.glob("*.sdf"), data_vars=["Electric_Field_Ex"], separate_times=True, @@ -746,7 +579,7 @@ def test_open_mfdataset_data_vars_separate_times_single(): def test_open_mfdataset_data_vars_separate_times_multiple(): - with open_mfdataset( + with sdfxr.open_mfdataset( TEST_FILES_DIR.glob("*.sdf"), data_vars=["Electric_Field_Ex", "Electric_Field_Ey"], separate_times=True, @@ -778,7 +611,7 @@ def test_open_mfdataset_data_vars_separate_times_multiple(): def test_open_mfdataset_data_vars_separate_times_multiple_times_keep_particles(): - with open_mfdataset( + with sdfxr.open_mfdataset( TEST_FILES_DIR.glob("*.sdf"), data_vars=["Electric_Field_Ex", "Particles_Px_electron_beam"], separate_times=True, @@ -813,35 +646,37 @@ def test_open_mfdataset_data_vars_separate_times_multiple_times_keep_particles() assert particle_px_coords["ID_electron_beam"] == 1440 -def test_open_dataset_deck_path_default(): - with open_dataset(TEST_FILES_DIR / "0000.sdf") as df: +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_open_dataset_deck_path_default(xrlib: XRLibrary): + with xrlib.open_dataset(TEST_FILES_DIR / "0000.sdf") as df: assert "deck" in df.attrs -def test_open_dataset_deck_path_failed(): - with ( - pytest.raises(FileNotFoundError), - open_dataset(TEST_FILES_DIR / "0000.sdf", deck_path="non_existent.deck"), - ): - pass +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_open_dataset_deck_path_failed(xrlib: XRLibrary): + with pytest.raises(FileNotFoundError): + xrlib.open_dataset(TEST_FILES_DIR / "0000.sdf", deck_path="non_existent.deck") -def test_open_dataset_deck_path_relative(): - with open_dataset(TEST_FILES_DIR / "0000.sdf", deck_path="input.deck") as df: +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_open_dataset_deck_path_relative(xrlib: XRLibrary): + with xrlib.open_dataset(TEST_FILES_DIR / "0000.sdf", deck_path="input.deck") as df: assert "deck" in df.attrs assert "constant" in df.attrs["deck"] -def test_open_dataset_deck_path_absolute(): - with open_dataset( +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_open_dataset_deck_path_absolute(xrlib: XRLibrary): + with xrlib.open_dataset( TEST_FILES_DIR / "0000.sdf", deck_path=TEST_FILES_DIR / "input.deck" ) as df: assert "deck" in df.attrs assert "constant" in df.attrs["deck"] -def test_open_dataset_deck_path_absolute_other_path(): - with open_dataset( +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_open_dataset_deck_path_absolute_other_path(xrlib: XRLibrary): + with xrlib.open_dataset( TEST_FILES_DIR / "0000.sdf", deck_path=TEST_3D_DIST_FN / "input.deck" ) as df: assert "deck" in df.attrs @@ -849,25 +684,19 @@ def test_open_dataset_deck_path_absolute_other_path(): def test_open_mfdataset_deck_path_default(): - with open_mfdataset( - TEST_FILES_DIR.glob("*.sdf"), - ) as df: + with sdfxr.open_mfdataset(TEST_FILES_DIR.glob("*.sdf")) as df: assert "deck" in df.attrs def test_open_mfdataset_deck_path_failed(): - with ( - pytest.raises(FileNotFoundError), - open_mfdataset( - TEST_FILES_DIR.glob("*.sdf"), - deck_path="non_existent.deck", - ), - ): - pass + with pytest.raises(FileNotFoundError): + sdfxr.open_mfdataset( + TEST_FILES_DIR.glob("*.sdf"), deck_path="non_existent.deck" + ) def test_open_mfdataset_deck_path_relative(): - with open_mfdataset( + with sdfxr.open_mfdataset( TEST_FILES_DIR.glob("*.sdf"), deck_path="input.deck", ) as df: @@ -876,7 +705,7 @@ def test_open_mfdataset_deck_path_relative(): def test_open_mfdataset_deck_path_absolute(): - with open_mfdataset( + with sdfxr.open_mfdataset( TEST_FILES_DIR.glob("*.sdf"), deck_path=TEST_FILES_DIR / "input.deck" ) as df: assert "deck" in df.attrs @@ -884,7 +713,7 @@ def test_open_mfdataset_deck_path_absolute(): def test_open_mfdataset_deck_path_absolute_other_path(): - with open_mfdataset( + with sdfxr.open_mfdataset( TEST_FILES_DIR.glob("*.sdf"), deck_path=TEST_3D_DIST_FN / "input.deck" ) as df: assert "deck" in df.attrs diff --git a/tests/test_datatree.py b/tests/test_datatree.py index 3d7b272..3b7b989 100644 --- a/tests/test_datatree.py +++ b/tests/test_datatree.py @@ -1,12 +1,12 @@ +from typing import Protocol + import numpy as np import numpy.testing as npt import pytest +import xarray as xr -from sdf_xarray import ( - download, - open_datatree, - open_mfdatatree, -) +import sdf_xarray as sdfxr +from sdf_xarray import download TEST_FILES_DIR = download.fetch_dataset("test_files_1D") TEST_MISMATCHED_FILES_DIR = download.fetch_dataset("test_mismatched_files") @@ -15,231 +15,220 @@ TEST_2D_PARTICLE_DATA = download.fetch_dataset("test_two_probes_2D") -def test_datatree_basic(): - dt = open_datatree(TEST_FILES_DIR / "0000.sdf") - # Electric field group and variable (multi-level from full_name hierarchy) - assert "/Electric_Field" in dt.groups - ex = dt["Electric_Field"]["Ex"] - assert "X_Grid_mid" in ex.coords - assert ex.coords["X_Grid_mid"].attrs["long_name"] == "X" +# Type hinting support +class XRLibrary(Protocol): + def open_datatree(self, *args, **kwargs) -> xr.DataTree: ... - # Particles should not be present by default - assert not any(g for g in dt.groups if g.endswith("/Particles")) +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_datatree_basic(xrlib: XRLibrary): + with xrlib.open_datatree(TEST_FILES_DIR / "0000.sdf") as dt: + # Electric field group and variable (multi-level from full_name hierarchy) + assert "/Electric_Field" in dt.groups + ex = dt["Electric_Field"]["Ex"] + assert "X_Grid_mid" in ex.coords + assert ex.coords["X_Grid_mid"].attrs["long_name"] == "X" -def test_datatree_attrs(): - dt = open_datatree(TEST_FILES_DIR / "0000.sdf") - assert dt.attrs != {} - assert dt.attrs["code_name"] == "Epoch1d" + # Particles should not be present by default + assert not any(g for g in dt.groups if g.endswith("/Particles")) -def test_datatree_constant_name_and_units(): - dt = open_datatree(TEST_FILES_DIR / "0000.sdf") - # Absorption group with constants - assert "/Absorption" in dt.groups - total = dt["Absorption"]["Total_Laser_Energy_Injected"] - assert total.attrs.get("units") == "J" - assert total.attrs.get("full_name") == "Absorption/Total Laser Energy Injected" +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_datatree_attrs(xrlib: XRLibrary): + with xrlib.open_datatree(TEST_FILES_DIR / "0000.sdf") as dt: + assert dt.attrs != {} + assert dt.attrs["code_name"] == "Epoch1d" -def test_datatree_coords(): - dt = open_datatree(TEST_FILES_DIR / "0010.sdf") - # dist_fn/px_py/Electron structure from full_name - var = dt["dist_fn"]["x_px"]["electron"] - assert "Px_x_px_electron" in var.coords - assert var.coords["Px_x_px_electron"].attrs["full_name"] == "Grid/x_px/electron" +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_datatree_constant_name_and_units(xrlib: XRLibrary): + with xrlib.open_datatree(TEST_FILES_DIR / "0000.sdf") as dt: + # Absorption group with constants + assert "/Absorption" in dt.groups + total = dt["Absorption"]["Total_Laser_Energy_Injected"] + assert total.attrs.get("units") == "J" + assert total.attrs.get("full_name") == "Absorption/Total Laser Energy Injected" -def test_datatree_particles(): - dt = open_datatree(TEST_FILES_DIR / "0010.sdf", keep_particles=True) - # Particles group appears when particles are kept - assert "/Particles" in dt.groups - # Particles_Px_proton -> Particles/Px/proton structure - px = dt["Particles"]["Px"]["proton"] - assert "X_Particles_proton" in px.coords - assert px.coords["X_Particles_proton"].attrs["long_name"] == "X" +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_datatree_coords(xrlib: XRLibrary): + with xrlib.open_datatree(TEST_FILES_DIR / "0010.sdf") as dt: + # dist_fn/px_py/Electron structure from full_name + var = dt["dist_fn"]["x_px"]["electron"] + assert "Px_x_px_electron" in var.coords + assert var.coords["Px_x_px_electron"].attrs["full_name"] == "Grid/x_px/electron" -def test_datatree_no_particles(): - dt = open_datatree(TEST_FILES_DIR / "0010.sdf", keep_particles=False) - # Particles group may still exist for non-point grid vars, - # but point-data variables like Px_proton must be absent. - particles_group_exists = any(g for g in dt.groups if g.endswith("/Particles")) - if particles_group_exists: - # Check that leaf groups don't have direct point-data - assert "proton" not in dt["Particles"].data_vars - else: - assert True +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_datatree_particles(xrlib: XRLibrary): + with xrlib.open_datatree(TEST_FILES_DIR / "0010.sdf", keep_particles=True) as dt: + # Particles group appears when particles are kept + assert "/Particles" in dt.groups + # Particles_Px_proton -> Particles/Px/proton structure + px = dt["Particles"]["Px"]["proton"] + assert "X_Particles_proton" in px.coords + assert px.coords["X_Particles_proton"].attrs["long_name"] == "X" -def test_datatree_multiple_files_one_time_dim(): - dt = open_mfdatatree(TEST_FILES_DIR.glob("*.sdf"), keep_particles=True) - # Electric_Field/Ex structure from full_name - ex = dt["Electric_Field"]["Ex"] - assert sorted(ex.coords) == sorted(("X_Grid_mid", "time")) - assert ex.shape == (11, 16) - - # Electric_Field/Ey structure - ey = dt["Electric_Field"]["Ey"] - assert sorted(ey.coords) == sorted(("X_Grid_mid", "time")) - assert ey.shape == (11, 16) - - # Particles/Px/proton structure - px_protons = dt["Particles"]["Px"]["proton"] - assert sorted(px_protons.coords) == sorted(("X_Particles_proton", "time")) - assert px_protons.shape == (11, 1920) - - # Particles/Weight/proton structure - weight_protons = dt["Particles"]["Weight"]["proton"] - assert sorted(weight_protons.coords) == sorted(("X_Particles_proton", "time")) - assert weight_protons.shape == (11, 1920) - - absorption = dt["Absorption"]["Total_Laser_Energy_Injected"] - # Single coordinate 'time' - assert tuple(absorption.coords) == ("time",) - assert absorption.shape == (11,) - - # Check values match baseline - ex_da = ex.isel(time=10) - ex_values = ex_da.values - ex_x_coords = ex_da.coords["X_Grid_mid"].values - time_values = np.array( - [ - 5.466993e-14, - 2.417504e-10, - 4.833915e-10, - 7.251419e-10, - 9.667830e-10, - 1.208533e-09, - 1.450175e-09, - 1.691925e-09, - 1.933566e-09, - 2.175316e-09, - 2.416958e-09, - ] - ) - - expected_ex = np.array( - [ - -3126528.4705715775, - -3249643.376122554, - -6827013.115662239, - -9350267.990220116, - -1643592.584873334, - -2044751.412071893, - -4342811.346661035, - -10420841.38402196, - -7038801.831545288, - 781649.3179168438, - 4476555.848531812, - 5873312.793856503, - -95930.60501570138, - -8977898.965479957, - -7951712.649878098, - -5655667.111713385, - ] - ) - expected_ex_coords = np.array( - [ - 1.72522447e-05, - 5.17567340e-05, - 8.62612233e-05, - 1.20765713e-04, - 1.55270202e-04, - 1.89774691e-04, - 2.24279181e-04, - 2.58783670e-04, - 2.93288159e-04, - 3.27792649e-04, - 3.62297138e-04, - 3.96801627e-04, - 4.31306117e-04, - 4.65810606e-04, - 5.00315095e-04, - 5.34819585e-04, - ] - ) +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_datatree_no_particles(xrlib: XRLibrary): + with xrlib.open_datatree(TEST_FILES_DIR / "0010.sdf", keep_particles=False) as dt: + # Particles group may still exist for non-point grid vars, + # but point-data variables like Px_proton must be absent. + particles_group_exists = any(g for g in dt.groups if g.endswith("/Particles")) + if particles_group_exists: + # Check that leaf groups don't have direct point-data + assert "proton" not in dt["Particles"].data_vars + else: + assert True - # time coordinate available on variables in DataTree - npt.assert_allclose(time_values, ex.coords["time"].values, rtol=1e-6) - npt.assert_allclose(ex_values, expected_ex) - npt.assert_allclose(ex_x_coords, expected_ex_coords) + +def test_datatree_multiple_files_one_time_dim(): + with sdfxr.open_mfdatatree(TEST_FILES_DIR.glob("*.sdf"), keep_particles=True) as dt: + # Electric_Field/Ex structure from full_name + ex = dt["Electric_Field"]["Ex"] + assert sorted(ex.coords) == sorted(("X_Grid_mid", "time")) + assert ex.shape == (11, 16) + + # Electric_Field/Ey structure + ey = dt["Electric_Field"]["Ey"] + assert sorted(ey.coords) == sorted(("X_Grid_mid", "time")) + assert ey.shape == (11, 16) + + # Particles/Px/proton structure + px_protons = dt["Particles"]["Px"]["proton"] + assert sorted(px_protons.coords) == sorted(("X_Particles_proton", "time")) + assert px_protons.shape == (11, 1920) + + # Particles/Weight/proton structure + weight_protons = dt["Particles"]["Weight"]["proton"] + assert sorted(weight_protons.coords) == sorted(("X_Particles_proton", "time")) + assert weight_protons.shape == (11, 1920) + + absorption = dt["Absorption"]["Total_Laser_Energy_Injected"] + # Single coordinate 'time' + assert tuple(absorption.coords) == ("time",) + assert absorption.shape == (11,) + + # Check values match baseline + ex_da = ex.isel(time=10) + ex_values = ex_da.values + ex_x_coords = ex_da.coords["X_Grid_mid"].values + time_values = np.array( + [ + 5.466993e-14, + 2.417504e-10, + 4.833915e-10, + 7.251419e-10, + 9.667830e-10, + 1.208533e-09, + 1.450175e-09, + 1.691925e-09, + 1.933566e-09, + 2.175316e-09, + 2.416958e-09, + ] + ) + + expected_ex = np.array( + [ + -3126528.4705715775, + -3249643.376122554, + -6827013.115662239, + -9350267.990220116, + -1643592.584873334, + -2044751.412071893, + -4342811.346661035, + -10420841.38402196, + -7038801.831545288, + 781649.3179168438, + 4476555.848531812, + 5873312.793856503, + -95930.60501570138, + -8977898.965479957, + -7951712.649878098, + -5655667.111713385, + ] + ) + expected_ex_coords = np.array( + [ + 1.72522447e-05, + 5.17567340e-05, + 8.62612233e-05, + 1.20765713e-04, + 1.55270202e-04, + 1.89774691e-04, + 2.24279181e-04, + 2.58783670e-04, + 2.93288159e-04, + 3.27792649e-04, + 3.62297138e-04, + 3.96801627e-04, + 4.31306117e-04, + 4.65810606e-04, + 5.00315095e-04, + 5.34819585e-04, + ] + ) + + # time coordinate available on variables in DataTree + npt.assert_allclose(time_values, ex.coords["time"].values, rtol=1e-6) + npt.assert_allclose(ex_values, expected_ex) + npt.assert_allclose(ex_x_coords, expected_ex_coords) def test_datatree_multiple_files_multiple_time_dims(): - dt = open_mfdatatree( + with sdfxr.open_mfdatatree( TEST_FILES_DIR.glob("*.sdf"), separate_times=True, keep_particles=True - ) - # With this dataset, Ex and Ey share the same time dimension - assert list(dt["Electric_Field"]["Ex"].coords) == list( - dt["Electric_Field"]["Ey"].coords - ) - assert dt["Electric_Field"]["Ex"].shape == (11, 16) - assert dt["Electric_Field"]["Ey"].shape == (11, 16) - assert dt["Particles"]["Px"]["proton"].shape == (1, 1920) - assert dt["Particles"]["Weight"]["proton"].shape == (2, 1920) - assert dt["Absorption"]["Total_Laser_Energy_Injected"].shape == (11,) - - -def test_datatree_time_dim(): - dt = open_mfdatatree(TEST_FILES_DIR.glob("*.sdf")) - # Access time from a representative variable's coords - time = dt["Electric_Field"]["Ex"].coords["time"] - assert time.units == "s" - assert time.long_name == "Time" - assert time.full_name == "time" - - time_values = np.array( - [ - 5.466993e-14, - 2.417504e-10, - 4.833915e-10, - 7.251419e-10, - 9.667830e-10, - 1.208533e-09, - 1.450175e-09, - 1.691925e-09, - 1.933566e-09, - 2.175316e-09, - 2.416958e-09, - ] - ) - npt.assert_allclose(time_values, time.values, rtol=1e-6) + ) as dt: + # With this dataset, Ex and Ey share the same time dimension + assert list(dt["Electric_Field"]["Ex"].coords) == list( + dt["Electric_Field"]["Ey"].coords + ) + assert dt["Electric_Field"]["Ex"].shape == (11, 16) + assert dt["Electric_Field"]["Ey"].shape == (11, 16) + assert dt["Particles"]["Px"]["proton"].shape == (1, 1920) + assert dt["Particles"]["Weight"]["proton"].shape == (2, 1920) + assert dt["Absorption"]["Total_Laser_Energy_Injected"].shape == (11,) def test_datatree_latex_rename_variables(): - dt = open_mfdatatree(TEST_ARRAYS_DIR.glob("*.sdf"), keep_particles=True) - assert dt["Electric_Field"]["Ex"].attrs["long_name"] == "Electric Field $E_x$" - assert dt["Electric_Field"]["Ey"].attrs["long_name"] == "Electric Field $E_y$" - assert dt["Electric_Field"]["Ez"].attrs["long_name"] == "Electric Field $E_z$" - assert dt["Magnetic_Field"]["Bx"].attrs["long_name"] == "Magnetic Field $B_x$" - assert dt["Magnetic_Field"]["By"].attrs["long_name"] == "Magnetic Field $B_y$" - assert dt["Magnetic_Field"]["Bz"].attrs["long_name"] == "Magnetic Field $B_z$" - assert ( - dt["Particles"]["Px"]["Electron"].attrs["long_name"] - == "Particles $P_x$ Electron" - ) + with sdfxr.open_mfdatatree( + TEST_ARRAYS_DIR.glob("*.sdf"), keep_particles=True + ) as dt: + assert dt["Electric_Field"]["Ex"].attrs["long_name"] == "Electric Field $E_x$" + assert dt["Electric_Field"]["Ey"].attrs["long_name"] == "Electric Field $E_y$" + assert dt["Electric_Field"]["Ez"].attrs["long_name"] == "Electric Field $E_z$" + assert dt["Magnetic_Field"]["Bx"].attrs["long_name"] == "Magnetic Field $B_x$" + assert dt["Magnetic_Field"]["By"].attrs["long_name"] == "Magnetic Field $B_y$" + assert dt["Magnetic_Field"]["Bz"].attrs["long_name"] == "Magnetic Field $B_z$" + assert ( + dt["Particles"]["Px"]["Electron"].attrs["long_name"] + == "Particles $P_x$ Electron" + ) def test_datatree_open_mfdatatree_data_vars_single(): - dt = open_mfdatatree(TEST_FILES_DIR.glob("*.sdf"), data_vars=["Electric_Field_Ex"]) - # Variable should be present under Electric_Field/Ex structure - assert "Ex" in dt["Electric_Field"].data_vars - # A different variable should not be anywhere - assert "Ey" not in dt["Electric_Field"].data_vars + with sdfxr.open_mfdatatree( + TEST_FILES_DIR.glob("*.sdf"), data_vars=["Electric_Field_Ex"] + ) as dt: + # Variable should be present under Electric_Field/Ex structure + assert "Ex" in dt["Electric_Field"].data_vars + # A different variable should not be anywhere + assert "Ey" not in dt["Electric_Field"].data_vars def test_datatree_open_mfdatatree_data_vars_multiple(): - dt = open_mfdatatree( + with sdfxr.open_mfdatatree( TEST_FILES_DIR.glob("*.sdf"), data_vars=["Electric_Field_Ex", "Electric_Field_Ey"], - ) - assert "Ex" in dt["Electric_Field"].data_vars - assert "Ey" in dt["Electric_Field"].data_vars + ) as dt: + assert "Ex" in dt["Electric_Field"].data_vars + assert "Ey" in dt["Electric_Field"].data_vars def test_datatree_open_mfdatatree_data_vars_sparse_multiple(): - dt = open_mfdatatree( + with sdfxr.open_mfdatatree( TEST_FILES_DIR.glob("*.sdf"), keep_particles=True, data_vars=[ @@ -247,101 +236,103 @@ def test_datatree_open_mfdatatree_data_vars_sparse_multiple(): "Electric_Field_Ez", "dist_fn_x_px_proton", ], - ) - # Check presence under corresponding multi-level groups - # Particles_Particles_Per_Cell_proton -> Particles/Particles_Per_Cell/proton - assert "Particles_Per_Cell" in [g.split("/")[-1] for g in dt["Particles"].groups] - # Electric_Field_Ez -> Electric_Field/Ez - assert "Ez" in dt["Electric_Field"].data_vars - # dist_fn_x_px_proton -> dist_fn/x_px/proton - assert "x_px" in [g.split("/")[-1] for g in dt["dist_fn"].groups] + ) as dt: + # Check presence under corresponding multi-level groups + # Particles_Particles_Per_Cell_proton -> Particles/Particles_Per_Cell/proton + assert "Particles_Per_Cell" in [ + g.split("/")[-1] for g in dt["Particles"].groups + ] + # Electric_Field_Ez -> Electric_Field/Ez + assert "Ez" in dt["Electric_Field"].data_vars + # dist_fn_x_px_proton -> dist_fn/x_px/proton + assert "x_px" in [g.split("/")[-1] for g in dt["dist_fn"].groups] def test_datatree_open_mfdatatree_data_vars_time(): - dt = open_mfdatatree(TEST_FILES_DIR.glob("*.sdf"), data_vars=["Electric_Field_Ex"]) - # Time coordinate exists on the variable (Electric_Field/Ex structure) - assert "time" in dt["Electric_Field"]["Ex"].coords + with sdfxr.open_mfdatatree( + TEST_FILES_DIR.glob("*.sdf"), data_vars=["Electric_Field_Ex"] + ) as dt: + # Time coordinate exists on the variable (Electric_Field/Ex structure) + assert "time" in dt["Electric_Field"]["Ex"].coords def test_datatree_open_mfdatatree_data_vars_sparse_time(): - dt = open_mfdatatree( + with sdfxr.open_mfdatatree( TEST_FILES_DIR.glob("*.sdf"), data_vars=["Particles_Particles_Per_Cell_proton"], - ) - # Particles_Particles_Per_Cell_proton -> Particles/Particles_Per_Cell/proton - assert "time" in dt["Particles"]["Particles_Per_Cell"]["proton"].coords + ) as dt: + # Particles_Particles_Per_Cell_proton -> Particles/Particles_Per_Cell/proton + assert "time" in dt["Particles"]["Particles_Per_Cell"]["proton"].coords def test_datatree_open_mfdatatree_data_vars_separate_times_single(): - dt = open_mfdatatree( + with sdfxr.open_mfdatatree( TEST_FILES_DIR.glob("*.sdf"), data_vars=["Electric_Field_Ex"], separate_times=True, - ) - assert dt["Electric_Field"]["Ex"].shape[0] == 11 + ) as dt: + assert dt["Electric_Field"]["Ex"].shape[0] == 11 def test_datatree_open_mfdatatree_data_vars_separate_times_multiple(): - dt = open_mfdatatree( + with sdfxr.open_mfdatatree( TEST_FILES_DIR.glob("*.sdf"), data_vars=["Electric_Field_Ex", "Electric_Field_Ey"], separate_times=True, - ) - # Shapes may differ by time dims when separate_times=True - assert dt["Electric_Field"]["Ex"].shape[0] >= 1 - assert dt["Electric_Field"]["Ey"].shape[0] >= 1 + ) as dt: + # Shapes may differ by time dims when separate_times=True + assert dt["Electric_Field"]["Ex"].shape[0] >= 1 + assert dt["Electric_Field"]["Ey"].shape[0] >= 1 def test_datatree_open_mfdatatree_data_vars_separate_times_multiple_times_keep_particles(): - dt = open_mfdatatree( + with sdfxr.open_mfdatatree( TEST_FILES_DIR.glob("*.sdf"), data_vars=["Electric_Field_Ex", "Particles_Px_electron_beam"], separate_times=True, keep_particles=True, - ) - assert dt["Electric_Field"]["Ex"].shape[0] >= 1 - # Particles_Px_electron_beam -> Particles/Px/electron_beam - assert dt["Particles"]["Px"]["electron_beam"].shape[0] >= 1 - - -# Parity for mismatched jobid behaviour + ) as dt: + assert dt["Electric_Field"]["Ex"].shape[0] >= 1 + # Particles_Px_electron_beam -> Particles/Px/electron_beam + assert dt["Particles"]["Px"]["electron_beam"].shape[0] >= 1 def test_datatree_erroring_on_mismatched_jobid_files(): with pytest.raises(ValueError): # noqa: PT011 - # open_mfdatatree uses open_mfdataset under the hood with SDFPreprocess - open_mfdatatree(TEST_MISMATCHED_FILES_DIR.glob("*.sdf")) + sdfxr.open_mfdatatree(TEST_MISMATCHED_FILES_DIR.glob("*.sdf")) -def test_open_datatree_deck_path_default(): - with open_datatree(TEST_FILES_DIR / "0000.sdf") as dt: +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_open_datatree_deck_path_default(xrlib: XRLibrary): + with xrlib.open_datatree(TEST_FILES_DIR / "0000.sdf") as dt: assert "deck" in dt.attrs -def test_open_datatree_deck_path_failed(): - with ( - pytest.raises(FileNotFoundError), - open_datatree(TEST_FILES_DIR / "0000.sdf", deck_path="non_existent.deck"), - ): - pass +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_open_datatree_deck_path_failed(xrlib: XRLibrary): + with pytest.raises(FileNotFoundError): + xrlib.open_datatree(TEST_FILES_DIR / "0000.sdf", deck_path="non_existent.deck") -def test_open_datatree_deck_path_relative(): - with open_datatree(TEST_FILES_DIR / "0000.sdf", deck_path="input.deck") as dt: +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_open_datatree_deck_path_relative(xrlib: XRLibrary): + with xrlib.open_datatree(TEST_FILES_DIR / "0000.sdf", deck_path="input.deck") as dt: assert "deck" in dt.attrs assert "constant" in dt.attrs["deck"] -def test_open_datatree_deck_path_absolute(): - with open_datatree( +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_open_datatree_deck_path_absolute(xrlib: XRLibrary): + with xrlib.open_datatree( TEST_FILES_DIR / "0000.sdf", deck_path=TEST_FILES_DIR / "input.deck" ) as dt: assert "deck" in dt.attrs assert "constant" in dt.attrs["deck"] -def test_open_datatree_deck_path_absolute_other_path(): - with open_datatree( +@pytest.mark.parametrize("xrlib", [xr, sdfxr]) +def test_open_datatree_deck_path_absolute_other_path(xrlib: XRLibrary): + with xrlib.open_datatree( TEST_FILES_DIR / "0000.sdf", deck_path=TEST_3D_DIST_FN / "input.deck" ) as dt: assert "deck" in dt.attrs @@ -349,23 +340,19 @@ def test_open_datatree_deck_path_absolute_other_path(): def test_open_mfdatatree_deck_path_default(): - with open_mfdatatree(TEST_FILES_DIR.glob("*.sdf")) as dt: + with sdfxr.open_mfdatatree(TEST_FILES_DIR.glob("*.sdf")) as dt: assert "deck" in dt.attrs def test_open_mfdatatree_deck_path_failed(): - with ( - pytest.raises(FileNotFoundError), - open_mfdatatree( - TEST_FILES_DIR.glob("*.sdf"), - deck_path="non_existent.deck", - ), - ): - pass + with pytest.raises(FileNotFoundError): + sdfxr.open_mfdatatree( + TEST_FILES_DIR.glob("*.sdf"), deck_path="non_existent.deck" + ) def test_open_mfdatatree_deck_path_relative(): - with open_mfdatatree( + with sdfxr.open_mfdatatree( TEST_FILES_DIR.glob("*.sdf"), deck_path="input.deck", ) as dt: @@ -374,7 +361,7 @@ def test_open_mfdatatree_deck_path_relative(): def test_open_mfdatatree_deck_path_absolute(): - with open_mfdatatree( + with sdfxr.open_mfdatatree( TEST_FILES_DIR.glob("*.sdf"), deck_path=TEST_FILES_DIR / "input.deck" ) as dt: assert "deck" in dt.attrs @@ -382,7 +369,7 @@ def test_open_mfdatatree_deck_path_absolute(): def test_open_mfdatatree_deck_path_absolute_other_path(): - with open_mfdatatree( + with sdfxr.open_mfdatatree( TEST_FILES_DIR.glob("*.sdf"), deck_path=TEST_3D_DIST_FN / "input.deck" ) as dt: assert "deck" in dt.attrs diff --git a/tests/test_epoch_dataarray_accessor.py b/tests/test_epoch_dataarray_accessor.py index 0cc9c50..9069a74 100644 --- a/tests/test_epoch_dataarray_accessor.py +++ b/tests/test_epoch_dataarray_accessor.py @@ -8,8 +8,9 @@ from matplotlib.animation import PillowWriter from packaging.version import Version +import sdf_xarray as sdfxr import sdf_xarray.plotting as sxp -from sdf_xarray import SDFPreprocess, download, open_mfdataset +from sdf_xarray import SDFPreprocess, download mpl.use("Agg") @@ -33,8 +34,22 @@ def test_animation_accessor(): assert hasattr(array.epoch, "animate") -def test_animate_headless(): - with open_mfdataset(TEST_FILES_DIR_1D.glob("*.sdf")) as ds: +@pytest.mark.parametrize( + ("xrlib", "params"), + [ + ( + xr, + { + "compat": "no_conflicts", + "join": "outer", + "preprocess": SDFPreprocess(), + }, + ), + (sdfxr, {}), + ], +) +def test_animate_headless(xrlib, params): + with xrlib.open_mfdataset(TEST_FILES_DIR_1D.glob("*.sdf"), **params) as ds: anim = ds["Derived_Number_Density_electron"].epoch.animate() # Specify a custom writable temporary directory @@ -46,25 +61,7 @@ def test_animate_headless(): pytest.fail(f"animate().save() failed in headless mode: {e}") -def test_xr_animate_headless(): - with xr.open_mfdataset( - TEST_FILES_DIR_1D.glob("*.sdf"), - compat="no_conflicts", - join="outer", - preprocess=SDFPreprocess(), - ) as ds: - anim = ds["Derived_Number_Density_electron"].epoch.animate() - - # Specify a custom writable temporary directory - with tempfile.TemporaryDirectory() as temp_dir: - temp_file_path = f"{temp_dir}/output.gif" - try: - anim.save(temp_file_path, writer=PillowWriter(fps=2)) - except Exception as e: - pytest.fail(f"animate().save() failed in headless mode: {e}") - - -def test_xr_get_frame_title_no_optional_params(): +def test_get_frame_title_no_optional_params(): with xr.open_mfdataset( TEST_FILES_DIR_1D.glob("*.sdf"), compat="no_conflicts", @@ -77,7 +74,7 @@ def test_xr_get_frame_title_no_optional_params(): assert expected_result == result -def test_xr_get_frame_title_sdf_name(): +def test_get_frame_title_sdf_name(): with xr.open_mfdataset( TEST_FILES_DIR_1D.glob("*.sdf"), compat="no_conflicts", @@ -90,7 +87,7 @@ def test_xr_get_frame_title_sdf_name(): assert expected_result == result -def test_xr_get_frame_title_custom_title(): +def test_get_frame_title_custom_title(): with xr.open_mfdataset( TEST_FILES_DIR_1D.glob("*.sdf"), compat="no_conflicts", @@ -103,7 +100,7 @@ def test_xr_get_frame_title_custom_title(): assert expected_result == result -def test_xr_get_frame_title_custom_title_and_sdf_name(): +def test_get_frame_title_custom_title_and_sdf_name(): with xr.open_mfdataset( TEST_FILES_DIR_1D.glob("*.sdf"), compat="no_conflicts", @@ -126,7 +123,7 @@ def test_get_frame_title_Z_Grid_mid(): assert expected_result == result -def test_xr_calculate_window_boundaries_1D(): +def test_calculate_window_boundaries_1D(): with xr.open_mfdataset( TEST_FILES_DIR_2D_MW.glob("*.sdf"), preprocess=SDFPreprocess(), @@ -142,7 +139,7 @@ def test_xr_calculate_window_boundaries_1D(): assert result == pytest.approx(expected_result, abs=0.1) -def test_xr_calculate_window_boundaries_2D(): +def test_calculate_window_boundaries_2D(): with xr.open_mfdataset( TEST_FILES_DIR_2D_MW.glob("*.sdf"), preprocess=SDFPreprocess(), @@ -158,7 +155,7 @@ def test_xr_calculate_window_boundaries_2D(): assert result == pytest.approx(expected_result, abs=0.1) -def test_xr_calculate_window_boundaries_1D_xlim(): +def test_calculate_window_boundaries_1D_xlim(): with xr.open_mfdataset( TEST_FILES_DIR_2D_MW.glob("*.sdf"), preprocess=SDFPreprocess(), @@ -174,7 +171,7 @@ def test_xr_calculate_window_boundaries_1D_xlim(): assert result == pytest.approx(expected_result, abs=0.1) -def test_xr_calculate_window_boundaries_2D_xlim(): +def test_calculate_window_boundaries_2D_xlim(): with xr.open_mfdataset( TEST_FILES_DIR_2D_MW.glob("*.sdf"), preprocess=SDFPreprocess(), @@ -190,7 +187,7 @@ def test_xr_calculate_window_boundaries_2D_xlim(): assert result == pytest.approx(expected_result, abs=0.1) -def test_xr_compute_global_limits(): +def test_compute_global_limits(): with xr.open_mfdataset( TEST_FILES_DIR_1D.glob("*.sdf"), compat="no_conflicts", @@ -206,7 +203,7 @@ def test_xr_compute_global_limits(): assert result_max == pytest.approx(expected_result_max, abs=1e19) -def test_xr_compute_global_limits_percentile(): +def test_compute_global_limits_percentile(): with xr.open_mfdataset( TEST_FILES_DIR_1D.glob("*.sdf"), compat="no_conflicts", @@ -222,7 +219,7 @@ def test_xr_compute_global_limits_percentile(): assert result_max == pytest.approx(expected_result_max, abs=1e18) -def test_xr_compute_global_limits_NaNs(): +def test_compute_global_limits_NaNs(): with xr.open_mfdataset( TEST_FILES_DIR_2D_MW.glob("*.sdf"), preprocess=SDFPreprocess(),