diff --git a/.gitignore b/.gitignore index bc584112..4551f2d9 100644 --- a/.gitignore +++ b/.gitignore @@ -48,3 +48,4 @@ coverage.xml # vim *.sw? +/.vscode diff --git a/CHANGELOG.md b/CHANGELOG.md index 9597d827..d845a3f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/). ## [Unreleased] +### Changed +- `Data.norm_for_each`: allow for normlization for each combination of a set of variables + ## [3.6.1] ### Changed diff --git a/WrightTools/data/_data.py b/WrightTools/data/_data.py index 6b46a1ce..1fd767fc 100644 --- a/WrightTools/data/_data.py +++ b/WrightTools/data/_data.py @@ -749,7 +749,7 @@ def get_axis(self, hint: str | int | Axis) -> Axis: def norm_for_each( self, - var: str | Variable | int, + *vars: str | Variable | int, channel: str | Channel | int = 0, new_channel: dict = {}, ): @@ -758,8 +758,8 @@ def norm_for_each( Parameters ---------- - var : str, int, or WrightTools.data.Variable - variable to apply normalization at each unique point. + *vars : str, int, or WrightTools.data.Variable + variable to apply normalization at each unique point. if many variables are given, normalization occurs across the joint pairs of values channel : str, int or WrightTools.data.Channel (default 0) channel to apply normalization. Channel should have more non-trivial dimensions than variable new_channel : dict @@ -767,22 +767,22 @@ def norm_for_each( If not empty, a new channel will be created. Fields (e.g. name) can be supplied by supplying a dictionary (consult `Data.create_channel`). - Examples -------- import WrightTools.datasets as ds import WrightTools as wt d = wt.open(ds.wt5.v1p0p1_MoS2_TrEE_movie) - d.norm_for_each("d2", "ai0") # equivalent to d.ai0[:] /= d.ai0[:].max(axis=(0,1))[None, None, :] + d.norm_for_each("d2", channel="ai0") # equivalent to d.ai0[:] /= d.ai0[:].max(axis=(0,1))[None, None, :] """ - variable = self.get_var(var) + variables = [self.get_var(var) for var in vars] channel = self.get_channel(channel) - trivial = {i for i, si in enumerate(variable.shape) if si == 1} + joint_shape = [max([v.shape[i] for v in variables]) for i in range(self.ndim)] + trivial = {i for i, si in enumerate(joint_shape) if si == 1} if not trivial: raise wt_exceptions.WrightToolsWarning( - f"Variable {variable.natural_name} and Channel {channel.natural_name} have the same shape {variable.shape}. " + f"variable(s) {[var.natural_name for var in variables]} and Channel {channel.natural_name} have the same shape {channel.shape}. " + "Produces a ones array channel." ) # nontrivial = tuple({i for i in range(self.ndim)} - trivial) @@ -793,7 +793,10 @@ def norm_for_each( if not isinstance(new_channel, dict): new_channel = {} self.create_channel( - new_channel.pop("name", f"{channel.natural_name}_{variable.natural_name}_norm"), + new_channel.pop( + "name", + f"{channel.natural_name}_{''.join([f'v{self.variable_names.index(v.natural_name)}' for v in variables])}_norm", + ), values=new, **new_channel, ) diff --git a/tests/data/norm_for_each.py b/tests/data/norm_for_each.py index d3f86673..846f621d 100644 --- a/tests/data/norm_for_each.py +++ b/tests/data/norm_for_each.py @@ -6,18 +6,26 @@ def test_3D(): data = wt.open(datasets.wt5.v1p0p1_MoS2_TrEE_movie) - data.norm_for_each("w1", 0) + data.norm_for_each("w1", channel=0) assert np.all(data.channels[0][:].max(axis=(0, 2)) == 1) - data.norm_for_each("d2", 0, new_channel={"name": "ai0_d2_norm"}) + data.norm_for_each("d2", new_channel={"name": "ai0_d2_norm"}) assert np.all(data.channels[-1][:].max(axis=(0, 1)) == 1) - data.norm_for_each("d1", 0, new_channel=True) + data.norm_for_each("d1", new_channel=True) + assert data.channels[-1].natural_name == "ai0_v6_norm" - data.norm_for_each("d1", 0, new_channel={"name": "ai0_d1_norm1"}) + data.norm_for_each("d1", new_channel={"name": "ai0_d1_norm"}) data.channels[0].normalize() - assert np.all(np.isclose(data.ai0_d1_norm[:], data.channels[0][:])) + assert np.all(np.isclose(data.channels[-1][:], data.channels[0][:])) + + +def test_two_vars(): + data = wt.open(datasets.wt5.v1p0p1_MoS2_TrEE_movie) + data.norm_for_each("w1", "d2") + assert np.all(data.channels[0][:].max(axis=0) == 1) if __name__ == "__main__": test_3D() + test_two_vars()