bridgescaler.backend#
Attributes#
Classes#
NumpyEncoder | Custom encoder for numpy data types
Functions#
ensure_torch() | Validates the torch installation and loads the module.
save_scaler(scaler, scaler_file) | Save a scikit-learn or bridgescaler scaler object to json format.
print_scaler(scaler) | Output scikit-learn or bridgescaler scaler object to json string.
object_hook(dct) |
read_scaler(scaler_str) | Initialize scikit-learn or bridgescaler scaler from json str.
load_scaler(scaler_file) | Initialize scikit-learn or bridgescaler scaler from saved json file.
apply_to_dict_leaves(d, operation) | Recursively applies an operation to each leaf value in a nested dictionary.
save_scaler_dict(scaler_dict, scaler_dict_file) | Serializes and saves a nested dictionary of Bridgescaler scalers to a JSON file.
load_scaler_dict(scaler_dict_file) | Loads and deserializes a nested dictionary of Bridgescaler scalers from a JSON file.
scale_var_dict(var_dict, scalers, method, var_list=None, _key_path=()) | Recursively traverses a nested dict of tensor variables and applies a scaler method to each variable.
create_synthetic_data() |
Module Contents#
- bridgescaler.backend.scaler_objs#
- bridgescaler.backend.ensure_torch()#
Validates the torch installation and loads the module.
- bridgescaler.backend.save_scaler(scaler, scaler_file)#
Save a scikit-learn or bridgescaler scaler object to json format.
- Parameters:
scaler – scikit-learn-style scaler object
scaler_file – path to json file where scaler information is stored.
- bridgescaler.backend.print_scaler(scaler)#
Output scikit-learn or bridgescaler scaler object to json string.
- Parameters:
scaler – scikit-learn-style scaler object
- Returns:
str representation of object in json format
- bridgescaler.backend.object_hook(dct: dict[Any, Any])#
- bridgescaler.backend.read_scaler(scaler_str)#
Initialize scikit-learn or bridgescaler scaler from json str.
- Parameters:
scaler_str – json str
- Returns:
scaler object.
- bridgescaler.backend.load_scaler(scaler_file)#
Initialize scikit-learn or bridgescaler scaler from saved json file.
- Parameters:
scaler_file – path to json file.
- Returns:
scaler object.
- bridgescaler.backend.apply_to_dict_leaves(d, operation)#
Recursively applies an operation to each leaf value in a nested dictionary.
- Parameters:
d (dict) – A nested dictionary where the operation will be applied to each leaf value.
operation (callable) – A function to apply to each leaf value.
- Returns:
A nested dictionary with the same structure as d, where each leaf is the result of operation(leaf).
- Return type:
dict
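The recursion described above can be sketched in a few lines. This is a minimal stand-alone illustration, not the library's implementation: `map_leaves` is a hypothetical name, and the sketch assumes that any value that is not a `dict` counts as a leaf.

```python
def map_leaves(d, operation):
    """Return a new nested dict with `operation` applied to every non-dict value.

    Hypothetical stand-in for the documented behavior; it assumes that
    "leaf" means "any value that is not itself a dict".
    """
    return {
        key: map_leaves(value, operation) if isinstance(value, dict) else operation(value)
        for key, value in d.items()
    }


nested = {"a": 1, "b": {"c": 2, "d": {"e": 3}}}
doubled = map_leaves(nested, lambda x: x * 2)
print(doubled)  # {'a': 2, 'b': {'c': 4, 'd': {'e': 6}}}
```

The sketch builds a new dictionary at every level, so the input is left unmodified. Whether the library function copies or mutates its input is not stated on this page.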
- bridgescaler.backend.save_scaler_dict(scaler_dict, scaler_dict_file)#
Serializes and saves a nested dictionary of Bridgescaler scalers to a JSON file.
- Parameters:
scaler_dict (dict) – A nested dictionary of fitted Bridgescaler scaler objects to be saved.
scaler_dict_file (str or Path) – The file path where the scaler dictionary will be saved as a JSON file.
- bridgescaler.backend.load_scaler_dict(scaler_dict_file)#
Loads and deserializes a nested dictionary of Bridgescaler scalers from a JSON file.
- Parameters:
scaler_dict_file (str or Path) – The file path to the JSON file containing the serialized scaler dictionary.
- Returns:
A nested dictionary of reconstructed scaler objects, with the same structure as the original dictionary passed to save_scaler_dict.
- Return type:
dict
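To show what a save and load round trip over a nested dictionary involves, here is a self-contained sketch that uses only the standard library. Everything in it is an assumption made for illustration: `ToyScaler`, `to_jsonable`, `dump_nested`, `restore`, and the `"type"` marker key are hypothetical, and the on-disk layout that bridgescaler actually writes is not described on this page.

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass
class ToyScaler:
    """Hypothetical stand-in for a fitted scaler."""
    mean: float
    scale: float


def to_jsonable(scaler):
    # Tag each serialized scaler so it can be told apart from a nesting level.
    return {"type": type(scaler).__name__, **asdict(scaler)}


def dump_nested(node):
    # Recurse through nesting levels; convert scaler leaves to plain dicts.
    if isinstance(node, dict):
        return {key: dump_nested(value) for key, value in node.items()}
    return to_jsonable(node)


def restore(node):
    # A dict carrying the "type" marker is a serialized scaler, not a level.
    if "type" in node:
        params = {k: v for k, v in node.items() if k != "type"}
        return ToyScaler(**params)
    return {key: restore(value) for key, value in node.items()}


scalers = {"era5": {"T": ToyScaler(280.0, 15.0), "Q": ToyScaler(0.004, 0.002)}}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "scalers.json"
    path.write_text(json.dumps(dump_nested(scalers), indent=2))
    restored = restore(json.loads(path.read_text()))

print(restored == scalers)  # True
```

The marker key matters because a serialized scaler is itself a dictionary. Without some way to recognize it, the loader could not tell where the nesting ends and a scaler begins.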
- class bridgescaler.backend.NumpyEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)#
Bases: json.JSONEncoder
Custom encoder for numpy data types
- default(obj)#
Implement this method in a subclass such that it returns a serializable object for o, or calls the base implementation (to raise a TypeError).
For example, to support arbitrary iterators, you could implement default like this:
    def default(self, o):
        try:
            iterable = iter(o)
        except TypeError:
            pass
        else:
            return list(iterable)
        # Let the base class default method raise the TypeError
        return super().default(o)
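The same pattern can be adapted to array-like objects. The sketch below is not the NumpyEncoder source. `ToListEncoder` is a hypothetical class that duck-types on a `tolist()` method, which numpy arrays and numpy scalars provide; the standard library `array.array` stands in for a numpy array here so that the example runs without numpy installed.

```python
import json
from array import array


class ToListEncoder(json.JSONEncoder):
    """Hypothetical encoder: converts any object exposing tolist() to a list."""

    def default(self, o):
        if hasattr(o, "tolist"):
            return o.tolist()
        # Let the base class raise TypeError for unsupported objects.
        return super().default(o)


payload = {"mean": array("d", [0.5, 1.5]), "n": 3}
text = json.dumps(payload, cls=ToListEncoder)
print(text)  # {"mean": [0.5, 1.5], "n": 3}
```

An encoder class is passed to `json.dumps` or `json.dump` through the `cls` argument, so `json.dumps(obj, cls=NumpyEncoder)` should work the same way for the class documented here.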
- bridgescaler.backend.scale_var_dict(var_dict, scalers, method, var_list=None, _key_path=())#
Recursively traverses a nested dict of tensor variables and applies a scaler method to each variable.
- Parameters:
var_dict (dict) – A nested dictionary whose leaves are torch.Tensor variables to be scaled.
scalers (object or dict) – A single scaler instance (for fit and fit_transform) or a nested dict of scalers matching the structure of var_dict (for transform and inverse_transform).
method (str) – The scaler method to apply. Must be one of fit, transform, inverse_transform, or fit_transform.
var_list (list of str, optional) – A list of leaf key names to apply the scaler method to. Keys not in var_list are skipped during fit, and left unchanged during transform, inverse_transform, and fit_transform. If None, all leaf keys are processed.
- Returns:
A nested dictionary with the same structure as var_dict, where each leaf is either a fitted scaler (for fit) or a transformed variable (for transform, inverse_transform, fit_transform). Keys named metadata and keys excluded by var_list are omitted for fit, and passed through unchanged for other methods.
- Return type:
dict
- Raises:
AssertionError – If var_dict is not a dict.
AssertionError – If method is not one of the valid methods.
AssertionError – If scalers is not a dict when using transform or inverse_transform.
AssertionError – If a key path in var_dict is missing in scalers.
AssertionError – If a scaler at a given key path does not have the requested method.
Example
>>> import torch
>>> from bridgescaler.distributed_tensor import DStandardScalerTensor
>>> from bridgescaler.backend import scale_var_dict
>>> T = torch.randn((20, 5, 4, 8))
>>> var_dict = {
...     "era5": {
...         "input": {"era5/prognostic/3d/T": T},
...         "target": {"era5/prognostic/3d/T": T},
...         "metadata": {"input_datetime": int, "target_datetime": int}
...     }
... }
>>> scalers = DStandardScalerTensor(channels_last=False)
>>> scaler_dict = scale_var_dict(var_dict, scalers, method="fit")
>>> transformed = scale_var_dict(var_dict, scaler_dict, method="transform")
>>> inverse_transformed = scale_var_dict(transformed, scaler_dict, method="inverse_transform")
>>> fitted_transformed = scale_var_dict(var_dict, scalers, method="fit_transform")
>>> # Only scale specific variables
>>> filtered = scale_var_dict(var_dict, scaler_dict, method="transform", var_list=["era5/prognostic/3d/T"])
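To make the pass-through rules for metadata and var_list concrete, the following stand-alone sketch mimics only the transform case with plain Python lists instead of tensors. It is a simplified illustration under stated assumptions, not the library code: `ShiftScaler` and `transform_leaves` are hypothetical names, and the short keys `T` and `Q` are invented for brevity.

```python
class ShiftScaler:
    """Hypothetical toy scaler that subtracts a fixed offset from each value."""

    def __init__(self, offset):
        self.offset = offset

    def transform(self, values):
        return [v - self.offset for v in values]


def transform_leaves(var_dict, scalers, var_list=None):
    """Walk var_dict and a matching scalers dict in parallel (sketch only)."""
    out = {}
    for key, value in var_dict.items():
        if key == "metadata":
            out[key] = value  # metadata is passed through unchanged
        elif isinstance(value, dict):
            out[key] = transform_leaves(value, scalers[key], var_list)
        elif var_list is not None and key not in var_list:
            out[key] = value  # excluded leaves are left unchanged
        else:
            out[key] = scalers[key].transform(value)
    return out


var_dict = {
    "era5": {
        "input": {"T": [1.0, 2.0], "Q": [5.0]},
        "metadata": {"step": 0},
    }
}
scalers = {"era5": {"input": {"T": ShiftScaler(1.0), "Q": ShiftScaler(5.0)}}}

result = transform_leaves(var_dict, scalers, var_list=["T"])
print(result)
# {'era5': {'input': {'T': [0.0, 1.0], 'Q': [5.0]}, 'metadata': {'step': 0}}}
```

Only `T` is scaled. `Q` is excluded by var_list and the metadata branch is never matched against a scaler, so both come back unchanged, which matches the behavior described for transform above.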
- bridgescaler.backend.create_synthetic_data()#