From 74d8ca281c231d5eda6a476282710c757df5beda Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Mon, 22 Dec 2025 14:21:41 -0700 Subject: [PATCH 01/30] add scalers for tensors --- bridgescaler/backend.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgescaler/backend.py b/bridgescaler/backend.py index adcfaae..1e335f6 100644 --- a/bridgescaler/backend.py +++ b/bridgescaler/backend.py @@ -3,6 +3,7 @@ from bridgescaler.group import GroupStandardScaler, GroupRobustScaler, GroupMinMaxScaler from bridgescaler.deep import DeepStandardScaler, DeepMinMaxScaler, DeepQuantileTransformer from bridgescaler.distributed import DStandardScaler, DMinMaxScaler, DQuantileScaler +from bridgescaler.distributed_tensor import DStandardScalerTensor, DMinMaxScalerTensor import numpy as np import json import pandas as pd @@ -27,6 +28,8 @@ "DStandardScaler": DStandardScaler, "DMinMaxScaler": DMinMaxScaler, "DQuantileScaler": DQuantileScaler, + "DStandardScalerTensor": DStandardScalerTensor, + "DMinMaxScalerTensor": DMinMaxScalerTensor, } From 858c5ece4f1af4932048b31f9581d5a138899461 Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Tue, 23 Dec 2025 14:12:25 -0700 Subject: [PATCH 02/30] modified print_scaler() to convert tensors to NumPy arrays --- bridgescaler/backend.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bridgescaler/backend.py b/bridgescaler/backend.py index 1e335f6..67220a2 100644 --- a/bridgescaler/backend.py +++ b/bridgescaler/backend.py @@ -10,6 +10,7 @@ from numpy.lib.format import descr_to_dtype, dtype_to_descr from base64 import b64decode, b64encode from typing import Any +import torch scaler_objs = {"StandardScaler": StandardScaler, @@ -60,6 +61,12 @@ def print_scaler(scaler): """ scaler_params = scaler.__dict__ scaler_params["type"] = str(type(scaler))[1:-2].split(".")[-1] + + if "Tensor" in scaler_params["type"]: + for keys in scaler_params: + if type(scaler_params[keys]) == torch.Tensor: + scaler_params[keys] = scaler_params[keys].cpu().numpy().copy() + return json.dumps(scaler_params, indent=4, sort_keys=True, cls=NumpyEncoder) From 830d00cbbbd4f8411cad10758a8a7c992e3e85e9 Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Tue, 23 Dec 2025 16:03:30 -0700 Subject: [PATCH 03/30] modified read_scaler() to convert NumPy arrays to tensors --- bridgescaler/backend.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bridgescaler/backend.py b/bridgescaler/backend.py index 67220a2..21c7bbb 100644 --- a/bridgescaler/backend.py +++ b/bridgescaler/backend.py @@ -90,12 +90,16 @@ def read_scaler(scaler_str): """ scaler_params = json.loads(scaler_str, object_hook=object_hook) scaler = scaler_objs[scaler_params["type"]]() + tensor_flag = "Tensor" in scaler_params["type"] del scaler_params["type"] for k, v in scaler_params.items(): if isinstance(v, dict) and v["object"] == "ndarray": setattr(scaler, k, np.array(v['data'], dtype=v['dtype']).reshape(v['shape'])) else: - setattr(scaler, k, v) + if tensor_flag: + setattr(scaler, k, torch.tensor(v)) + else: + setattr(scaler, k, v) return scaler From bede810f3acc2a88cefa82511bc914b58018cacf Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Fri, 2 Jan 2026 13:12:36 -0700 Subject: [PATCH 04/30] add PyTorch check and public API --- bridgescaler/__init__.py | 47 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/bridgescaler/__init__.py b/bridgescaler/__init__.py index 0e60896..53c82ad 100644 --- a/bridgescaler/__init__.py +++ b/bridgescaler/__init__.py @@ -1,4 +1,49 @@ 
+from importlib.metadata import version, PackageNotFoundError + +from packaging.version import parse + +# 1. Internal Constants & PyTorch Checks +REQUIRED_TORCH_VERSION = "2.0.0" + +def _get_torch_status(): + """Checks torch version via metadata without importing the module.""" + try: + installed_version = version("torch") + + if parse(installed_version) < parse(REQUIRED_TORCH_VERSION): + raise RuntimeError( + f"PyTorch >= {REQUIRED_TORCH_VERSION} required; found {installed_version}" + ) + + return True, installed_version + except PackageNotFoundError: + return False, None + +TORCH_AVAILABLE, TORCH_VERSION = _get_torch_status() + +# 2. Base Imports from .backend import save_scaler, load_scaler, print_scaler, read_scaler from .group import GroupStandardScaler, GroupRobustScaler, GroupMinMaxScaler from .deep import DeepStandardScaler, DeepMinMaxScaler, DeepQuantileTransformer -from .distributed import DStandardScaler, DMinMaxScaler, DQuantileScaler +from .distributed import (DStandardScaler, DMinMaxScaler, DQuantileScaler) + +# 3. Conditional Torch Imports +if TORCH_AVAILABLE: + from .distributed_tensor import ( + DStandardScalerTensor, + DMinMaxScalerTensor, + ) + +# 4. Define Public API +__all__ = [ + # Utilities + "save_scaler", "load_scaler", "print_scaler", "read_scaler", + "TORCH_AVAILABLE", + # Scalers + "GroupStandardScaler", "GroupRobustScaler", "GroupMinMaxScaler", + "DeepStandardScaler", "DeepMinMaxScaler", "DeepQuantileTransformer", + "DStandardScaler", "DMinMaxScaler", "DQuantileScaler", +] + +if TORCH_AVAILABLE: + __all__ += ["DStandardScalerTensor", "DMinMaxScalerTensor"] \ No newline at end of file From 6a0d1e229ccb394ac596ae29b29fc3097b3007fd Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Fri, 2 Jan 2026 14:12:07 -0700 Subject: [PATCH 05/30] remove PyTorch check (already in __init__.py) --- bridgescaler/distributed_tensor.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/bridgescaler/distributed_tensor.py b/bridgescaler/distributed_tensor.py index a81f179..8b817a2 100644 --- a/bridgescaler/distributed_tensor.py +++ b/bridgescaler/distributed_tensor.py @@ -1,25 +1,7 @@ from copy import deepcopy -import importlib.util -from packaging import version import torch -REQUIRED_VERSION = "2.0.0" # required torch version - -# Check if PyTorch is installed -if importlib.util.find_spec("torch") is None: - raise ImportError("PyTorch is not installed") - -installed_version = torch.__version__ - -# Validate version -if version.parse(installed_version) < version.parse(REQUIRED_VERSION): - raise RuntimeError( - f"PyTorch version mismatch: required {REQUIRED_VERSION}, " - f"found {installed_version}" - ) - - class DBaseScalerTensor: """ Base distributed scaler class for tensor. 
Used only to store attributes and methods shared across all distributed From 0c7ccb4967e3fd040485a3a22c4a81ede421aa33 Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Sun, 4 Jan 2026 20:10:37 -0700 Subject: [PATCH 06/30] revert changes to backend.py --- bridgescaler/backend.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/bridgescaler/backend.py b/bridgescaler/backend.py index 21c7bbb..4ef775b 100644 --- a/bridgescaler/backend.py +++ b/bridgescaler/backend.py @@ -3,14 +3,12 @@ from bridgescaler.group import GroupStandardScaler, GroupRobustScaler, GroupMinMaxScaler from bridgescaler.deep import DeepStandardScaler, DeepMinMaxScaler, DeepQuantileTransformer from bridgescaler.distributed import DStandardScaler, DMinMaxScaler, DQuantileScaler -from bridgescaler.distributed_tensor import DStandardScalerTensor, DMinMaxScalerTensor import numpy as np import json import pandas as pd from numpy.lib.format import descr_to_dtype, dtype_to_descr from base64 import b64decode, b64encode from typing import Any -import torch scaler_objs = {"StandardScaler": StandardScaler, @@ -29,8 +27,6 @@ "DStandardScaler": DStandardScaler, "DMinMaxScaler": DMinMaxScaler, "DQuantileScaler": DQuantileScaler, - "DStandardScalerTensor": DStandardScalerTensor, - "DMinMaxScalerTensor": DMinMaxScalerTensor, } @@ -61,12 +57,6 @@ def print_scaler(scaler): """ scaler_params = scaler.__dict__ scaler_params["type"] = str(type(scaler))[1:-2].split(".")[-1] - - if "Tensor" in scaler_params["type"]: - for keys in scaler_params: - if type(scaler_params[keys]) == torch.Tensor: - scaler_params[keys] = scaler_params[keys].cpu().numpy().copy() - return json.dumps(scaler_params, indent=4, sort_keys=True, cls=NumpyEncoder) @@ -90,16 +80,12 @@ def read_scaler(scaler_str): """ scaler_params = json.loads(scaler_str, object_hook=object_hook) scaler = scaler_objs[scaler_params["type"]]() - tensor_flag = "Tensor" in scaler_params["type"] del scaler_params["type"] for k, v in scaler_params.items(): if isinstance(v, dict) and v["object"] == "ndarray": setattr(scaler, k, np.array(v['data'], dtype=v['dtype']).reshape(v['shape'])) else: - if tensor_flag: - setattr(scaler, k, torch.tensor(v)) - else: - setattr(scaler, k, v) + setattr(scaler, k, v) return scaler @@ -170,4 +156,4 @@ def create_synthetic_data(): for l in range(locs.shape[0]): x_data_dict[names[l]] = np.random.normal(loc=locs[l], scale=scales[l], size=num_examples) x_data = pd.DataFrame(x_data_dict) - return x_data + return x_data \ No newline at end of file From 9b738dedd30a92e7e614dbc356fdb990767e8d6a Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Mon, 5 Jan 2026 10:28:23 -0700 Subject: [PATCH 07/30] modify PyTorch check and import backend for tensors --- bridgescaler/__init__.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/bridgescaler/__init__.py b/bridgescaler/__init__.py index 53c82ad..4f9c9f6 100644 --- a/bridgescaler/__init__.py +++ b/bridgescaler/__init__.py @@ -1,25 +1,26 @@ from importlib.metadata import version, PackageNotFoundError -from packaging.version import parse +from packaging.version import Version -# 1. 
Internal Constants & PyTorch Checks -REQUIRED_TORCH_VERSION = "2.0.0" -def _get_torch_status(): - """Checks torch version via metadata without importing the module.""" - try: - installed_version = version("torch") - - if parse(installed_version) < parse(REQUIRED_TORCH_VERSION): - raise RuntimeError( - f"PyTorch >= {REQUIRED_TORCH_VERSION} required; found {installed_version}" - ) +# 1. PyTorch Checks +REQUIRED_TORCH_VERSION = Version("2.0.0") - return True, installed_version +def get_torch_status() -> tuple[bool, Version | None]: + try: + return True, Version(version("torch")) except PackageNotFoundError: return False, None -TORCH_AVAILABLE, TORCH_VERSION = _get_torch_status() +TORCH_AVAILABLE, TORCH_VERSION = get_torch_status() + +def require_torch() -> None: + if not TORCH_AVAILABLE: + raise ImportError("PyTorch is required but not installed") + if TORCH_VERSION < REQUIRED_TORCH_VERSION: + raise RuntimeError( + f"PyTorch >= {REQUIRED_TORCH_VERSION} required; found {TORCH_VERSION}" + ) # 2. Base Imports from .backend import save_scaler, load_scaler, print_scaler, read_scaler @@ -33,6 +34,7 @@ def _get_torch_status(): DStandardScalerTensor, DMinMaxScalerTensor, ) + from .backend_tensor import print_scaler_tensor, read_scaler_tensor # 4. Define Public API __all__ = [ From 7dbea592cd679f73bdfb6991367422602cc3b5b3 Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Mon, 5 Jan 2026 10:29:46 -0700 Subject: [PATCH 08/30] add PyTorch hard check --- bridgescaler/distributed_tensor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bridgescaler/distributed_tensor.py b/bridgescaler/distributed_tensor.py index 8b817a2..50d196b 100644 --- a/bridgescaler/distributed_tensor.py +++ b/bridgescaler/distributed_tensor.py @@ -1,6 +1,9 @@ +from . import require_torch +require_torch() # enforce torch availability/version at import time +import torch + from copy import deepcopy -import torch class DBaseScalerTensor: """ From b5d18f5b73f94e999732690dbe4c20cb02868293 Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Mon, 5 Jan 2026 10:31:46 -0700 Subject: [PATCH 09/30] backend methods for tensors --- bridgescaler/backend_tensor.py | 40 ++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 bridgescaler/backend_tensor.py diff --git a/bridgescaler/backend_tensor.py b/bridgescaler/backend_tensor.py new file mode 100644 index 0000000..1eb97f9 --- /dev/null +++ b/bridgescaler/backend_tensor.py @@ -0,0 +1,40 @@ +from . import require_torch +require_torch() # enforce torch availability/version at import time +import torch + +import json +import numpy as np + +from bridgescaler.distributed_tensor import DStandardScalerTensor, DMinMaxScalerTensor +from .backend import NumpyEncoder, object_hook + + +scaler_objs = {"DStandardScalerTensor": DStandardScalerTensor, + "DMinMaxScalerTensor": DMinMaxScalerTensor, + } + + +def print_scaler_tensor(scaler): + """ + Modify the print_scaler() in backend.py for tensors. + """ + scaler_params = scaler.__dict__ + scaler_params["type"] = str(type(scaler))[1:-2].split(".")[-1] + + for keys in scaler_params: + if type(scaler_params[keys]) == torch.Tensor: + scaler_params[keys] = scaler_params[keys].cpu().numpy().copy() + + return json.dumps(scaler_params, indent=4, sort_keys=True, cls=NumpyEncoder) + + +def read_scaler_tensor(scaler_str): + """ + Modify the read_scaler() in backend.py for tensors. 
+ """ + scaler_params = json.loads(scaler_str, object_hook=object_hook) + scaler = scaler_objs[scaler_params["type"]]() + del scaler_params["type"] + for k, v in scaler_params.items(): + setattr(scaler, k, torch.tensor(v)) + return scaler \ No newline at end of file From 9b9a4aa94dfa0c14b22fad553e5f0564c63b0a2b Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Mon, 5 Jan 2026 13:11:31 -0700 Subject: [PATCH 10/30] modify the require version and conditional torch imoports --- bridgescaler/__init__.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/bridgescaler/__init__.py b/bridgescaler/__init__.py index 4f9c9f6..4c9558d 100644 --- a/bridgescaler/__init__.py +++ b/bridgescaler/__init__.py @@ -4,7 +4,7 @@ # 1. PyTorch Checks -REQUIRED_TORCH_VERSION = Version("2.0.0") +REQUIRED_TORCH_VERSION = Version("2.6.0") def get_torch_status() -> tuple[bool, Version | None]: try: @@ -30,11 +30,14 @@ def require_torch() -> None: # 3. Conditional Torch Imports if TORCH_AVAILABLE: - from .distributed_tensor import ( - DStandardScalerTensor, - DMinMaxScalerTensor, - ) - from .backend_tensor import print_scaler_tensor, read_scaler_tensor + try: # Ensure that no errors are raised if PyTorch is installed but does not meet the required version. + from .distributed_tensor import ( + DStandardScalerTensor, + DMinMaxScalerTensor, + ) + from .backend_tensor import print_scaler_tensor, read_scaler_tensor + except: + pass # 4. Define Public API __all__ = [ @@ -45,7 +48,4 @@ def require_torch() -> None: "GroupStandardScaler", "GroupRobustScaler", "GroupMinMaxScaler", "DeepStandardScaler", "DeepMinMaxScaler", "DeepQuantileTransformer", "DStandardScaler", "DMinMaxScaler", "DQuantileScaler", -] - -if TORCH_AVAILABLE: - __all__ += ["DStandardScalerTensor", "DMinMaxScalerTensor"] \ No newline at end of file +] \ No newline at end of file From ca18ee2a9394645c04538681c976c895d109cacc Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Mon, 5 Jan 2026 13:13:26 -0700 Subject: [PATCH 11/30] modify the required version --- bridgescaler/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgescaler/__init__.py b/bridgescaler/__init__.py index 4c9558d..80971d7 100644 --- a/bridgescaler/__init__.py +++ b/bridgescaler/__init__.py @@ -4,7 +4,7 @@ # 1. PyTorch Checks -REQUIRED_TORCH_VERSION = Version("2.6.0") +REQUIRED_TORCH_VERSION = Version("2.0.0") def get_torch_status() -> tuple[bool, Version | None]: try: From 2ba8785b6b0f6103906f90216236344621f339e0 Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Mon, 2 Feb 2026 13:58:06 -0700 Subject: [PATCH 12/30] tensors placement --- bridgescaler/distributed_tensor.py | 58 +++++++++++++++--------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/bridgescaler/distributed_tensor.py b/bridgescaler/distributed_tensor.py index 50d196b..99b700f 100644 --- a/bridgescaler/distributed_tensor.py +++ b/bridgescaler/distributed_tensor.py @@ -36,7 +36,7 @@ def extract_x_columns(x, channels_last=True): if not channels_last: var_dim_num = 1 assert isinstance(x, torch.Tensor), "Input must be a PyTorch tensor" - x_columns = torch.arange(x.shape[var_dim_num]) + x_columns = torch.arange(x.shape[var_dim_num], device=x.device) return x_columns def set_channel_dim(self, channels_last=None): @@ -56,9 +56,9 @@ def process_x_for_transform(self, x, channels_last=None): assert ( x.shape[channel_dim] == self.x_columns_.shape[0] ), "Number of input columns does not match scaler." 
- x_col_order = torch.arange(x.shape[channel_dim]) + x_col_order = torch.arange(x.shape[channel_dim], device=x.device) xv = x - x_transformed = torch.zeros(xv.shape, dtype=xv.dtype) + x_transformed = torch.zeros(xv.shape, dtype=xv.dtype, device=xv.device) return xv, x_transformed, channels_last, channel_dim, x_col_order def fit(self, x, weight=None): @@ -105,14 +105,14 @@ def fit(self, x, weight=None): self.x_columns_ = x_columns if len(xv.shape) > 2: if self.channels_last: - self.n_ += torch.prod(torch.tensor(xv.shape[:-1])) + self.n_ += torch.prod(torch.tensor(xv.shape[:-1], dtype=xv.dtype, device=xv.device)) else: self.n_ += xv.shape[0] * \ - torch.prod(torch.tensor(xv.shape[2:])) + torch.prod(torch.tensor(xv.shape[2:], dtype=xv.dtype, device=xv.device)) else: self.n_ += xv.shape[0] - self.mean_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype) - self.var_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype) + self.mean_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype, device=xv.device) + self.var_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype, device=xv.device) if self.channels_last: for i in range(xv.shape[channel_dim]): @@ -129,15 +129,15 @@ def fit(self, x, weight=None): x.shape[channel_dim] == self.x_columns_.shape[0] ), "New data has a different number of columns" if self.channels_last: - x_col_order = torch.arange(x.shape[-1]) + x_col_order = torch.arange(x.shape[-1], device=x.device) else: - x_col_order = torch.arange(x.shape[1]) + x_col_order = torch.arange(x.shape[1], device=x.device) if len(xv.shape) > 2: if self.channels_last: - new_n = torch.prod(torch.tensor(xv.shape[:-1])) + new_n = torch.prod(torch.tensor(xv.shape[:-1], dtype=xv.dtype, device=xv.device)) else: new_n = xv.shape[0] * \ - torch.prod(torch.tensor(xv.shape[2:])) + torch.prod(torch.tensor(xv.shape[2:], dtype=xv.dtype, device=xv.device)) else: new_n = xv.shape[0] for i, o in enumerate(x_col_order): @@ -185,11 +185,11 @@ def transform(self, x, channels_last=None): if channels_last: for i, o in enumerate(x_col_order): x_transformed[..., i] = ( - xv[..., i] - x_mean[o]) / torch.sqrt(x_var[o]) + xv[..., i] - x_mean[o].to(device=xv[..., i].device)) / torch.sqrt(x_var[o].to(device=xv[..., i].device)) else: for i, o in enumerate(x_col_order): x_transformed[:, i] = ( - xv[:, i] - x_mean[o]) / torch.sqrt(x_var[o]) + xv[:, i] - x_mean[o].to(device=xv[:, i].device)) / torch.sqrt(x_var[o].to(device=xv[:, i].device)) return x_transformed def inverse_transform(self, x, channels_last=None): @@ -204,11 +204,11 @@ def inverse_transform(self, x, channels_last=None): if channels_last: for i, o in enumerate(x_col_order): x_transformed[..., i] = xv[..., i] * \ - torch.sqrt(x_var[o]) + x_mean[o] + torch.sqrt(x_var[o].to(device=xv[..., i].device)) + x_mean[o].to(device=xv[..., i].device) else: for i, o in enumerate(x_col_order): x_transformed[:, i] = xv[:, i] * \ - torch.sqrt(x_var[o]) + x_mean[o] + torch.sqrt(x_var[o].to(device=xv[:, i].device)) + x_mean[o].to(device=xv[:, i].device) return x_transformed def get_scales(self): @@ -255,8 +255,8 @@ def fit(self, x, weight=None): channel_dim = self.set_channel_dim() if not self._fit: self.x_columns_ = x_columns - self.max_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype) - self.min_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype) + self.max_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype, device=xv.device) + self.min_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype, device=xv.device) if self.channels_last: for i in 
range(xv.shape[channel_dim]): @@ -272,9 +272,9 @@ def fit(self, x, weight=None): x.shape[channel_dim] == self.x_columns_.shape[0] ), "New data has a different number of columns" if self.channels_last: - x_col_order = torch.arange(x.shape[-1]) + x_col_order = torch.arange(x.shape[-1], device=x.device) else: - x_col_order = torch.arange(x.shape[1]) + x_col_order = torch.arange(x.shape[1], device=x.device) if self.channels_last: for i, o in enumerate(x_col_order): self.max_x_[o] = torch.maximum( @@ -299,15 +299,16 @@ def transform(self, x, channels_last=None): channel_dim, x_col_order, ) = self.process_x_for_transform(x, channels_last) + x_min, x_max = self.get_scales() if channels_last: for i, o in enumerate(x_col_order): - x_transformed[..., i] = (xv[..., i] - self.min_x_[o]) / ( - self.max_x_[o] - self.min_x_[o] + x_transformed[..., i] = (xv[..., i] - x_min[o].to(device=xv[..., i].device)) / ( + x_max[o].to(device=xv[..., i].device) - x_min[o].to(device=xv[..., i].device) ) else: for i, o in enumerate(x_col_order): - x_transformed[:, i] = (xv[:, i] - self.min_x_[o]) / ( - self.max_x_[o] - self.min_x_[o] + x_transformed[:, i] = (xv[:, i] - x_min[o].to(device=xv[:, i].device)) / ( + x_max[o].to(device=xv[:, i].device) - x_min[o].to(device=xv[:, i].device) ) return x_transformed @@ -319,17 +320,18 @@ def inverse_transform(self, x, channels_last=None): channel_dim, x_col_order, ) = self.process_x_for_transform(x, channels_last) + x_min, x_max = self.get_scales() if channels_last: for i, o in enumerate(x_col_order): x_transformed[..., i] = ( - xv[..., i] * (self.max_x_[o] - self.min_x_[o] - ) + self.min_x_[o] + xv[..., i] * (x_max[o].to(device=xv[..., i].device) - x_min[o].to(device=xv[..., i].device) + ) + x_min[o].to(device=xv[..., i].device) ) else: for i, o in enumerate(x_col_order): x_transformed[:, i] = ( - xv[:, i] * (self.max_x_[o] - self.min_x_[o]) + - self.min_x_[o] + xv[:, i] * (x_max[o].to(device=xv[:, i].device) - x_min[o].to(device=xv[:, i].device)) + + x_min[o].to(device=xv[:, i].device) ) return x_transformed @@ -344,4 +346,4 @@ def __add__(self, other): current = deepcopy(self) current.max_x_ = torch.maximum(self.max_x_, other.max_x_) current.min_x_ = torch.minimum(self.min_x_, other.min_x_) - return current + return current \ No newline at end of file From a7d255f4c6fd26c8684d5462e3bb43dd99f95013 Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Wed, 4 Feb 2026 08:55:00 -0700 Subject: [PATCH 13/30] code optimization: avoid for-looping --- bridgescaler/distributed_tensor.py | 139 ++++++++++++++--------------- 1 file changed, 67 insertions(+), 72 deletions(-) diff --git a/bridgescaler/distributed_tensor.py b/bridgescaler/distributed_tensor.py index 99b700f..8042634 100644 --- a/bridgescaler/distributed_tensor.py +++ b/bridgescaler/distributed_tensor.py @@ -83,6 +83,16 @@ def subset_columns(self, sel_columns): def add_variables(self, other): pass + @staticmethod + def reshape_to_channels_first(stat, target): + """Reshapes 'stat' to align with the channel dimension (index 1).""" + return stat.view(*(stat.size(0) if i == 1 else 1 for i in range(target.dim()))) + + @staticmethod + def reshape_to_channels_last(stat, target): + """Reshapes 'stat' to align with the last dimension.""" + return stat.view(*(stat.size(0) if i == target.dim() - 1 else 1 for i in range(target.dim()))) + class DStandardScalerTensor(DBaseScalerTensor): """ @@ -115,13 +125,11 @@ def fit(self, x, weight=None): self.var_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype, device=xv.device) if 
self.channels_last: - for i in range(xv.shape[channel_dim]): - self.mean_x_[i] = torch.mean(xv[..., i]) - self.var_x_[i] = torch.var(xv[..., i], correction=0) + self.mean_x_ = torch.mean(xv, dim=tuple(range(xv.ndim - 1))) + self.var_x_ = torch.var(xv, dim=tuple(range(xv.ndim - 1)), correction=0) else: - for i in range(xv.shape[channel_dim]): - self.mean_x_[i] = torch.mean(xv[:, i]) - self.var_x_[i] = torch.var(xv[:, i], correction=0) + self.mean_x_ = torch.mean(xv, dim=tuple(d for d in range(xv.ndim) if d != 1)) + self.var_x_ = torch.var(xv, dim=tuple(d for d in range(xv.ndim) if d != 1), correction=0) else: # Update existing scaler with new data @@ -140,24 +148,23 @@ def fit(self, x, weight=None): torch.prod(torch.tensor(xv.shape[2:], dtype=xv.dtype, device=xv.device)) else: new_n = xv.shape[0] - for i, o in enumerate(x_col_order): - if self.channels_last: - new_mean = torch.mean(xv[..., i]) - new_var = torch.var(xv[..., i], correction=0) - else: - new_mean = torch.mean(xv[:, i]) - new_var = torch.var(xv[:, i], correction=0) - combined_mean = (self.n_ * self.mean_x_[o] + new_n * new_mean) / ( - self.n_ + new_n - ) - weighted_var = (self.n_ * self.var_x_[o] + new_n * new_var) / ( - self.n_ + new_n - ) - var_correction = ( - self.n_ * new_n * (self.mean_x_[o] - new_mean) ** 2 - ) / ((self.n_ + new_n) ** 2) - self.mean_x_[o] = combined_mean - self.var_x_[o] = weighted_var + var_correction + if self.channels_last: + new_mean = torch.mean(xv, dim=tuple(range(xv.ndim - 1))) + new_var = torch.var(xv, dim=tuple(range(xv.ndim - 1)), correction=0) + else: + new_mean = torch.mean(xv, dim=tuple(d for d in range(xv.ndim) if d != 1)) + new_var = torch.var(xv, dim=tuple(d for d in range(xv.ndim) if d != 1), correction=0) + combined_mean = (self.n_ * self.mean_x_ + new_n * new_mean) / ( + self.n_ + new_n + ) + weighted_var = (self.n_ * self.var_x_ + new_n * new_var) / ( + self.n_ + new_n + ) + var_correction = ( + self.n_ * new_n * (self.mean_x_ - new_mean) ** 2 + ) / ((self.n_ + new_n) ** 2) + self.mean_x_ = combined_mean + self.var_x_ = weighted_var + var_correction self.n_ += new_n self._fit = True @@ -183,13 +190,11 @@ def transform(self, x, channels_last=None): ) = self.process_x_for_transform(x, channels_last) x_mean, x_var = self.get_scales() if channels_last: - for i, o in enumerate(x_col_order): - x_transformed[..., i] = ( - xv[..., i] - x_mean[o].to(device=xv[..., i].device)) / torch.sqrt(x_var[o].to(device=xv[..., i].device)) + x_transformed = ( + xv - self.reshape_to_channels_last(x_mean.to(device=xv.device), xv)) / torch.sqrt(self.reshape_to_channels_last(x_var.to(device=xv.device), xv)) else: - for i, o in enumerate(x_col_order): - x_transformed[:, i] = ( - xv[:, i] - x_mean[o].to(device=xv[:, i].device)) / torch.sqrt(x_var[o].to(device=xv[:, i].device)) + x_transformed = ( + xv - self.reshape_to_channels_first(x_mean.to(device=xv.device), xv)) / torch.sqrt(self.reshape_to_channels_first(x_var.to(device=xv.device), xv)) return x_transformed def inverse_transform(self, x, channels_last=None): @@ -202,13 +207,11 @@ def inverse_transform(self, x, channels_last=None): ) = self.process_x_for_transform(x, channels_last) x_mean, x_var = self.get_scales() if channels_last: - for i, o in enumerate(x_col_order): - x_transformed[..., i] = xv[..., i] * \ - torch.sqrt(x_var[o].to(device=xv[..., i].device)) + x_mean[o].to(device=xv[..., i].device) + x_transformed = xv * \ + torch.sqrt(self.reshape_to_channels_last(x_var.to(device=xv.device), xv)) + 
self.reshape_to_channels_last(x_mean.to(device=xv.device), xv) else: - for i, o in enumerate(x_col_order): - x_transformed[:, i] = xv[:, i] * \ - torch.sqrt(x_var[o].to(device=xv[:, i].device)) + x_mean[o].to(device=xv[:, i].device) + x_transformed = xv * \ + torch.sqrt(self.reshape_to_channels_first(x_var.to(device=xv.device), xv)) + self.reshape_to_channels_first(x_mean.to(device=xv.device), xv) return x_transformed def get_scales(self): @@ -259,13 +262,11 @@ def fit(self, x, weight=None): self.min_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype, device=xv.device) if self.channels_last: - for i in range(xv.shape[channel_dim]): - self.max_x_[i] = torch.max(xv[..., i]) - self.min_x_[i] = torch.min(xv[..., i]) + self.max_x_ = torch.amax(xv, dim=tuple(range(xv.ndim - 1))) + self.min_x_ = torch.amin(xv, dim=tuple(range(xv.ndim - 1))) else: - for i in range(xv.shape[channel_dim]): - self.max_x_[i] = torch.max(xv[:, i]) - self.min_x_[i] = torch.min(xv[:, i]) + self.max_x_ = torch.amax(xv, dim=tuple(d for d in range(xv.ndim) if d != 1)) + self.min_x_ = torch.amin(xv, dim=tuple(d for d in range(xv.ndim) if d != 1)) else: # Update existing scaler with new data assert ( @@ -276,19 +277,17 @@ def fit(self, x, weight=None): else: x_col_order = torch.arange(x.shape[1], device=x.device) if self.channels_last: - for i, o in enumerate(x_col_order): - self.max_x_[o] = torch.maximum( - self.max_x_[o], torch.max(xv[..., i]) - ) - self.min_x_[o] = torch.minimum( - self.min_x_[o], torch.min(xv[..., i]) - ) + self.max_x_ = torch.maximum( + self.max_x_, torch.amax(xv, dim=tuple(range(xv.ndim - 1))) + ) + self.min_x_ = torch.minimum( + self.min_x_, torch.amin(xv, dim=tuple(range(xv.ndim - 1))) + ) else: - for i, o in enumerate(xv.shape[channel_dim]): - self.max_x_[o] = torch.maximum( - self.max_x_[o], torch.max(xv[:, i])) - self.min_x_[o] = torch.minimum( - self.min_x_[o], torch.min(xv[:, i])) + self.max_x_ = torch.maximum( + self.max_x_, torch.amax(xv, dim=tuple(d for d in range(xv.ndim) if d != 1))) + self.min_x_ = torch.minimum( + self.min_x_, torch.amin(xv, dim=tuple(d for d in range(xv.ndim) if d != 1))) self._fit = True def transform(self, x, channels_last=None): @@ -301,15 +300,13 @@ def transform(self, x, channels_last=None): ) = self.process_x_for_transform(x, channels_last) x_min, x_max = self.get_scales() if channels_last: - for i, o in enumerate(x_col_order): - x_transformed[..., i] = (xv[..., i] - x_min[o].to(device=xv[..., i].device)) / ( - x_max[o].to(device=xv[..., i].device) - x_min[o].to(device=xv[..., i].device) - ) + x_transformed = (xv - self.reshape_to_channels_last(x_min.to(device=xv.device), xv)) / ( + self.reshape_to_channels_last(x_max.to(device=xv.device), xv) - self.reshape_to_channels_last(x_min.to(device=xv.device), xv) + ) else: - for i, o in enumerate(x_col_order): - x_transformed[:, i] = (xv[:, i] - x_min[o].to(device=xv[:, i].device)) / ( - x_max[o].to(device=xv[:, i].device) - x_min[o].to(device=xv[:, i].device) - ) + x_transformed = (xv - self.reshape_to_channels_first(x_min.to(device=xv.device), xv)) / ( + self.reshape_to_channels_first(x_max.to(device=xv.device), xv) - self.reshape_to_channels_first(x_min.to(device=xv.device), xv) + ) return x_transformed def inverse_transform(self, x, channels_last=None): @@ -322,17 +319,15 @@ def inverse_transform(self, x, channels_last=None): ) = self.process_x_for_transform(x, channels_last) x_min, x_max = self.get_scales() if channels_last: - for i, o in enumerate(x_col_order): - x_transformed[..., i] = ( - xv[..., i] * 
(x_max[o].to(device=xv[..., i].device) - x_min[o].to(device=xv[..., i].device) - ) + x_min[o].to(device=xv[..., i].device) - ) + x_transformed = ( + xv * (self.reshape_to_channels_last(x_max.to(device=xv.device), xv) - self.reshape_to_channels_last(x_min.to(device=xv.device), xv) + ) + self.reshape_to_channels_last(x_min.to(device=xv.device), xv) + ) else: - for i, o in enumerate(x_col_order): - x_transformed[:, i] = ( - xv[:, i] * (x_max[o].to(device=xv[:, i].device) - x_min[o].to(device=xv[:, i].device)) + - x_min[o].to(device=xv[:, i].device) - ) + x_transformed = ( + xv * (self.reshape_to_channels_first(x_max.to(device=xv.device), xv) - self.reshape_to_channels_first(x_min.to(device=xv.device), xv)) + + self.reshape_to_channels_first(x_min.to(device=xv.device), xv) + ) return x_transformed def get_scales(self): From 8bc88d6ad186c531ce3de0f5cbc34f7a762a8ffd Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Wed, 4 Feb 2026 13:00:53 -0700 Subject: [PATCH 14/30] add path to environment --- .github/workflows/python-package-conda.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index e32ec3f..59faf0e 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -46,3 +46,5 @@ jobs: - name: Test with pytest run: | pytest + env: + PYTHONPATH: ${{ github.workspace }} From 5a3d33f3099fcebe8e7b7a3abfecfba7b6ef14f2 Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Wed, 4 Feb 2026 13:09:16 -0700 Subject: [PATCH 15/30] installing in editable mode --- .github/workflows/python-package-conda.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index 59faf0e..efbae18 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -33,7 +33,7 @@ jobs: run: | python -m pip install --upgrade uv uv pip install torch --system --index-url https://download.pytorch.org/whl/cpu - uv pip install . --system + uv pip install -e . --system uv pip install ruff pytest --system - name: Lint with ruff run: | @@ -46,5 +46,3 @@ jobs: - name: Test with pytest run: | pytest - env: - PYTHONPATH: ${{ github.workspace }} From 37c4101f4974f112c2453284dbd94324d4a8ee72 Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Thu, 5 Feb 2026 10:40:47 -0700 Subject: [PATCH 16/30] specify uv version --- .github/workflows/python-package-conda.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index efbae18..b8919e7 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -31,9 +31,9 @@ jobs: cache: 'pip' - name: Install dependencies run: | - python -m pip install --upgrade uv + python -m pip install uv=0.9.15 uv pip install torch --system --index-url https://download.pytorch.org/whl/cpu - uv pip install -e . --system + uv pip install . 
--system uv pip install ruff pytest --system - name: Lint with ruff run: | From 7ed3595b6003ad586c7ae5f9d1df14d353dea269 Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Thu, 5 Feb 2026 10:43:15 -0700 Subject: [PATCH 17/30] fix syntax error --- .github/workflows/python-package-conda.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index b8919e7..7169b9f 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -31,7 +31,7 @@ jobs: cache: 'pip' - name: Install dependencies run: | - python -m pip install uv=0.9.15 + python -m pip install uv==0.9.15 uv pip install torch --system --index-url https://download.pytorch.org/whl/cpu uv pip install . --system uv pip install ruff pytest --system From 262423b4e8c8a8bcbb94bbc984520d175c40efac Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Thu, 5 Feb 2026 10:54:22 -0700 Subject: [PATCH 18/30] downgrade Python version --- .github/workflows/python-package-conda.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index 7169b9f..f99bf57 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -27,7 +27,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v6 with: - python-version: '3.11' + python-version: '3.9' cache: 'pip' - name: Install dependencies run: | From bb87cb07774bf9ea5644aec2291696080f71e71d Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Thu, 5 Feb 2026 11:19:42 -0700 Subject: [PATCH 19/30] try this version combination --- .github/workflows/python-package-conda.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index f99bf57..287f696 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -27,7 +27,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v6 with: - python-version: '3.9' + python-version: '3.11.14' cache: 'pip' - name: Install dependencies run: | From 9d8466b35604d05cca5c3d3433feb566df7f905f Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Thu, 5 Feb 2026 11:29:01 -0700 Subject: [PATCH 20/30] use Python 3.9 in GitHub workflow and latest uv --- .github/workflows/python-package-conda.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index 287f696..f083cf9 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -27,11 +27,11 @@ jobs: - name: Setup Python uses: actions/setup-python@v6 with: - python-version: '3.11.14' + python-version: '3.9' cache: 'pip' - name: Install dependencies run: | - python -m pip install uv==0.9.15 + python -m pip install --upgrade uv uv pip install torch --system --index-url https://download.pytorch.org/whl/cpu uv pip install . 
--system uv pip install ruff pytest --system From 9eea9444c71ffb053d0de88e5a811d6d6a299bc5 Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Thu, 5 Feb 2026 11:29:44 -0700 Subject: [PATCH 21/30] fix syntax for backward compatibility --- bridgescaler/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bridgescaler/__init__.py b/bridgescaler/__init__.py index 80971d7..0191a31 100644 --- a/bridgescaler/__init__.py +++ b/bridgescaler/__init__.py @@ -1,4 +1,5 @@ from importlib.metadata import version, PackageNotFoundError +from typing import Union from packaging.version import Version @@ -6,7 +7,7 @@ # 1. PyTorch Checks REQUIRED_TORCH_VERSION = Version("2.0.0") -def get_torch_status() -> tuple[bool, Version | None]: +def get_torch_status() -> tuple[bool, Union[Version, None]]: try: return True, Version(version("torch")) except PackageNotFoundError: @@ -48,4 +49,4 @@ def require_torch() -> None: "GroupStandardScaler", "GroupRobustScaler", "GroupMinMaxScaler", "DeepStandardScaler", "DeepMinMaxScaler", "DeepQuantileTransformer", "DStandardScaler", "DMinMaxScaler", "DQuantileScaler", -] \ No newline at end of file +] From 18d1876a301c024f38f3632307cf18fe43347e52 Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Thu, 5 Feb 2026 13:28:17 -0700 Subject: [PATCH 22/30] revert changes to pass workflow runs --- .github/workflows/python-package-conda.yml | 2 +- bridgescaler/__init__.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index f083cf9..e32ec3f 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -27,7 +27,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v6 with: - python-version: '3.9' + python-version: '3.11' cache: 'pip' - name: Install dependencies run: | diff --git a/bridgescaler/__init__.py b/bridgescaler/__init__.py index 0191a31..c4ab6ec 100644 --- a/bridgescaler/__init__.py +++ b/bridgescaler/__init__.py @@ -1,5 +1,4 @@ from importlib.metadata import version, PackageNotFoundError -from typing import Union from packaging.version import Version @@ -7,7 +6,7 @@ # 1. 
PyTorch Checks
 REQUIRED_TORCH_VERSION = Version("2.0.0")
 
-def get_torch_status() -> tuple[bool, Union[Version, None]]:
+def get_torch_status() -> tuple[bool, Version | None]:
     try:
         return True, Version(version("torch"))
     except PackageNotFoundError:

From d19262833bc59603e4aa6b2ed86d0f1f5510b5d3 Mon Sep 17 00:00:00 2001
From: kevinyang-cky
Date: Thu, 5 Feb 2026 13:30:57 -0700
Subject: [PATCH 23/30] Pin Pandas and NumPy versions per Katelyn's suggestion

---
 setup.cfg | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 1b0fcbb..605cf17 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -26,8 +26,8 @@ setup_requires = setuptools
 python_requires = >=3.7
 install_requires =
     scikit-learn>=1.0
-    numpy
-    pandas
+    numpy<2.4
+    pandas<3
     crick
     scipy
     xarray

From 5775e2cd9c9917ab82e4a63adff152fd96b81ae5 Mon Sep 17 00:00:00 2001
From: kevinyang-cky
Date: Thu, 12 Feb 2026 16:12:18 -0700
Subject: [PATCH 24/30] include the variable_names attribute in operations for DStandardScalerTensor

---
 bridgescaler/distributed_tensor.py | 88 ++++++++++++++++++++++--------
 1 file changed, 66 insertions(+), 22 deletions(-)

diff --git a/bridgescaler/distributed_tensor.py b/bridgescaler/distributed_tensor.py
index 8042634..eee2219 100644
--- a/bridgescaler/distributed_tensor.py
+++ b/bridgescaler/distributed_tensor.py
@@ -22,23 +22,70 @@ def is_fit(self):
     @staticmethod
     def extract_x_columns(x, channels_last=True):
         """
-        Extract column indices to be transformed from x. All of these assume that the columns are in the last dimension.
+        Extract the variable (column) names from input x.
+
+        The variable/channel names are stored as a list of strings in the
+        `variable_names` attribute. If this attribute does not exist, integer column indices are used instead.
+        The channel dimension is assumed to be last unless channels_last is False.
 
         Args:
-            x (torch.tensor): tensor of values to be transformed.
+            x (torch.tensor): tensor of values.
             channels_last (bool): If True, then assume the variable or channel dimension is the last dimension of the
                 array. If False, then assume the variable or channel dimension is second.
 
         Returns:
-            x_columns (torch.tensor): tensor of column indices.
+            x_columns (list): list of variable names, or integer column indices as a fallback.
         """
         var_dim_num = -1
         if not channels_last:
            var_dim_num = 1
         assert isinstance(x, torch.Tensor), "Input must be a PyTorch tensor"
-        x_columns = torch.arange(x.shape[var_dim_num], device=x.device)
+        if hasattr(x, 'variable_names'):
+            x_columns = x.variable_names
+        else:
+            x_columns = list(range(x.shape[var_dim_num]))
+        #assert getattr(x, 'variable_names', None) is not None, "variable_names attribute is missing or empty"
+        assert len(x_columns) == len(
+            set(x_columns)), f"Duplicates found! Unique count: {len(set(x_columns))}, Total count: {len(x_columns)}"
         return x_columns
+
+    @staticmethod
+    def extract_array(x):
+        pass
+
+    def get_column_order(self, x_in_columns):
+        """
+        Get the indices of the scaler columns whose names match the variables (columns) in the input x tensor. This
+        enables users to pass a torch.Tensor to transform or inverse_transform with fewer variables than
+        the original scaler, or with variables in a different order, and still have the input transformed properly.
+
+        Args:
+            x_in_columns (list): list of input variable names
+
+        Returns:
+            x_in_col_indices (list): indices of the input variables from x in the scaler, in order.
+        """
+        assert all(var in self.x_columns_ for var in x_in_columns), "Some input variables not in scaler x_columns."
+ x_in_col_indices = [self.x_columns_.index(item) for item in x_in_columns if item in self.x_columns_] + return x_in_col_indices + + @staticmethod + def package_transformed_x(x_transformed, x): + """ + Repackaged a transformed torch.Tensor into the same datatype as the original x, including + all metadata. + + Args: + x_transformed (torch.Tensor): array after being transformed or inverse transformed + x (torch.Tensor) + + Returns: + + """ + x_packaged = x_transformed + x_packaged.variable_names = x.variable_names + return x_packaged + def set_channel_dim(self, channels_last=None): if channels_last is None: channels_last = self.channels_last @@ -53,10 +100,11 @@ def process_x_for_transform(self, x, channels_last=None): channels_last = self.channels_last channel_dim = self.set_channel_dim(channels_last) assert self._fit, "Scaler has not been fit." + x_in_cols = self.extract_x_columns(x, channels_last=channels_last) assert ( - x.shape[channel_dim] == self.x_columns_.shape[0] + x.shape[channel_dim] == len(self.x_columns_) ), "Number of input columns does not match scaler." - x_col_order = torch.arange(x.shape[channel_dim], device=x.device) + x_col_order = self.get_column_order(x_in_cols) xv = x x_transformed = torch.zeros(xv.shape, dtype=xv.dtype, device=xv.device) return xv, x_transformed, channels_last, channel_dim, x_col_order @@ -134,12 +182,9 @@ def fit(self, x, weight=None): else: # Update existing scaler with new data assert ( - x.shape[channel_dim] == self.x_columns_.shape[0] + x.shape[channel_dim] == len(self.x_columns_) ), "New data has a different number of columns" - if self.channels_last: - x_col_order = torch.arange(x.shape[-1], device=x.device) - else: - x_col_order = torch.arange(x.shape[1], device=x.device) + x_col_order = self.get_column_order(x_columns) if len(xv.shape) > 2: if self.channels_last: new_n = torch.prod(torch.tensor(xv.shape[:-1], dtype=xv.dtype, device=xv.device)) @@ -149,11 +194,11 @@ def fit(self, x, weight=None): else: new_n = xv.shape[0] if self.channels_last: - new_mean = torch.mean(xv, dim=tuple(range(xv.ndim - 1))) - new_var = torch.var(xv, dim=tuple(range(xv.ndim - 1)), correction=0) + new_mean = torch.mean(xv[...,x_col_order], dim=tuple(range(xv.ndim - 1))) + new_var = torch.var(xv[...,x_col_order], dim=tuple(range(xv.ndim - 1)), correction=0) else: - new_mean = torch.mean(xv, dim=tuple(d for d in range(xv.ndim) if d != 1)) - new_var = torch.var(xv, dim=tuple(d for d in range(xv.ndim) if d != 1), correction=0) + new_mean = torch.mean(xv[:, x_col_order], dim=tuple(d for d in range(xv.ndim) if d != 1)) + new_var = torch.var(xv[:, x_col_order], dim=tuple(d for d in range(xv.ndim) if d != 1), correction=0) combined_mean = (self.n_ * self.mean_x_ + new_n * new_mean) / ( self.n_ + new_n ) @@ -188,7 +233,7 @@ def transform(self, x, channels_last=None): channel_dim, x_col_order, ) = self.process_x_for_transform(x, channels_last) - x_mean, x_var = self.get_scales() + x_mean, x_var = self.get_scales(x_col_order) if channels_last: x_transformed = ( xv - self.reshape_to_channels_last(x_mean.to(device=xv.device), xv)) / torch.sqrt(self.reshape_to_channels_last(x_var.to(device=xv.device), xv)) @@ -205,7 +250,7 @@ def inverse_transform(self, x, channels_last=None): channel_dim, x_col_order, ) = self.process_x_for_transform(x, channels_last) - x_mean, x_var = self.get_scales() + x_mean, x_var = self.get_scales(x_col_order) if channels_last: x_transformed = xv * \ torch.sqrt(self.reshape_to_channels_last(x_var.to(device=xv.device), xv)) + 
self.reshape_to_channels_last(x_mean.to(device=xv.device), xv) @@ -214,15 +259,14 @@ def inverse_transform(self, x, channels_last=None): torch.sqrt(self.reshape_to_channels_first(x_var.to(device=xv.device), xv)) + self.reshape_to_channels_first(x_mean.to(device=xv.device), xv) return x_transformed - def get_scales(self): - return self.mean_x_, self.var_x_ + def get_scales(self, x_col_order): + return self.mean_x_[x_col_order], self.var_x_[x_col_order] def __add__(self, other): assert ( type(other) is DStandardScalerTensor ), "Input is not DStandardScalerTensor" - assert torch.all( - other.x_columns_ == self.x_columns_ + assert (other.x_columns_ == self.x_columns_ ), "Scaler columns do not match." current = deepcopy(self) current.mean_x_ = (self.n_ * self.mean_x_ + other.n_ * other.mean_x_) / ( @@ -341,4 +385,4 @@ def __add__(self, other): current = deepcopy(self) current.max_x_ = torch.maximum(self.max_x_, other.max_x_) current.min_x_ = torch.minimum(self.min_x_, other.min_x_) - return current \ No newline at end of file + return current From 03c950ea38ce832cccce79c6eca61404c0b16248 Mon Sep 17 00:00:00 2001 From: kevinyang-cky Date: Thu, 12 Feb 2026 21:47:42 -0700 Subject: [PATCH 25/30] include attribute variable_names into operations for DMinMaxScalerTensor and errors in DStandScalerTensor --- bridgescaler/distributed_tensor.py | 36 ++++++++++++++---------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/bridgescaler/distributed_tensor.py b/bridgescaler/distributed_tensor.py index eee2219..9161cc2 100644 --- a/bridgescaler/distributed_tensor.py +++ b/bridgescaler/distributed_tensor.py @@ -240,7 +240,8 @@ def transform(self, x, channels_last=None): else: x_transformed = ( xv - self.reshape_to_channels_first(x_mean.to(device=xv.device), xv)) / torch.sqrt(self.reshape_to_channels_first(x_var.to(device=xv.device), xv)) - return x_transformed + x_transformed_final = self.package_transformed_x(x_transformed, x) + return x_transformed_final def inverse_transform(self, x, channels_last=None): ( @@ -257,9 +258,10 @@ def inverse_transform(self, x, channels_last=None): else: x_transformed = xv * \ torch.sqrt(self.reshape_to_channels_first(x_var.to(device=xv.device), xv)) + self.reshape_to_channels_first(x_mean.to(device=xv.device), xv) - return x_transformed + x_transformed_final = self.package_transformed_x(x_transformed, x) + return x_transformed_final - def get_scales(self, x_col_order): + def get_scales(self, x_col_order=slice(None)): return self.mean_x_[x_col_order], self.var_x_[x_col_order] def __add__(self, other): @@ -286,9 +288,8 @@ def __add__(self, other): class DMinMaxScalerTensor(DBaseScalerTensor): """ Distributed MinMaxScaler enables calculation of min and max of variables in datasets in parallel, then combining - the mins and maxes as a reduction step. Scaler - supports torch.tensor and will return a transformed array in the - same form as the original with column or coordinate names preserved. + the mins and maxes as a reduction step. Scaler supports torch.Tensor and will return a transformed tensor in the + same form as the original with variable/column names preserved. 
""" def __init__(self, channels_last=True): @@ -316,10 +317,7 @@ def fit(self, x, weight=None): assert ( x.shape[channel_dim] == self.x_columns_.shape[0] ), "New data has a different number of columns" - if self.channels_last: - x_col_order = torch.arange(x.shape[-1], device=x.device) - else: - x_col_order = torch.arange(x.shape[1], device=x.device) + x_col_order = self.get_column_order(x_columns) if self.channels_last: self.max_x_ = torch.maximum( self.max_x_, torch.amax(xv, dim=tuple(range(xv.ndim - 1))) @@ -342,7 +340,7 @@ def transform(self, x, channels_last=None): channel_dim, x_col_order, ) = self.process_x_for_transform(x, channels_last) - x_min, x_max = self.get_scales() + x_min, x_max = self.get_scales(x_col_order) if channels_last: x_transformed = (xv - self.reshape_to_channels_last(x_min.to(device=xv.device), xv)) / ( self.reshape_to_channels_last(x_max.to(device=xv.device), xv) - self.reshape_to_channels_last(x_min.to(device=xv.device), xv) @@ -351,7 +349,8 @@ def transform(self, x, channels_last=None): x_transformed = (xv - self.reshape_to_channels_first(x_min.to(device=xv.device), xv)) / ( self.reshape_to_channels_first(x_max.to(device=xv.device), xv) - self.reshape_to_channels_first(x_min.to(device=xv.device), xv) ) - return x_transformed + x_transformed_final = self.package_transformed_x(x_transformed, x) + return x_transformed_final def inverse_transform(self, x, channels_last=None): ( @@ -361,7 +360,7 @@ def inverse_transform(self, x, channels_last=None): channel_dim, x_col_order, ) = self.process_x_for_transform(x, channels_last) - x_min, x_max = self.get_scales() + x_min, x_max = self.get_scales(x_col_order) if channels_last: x_transformed = ( xv * (self.reshape_to_channels_last(x_max.to(device=xv.device), xv) - self.reshape_to_channels_last(x_min.to(device=xv.device), xv) @@ -372,16 +371,15 @@ def inverse_transform(self, x, channels_last=None): xv * (self.reshape_to_channels_first(x_max.to(device=xv.device), xv) - self.reshape_to_channels_first(x_min.to(device=xv.device), xv)) + self.reshape_to_channels_first(x_min.to(device=xv.device), xv) ) - return x_transformed + x_transformed_final = self.package_transformed_x(x_transformed, x) + return x_transformed_final - def get_scales(self): - return self.min_x_, self.max_x_ + def get_scales(self, x_col_order=slice(None)): + return self.min_x_[x_col_order], self.max_x_[x_col_order] def __add__(self, other): assert type(other) is DMinMaxScalerTensor, "Input is not DMinMaxScaler" - assert torch.all( - other.x_columns_ == self.x_columns_ - ), "Scaler columns do not match." + assert other.x_columns_ == self.x_columns_, "Scaler columns do not match." 
         current = deepcopy(self)
         current.max_x_ = torch.maximum(self.max_x_, other.max_x_)
         current.min_x_ = torch.minimum(self.min_x_, other.min_x_)
         return current

From 37e0877efc010d8cb75ae21eb2a1d80ddfe35d89 Mon Sep 17 00:00:00 2001
From: kevinyang-cky
Date: Thu, 12 Feb 2026 22:06:35 -0700
Subject: [PATCH 26/30] fix columns check in DMinMaxScalerTensor()

---
 bridgescaler/distributed_tensor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bridgescaler/distributed_tensor.py b/bridgescaler/distributed_tensor.py
index 9161cc2..4d25122 100644
--- a/bridgescaler/distributed_tensor.py
+++ b/bridgescaler/distributed_tensor.py
@@ -315,7 +315,7 @@ def fit(self, x, weight=None):
         else:
             # Update existing scaler with new data
             assert (
-                x.shape[channel_dim] == self.x_columns_.shape[0]
+                x.shape[channel_dim] == len(self.x_columns_)
             ), "New data has a different number of columns"
             x_col_order = self.get_column_order(x_columns)

From 3c615e824fee3e55b8d8a5addde50e93b0d8305b Mon Sep 17 00:00:00 2001
From: kevinyang-cky
Date: Thu, 12 Feb 2026 22:17:46 -0700
Subject: [PATCH 27/30] modified variable_names attribute decoding

---
 bridgescaler/backend_tensor.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/bridgescaler/backend_tensor.py b/bridgescaler/backend_tensor.py
index 1eb97f9..b2a200d 100644
--- a/bridgescaler/backend_tensor.py
+++ b/bridgescaler/backend_tensor.py
@@ -36,5 +36,8 @@ def read_scaler_tensor(scaler_str):
     scaler = scaler_objs[scaler_params["type"]]()
     del scaler_params["type"]
     for k, v in scaler_params.items():
-        setattr(scaler, k, torch.tensor(v))
+        if k == "x_columns_":
+            setattr(scaler, k, v)
+        else:
+            setattr(scaler, k, torch.tensor(v))
     return scaler
\ No newline at end of file

From bb095b060e3427f9dc7d6167881047fbe59b3163 Mon Sep 17 00:00:00 2001
From: kevinyang-cky
Date: Thu, 12 Feb 2026 22:50:25 -0700
Subject: [PATCH 28/30] modify package_transformed_x() to accommodate input data without the attribute

---
 bridgescaler/distributed_tensor.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/bridgescaler/distributed_tensor.py b/bridgescaler/distributed_tensor.py
index 9161cc2..6e5ce9d 100644
--- a/bridgescaler/distributed_tensor.py
+++ b/bridgescaler/distributed_tensor.py
@@ -83,7 +83,8 @@ def package_transformed_x(x_transformed, x):
         """
         x_packaged = x_transformed
-        x_packaged.variable_names = x.variable_names
+        if getattr(x, 'variable_names', None) is not None:
+            x_packaged.variable_names = x.variable_names
         return x_packaged
 
     def set_channel_dim(self, channels_last=None):

From a4ec6c61372e93545cf5fb8e599a03dc462d3122 Mon Sep 17 00:00:00 2001
From: kevinyang-cky
Date: Fri, 13 Feb 2026 08:35:08 -0700
Subject: [PATCH 29/30] allow the data to be transformed to have fewer variables than the scaler

---
 bridgescaler/distributed_tensor.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/bridgescaler/distributed_tensor.py b/bridgescaler/distributed_tensor.py
index 6e5ce9d..f995bc3 100644
--- a/bridgescaler/distributed_tensor.py
+++ b/bridgescaler/distributed_tensor.py
@@ -101,10 +101,10 @@ def process_x_for_transform(self, x, channels_last=None):
             channels_last = self.channels_last
         channel_dim = self.set_channel_dim(channels_last)
         assert self._fit, "Scaler has not been fit."
-        x_in_cols = self.extract_x_columns(x, channels_last=channels_last)
-        assert (
-            x.shape[channel_dim] == len(self.x_columns_)
-        ), "Number of input columns does not match scaler."
+        #x_in_cols = self.extract_x_columns(x, channels_last=channels_last)
+        #assert (
+        #    x.shape[channel_dim] == len(self.x_columns_)
+        #), "Number of input columns does not match scaler."
         x_col_order = self.get_column_order(x_in_cols)
         xv = x
         x_transformed = torch.zeros(xv.shape, dtype=xv.dtype, device=xv.device)

From 614acde9950c87523a2e7423c5c600da7916e7a6 Mon Sep 17 00:00:00 2001
From: kevinyang-cky
Date: Fri, 13 Feb 2026 08:37:25 -0700
Subject: [PATCH 30/30] uncomment code

---
 bridgescaler/distributed_tensor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bridgescaler/distributed_tensor.py b/bridgescaler/distributed_tensor.py
index f995bc3..6fe11d0 100644
--- a/bridgescaler/distributed_tensor.py
+++ b/bridgescaler/distributed_tensor.py
@@ -101,7 +101,7 @@ def process_x_for_transform(self, x, channels_last=None):
         channels_last = self.channels_last
         channel_dim = self.set_channel_dim(channels_last)
         assert self._fit, "Scaler has not been fit."
-        #x_in_cols = self.extract_x_columns(x, channels_last=channels_last)
+        x_in_cols = self.extract_x_columns(x, channels_last=channels_last)
         #assert (
         #    x.shape[channel_dim] == len(self.x_columns_)
         #), "Number of input columns does not match scaler."
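
A minimal usage sketch of where the series lands after PATCH 30/30. It assumes PyTorch >= 2.0.0 is installed so the conditional imports in bridgescaler/__init__.py succeed, that DStandardScalerTensor accepts the same channels_last flag shown for DMinMaxScalerTensor.__init__, and that the channel names ("t2m", "u10", "v10") and tensor shape are purely illustrative:

    import torch
    from bridgescaler import DStandardScalerTensor, print_scaler_tensor, read_scaler_tensor

    # 32 examples on a 10 x 10 grid with 3 channels in the last dimension.
    x = torch.randn(32, 10, 10, 3)
    # extract_x_columns() reads per-channel names from this attribute;
    # without it, integer column indices are used as a fallback.
    x.variable_names = ["t2m", "u10", "v10"]

    scaler = DStandardScalerTensor(channels_last=True)
    scaler.fit(x)
    # package_transformed_x() copies variable_names onto the result.
    x_scaled = scaler.transform(x)

    # Round-trip through JSON: print_scaler_tensor() writes tensor attributes
    # out as NumPy arrays, and read_scaler_tensor() rebuilds them as tensors
    # while keeping x_columns_ as a plain list.
    scaler_json = print_scaler_tensor(scaler)
    restored = read_scaler_tensor(scaler_json)
    x_back = restored.inverse_transform(x_scaled)
    assert torch.allclose(x_back, x, atol=1e-5)  # should recover the input within float tolerance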