Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ coverage.xml

# vim
*.sw?
/.vscode
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).

## [Unreleased]

### Changed
- `Data.norm_for_each`: allow for normlization for each combination of a set of variables

## [3.6.1]

### Changed
Expand Down
21 changes: 12 additions & 9 deletions WrightTools/data/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def get_axis(self, hint: str | int | Axis) -> Axis:

def norm_for_each(
self,
var: str | Variable | int,
*vars: str | Variable | int,
channel: str | Channel | int = 0,
new_channel: dict = {},
):
Expand All @@ -758,31 +758,31 @@ def norm_for_each(

Parameters
----------
var : str, int, or WrightTools.data.Variable
variable to apply normalization at each unique point.
*vars : str, int, or WrightTools.data.Variable
variable to apply normalization at each unique point. if many variables are given, normalization occurs across the joint pairs of values
channel : str, int or WrightTools.data.Channel (default 0)
channel to apply normalization. Channel should have more non-trivial dimensions than variable
new_channel : dict
Default is empty, and channel is overwriten with norm values.
If not empty, a new channel will be created.
Fields (e.g. name) can be supplied by supplying a dictionary (consult `Data.create_channel`).


Examples
--------
import WrightTools.datasets as ds
import WrightTools as wt

d = wt.open(ds.wt5.v1p0p1_MoS2_TrEE_movie)
d.norm_for_each("d2", "ai0") # equivalent to d.ai0[:] /= d.ai0[:].max(axis=(0,1))[None, None, :]
d.norm_for_each("d2", channel="ai0") # equivalent to d.ai0[:] /= d.ai0[:].max(axis=(0,1))[None, None, :]

"""
variable = self.get_var(var)
variables = [self.get_var(var) for var in vars]
channel = self.get_channel(channel)
trivial = {i for i, si in enumerate(variable.shape) if si == 1}
joint_shape = [max([v.shape[i] for v in variables]) for i in range(self.ndim)]
trivial = {i for i, si in enumerate(joint_shape) if si == 1}
if not trivial:
raise wt_exceptions.WrightToolsWarning(
f"Variable {variable.natural_name} and Channel {channel.natural_name} have the same shape {variable.shape}. "
f"variable(s) {[var.natural_name for var in variables]} and Channel {channel.natural_name} have the same shape {channel.shape}. "
+ "Produces a ones array channel."
)
# nontrivial = tuple({i for i in range(self.ndim)} - trivial)
Expand All @@ -793,7 +793,10 @@ def norm_for_each(
if not isinstance(new_channel, dict):
new_channel = {}
self.create_channel(
new_channel.pop("name", f"{channel.natural_name}_{variable.natural_name}_norm"),
new_channel.pop(
"name",
f"{channel.natural_name}_{''.join([f'v{self.variable_names.index(v.natural_name)}' for v in variables])}_norm",
),
values=new,
**new_channel,
)
Expand Down
18 changes: 13 additions & 5 deletions tests/data/norm_for_each.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,26 @@
def test_3D():
data = wt.open(datasets.wt5.v1p0p1_MoS2_TrEE_movie)

data.norm_for_each("w1", 0)
data.norm_for_each("w1", channel=0)
assert np.all(data.channels[0][:].max(axis=(0, 2)) == 1)

data.norm_for_each("d2", 0, new_channel={"name": "ai0_d2_norm"})
data.norm_for_each("d2", new_channel={"name": "ai0_d2_norm"})
assert np.all(data.channels[-1][:].max(axis=(0, 1)) == 1)

data.norm_for_each("d1", 0, new_channel=True)
data.norm_for_each("d1", new_channel=True)
assert data.channels[-1].natural_name == "ai0_v6_norm"

data.norm_for_each("d1", 0, new_channel={"name": "ai0_d1_norm1"})
data.norm_for_each("d1", new_channel={"name": "ai0_d1_norm"})
data.channels[0].normalize()
assert np.all(np.isclose(data.ai0_d1_norm[:], data.channels[0][:]))
assert np.all(np.isclose(data.channels[-1][:], data.channels[0][:]))


def test_two_vars():
data = wt.open(datasets.wt5.v1p0p1_MoS2_TrEE_movie)
data.norm_for_each("w1", "d2")
assert np.all(data.channels[0][:].max(axis=0) == 1)


if __name__ == "__main__":
test_3D()
test_two_vars()