diff --git a/bridgescaler/__init__.py b/bridgescaler/__init__.py index 0e60896..c4ab6ec 100644 --- a/bridgescaler/__init__.py +++ b/bridgescaler/__init__.py @@ -1,4 +1,51 @@ +from importlib.metadata import version, PackageNotFoundError + +from packaging.version import Version + + +# 1. PyTorch Checks +REQUIRED_TORCH_VERSION = Version("2.0.0") + +def get_torch_status() -> tuple[bool, Version | None]: + try: + return True, Version(version("torch")) + except PackageNotFoundError: + return False, None + +TORCH_AVAILABLE, TORCH_VERSION = get_torch_status() + +def require_torch() -> None: + if not TORCH_AVAILABLE: + raise ImportError("PyTorch is required but not installed") + if TORCH_VERSION < REQUIRED_TORCH_VERSION: + raise RuntimeError( + f"PyTorch >= {REQUIRED_TORCH_VERSION} required; found {TORCH_VERSION}" + ) + +# 2. Base Imports from .backend import save_scaler, load_scaler, print_scaler, read_scaler from .group import GroupStandardScaler, GroupRobustScaler, GroupMinMaxScaler from .deep import DeepStandardScaler, DeepMinMaxScaler, DeepQuantileTransformer -from .distributed import DStandardScaler, DMinMaxScaler, DQuantileScaler +from .distributed import (DStandardScaler, DMinMaxScaler, DQuantileScaler) + +# 3. Conditional Torch Imports +if TORCH_AVAILABLE: + try: # Ensure that no errors are raised if PyTorch is installed but does not meet the required version. + from .distributed_tensor import ( + DStandardScalerTensor, + DMinMaxScalerTensor, + ) + from .backend_tensor import print_scaler_tensor, read_scaler_tensor + except (ImportError, RuntimeError): + pass + +# 4. 
Define Public API +__all__ = [ + # Utilities + "save_scaler", "load_scaler", "print_scaler", "read_scaler", + "TORCH_AVAILABLE", + # Scalers + "GroupStandardScaler", "GroupRobustScaler", "GroupMinMaxScaler", + "DeepStandardScaler", "DeepMinMaxScaler", "DeepQuantileTransformer", + "DStandardScaler", "DMinMaxScaler", "DQuantileScaler", +] diff --git a/bridgescaler/backend.py b/bridgescaler/backend.py index adcfaae..4ef775b 100644 --- a/bridgescaler/backend.py +++ b/bridgescaler/backend.py @@ -156,4 +156,4 @@ def create_synthetic_data(): for l in range(locs.shape[0]): x_data_dict[names[l]] = np.random.normal(loc=locs[l], scale=scales[l], size=num_examples) x_data = pd.DataFrame(x_data_dict) - return x_data + return x_data \ No newline at end of file diff --git a/bridgescaler/backend_tensor.py b/bridgescaler/backend_tensor.py new file mode 100644 index 0000000..b2a200d --- /dev/null +++ b/bridgescaler/backend_tensor.py @@ -0,0 +1,43 @@ +from . import require_torch +require_torch() # enforce torch availability/version at import time +import torch + +import json +import numpy as np + +from bridgescaler.distributed_tensor import DStandardScalerTensor, DMinMaxScalerTensor +from .backend import NumpyEncoder, object_hook + + +scaler_objs = {"DStandardScalerTensor": DStandardScalerTensor, + "DMinMaxScalerTensor": DMinMaxScalerTensor, + } + + +def print_scaler_tensor(scaler): + """ + Modify the print_scaler() in backend.py for tensors. + """ + scaler_params = dict(scaler.__dict__) + scaler_params["type"] = type(scaler).__name__ + + for keys in scaler_params: + if isinstance(scaler_params[keys], torch.Tensor): + scaler_params[keys] = scaler_params[keys].cpu().numpy().copy() + + return json.dumps(scaler_params, indent=4, sort_keys=True, cls=NumpyEncoder) + + +def read_scaler_tensor(scaler_str): + """ + Modify the read_scaler() in backend.py for tensors. 
+ """ + scaler_params = json.loads(scaler_str, object_hook=object_hook) + scaler = scaler_objs[scaler_params["type"]]() + del scaler_params["type"] + for k, v in scaler_params.items(): + if k == "x_columns_": + setattr(scaler, k, v) + else: + setattr(scaler, k, torch.tensor(v)) + return scaler \ No newline at end of file diff --git a/bridgescaler/distributed_tensor.py b/bridgescaler/distributed_tensor.py index a81f179..834c0f8 100644 --- a/bridgescaler/distributed_tensor.py +++ b/bridgescaler/distributed_tensor.py @@ -1,23 +1,8 @@ -from copy import deepcopy -import importlib.util - -from packaging import version +from . import require_torch +require_torch() # enforce torch availability/version at import time import torch -REQUIRED_VERSION = "2.0.0" # required torch version - -# Check if PyTorch is installed -if importlib.util.find_spec("torch") is None: - raise ImportError("PyTorch is not installed") - -installed_version = torch.__version__ - -# Validate version -if version.parse(installed_version) < version.parse(REQUIRED_VERSION): - raise RuntimeError( - f"PyTorch version mismatch: required {REQUIRED_VERSION}, " - f"found {installed_version}" - ) +from copy import deepcopy class DBaseScalerTensor: @@ -37,23 +22,71 @@ def is_fit(self): @staticmethod def extract_x_columns(x, channels_last=True): """ - Extract column indices to be transformed from x. All of these assume that the columns are in the last dimension. + Extract the variable (column) names from input x + + The variable/channel names are stored as a list of strings in the + `variable_names` attribute. If this attribute does not exist, an error will be raised. + This extraction assumes that the channels are located in the last dimension. Args: - x (torch.tensor): tensor of values to be transformed. + x (torch.tensor): tensor of values. channels_last (bool): If True, then assume the variable or channel dimension is the last dimension of the array. 
If False, then assume the variable or channel dimension is second. Returns: - x_columns (torch.tensor): tensor of column indices. + x_columns (torch.tensor): list of strings with variable names """ var_dim_num = -1 if not channels_last: var_dim_num = 1 assert isinstance(x, torch.Tensor), "Input must be a PyTorch tensor" - x_columns = torch.arange(x.shape[var_dim_num]) + if hasattr(x, 'variable_names'): + x_columns = x.variable_names + else: + x_columns = list(range(x.shape[var_dim_num])) + #assert getattr(x, 'variable_names', None) is not None, "variable_names attribute is missing or empty" + assert len(x_columns) == len( + set(x_columns)), f"Duplicates found! Unique count: {len(set(x_columns))}, Total count: {len(x_columns)}" return x_columns + @staticmethod + def extract_array(x): + pass + + def get_column_order(self, x_in_columns): + """ + Get the indices of the scaler columns that have the same name as the variables (columns) in the input x tensor. This + enables users to pass a torch.Tensor to transform or inverse_transform with fewer variables than + the original scaler or variables in a different order and still have the input dataset be transformed properly. + + Args: + x_in_columns (list): list of input variable names + + Returns: + x_in_col_indices (torch.Tensor): indices of the input variables from x in the scaler in order. + """ + assert all(var in self.x_columns_ for var in x_in_columns), "Some input variables not in scaler x_columns." + x_in_col_indices = [self.x_columns_.index(item) for item in x_in_columns if item in self.x_columns_] + return x_in_col_indices + + @staticmethod + def package_transformed_x(x_transformed, x): + """ + Repackaged a transformed torch.Tensor into the same datatype as the original x, including + all metadata. 
+ + Args: + x_transformed (torch.Tensor): array after being transformed or inverse transformed + x (torch.Tensor) + + Returns: + + """ + x_packaged = x_transformed + if getattr(x, 'variable_names', None) is not None: + x_packaged.variable_names = x.variable_names + return x_packaged + def set_channel_dim(self, channels_last=None): if channels_last is None: channels_last = self.channels_last @@ -68,12 +101,13 @@ def process_x_for_transform(self, x, channels_last=None): channels_last = self.channels_last channel_dim = self.set_channel_dim(channels_last) assert self._fit, "Scaler has not been fit." - assert ( - x.shape[channel_dim] == self.x_columns_.shape[0] - ), "Number of input columns does not match scaler." - x_col_order = torch.arange(x.shape[channel_dim]) + x_in_cols = self.extract_x_columns(x, channels_last=channels_last) + #assert ( + # x.shape[channel_dim] == len(self.x_columns_) + #), "Number of input columns does not match scaler." + x_col_order = self.get_column_order(x_in_cols) xv = x - x_transformed = torch.zeros(xv.shape, dtype=xv.dtype) + x_transformed = torch.zeros(xv.shape, dtype=xv.dtype, device=xv.device) return xv, x_transformed, channels_last, channel_dim, x_col_order def fit(self, x, weight=None): @@ -98,6 +132,16 @@ def subset_columns(self, sel_columns): def add_variables(self, other): pass + @staticmethod + def reshape_to_channels_first(stat, target): + """Reshapes 'stat' to align with the channel dimension (index 1).""" + return stat.view(*(stat.size(0) if i == 1 else 1 for i in range(target.dim()))) + + @staticmethod + def reshape_to_channels_last(stat, target): + """Reshapes 'stat' to align with the last dimension.""" + return stat.view(*(stat.size(0) if i == target.dim() - 1 else 1 for i in range(target.dim()))) + class DStandardScalerTensor(DBaseScalerTensor): """ @@ -120,59 +164,53 @@ def fit(self, x, weight=None): self.x_columns_ = x_columns if len(xv.shape) > 2: if self.channels_last: - self.n_ += 
torch.prod(torch.tensor(xv.shape[:-1])) + self.n_ += torch.prod(torch.tensor(xv.shape[:-1], device=xv.device)) else: self.n_ += xv.shape[0] * \ - torch.prod(torch.tensor(xv.shape[2:])) + torch.prod(torch.tensor(xv.shape[2:], device=xv.device)) else: self.n_ += xv.shape[0] - self.mean_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype) - self.var_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype) + self.mean_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype, device=xv.device) + self.var_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype, device=xv.device) if self.channels_last: - for i in range(xv.shape[channel_dim]): - self.mean_x_[i] = torch.mean(xv[..., i]) - self.var_x_[i] = torch.var(xv[..., i], correction=0) + self.mean_x_ = torch.mean(xv, dim=tuple(range(xv.ndim - 1))) + self.var_x_ = torch.var(xv, dim=tuple(range(xv.ndim - 1)), correction=0) else: - for i in range(xv.shape[channel_dim]): - self.mean_x_[i] = torch.mean(xv[:, i]) - self.var_x_[i] = torch.var(xv[:, i], correction=0) + self.mean_x_ = torch.mean(xv, dim=tuple(d for d in range(xv.ndim) if d != 1)) + self.var_x_ = torch.var(xv, dim=tuple(d for d in range(xv.ndim) if d != 1), correction=0) else: # Update existing scaler with new data assert ( - x.shape[channel_dim] == self.x_columns_.shape[0] + x.shape[channel_dim] == len(self.x_columns_) ), "New data has a different number of columns" - if self.channels_last: - x_col_order = torch.arange(x.shape[-1]) - else: - x_col_order = torch.arange(x.shape[1]) + x_col_order = self.get_column_order(x_columns) if len(xv.shape) > 2: if self.channels_last: - new_n = torch.prod(torch.tensor(xv.shape[:-1])) + new_n = torch.prod(torch.tensor(xv.shape[:-1], device=xv.device)) else: new_n = xv.shape[0] * \ - torch.prod(torch.tensor(xv.shape[2:])) + torch.prod(torch.tensor(xv.shape[2:], device=xv.device)) else: new_n = xv.shape[0] - for i, o in enumerate(x_col_order): - if 
self.channels_last: - new_mean = torch.mean(xv[..., i]) - new_var = torch.var(xv[..., i], correction=0) - else: - new_mean = torch.mean(xv[:, i]) - new_var = torch.var(xv[:, i], correction=0) - combined_mean = (self.n_ * self.mean_x_[o] + new_n * new_mean) / ( - self.n_ + new_n - ) - weighted_var = (self.n_ * self.var_x_[o] + new_n * new_var) / ( - self.n_ + new_n - ) - var_correction = ( - self.n_ * new_n * (self.mean_x_[o] - new_mean) ** 2 - ) / ((self.n_ + new_n) ** 2) - self.mean_x_[o] = combined_mean - self.var_x_[o] = weighted_var + var_correction + if self.channels_last: + new_mean = torch.mean(xv[...,x_col_order], dim=tuple(range(xv.ndim - 1))) + new_var = torch.var(xv[...,x_col_order], dim=tuple(range(xv.ndim - 1)), correction=0) + else: + new_mean = torch.mean(xv[:, x_col_order], dim=tuple(d for d in range(xv.ndim) if d != 1)) + new_var = torch.var(xv[:, x_col_order], dim=tuple(d for d in range(xv.ndim) if d != 1), correction=0) + combined_mean = (self.n_ * self.mean_x_ + new_n * new_mean) / ( + self.n_ + new_n + ) + weighted_var = (self.n_ * self.var_x_ + new_n * new_var) / ( + self.n_ + new_n + ) + var_correction = ( + self.n_ * new_n * (self.mean_x_ - new_mean) ** 2 + ) / ((self.n_ + new_n) ** 2) + self.mean_x_ = combined_mean + self.var_x_ = weighted_var + var_correction self.n_ += new_n self._fit = True @@ -196,16 +234,15 @@ def transform(self, x, channels_last=None): channel_dim, x_col_order, ) = self.process_x_for_transform(x, channels_last) - x_mean, x_var = self.get_scales() + x_mean, x_var = self.get_scales(x_col_order) if channels_last: - for i, o in enumerate(x_col_order): - x_transformed[..., i] = ( - xv[..., i] - x_mean[o]) / torch.sqrt(x_var[o]) + x_transformed = ( + xv - self.reshape_to_channels_last(x_mean.to(device=xv.device), xv)) / torch.sqrt(self.reshape_to_channels_last(x_var.to(device=xv.device), xv)) else: - for i, o in enumerate(x_col_order): - x_transformed[:, i] = ( - xv[:, i] - x_mean[o]) / torch.sqrt(x_var[o]) - return 
x_transformed + x_transformed = ( + xv - self.reshape_to_channels_first(x_mean.to(device=xv.device), xv)) / torch.sqrt(self.reshape_to_channels_first(x_var.to(device=xv.device), xv)) + x_transformed_final = self.package_transformed_x(x_transformed, x) + return x_transformed_final def inverse_transform(self, x, channels_last=None): ( @@ -215,26 +252,24 @@ def inverse_transform(self, x, channels_last=None): channel_dim, x_col_order, ) = self.process_x_for_transform(x, channels_last) - x_mean, x_var = self.get_scales() + x_mean, x_var = self.get_scales(x_col_order) if channels_last: - for i, o in enumerate(x_col_order): - x_transformed[..., i] = xv[..., i] * \ - torch.sqrt(x_var[o]) + x_mean[o] + x_transformed = xv * \ + torch.sqrt(self.reshape_to_channels_last(x_var.to(device=xv.device), xv)) + self.reshape_to_channels_last(x_mean.to(device=xv.device), xv) else: - for i, o in enumerate(x_col_order): - x_transformed[:, i] = xv[:, i] * \ - torch.sqrt(x_var[o]) + x_mean[o] - return x_transformed + x_transformed = xv * \ + torch.sqrt(self.reshape_to_channels_first(x_var.to(device=xv.device), xv)) + self.reshape_to_channels_first(x_mean.to(device=xv.device), xv) + x_transformed_final = self.package_transformed_x(x_transformed, x) + return x_transformed_final - def get_scales(self): - return self.mean_x_, self.var_x_ + def get_scales(self, x_col_order=slice(None)): + return self.mean_x_[x_col_order], self.var_x_[x_col_order] def __add__(self, other): assert ( type(other) is DStandardScalerTensor ), "Input is not DStandardScalerTensor" - assert torch.all( - other.x_columns_ == self.x_columns_ + assert (other.x_columns_ == self.x_columns_ ), "Scaler columns do not match." 
current = deepcopy(self) current.mean_x_ = (self.n_ * self.mean_x_ + other.n_ * other.mean_x_) / ( @@ -254,9 +289,8 @@ def __add__(self, other): class DMinMaxScalerTensor(DBaseScalerTensor): """ Distributed MinMaxScaler enables calculation of min and max of variables in datasets in parallel, then combining - the mins and maxes as a reduction step. Scaler - supports torch.tensor and will return a transformed array in the - same form as the original with column or coordinate names preserved. + the mins and maxes as a reduction step. Scaler supports torch.Tensor and will return a transformed tensor in the + same form as the original with variable/column names preserved. """ def __init__(self, channels_last=True): @@ -270,40 +304,33 @@ def fit(self, x, weight=None): channel_dim = self.set_channel_dim() if not self._fit: self.x_columns_ = x_columns - self.max_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype) - self.min_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype) + self.max_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype, device=xv.device) + self.min_x_ = torch.zeros(xv.shape[channel_dim], dtype=xv.dtype, device=xv.device) if self.channels_last: - for i in range(xv.shape[channel_dim]): - self.max_x_[i] = torch.max(xv[..., i]) - self.min_x_[i] = torch.min(xv[..., i]) + self.max_x_ = torch.amax(xv, dim=tuple(range(xv.ndim - 1))) + self.min_x_ = torch.amin(xv, dim=tuple(range(xv.ndim - 1))) else: - for i in range(xv.shape[channel_dim]): - self.max_x_[i] = torch.max(xv[:, i]) - self.min_x_[i] = torch.min(xv[:, i]) + self.max_x_ = torch.amax(xv, dim=tuple(d for d in range(xv.ndim) if d != 1)) + self.min_x_ = torch.amin(xv, dim=tuple(d for d in range(xv.ndim) if d != 1)) else: # Update existing scaler with new data assert ( - x.shape[channel_dim] == self.x_columns_.shape[0] + x.shape[channel_dim] == len(self.x_columns_) ), "New data has a different number of columns" + x_col_order = self.get_column_order(x_columns) if self.channels_last: - x_col_order 
= torch.arange(x.shape[-1]) - else: - x_col_order = torch.arange(x.shape[1]) - if self.channels_last: - for i, o in enumerate(x_col_order): - self.max_x_[o] = torch.maximum( - self.max_x_[o], torch.max(xv[..., i]) - ) - self.min_x_[o] = torch.minimum( - self.min_x_[o], torch.min(xv[..., i]) - ) + self.max_x_ = torch.maximum( + self.max_x_, torch.amax(xv, dim=tuple(range(xv.ndim - 1))) + ) + self.min_x_ = torch.minimum( + self.min_x_, torch.amin(xv, dim=tuple(range(xv.ndim - 1))) + ) else: - for i, o in enumerate(xv.shape[channel_dim]): - self.max_x_[o] = torch.maximum( - self.max_x_[o], torch.max(xv[:, i])) - self.min_x_[o] = torch.minimum( - self.min_x_[o], torch.min(xv[:, i])) + self.max_x_ = torch.maximum( + self.max_x_, torch.amax(xv, dim=tuple(d for d in range(xv.ndim) if d != 1))) + self.min_x_ = torch.minimum( + self.min_x_, torch.amin(xv, dim=tuple(d for d in range(xv.ndim) if d != 1))) self._fit = True def transform(self, x, channels_last=None): @@ -314,17 +341,17 @@ def transform(self, x, channels_last=None): channel_dim, x_col_order, ) = self.process_x_for_transform(x, channels_last) + x_min, x_max = self.get_scales(x_col_order) if channels_last: - for i, o in enumerate(x_col_order): - x_transformed[..., i] = (xv[..., i] - self.min_x_[o]) / ( - self.max_x_[o] - self.min_x_[o] - ) + x_transformed = (xv - self.reshape_to_channels_last(x_min.to(device=xv.device), xv)) / ( + self.reshape_to_channels_last(x_max.to(device=xv.device), xv) - self.reshape_to_channels_last(x_min.to(device=xv.device), xv) + ) else: - for i, o in enumerate(x_col_order): - x_transformed[:, i] = (xv[:, i] - self.min_x_[o]) / ( - self.max_x_[o] - self.min_x_[o] - ) - return x_transformed + x_transformed = (xv - self.reshape_to_channels_first(x_min.to(device=xv.device), xv)) / ( + self.reshape_to_channels_first(x_max.to(device=xv.device), xv) - self.reshape_to_channels_first(x_min.to(device=xv.device), xv) + ) + x_transformed_final = self.package_transformed_x(x_transformed, x) + return 
x_transformed_final def inverse_transform(self, x, channels_last=None): ( @@ -334,28 +361,26 @@ def inverse_transform(self, x, channels_last=None): channel_dim, x_col_order, ) = self.process_x_for_transform(x, channels_last) + x_min, x_max = self.get_scales(x_col_order) if channels_last: - for i, o in enumerate(x_col_order): - x_transformed[..., i] = ( - xv[..., i] * (self.max_x_[o] - self.min_x_[o] - ) + self.min_x_[o] - ) + x_transformed = ( + xv * (self.reshape_to_channels_last(x_max.to(device=xv.device), xv) - self.reshape_to_channels_last(x_min.to(device=xv.device), xv) + ) + self.reshape_to_channels_last(x_min.to(device=xv.device), xv) + ) else: - for i, o in enumerate(x_col_order): - x_transformed[:, i] = ( - xv[:, i] * (self.max_x_[o] - self.min_x_[o]) + - self.min_x_[o] - ) - return x_transformed + x_transformed = ( + xv * (self.reshape_to_channels_first(x_max.to(device=xv.device), xv) - self.reshape_to_channels_first(x_min.to(device=xv.device), xv)) + + self.reshape_to_channels_first(x_min.to(device=xv.device), xv) + ) + x_transformed_final = self.package_transformed_x(x_transformed, x) + return x_transformed_final - def get_scales(self): - return self.min_x_, self.max_x_ + def get_scales(self, x_col_order=slice(None)): + return self.min_x_[x_col_order], self.max_x_[x_col_order] def __add__(self, other): assert type(other) is DMinMaxScalerTensor, "Input is not DMinMaxScaler" - assert torch.all( - other.x_columns_ == self.x_columns_ - ), "Scaler columns do not match." + assert other.x_columns_ == self.x_columns_, "Scaler columns do not match." current = deepcopy(self) current.max_x_ = torch.maximum(self.max_x_, other.max_x_) current.min_x_ = torch.minimum(self.min_x_, other.min_x_) diff --git a/setup.cfg b/setup.cfg index 1b0fcbb..605cf17 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,8 +26,8 @@ setup_requires = setuptools python_requires = >=3.7 install_requires = scikit-learn>=1.0 - numpy - pandas + numpy<2.4 + pandas<3 crick scipy xarray