bridgescaler
============

.. py:module:: bridgescaler


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/bridgescaler/_version/index
   /autoapi/bridgescaler/backend/index
   /autoapi/bridgescaler/deep/index
   /autoapi/bridgescaler/distributed/index
   /autoapi/bridgescaler/distributed_tensor/index
   /autoapi/bridgescaler/group/index


Classes
-------

.. autoapisummary::

   bridgescaler.GroupStandardScaler
   bridgescaler.GroupRobustScaler
   bridgescaler.GroupMinMaxScaler
   bridgescaler.DeepStandardScaler
   bridgescaler.DeepMinMaxScaler
   bridgescaler.DeepQuantileTransformer
   bridgescaler.DStandardScaler
   bridgescaler.DMinMaxScaler
   bridgescaler.DQuantileScaler


Functions
---------

.. autoapisummary::

   bridgescaler.save_scaler
   bridgescaler.load_scaler
   bridgescaler.print_scaler
   bridgescaler.read_scaler
   bridgescaler.save_scaler_dict
   bridgescaler.load_scaler_dict
   bridgescaler.scale_var_dict


Package Contents
----------------

.. py:function:: save_scaler(scaler, scaler_file)

   Save a scikit-learn or bridgescaler scaler object to a JSON file.

   :param scaler: scikit-learn-style scaler object.
   :param scaler_file: path to the JSON file where the scaler information is stored.


.. py:function:: load_scaler(scaler_file)

   Initialize a scikit-learn or bridgescaler scaler from a saved JSON file.

   :param scaler_file: path to the JSON file.

   :returns: scaler object.


.. py:function:: print_scaler(scaler)

   Serialize a scikit-learn or bridgescaler scaler object to a JSON string.

   :param scaler: scikit-learn-style scaler object.

   :returns: str representation of the object in JSON format.


.. py:function:: read_scaler(scaler_str)

   Initialize a scikit-learn or bridgescaler scaler from a JSON string.

   :param scaler_str: JSON string.

   :returns: scaler object.
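
   The JSON functions above turn a fitted scaler's state into text and back. As a rough illustration of that
   round trip (not bridgescaler's actual file layout), the sketch below stores a hypothetical ``ToyScaler``'s
   attributes together with a type tag, converting numpy arrays to lists because JSON has no array type. The
   helpers ``to_json`` and ``from_json`` and the ``registry`` argument are assumptions made for this example.

   .. code-block:: python

      import json

      import numpy as np


      class ToyScaler:
          """Hypothetical stand-in for a fitted scaler with array-valued attributes."""

          def __init__(self):
              self.mean_ = None
              self.sd_ = None

          def fit(self, x):
              self.mean_ = x.mean(axis=0)
              self.sd_ = x.std(axis=0)
              return self


      def to_json(scaler):
          # Record the class name so the reader knows which object to rebuild.
          state = {"type": type(scaler).__name__}
          for key, value in vars(scaler).items():
              # JSON has no array type, so numpy arrays become (nested) lists.
              state[key] = value.tolist() if isinstance(value, np.ndarray) else value
          return json.dumps(state)


      def from_json(scaler_str, registry):
          state = json.loads(scaler_str)
          # Look up the class by its stored name, then restore each attribute.
          scaler = registry[state.pop("type")]()
          for key, value in state.items():
              setattr(scaler, key, np.asarray(value) if isinstance(value, list) else value)
          return scaler


      x = np.array([[1.0, 10.0], [3.0, 30.0]])
      restored = from_json(to_json(ToyScaler().fit(x)), {"ToyScaler": ToyScaler})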


.. py:function:: save_scaler_dict(scaler_dict, scaler_dict_file)

   Serializes and saves a nested dictionary of Bridgescaler scalers to a JSON file.

   :param scaler_dict: A nested dictionary of fitted Bridgescaler scaler objects
                       to be saved.
   :type scaler_dict: dict
   :param scaler_dict_file: The file path where the scaler
                            dictionary will be saved as a JSON file.
   :type scaler_dict_file: str or Path


.. py:function:: load_scaler_dict(scaler_dict_file)

   Loads and deserializes a nested dictionary of Bridgescaler scalers from a JSON file.

   :param scaler_dict_file: The file path to the JSON file
                            containing the serialized scaler dictionary.
   :type scaler_dict_file: str or Path

   :returns: A nested dictionary of reconstructed scaler objects, with the
             same structure as the original dictionary passed to
             ``save_scaler_dict``.
   :rtype: dict
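
   Because ``save_scaler_dict`` and ``load_scaler_dict`` work on nested dictionaries, one way to picture them is
   a recursive walk that serializes each leaf scaler and tags it, so that a leaf can be told apart from an
   ordinary sub-dictionary when the file is read back. The sketch below assumes a hypothetical ``Params`` leaf
   type and a ``"__scaler__"`` tag; bridgescaler's real on-disk format may differ.

   .. code-block:: python

      import json


      class Params:
          """Hypothetical stand-in for a fitted scaler."""

          def __init__(self, mean, sd):
              self.mean, self.sd = mean, sd


      def encode(tree):
          # Recurse into sub-dicts; tag each leaf so it is distinguishable from a sub-dict later.
          if isinstance(tree, dict):
              return {key: encode(value) for key, value in tree.items()}
          return {"__scaler__": {"mean": tree.mean, "sd": tree.sd}}


      def decode(tree):
          # A tagged dict is a leaf; anything else is a sub-dict to recurse into.
          if "__scaler__" in tree:
              return Params(**tree["__scaler__"])
          return {key: decode(value) for key, value in tree.items()}


      scaler_dict = {"era5": {"input": {"T": Params(280.0, 10.0)}, "target": {"T": Params(281.0, 9.5)}}}
      restored = decode(json.loads(json.dumps(encode(scaler_dict))))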


.. py:function:: scale_var_dict(var_dict, scalers, method, var_list=None, _key_path=())

   Recursively traverses a nested dict of tensor variables and applies a scaler method to each variable.

   :param var_dict: A nested dictionary whose leaves are the ``torch.Tensor`` variables to be scaled.
   :type var_dict: dict
   :param scalers: A single scaler instance (for ``fit`` and
                   ``fit_transform``) or a nested dict of scalers matching the structure
                   of ``var_dict`` (for ``transform`` and ``inverse_transform``).
   :type scalers: object or dict
   :param method: The scaler method to apply. Must be one of ``fit``,
                  ``transform``, ``inverse_transform``, or ``fit_transform``.
   :type method: str
   :param var_list: A list of leaf key names to apply the
                    scaler method to. Keys not in ``var_list`` are skipped during ``fit``,
                    and left unchanged during ``transform``, ``inverse_transform``, and
                    ``fit_transform``. If ``None``, all leaf keys are processed.
   :type var_list: list of str, optional

   :returns: A nested dictionary with the same structure as ``var_dict``,
             where each leaf is either a fitted scaler (for ``fit``) or a
             transformed variable (for ``transform``, ``inverse_transform``,
             and ``fit_transform``). Keys named ``metadata`` and keys excluded by
             ``var_list`` are omitted for ``fit`` and passed through unchanged
             for the other methods.
   :rtype: dict

   :raises AssertionError: If ``var_dict`` is not a dict.
   :raises AssertionError: If ``method`` is not one of the valid methods.
   :raises AssertionError: If ``scalers`` is not a dict when using ``transform``
       or ``inverse_transform``.
   :raises AssertionError: If a key path in ``var_dict`` is missing in ``scalers``.
   :raises AssertionError: If a scaler at a given key path does not have the
       requested ``method``.

   .. rubric:: Example

   >>> import torch
   >>> from bridgescaler.distributed_tensor import DStandardScalerTensor
   >>> from bridgescaler.backend import scale_var_dict
   >>> T = torch.randn((20, 5, 4, 8))
   >>> var_dict = {
   ...     "era5": {
   ...         "input": {"era5/prognostic/3d/T": T},
   ...         "target": {"era5/prognostic/3d/T": T},
   ...         "metadata": {"input_datetime": int, "target_datetime": int},
   ...     }
   ... }
   >>> scalers = DStandardScalerTensor(channels_last=False)
   >>> scaler_dict = scale_var_dict(var_dict, scalers, method="fit")
   >>> transformed = scale_var_dict(var_dict, scaler_dict, method="transform")
   >>> inverse_transformed = scale_var_dict(transformed, scaler_dict, method="inverse_transform")
   >>> fitted_transformed = scale_var_dict(var_dict, scalers, method="fit_transform")
   >>> # Only scale specific variables
   >>> filtered = scale_var_dict(var_dict, scaler_dict, method="transform", var_list=["era5/prognostic/3d/T"])
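
   To make the traversal rules above concrete, here is a simplified pure-Python sketch that mimics the documented
   behavior on lists of floats instead of tensors. It assumes each selected leaf receives its own deep copy of the
   template scaler during ``fit`` and ``fit_transform``; ``ToyScaler`` and ``apply_to_leaves`` are hypothetical
   and omit the input validation (the ``AssertionError`` checks) of the real function.

   .. code-block:: python

      import copy


      class ToyScaler:
          """Hypothetical scaler that subtracts the mean of a list of numbers."""

          def __init__(self):
              self.mean_ = None

          def fit(self, x):
              self.mean_ = sum(x) / len(x)
              return self

          def transform(self, x):
              return [v - self.mean_ for v in x]

          def inverse_transform(self, x):
              return [v + self.mean_ for v in x]

          def fit_transform(self, x):
              return self.fit(x).transform(x)


      def apply_to_leaves(var_dict, scalers, method, var_list=None):
          out = {}
          for key, value in var_dict.items():
              selected = key != "metadata" and (var_list is None or key in var_list)
              if isinstance(value, dict) and key != "metadata":
                  # A single template scaler is reused at every level; a scaler dict is descended in step.
                  sub_scalers = scalers[key] if isinstance(scalers, dict) else scalers
                  out[key] = apply_to_leaves(value, sub_scalers, method, var_list)
              elif not selected:
                  # Omitted from the result of fit, passed through unchanged otherwise.
                  if method != "fit":
                      out[key] = value
              elif method == "fit":
                  out[key] = copy.deepcopy(scalers).fit(value)
              elif method == "fit_transform":
                  out[key] = copy.deepcopy(scalers).fit_transform(value)
              else:
                  out[key] = getattr(scalers[key], method)(value)
          return out


      var_dict = {"era5": {"input": {"T": [1.0, 2.0, 3.0], "Q": [10.0, 20.0]}, "metadata": {"input_datetime": 0}}}
      fitted = apply_to_leaves(var_dict, ToyScaler(), "fit")
      scaled = apply_to_leaves(var_dict, fitted, "transform", var_list=["T"])
      restored = apply_to_leaves(scaled, fitted, "inverse_transform", var_list=["T"])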


.. py:class:: GroupStandardScaler

   Bases: :py:obj:`GroupBaseScaler`


   Scaler that enables calculation and sharing of scaling parameters among multiple variables via variable groupings.
   This is useful when variables are related, such as temperatures at different height levels.

   Groups are specified as a list of column ids, which can be column names for pandas dataframes or column indices
   for numpy arrays. For example:

   .. code-block:: python

      groups = [["a", "b"], ["c", "d"], "e"]

   Here "a" and "b" form a single group, and all values of both columns are included when calculating the mean and
   standard deviation for that group.
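
   A minimal numpy sketch of the pooled statistics, using column indices in place of names: every value in a
   group's columns contributes to one mean and one standard deviation. ``group_standardize`` is a hypothetical
   helper, not the class's ``_fit``; as with ``"e"`` above, a bare entry is treated here as a one-column group.

   .. code-block:: python

      import numpy as np


      def group_standardize(x, groups):
          """Standardize columns of ``x`` using statistics pooled within each group."""
          out = np.empty_like(x, dtype=float)
          for group in groups:
              cols = group if isinstance(group, list) else [group]
              values = x[:, cols]  # every value from every column in the group
              out[:, cols] = (values - values.mean()) / values.std()
          return out


      x = np.array([[1.0, 3.0, 100.0],
                    [3.0, 5.0, 300.0]])
      scaled = group_standardize(x, groups=[[0, 1], 2])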


   .. py:attribute:: center_
      :value: None



   .. py:attribute:: scale_
      :value: None



   .. py:method:: _fit(x, groups=None)


   .. py:method:: _transform_column(x_column, group_index)


   .. py:method:: _inverse_transform_column(x_column, group_index)


.. py:class:: GroupRobustScaler(quartile_range=(25.0, 75.0))

   Bases: :py:obj:`GroupBaseScaler`


   Group version of scikit-learn's ``RobustScaler``, in which each group of variables shares one center and scale.
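
   Assuming groups are specified as for ``GroupStandardScaler``, the robust variant can be pictured as centering
   each group on the median of all its values and dividing by the spread between the two percentiles in
   ``quartile_range``. The helper below is a hypothetical numpy sketch of that idea, not the class's ``_fit``.

   .. code-block:: python

      import numpy as np


      def group_robust_scale(x, groups, quartile_range=(25.0, 75.0)):
          """Center on the pooled median and divide by the pooled percentile range per group."""
          out = np.empty_like(x, dtype=float)
          for group in groups:
              cols = group if isinstance(group, list) else [group]
              values = x[:, cols]
              # Percentiles are taken over all values of the group's columns at once.
              q_low, median, q_high = np.percentile(values, [quartile_range[0], 50.0, quartile_range[1]])
              out[:, cols] = (values - median) / (q_high - q_low)
          return out


      x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
      scaled = group_robust_scale(x, groups=[[0, 1]])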



   .. py:attribute:: quartile_range
      :value: (25.0, 75.0)



   .. py:attribute:: center_
      :value: None



   .. py:attribute:: scale_
      :value: None



   .. py:method:: _fit(x, groups)


   .. py:method:: _transform_column(x_column, group_index)


   .. py:method:: _inverse_transform_column(x_column, group_index)


.. py:class:: GroupMinMaxScaler(feature_range=(0, 1))

   Bases: :py:obj:`GroupBaseScaler`


   Group version of scikit-learn's ``MinMaxScaler``, in which each group of variables shares one minimum and maximum.
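
   Under the same assumption about group specification, each group is rescaled with one shared minimum and
   maximum, so columns in a group keep their relative magnitudes instead of each being stretched to the full
   ``feature_range``. The helper is a hypothetical numpy sketch.

   .. code-block:: python

      import numpy as np


      def group_min_max_scale(x, groups, feature_range=(0, 1)):
          """Rescale each group to ``feature_range`` using its pooled min and max."""
          lo, hi = feature_range
          out = np.empty_like(x, dtype=float)
          for group in groups:
              cols = group if isinstance(group, list) else [group]
              values = x[:, cols]
              unit = (values - values.min()) / (values.max() - values.min())
              out[:, cols] = unit * (hi - lo) + lo
          return out


      x = np.array([[0.0, 5.0], [10.0, 20.0]])
      scaled = group_min_max_scale(x, groups=[[0, 1]])
      scaled_sym = group_min_max_scale(x, groups=[[0, 1]], feature_range=(-1, 1))

   Per-column scaling would map the second column to [0, 1]; pooled scaling keeps it at [0.25, 1.0].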


   .. py:attribute:: feature_range
      :value: (0, 1)



   .. py:attribute:: mins_
      :value: None



   .. py:attribute:: maxes_
      :value: None



   .. py:method:: _fit(x, groups)


   .. py:method:: _transform_column(x_column, group_index)


   .. py:method:: _inverse_transform_column(x_column, group_index)


.. py:class:: DeepStandardScaler

   Bases: :py:obj:`object`


   Calculate standardized scores for a dataset with an arbitrary number of dimensions, provided the last dimension
   is the variable dimension.
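
   The key operation is reducing over every axis except the last. The hypothetical helper below sketches that
   with numpy, returning per-variable statistics analogous to ``mean_`` and ``sd_``; it illustrates the idea and
   is not the class's implementation.

   .. code-block:: python

      import numpy as np


      def deep_standardize(x):
          """Standardize an N-d array whose last axis indexes variables."""
          # Reduce over every axis except the last, leaving one statistic per variable.
          axes = tuple(range(x.ndim - 1))
          mean = x.mean(axis=axes)
          sd = x.std(axis=axes)
          # Statistics of shape (n_vars,) broadcast against data of shape (..., n_vars).
          return (x - mean) / sd, mean, sd


      rng = np.random.default_rng(0)
      x = rng.normal(loc=[0.0, 50.0, -3.0], scale=[1.0, 10.0, 0.1], size=(8, 16, 16, 3))
      z, mean, sd = deep_standardize(x)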



   .. py:attribute:: mean_
      :value: None



   .. py:attribute:: sd_
      :value: None



   .. py:method:: fit(x)


   .. py:method:: transform(x)


   .. py:method:: fit_transform(x)


   .. py:method:: inverse_transform(x)


.. py:class:: DeepMinMaxScaler

   Bases: :py:obj:`object`


   .. py:attribute:: max_
      :value: None



   .. py:attribute:: min_
      :value: None



   .. py:method:: fit(x)


   .. py:method:: transform(x)


   .. py:method:: fit_transform(x)


   .. py:method:: inverse_transform(x)


.. py:class:: DeepQuantileTransformer(n_quantiles=1000, stochastic=False)

   Bases: :py:obj:`object`


   Performs a quantile transform on N-dimensional arrays where the variable dimension is the last one.

   .. attribute:: n_quantiles

      Number of quantiles to calculate and store.

   .. attribute:: stochastic

      When transforming to quantile space, whether to take the mean of the left and right interpolation values (False)
      or to pick a random point in between (True).
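
   A one-column numpy sketch of the quantile transform, assuming the fitted quantiles are evenly spaced quantiles
   of the data and the references are the matching probabilities. Interpolating forward and over the reversed,
   negated arrays resolves a run of tied quantiles from its two ends (the same trick scikit-learn's
   ``QuantileTransformer`` uses); ``stochastic=False`` averages the two results and ``stochastic=True`` draws a
   uniform point between them. All function names are hypothetical.

   .. code-block:: python

      import numpy as np


      def fit_quantiles(x_col, n_quantiles=1000):
          """Store ``n_quantiles`` evenly spaced quantiles of one variable."""
          references = np.linspace(0.0, 1.0, n_quantiles)
          return np.quantile(x_col, references), references


      def to_quantile_space(x_col, quantiles, references, stochastic=False, rng=None):
          # The two interpolations land on opposite ends of any run of tied quantiles.
          forward = np.interp(x_col, quantiles, references)
          backward = -np.interp(-x_col, -quantiles[::-1], -references[::-1])
          if not stochastic:
              return 0.5 * (forward + backward)
          rng = np.random.default_rng() if rng is None else rng
          # Pick a uniformly random point between the two ends.
          return forward + rng.random(np.shape(x_col)) * (backward - forward)


      def from_quantile_space(q_col, quantiles, references):
          """Map probabilities back to data values."""
          return np.interp(q_col, references, quantiles)


      x = np.array([0.0, 1.0, 1.0, 1.0, 2.0])
      quantiles, references = fit_quantiles(x, n_quantiles=5)
      u = to_quantile_space(x, quantiles, references)
      u_random = to_quantile_space(x, quantiles, references, stochastic=True, rng=np.random.default_rng(0))
      x_back = from_quantile_space(u, quantiles, references)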


   .. py:attribute:: n_quantiles
      :value: 1000



   .. py:attribute:: stochastic
      :value: False



   .. py:attribute:: quantiles_
      :value: None



   .. py:attribute:: references_
      :value: None



   .. py:attribute:: fitted_
      :value: False



   .. py:attribute:: x_column_names_
      :value: None



   .. py:method:: fit(x)


   .. py:method:: transform(x)


   .. py:method:: fit_transform(x)


   .. py:method:: inverse_transform(x)


   .. py:method:: _transform_col(x_col, col_index)


   .. py:method:: _inverse_transform_col(x_col, col_index)


.. py:class:: DStandardScaler(channels_last=True)

   Bases: :py:obj:`DBaseScaler`


   Distributed version of StandardScaler. It can be fit map-reduce style: fit a separate scaler on each data file,
   return the fitted objects, and then sum them together to represent the full dataset. The scaler
   supports numpy arrays, pandas dataframes, and xarray DataArrays and returns a transformed array in the
   same form as the original, with column or coordinate names preserved.
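
   Summing partial fits works because counts, means, and variances from separate chunks can be merged exactly.
   The sketch below uses the standard pairwise update for combining means and variances; it is an assumption that
   ``__add__`` combines ``n_``, ``mean_x_``, and ``var_x_`` in an equivalent way, weights are ignored, and
   ``partial_stats`` and ``merge_stats`` are hypothetical.

   .. code-block:: python

      import functools

      import numpy as np


      def partial_stats(x):
          """Per-column count, mean, and population variance for one data chunk."""
          return x.shape[0], x.mean(axis=0), x.var(axis=0)


      def merge_stats(a, b):
          """Combine two (n, mean, var) summaries as if computed on the concatenated data."""
          n_a, mean_a, var_a = a
          n_b, mean_b, var_b = b
          n = n_a + n_b
          delta = mean_b - mean_a
          mean = mean_a + delta * n_b / n
          # Sum of squared deviations of each chunk, plus a correction for the shift between means.
          m2 = var_a * n_a + var_b * n_b + delta**2 * n_a * n_b / n
          return n, mean, m2 / n


      rng = np.random.default_rng(42)
      chunks = [rng.normal(size=(n_rows, 3)) for n_rows in (100, 250, 50)]
      # "Map" over chunks (e.g. one per file), then "reduce" pairwise.
      n, mean, var = functools.reduce(merge_stats, map(partial_stats, chunks))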



   .. py:attribute:: mean_x_
      :value: None



   .. py:attribute:: n_
      :value: 0



   .. py:attribute:: var_x_
      :value: None



   .. py:method:: fit(x, weight=None)


   .. py:method:: transform(x, channels_last=None)

      Transform the input data from its original form to standard scaled form. If your input data have a
      different dimension order than the data used to fit the scaler, use the ``channels_last`` keyword argument
      to specify whether the new data are channels-last (True) or channels-first (False).

      :param x: Input data.
      :param channels_last: Override the default ``channels_last`` parameter of the scaler.

      :returns: Transformed data in the same shape and type as ``x``.
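
      To illustrate what ``channels_last`` changes, the hypothetical helper below applies the same per-channel
      statistics to channels-last data (channel on the last axis) and channels-first data (channel on the second
      axis). It mirrors the documented semantics, not necessarily the internal code.

      .. code-block:: python

         import numpy as np


         def standardize(x, mean, sd, channels_last=True):
             """Apply per-channel statistics to data stored channels-last or channels-first."""
             if channels_last:
                 # (..., C): 1-d statistics broadcast along the trailing axis as-is.
                 return (x - mean) / sd
             # (N, C, ...): reshape statistics to (1, C, 1, ..., 1) so they line up with axis 1.
             shape = (1, -1) + (1,) * (x.ndim - 2)
             return (x - mean.reshape(shape)) / sd.reshape(shape)


         mean = np.array([1.0, 10.0])
         sd = np.array([2.0, 5.0])
         x_last = np.arange(6.0).reshape(1, 3, 2)  # channels last: shape (1, 3, 2)
         x_first = np.moveaxis(x_last, -1, 1)      # channels first: shape (1, 2, 3)
         z_last = standardize(x_last, mean, sd)
         z_first = standardize(x_first, mean, sd, channels_last=False)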



   .. py:method:: inverse_transform(x, channels_last=None)


   .. py:method:: get_scales()


   .. py:method:: __add__(other)


.. py:class:: DMinMaxScaler(channels_last=True)

   Bases: :py:obj:`DBaseScaler`


   Distributed MinMaxScaler that calculates the minimum and maximum of each variable on separate subsets of a
   dataset in parallel, then combines the mins and maxes in a reduction step. The scaler
   supports numpy arrays, pandas dataframes, and xarray DataArrays and returns a transformed array in the
   same form as the original, with column or coordinate names preserved.
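
   Because elementwise minimum and maximum are associative and commutative, partial results can be merged in any
   order. A minimal numpy sketch of the map and reduce steps, assuming ``__add__`` merges ``min_x_`` and
   ``max_x_`` this way; ``partial_min_max`` and ``merge_min_max`` are hypothetical.

   .. code-block:: python

      from functools import reduce

      import numpy as np


      def partial_min_max(x):
          """Per-column min and max of one chunk (the "map" step)."""
          return x.min(axis=0), x.max(axis=0)


      def merge_min_max(a, b):
          """Combine two (min, max) summaries (the "reduce" step)."""
          return np.minimum(a[0], b[0]), np.maximum(a[1], b[1])


      chunks = [np.array([[0.0, 5.0], [2.0, 7.0]]), np.array([[-1.0, 6.0]]), np.array([[3.0, 4.0]])]
      min_x, max_x = reduce(merge_min_max, map(partial_min_max, chunks))
      # Each chunk can then be scaled independently with the global statistics.
      scaled = [(c - min_x) / (max_x - min_x) for c in chunks]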



   .. py:attribute:: max_x_
      :value: None



   .. py:attribute:: min_x_
      :value: None



   .. py:method:: fit(x, weight=None)


   .. py:method:: transform(x, channels_last=None)


   .. py:method:: inverse_transform(x, channels_last=None)


   .. py:method:: get_scales()


   .. py:method:: __add__(other)


.. py:class:: DQuantileScaler(compression=250, distribution='uniform', min_val=1e-07, max_val=0.9999999, channels_last=True)

   Bases: :py:obj:`DBaseScaler`


   Distributed Quantile Scaler that uses the crick TDigest Cython library to compute quantiles across multiple
   datasets in parallel. The scaler can perform fitting, transforms, and inverse transforms across variables
   in parallel using the multiprocessing library. Multidimensional arrays are stored in shared memory across
   processes to minimize inter-process communication.

   DQuantileScaler supports

   .. attribute:: compression

      Recommended number of centroids to use.

   .. attribute:: distribution

      Target distribution of the transformed values: ``"uniform"``, ``"normal"``, or ``"logistic"``.

   .. attribute:: min_val

      Lower bound applied to quantiles to prevent -inf results when distribution is normal or logistic.

   .. attribute:: max_val

      Upper bound applied to quantiles to prevent inf results when distribution is normal or logistic.

   .. attribute:: channels_last

      If True, the last dimension is treated as the channel/variable dimension; if False, the second
      dimension is.
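
   To see why ``min_val`` and ``max_val`` matter, note that the inverse CDFs of the normal and logistic
   distributions diverge at 0 and 1. The stdlib sketch below clips a quantile before mapping it; it assumes
   clipping is applied only for the non-uniform distributions, does not use crick, and the helper
   ``quantile_to_distribution`` is hypothetical.

   .. code-block:: python

      import math
      from statistics import NormalDist


      def quantile_to_distribution(q, distribution="uniform", min_val=1e-07, max_val=0.9999999):
          """Map a quantile in [0, 1] to the requested output distribution."""
          if distribution == "uniform":
              return q
          # Clip so the inverse CDF is evaluated strictly inside (0, 1), where it is finite.
          q = min(max(q, min_val), max_val)
          if distribution == "normal":
              return NormalDist().inv_cdf(q)
          if distribution == "logistic":
              return math.log(q / (1.0 - q))  # inverse CDF of the standard logistic
          raise ValueError(f"unknown distribution: {distribution}")


      edge_normal = quantile_to_distribution(0.0, "normal")  # finite despite q == 0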


   .. py:attribute:: compression
      :value: 250



   .. py:attribute:: distribution
      :value: 'uniform'



   .. py:attribute:: min_val
      :value: 1e-07



   .. py:attribute:: max_val
      :value: 0.9999999



   .. py:attribute:: centroids_
      :value: None



   .. py:attribute:: size_
      :value: None



   .. py:attribute:: min_
      :value: None



   .. py:attribute:: max_
      :value: None



   .. py:method:: td_objs_to_attributes(td_objs)


   .. py:method:: attributes_to_td_objs()


   .. py:method:: fit(x, weight=None)


   .. py:method:: transform(x, channels_last=None, pool=None)


   .. py:method:: fit_transform(x, channels_last=None, weight=None, pool=None)


   .. py:method:: inverse_transform(x, channels_last=None, pool=None)


   .. py:method:: __add__(other)


