bridgescaler.backend

Attributes

scaler_objs

Classes

NumpyEncoder

Custom JSON encoder for NumPy data types.

Functions

ensure_torch()

Validates the torch installation and loads the module.

save_scaler(scaler, scaler_file)

Save a scikit-learn or bridgescaler scaler object to a file in JSON format.

print_scaler(scaler)

Output a scikit-learn or bridgescaler scaler object as a JSON string.

object_hook(dct)

read_scaler(scaler_str)

Initialize a scikit-learn or bridgescaler scaler from a JSON string.

load_scaler(scaler_file)

Initialize a scikit-learn or bridgescaler scaler from a saved JSON file.

apply_to_dict_leaves(d, operation)

Recursively applies an operation to each leaf value in a nested dictionary.

save_scaler_dict(scaler_dict, scaler_dict_file)

Serializes and saves a nested dictionary of Bridgescaler scalers to a JSON file.

load_scaler_dict(scaler_dict_file)

Loads and deserializes a nested dictionary of Bridgescaler scalers from a JSON file.

scale_var_dict(var_dict, scalers, method[, var_list, ...])

Recursively traverses a nested dict of tensor variables and applies a scaler method to each variable.

create_synthetic_data()

Module Contents

bridgescaler.backend.scaler_objs

bridgescaler.backend.ensure_torch()

Validates the torch installation and loads the module.
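ensure_torch() takes no arguments, and its exact error message is not documented here. The general pattern of checking an optional dependency at call time can be sketched with importlib; ensure_module below is a hypothetical helper written for illustration, not part of bridgescaler.

```python
import importlib
from types import ModuleType


def ensure_module(name: str) -> ModuleType:
    """Import and return module `name`, raising a clear ImportError if it is missing.

    Hypothetical helper illustrating the ensure_torch() pattern; not part of bridgescaler.
    """
    try:
        return importlib.import_module(name)
    except ImportError as err:
        # Re-raise with an actionable message while keeping the original cause.
        raise ImportError(
            f"'{name}' is required for this feature but could not be imported; "
            f"install it and try again."
        ) from err


if __name__ == "__main__":
    json_module = ensure_module("json")  # stdlib, always importable
    print(json_module.dumps({"ok": True}))
```

A torch-dependent code path would call something like torch = ensure_module("torch") right before use, which keeps torch optional for callers that never reach that path.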

bridgescaler.backend.save_scaler(scaler, scaler_file)

Save a scikit-learn or bridgescaler scaler object to a file in JSON format.

Parameters:
  • scaler – scikit-learn-style scaler object.

  • scaler_file – path to the JSON file where the scaler information is stored.

bridgescaler.backend.print_scaler(scaler)

Output a scikit-learn or bridgescaler scaler object as a JSON string.

Parameters:

scaler – scikit-learn-style scaler object.

Returns:

String representation of the object in JSON format.

bridgescaler.backend.object_hook(dct: dict[Any, Any])

bridgescaler.backend.read_scaler(scaler_str)

Initialize a scikit-learn or bridgescaler scaler from a JSON string.

Parameters:

scaler_str – JSON string.

Returns:

Scaler object.

bridgescaler.backend.load_scaler(scaler_file)

Initialize a scikit-learn or bridgescaler scaler from a saved JSON file.

Parameters:

scaler_file – path to the JSON file.

Returns:

Scaler object.
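The save, print, read, and load functions above round-trip a scaler through JSON. The following stdlib-only sketch shows one way such a round trip can work, assuming the scaler's fitted state lives in its instance attributes and a type tag stored alongside it tells the loader which class to rebuild. DemoScaler, REGISTRY, to_json, from_json, save, and load are hypothetical names; REGISTRY only plays a role similar to the scaler_objs attribute, and bridgescaler's own object_hook may behave differently.

```python
import json


class DemoScaler:
    """Hypothetical stand-in for a fitted scikit-learn-style scaler."""

    def __init__(self):
        self.mean_ = None
        self.scale_ = None

    def fit(self, values):
        n = len(values)
        self.mean_ = sum(values) / n
        variance = sum((v - self.mean_) ** 2 for v in values) / n
        self.scale_ = variance ** 0.5 or 1.0  # avoid dividing by zero for constant data
        return self

    def transform(self, values):
        return [(v - self.mean_) / self.scale_ for v in values]


# Hypothetical class-name registry, playing a role similar to scaler_objs.
REGISTRY = {"DemoScaler": DemoScaler}


def to_json(scaler) -> str:
    # Copy the instance attributes and add a tag naming the class.
    params = dict(vars(scaler), type=type(scaler).__name__)
    return json.dumps(params, sort_keys=True)


def _object_hook(dct):
    # Called for every decoded JSON object; rebuild the ones tagged with a known class.
    cls = REGISTRY.get(dct.get("type"))
    if cls is None:
        return dct
    scaler = cls()
    for key, value in dct.items():
        if key != "type":
            setattr(scaler, key, value)
    return scaler


def from_json(scaler_str: str):
    return json.loads(scaler_str, object_hook=_object_hook)


def save(scaler, path) -> None:
    with open(path, "w") as f:
        f.write(to_json(scaler))


def load(path):
    with open(path) as f:
        return from_json(f.read())


if __name__ == "__main__":
    fitted = DemoScaler().fit([1.0, 2.0, 3.0])
    restored = from_json(to_json(fitted))
    print(restored.transform([2.0]))  # [0.0]
```

Storing the class name next to the attributes is what lets a single loader rebuild different scaler types; the trade-off is that every loadable class must be registered in advance.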

bridgescaler.backend.apply_to_dict_leaves(d, operation)

Recursively applies an operation to each leaf value in a nested dictionary.

Parameters:
  • d (dict) – A nested dictionary where the operation will be applied to each leaf value.

  • operation (callable) – A function to apply to each leaf value.

Returns:

A nested dictionary with the same structure as d, where each leaf is the result of operation(leaf).

Return type:

dict
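A minimal sketch of the recursion described above, assuming that any value which is not itself a dict counts as a leaf. apply_to_leaves is a hypothetical name, chosen so as not to imply this is the library's exact code.

```python
from typing import Any, Callable


def apply_to_leaves(d: dict, operation: Callable[[Any], Any]) -> dict:
    """Return a new dict shaped like d, with operation applied to every non-dict value."""
    return {
        key: apply_to_leaves(value, operation) if isinstance(value, dict) else operation(value)
        for key, value in d.items()
    }


if __name__ == "__main__":
    nested = {"a": 1, "b": {"c": 2, "d": {"e": 3}}}
    print(apply_to_leaves(nested, lambda x: x * 10))
    # {'a': 10, 'b': {'c': 20, 'd': {'e': 30}}}
```

In this sketch an empty sub-dictionary stays an empty dictionary rather than being passed to operation, since it contains no leaves, and the input dictionary is not modified.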

bridgescaler.backend.save_scaler_dict(scaler_dict, scaler_dict_file)

Serializes and saves a nested dictionary of Bridgescaler scalers to a JSON file.

Parameters:
  • scaler_dict (dict) – A nested dictionary of fitted Bridgescaler scaler objects to be saved.

  • scaler_dict_file (str or Path) – The file path where the scaler dictionary will be saved as a JSON file.

bridgescaler.backend.load_scaler_dict(scaler_dict_file)

Loads and deserializes a nested dictionary of Bridgescaler scalers from a JSON file.

Parameters:

scaler_dict_file (str or Path) – The file path to the JSON file containing the serialized scaler dictionary.

Returns:

A nested dictionary of reconstructed scaler objects, with the same structure as the original dictionary passed to save_scaler_dict.

Return type:

dict
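One plausible way to serialize a nested dictionary of scalers is sketched below, under the assumption that each leaf is converted to its own JSON string (in the spirit of print_scaler and read_scaler) while the nesting is kept as ordinary JSON objects. ToyScaler, map_leaves, dump_nested_scalers, and load_nested_scalers are hypothetical names; the actual file layout written by save_scaler_dict may differ.

```python
import json
from dataclasses import asdict, dataclass
from typing import Any, Callable


@dataclass
class ToyScaler:
    """Hypothetical fitted scaler with two learned parameters."""

    mean: float
    std: float


def map_leaves(d: dict, fn: Callable[[Any], Any]) -> dict:
    # Recurse into dict values; apply fn to everything else.
    return {k: map_leaves(v, fn) if isinstance(v, dict) else fn(v) for k, v in d.items()}


def dump_nested_scalers(scaler_dict: dict, path: str) -> None:
    # Each leaf scaler becomes its own JSON string; the nesting stays as JSON objects.
    encoded = map_leaves(scaler_dict, lambda s: json.dumps(asdict(s)))
    with open(path, "w") as f:
        json.dump(encoded, f, indent=2)


def load_nested_scalers(path: str) -> dict:
    with open(path) as f:
        encoded = json.load(f)
    # Every leaf is a string, so decode it back into a ToyScaler.
    return map_leaves(encoded, lambda s: ToyScaler(**json.loads(s)))
```

Encoding each leaf as a string keeps leaves distinguishable from nesting levels on load: if a scaler were stored as a JSON object, the loader could not tell it apart from another level of the dictionary.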

class bridgescaler.backend.NumpyEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)

Bases: json.JSONEncoder

Custom JSON encoder for NumPy data types.

default(obj)

Return a JSON-serializable object for obj, or call the base implementation (to raise a TypeError). This description is inherited from json.JSONEncoder.default, whose argument is named o, as in the example below.

For example, to support arbitrary iterators, you could implement default like this:

def default(self, o):
    try:
        iterable = iter(o)
    except TypeError:
        pass
    else:
        return list(iterable)
    # Let the base class default method raise the TypeError
    return super().default(o)
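A sketch of what a NumPy-aware default method typically looks like. The set of handled types is an assumption, since the exact types NumpyEncoder supports are not listed above, and NumpyJSONEncoder is a hypothetical name.

```python
import json

import numpy as np


class NumpyJSONEncoder(json.JSONEncoder):
    """Sketch of a JSON encoder that converts NumPy values to built-in Python types."""

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        # np.floating covers float16/float32; float64 already subclasses float
        # and is encoded natively without reaching this method.
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Anything else: let the base class raise TypeError.
        return super().default(obj)


if __name__ == "__main__":
    params = {"mean_": np.arange(3), "n_samples_seen_": np.int64(20)}
    print(json.dumps(params, cls=NumpyJSONEncoder))
```

The encoder is selected through the cls argument of json.dumps or json.dump, so the objects being serialized need no changes.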

bridgescaler.backend.scale_var_dict(var_dict, scalers, method, var_list=None, _key_path=())

Recursively traverses a nested dict of tensor variables and applies a scaler method to each variable.

Parameters:
  • var_dict (dict) – A nested dictionary where leaves are variables in torch.Tensor to be scaled.

  • scalers (object or dict) – A single scaler instance (for fit and fit_transform) or a nested dict of scalers matching the structure of var_dict (for transform and inverse_transform).

  • method (str) – The scaler method to apply. Must be one of fit, transform, inverse_transform, or fit_transform.

  • var_list (list of str, optional) – A list of leaf key names to apply the scaler method to. Keys not in var_list are skipped during fit, and left unchanged during transform, inverse_transform, and fit_transform. If None, all leaf keys are processed.

Returns:

A nested dictionary with the same structure as var_dict, where each leaf is either a fitted scaler (for fit) or a transformed variable (for transform, inverse_transform, fit_transform). Keys named metadata and keys excluded by var_list are omitted for fit, and passed through unchanged for other methods.

Return type:

dict

Raises:
  • AssertionError – If var_dict is not a dict.

  • AssertionError – If method is not one of the valid methods.

  • AssertionError – If scalers is not a dict when using transform or inverse_transform.

  • AssertionError – If a key path in var_dict is missing in scalers.

  • AssertionError – If a scaler at a given key path does not have the requested method.

Example

>>> import torch
>>> from bridgescaler.distributed_tensor import DStandardScalerTensor
>>> from bridgescaler.backend import scale_var_dict
>>> T = torch.randn((20, 5, 4, 8))
>>> var_dict = {
...     "era5": {
...         "input": {"era5/prognostic/3d/T": T},
...         "target": {"era5/prognostic/3d/T": T},
...         "metadata": {"input_datetime": int, "target_datetime": int},
...     }
... }
>>> scalers = DStandardScalerTensor(channels_last=False)
>>> scaler_dict = scale_var_dict(var_dict, scalers, method="fit")
>>> transformed = scale_var_dict(var_dict, scaler_dict, method="transform")
>>> inverse_transformed = scale_var_dict(transformed, scaler_dict, method="inverse_transform")
>>> fitted_transformed = scale_var_dict(var_dict, scalers, method="fit_transform")
>>> # Only scale specific variables
>>> filtered = scale_var_dict(var_dict, scaler_dict, method="transform", var_list=["era5/prognostic/3d/T"])
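The traversal rules documented above (metadata and var_list handling, a single scaler for fit and fit_transform, a nested dict of scalers for transform and inverse_transform) can be sketched without torch. This is an illustration under stated assumptions, not the library's implementation: ToyStandardScaler works on plain lists of floats, each leaf gets an independent copy.deepcopy of the template scaler, and scale_nested is a hypothetical name.

```python
import copy
from typing import Any, Optional, Sequence

VALID_METHODS = ("fit", "transform", "inverse_transform", "fit_transform")


class ToyStandardScaler:
    """Hypothetical scaler on lists of floats, standing in for a tensor scaler."""

    def fit(self, x):
        self.mean_ = sum(x) / len(x)
        self.std_ = (sum((v - self.mean_) ** 2 for v in x) / len(x)) ** 0.5 or 1.0
        return self

    def transform(self, x):
        return [(v - self.mean_) / self.std_ for v in x]

    def inverse_transform(self, x):
        return [v * self.std_ + self.mean_ for v in x]

    def fit_transform(self, x):
        return self.fit(x).transform(x)


def scale_nested(var_dict: dict, scalers: Any, method: str,
                 var_list: Optional[Sequence[str]] = None) -> dict:
    assert isinstance(var_dict, dict), "var_dict must be a dict"
    assert method in VALID_METHODS, f"method must be one of {VALID_METHODS}"
    out = {}
    for key, value in var_dict.items():
        is_leaf = not isinstance(value, dict)
        excluded = key == "metadata" or (is_leaf and var_list is not None and key not in var_list)
        if excluded:
            if method != "fit":
                out[key] = value  # other methods: pass through unchanged
            continue  # fit: omitted from the result
        if not is_leaf:
            # transform/inverse_transform walk the scaler tree in step with the data.
            sub = scalers[key] if method in ("transform", "inverse_transform") else scalers
            out[key] = scale_nested(value, sub, method, var_list)
        elif method in ("fit", "fit_transform"):
            # Each variable gets its own copy of the template scaler.
            out[key] = getattr(copy.deepcopy(scalers), method)(value)
        else:
            out[key] = getattr(scalers[key], method)(value)
    return out
```

Unlike the documented function, which raises AssertionError for a missing key path, a missing scaler in this sketch surfaces as a KeyError.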

bridgescaler.backend.create_synthetic_data()