Enhance the capability of Bridgescaler for supporting tensors#23

Open
kevinyang-cky wants to merge 34 commits into NCAR:main from kevinyang-cky:main

Conversation

@kevinyang-cky
Collaborator

This PR addresses the following (probably not worth diving into each commit, as I reverted a couple of things during development; see the latest version of the files):

  • Support saving out and reading in distributed scalers for tensors: print_scaler_tensor() and read_scaler_tensor() in backend_tensor.py handle this.
  • Add a PyTorch library check: the basic idea is that if a user does not have PyTorch installed in the environment, Bridgescaler can still function properly. Errors are raised only if a user tries to use the distributed tensor scalers without PyTorch installed, or if the required version is not met.
  • Tensor placement in distributed_tensor.py: ensure input tensors and the subsequent fitting or transforming calculations stay on the same device.
  • Code optimization for distributed_tensor.py: remove the for-loop over channels and use vectorized operations instead.
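
As a side note, the optional-import guard described above can be sketched roughly like this. This is a minimal illustration only: require_torch is a hypothetical helper name, and the actual check in backend_tensor.py may be structured differently.

```python
# Minimal sketch of an optional PyTorch dependency guard.
# NOTE: require_torch() is a hypothetical helper for illustration;
# the real guard in backend_tensor.py may differ.
try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    torch = None
    TORCH_AVAILABLE = False


def require_torch():
    """Raise a clear error when a tensor scaler is requested without PyTorch."""
    if not TORCH_AVAILABLE:
        raise ImportError(
            "PyTorch is required for the distributed tensor scalers. "
            "Install it (e.g., `pip install torch`) to use this feature."
        )
```

With a pattern like this, the rest of Bridgescaler keeps working without PyTorch, and only the tensor-scaler entry points need to call the guard.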

Unit tests passed, and I suggest CREDIT use this version of Bridgescaler going forward.

Here is an example of using the distributed scalers for tensors; happy to put it into the docs if @djgagne can point me to a place to include it.

import numpy as np
import pandas as pd
import torch

from bridgescaler.distributed_tensor import DStandardScalerTensor
from bridgescaler import print_scaler_tensor, read_scaler_tensor

# create synthetic data
x_1 = np.random.normal(0, 2.2, (20, 5, 4, 8))
x_2 = np.random.normal(1, 3.5, (25, 4, 8, 5))

# fitting and transform
dss_1_tensor = DStandardScalerTensor(channels_last=False)
dss_2_tensor = DStandardScalerTensor(channels_last=True)
dss_1_tensor.fit(torch.from_numpy(x_1))
dss_2_tensor.fit(torch.from_numpy(x_2))
dss_combined_tensor = dss_1_tensor + dss_2_tensor

dss_combined_tensor.transform(torch.from_numpy(x_1), channels_last=False)

# save out scalers and read back in
scaler_list = [dss_1_tensor, dss_2_tensor]
df = pd.DataFrame({"scalers": [print_scaler_tensor(s) for s in scaler_list]})
df.to_parquet("scalers.parquet")
df_new = pd.read_parquet("scalers.parquet")
scaler_objs = df_new["scalers"].apply(read_scaler_tensor)
total_scaler = scaler_objs.sum()

@kevinyang-cky kevinyang-cky requested a review from djgagne February 4, 2026 17:13
@djgagne
Collaborator

djgagne commented Feb 9, 2026

@kevinyang-cky Your changes all look good code-wise. For the documentation, can you add a separate file on the tensor scalers to the directory https://github.com/NCAR/bridgescaler/tree/main/doc/source as an .rst file? See the other doc files for examples.

@charlie-becker
Collaborator

This is coming together nicely!

However, I do not believe the ability to transform data with a different channel order is working correctly. Please see the following tests, which fail on my end (this code was added directly to the end of your test example above).

x3 = np.random.normal(0, 2.2, (20, 5, 44, 11))
x3_tensor = torch.from_numpy(x3)
x3_tensor.variable_names = ['a', 'b', 'c', 'd', 'e']

x4_tensor = torch.from_numpy(x3)
x4_tensor.variable_names = ['b', 'a', 'c', 'd', 'e'] # reverse the first and second channel dim

x3_transformed = total_scaler.transform(x3_tensor)
x4_transformed = total_scaler.transform(x4_tensor)

assert (x3_transformed[:, 2:, :, :] == x4_transformed[:, 2:, :, :]).all() ## passes
assert (x3_transformed[:, 0, :, :] == x4_transformed[:, 1, :, :]).all()   ## fails

@kevinyang-cky
Collaborator Author

kevinyang-cky commented Feb 13, 2026

Thanks @charlie-becker for the testing feedback. I think the issue is that x4_tensor's variable_names are reordered, but the underlying data is still in the same order as x3_tensor. Try changing x4_tensor = torch.from_numpy(x3) to x4_tensor = torch.from_numpy(x3[:, [1, 0, 2, 3, 4], :, :]), and the test should pass. Let me know if you still run into issues.
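
The underlying point (that the data must be permuted along with the names) can be checked with plain NumPy, independent of Bridgescaler:

```python
import numpy as np

rng = np.random.default_rng(0)
x3 = rng.normal(0, 2.2, (20, 5, 44, 11))

# Swapping only the variable_names list does not move any data.
# To actually swap channels 'a' and 'b', the channel axis itself
# must be permuted:
perm = [1, 0, 2, 3, 4]
x4 = x3[:, perm, :, :]

# Channel 0 of x4 now holds what was channel 1 of x3 (and vice versa),
# while the remaining channels are untouched.
assert (x4[:, 0] == x3[:, 1]).all()
assert (x4[:, 1] == x3[:, 0]).all()
assert (x4[:, 2:] == x3[:, 2:]).all()
```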

I will continue working on incorporating more unit tests into the current test script today.

@charlie-becker
Collaborator

@kevinyang-cky

Yup, you're exactly right! Passes no problem. Thank you for catching my testing bug!

@djgagne
Collaborator

djgagne commented Feb 14, 2026

@kevinyang-cky I think the code looks good, but I will wait to approve and merge until you have added your remaining tests. I'm going to fix some of the docs and address some other library cleanup issues in my PR.

@kevinyang-cky
Collaborator Author

@djgagne sounds good to me! I will tag you again when I have all the tests and the example added to the docs. Have a great long weekend!
