bridgescaler.backend
====================

.. py:module:: bridgescaler.backend


Attributes
----------

.. autoapisummary::

   bridgescaler.backend.scaler_objs


Classes
-------

.. autoapisummary::

   bridgescaler.backend.NumpyEncoder


Functions
---------

.. autoapisummary::

   bridgescaler.backend.ensure_torch
   bridgescaler.backend.save_scaler
   bridgescaler.backend.print_scaler
   bridgescaler.backend.object_hook
   bridgescaler.backend.read_scaler
   bridgescaler.backend.load_scaler
   bridgescaler.backend.apply_to_dict_leaves
   bridgescaler.backend.save_scaler_dict
   bridgescaler.backend.load_scaler_dict
   bridgescaler.backend.scale_var_dict
   bridgescaler.backend.create_synthetic_data


Module Contents
---------------

.. py:data:: scaler_objs

.. py:function:: ensure_torch()

   Validate the torch installation and load the module.
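
   A minimal sketch of this kind of guard, assuming an ``importlib``-based check.
   The helper name ``ensure_module`` is hypothetical and is not part of bridgescaler:

   .. code-block:: python

      import importlib
      import importlib.util
      from types import ModuleType


      def ensure_module(name: str) -> ModuleType:
          """Import and return a top-level module, failing clearly if it is missing.

          Hypothetical helper illustrating an ``ensure_torch``-style check.
          """
          # find_spec returns None when a top-level package cannot be located.
          if importlib.util.find_spec(name) is None:
              raise ImportError(
                  f"{name} is required for this feature; install it first "
                  f"(for example: pip install {name})."
              )
          return importlib.import_module(name)


      # Usage (only succeeds when torch is installed):
      # torch = ensure_module("torch")

   Checking with ``find_spec`` first lets the caller raise a targeted error message instead of
   a bare ``ModuleNotFoundError`` from deep inside an optional code path.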


.. py:function:: save_scaler(scaler, scaler_file)

   Save a scikit-learn or bridgescaler scaler object to a JSON file.

   :param scaler: scikit-learn-style scaler object
   :param scaler_file: path to the JSON file where the scaler information is stored.
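
   A minimal sketch of the general idea behind saving a fitted scaler as JSON. It assumes,
   for illustration only, that the fitted state lives in the instance ``__dict__`` and that
   the class name is stored alongside it so the object can be rebuilt later. ``ToyScaler``
   and ``save_toy_scaler`` are hypothetical and do not reproduce bridgescaler's code:

   .. code-block:: python

      import json
      import os
      import tempfile

      import numpy as np


      class ToyScaler:
          """Hypothetical stand-in for a fitted scaler (not a bridgescaler class)."""

          def __init__(self):
              self.mean_ = None
              self.scale_ = None

          def fit(self, x):
              self.mean_ = x.mean(axis=0)
              self.scale_ = x.std(axis=0)
              return self


      def save_toy_scaler(scaler, scaler_file):
          # Tag the saved state with the class name so a reader can pick the right class.
          state = {"type": type(scaler).__name__}
          for key, value in vars(scaler).items():
              # numpy arrays are not JSON serializable, so store them as lists.
              state[key] = value.tolist() if isinstance(value, np.ndarray) else value
          with open(scaler_file, "w") as f:
              json.dump(state, f)


      scaler = ToyScaler().fit(np.array([[1.0, 10.0], [3.0, 30.0]]))
      path = os.path.join(tempfile.gettempdir(), "toy_scaler.json")
      save_toy_scaler(scaler, path)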


.. py:function:: print_scaler(scaler)

   Serialize a scikit-learn or bridgescaler scaler object to a JSON string.

   :param scaler: scikit-learn-style scaler object

   :returns: JSON string representation of the scaler.


.. py:function:: object_hook(dct: dict[Any, Any])
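
   Judging from its name and signature, this function is presumably passed as the
   ``object_hook`` argument of ``json.loads`` or ``json.load``, which call the hook on every
   decoded JSON object (a ``dict``) and use its return value in place of that ``dict``. The
   sketch below shows that standard-library mechanism with a hypothetical
   ``toy_object_hook``; the ``"__ndarray__"`` marker is invented for illustration and is not
   bridgescaler's format:

   .. code-block:: python

      import json

      import numpy as np


      def toy_object_hook(dct):
          # json calls this for every decoded JSON object, innermost objects first.
          if "__ndarray__" in dct:
              return np.asarray(dct["__ndarray__"])
          return dct


      decoded = json.loads(
          '{"mean_": {"__ndarray__": [1.0, 2.0]}, "n": 3}',
          object_hook=toy_object_hook,
      )
      print(type(decoded["mean_"]).__name__)  # ndarray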

.. py:function:: read_scaler(scaler_str)

   Initialize a scikit-learn or bridgescaler scaler from a JSON string.

   :param scaler_str: JSON string describing the scaler.

   :returns: the reconstructed scaler object.
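
   A minimal sketch of rebuilding an object from a type-tagged JSON string, assuming a
   registry that maps the stored class name to a class. The module's ``scaler_objs``
   attribute may serve a similar purpose, but that is an assumption. ``ToyScaler``,
   ``TOY_REGISTRY``, and ``read_toy_scaler`` are hypothetical:

   .. code-block:: python

      import json

      import numpy as np


      class ToyScaler:
          """Hypothetical stand-in for a scaler class (not a bridgescaler class)."""

          def transform(self, x):
              return (x - self.mean_) / self.scale_


      # Hypothetical registry mapping the stored "type" tag to a class.
      TOY_REGISTRY = {"ToyScaler": ToyScaler}


      def read_toy_scaler(scaler_str):
          state = json.loads(scaler_str)
          cls = TOY_REGISTRY[state.pop("type")]
          # Create the instance without calling __init__, then restore its attributes.
          scaler = cls.__new__(cls)
          for key, value in state.items():
              # Lists were arrays before serialization; convert them back.
              setattr(scaler, key, np.asarray(value) if isinstance(value, list) else value)
          return scaler


      scaler = read_toy_scaler('{"type": "ToyScaler", "mean_": [2.0, 20.0], "scale_": [1.0, 10.0]}')
      print(scaler.transform(np.array([[3.0, 30.0]])))  # [[1. 1.]]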


.. py:function:: load_scaler(scaler_file)

   Initialize a scikit-learn or bridgescaler scaler from a saved JSON file.

   :param scaler_file: path to the JSON file.

   :returns: the reconstructed scaler object.


.. py:function:: apply_to_dict_leaves(d, operation)

   Recursively applies an operation to each leaf value in a nested dictionary.

   :param d: A nested dictionary where the operation will be
             applied to each leaf value.
   :type d: dict
   :param operation: A function to apply to each leaf value.
   :type operation: callable

   :returns: A nested dictionary with the same structure as ``d``, where each
             leaf is the result of ``operation(leaf)``.
   :rtype: dict
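
   A minimal recursive sketch consistent with the description above. The name
   ``apply_to_leaves_sketch`` is hypothetical, and the real function may differ, for example
   in how it treats containers other than ``dict``:

   .. code-block:: python

      def apply_to_leaves_sketch(d, operation):
          """Return a new nested dict with ``operation`` applied to every non-dict value."""
          return {
              key: apply_to_leaves_sketch(value, operation)
              if isinstance(value, dict)
              else operation(value)
              for key, value in d.items()
          }


      nested = {"era5": {"input": {"T": 2, "Q": 3}, "target": {"T": 4}}}
      print(apply_to_leaves_sketch(nested, lambda x: x * 10))
      # {'era5': {'input': {'T': 20, 'Q': 30}, 'target': {'T': 40}}}

   Building a new dictionary rather than mutating ``d`` keeps the input intact, which matches
   the documented return of a dictionary "with the same structure as ``d``".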


.. py:function:: save_scaler_dict(scaler_dict, scaler_dict_file)

   Serializes and saves a nested dictionary of Bridgescaler scalers to a JSON file.

   :param scaler_dict: A nested dictionary of fitted Bridgescaler scaler objects
                       to be saved.
   :type scaler_dict: dict
   :param scaler_dict_file: The file path where the scaler
                            dictionary will be saved as a JSON file.
   :type scaler_dict_file: str or Path


.. py:function:: load_scaler_dict(scaler_dict_file)

   Loads and deserializes a nested dictionary of Bridgescaler scalers from a JSON file.

   :param scaler_dict_file: The file path to the JSON file
                            containing the serialized scaler dictionary.
   :type scaler_dict_file: str or Path

   :returns: A nested dictionary of reconstructed scaler objects, with the same
             structure as the original dictionary passed to ``save_scaler_dict``.
   :rtype: dict
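
   A minimal sketch of one way a nested scaler dictionary can round-trip through a single
   JSON file, assuming each leaf scaler is converted to a JSON string and rebuilt from that
   string. To stay self-contained it uses a hypothetical ``StatsScaler`` instead of real
   scalers; ``map_leaves``, ``save_dict_sketch``, and ``load_dict_sketch`` are not
   bridgescaler functions:

   .. code-block:: python

      import json
      import os
      import tempfile


      class StatsScaler:
          """Hypothetical stand-in for a fitted scaler (not a bridgescaler class)."""

          def __init__(self, mean, std):
              self.mean = mean
              self.std = std


      def map_leaves(d, operation):
          # Apply ``operation`` to every non-dict value, preserving the nesting.
          return {k: map_leaves(v, operation) if isinstance(v, dict) else operation(v) for k, v in d.items()}


      def to_json(scaler):
          # Each leaf becomes a JSON string, so the saved document nests strings.
          return json.dumps(vars(scaler))


      def from_json(scaler_str):
          return StatsScaler(**json.loads(scaler_str))


      def save_dict_sketch(scaler_dict, path):
          with open(path, "w") as f:
              json.dump(map_leaves(scaler_dict, to_json), f)


      def load_dict_sketch(path):
          with open(path) as f:
              return map_leaves(json.load(f), from_json)


      scalers = {"era5": {"input": {"T": StatsScaler(280.0, 15.0)}}}
      path = os.path.join(tempfile.gettempdir(), "scaler_dict_sketch.json")
      save_dict_sketch(scalers, path)
      restored = load_dict_sketch(path)
      print(vars(restored["era5"]["input"]["T"]))  # {'mean': 280.0, 'std': 15.0}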


.. py:class:: NumpyEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)

   Bases: :py:obj:`json.JSONEncoder`


   Custom JSON encoder for numpy data types.


   .. py:method:: default(obj)

      Return a JSON-serializable object for ``obj``, or call the base
      implementation (which raises a ``TypeError``). Subclasses override this
      method to support additional types; ``NumpyEncoder`` uses it for numpy
      data types.

      For example, to support arbitrary iterators, you could
      implement default like this::

          def default(self, o):
              try:
                  iterable = iter(o)
              except TypeError:
                  pass
              else:
                  return list(iterable)
              # Let the base class default method raise the TypeError
              return super().default(o)
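
   A minimal sketch of a ``json.JSONEncoder`` subclass that handles common numpy types.
   ``NumpyEncoderSketch`` is hypothetical; the exact set of types converted by
   ``NumpyEncoder`` is not documented here:

   .. code-block:: python

      import json

      import numpy as np


      class NumpyEncoderSketch(json.JSONEncoder):
          """Hypothetical encoder converting numpy objects to built-in Python types."""

          def default(self, obj):
              if isinstance(obj, np.integer):
                  return int(obj)
              if isinstance(obj, np.floating):
                  return float(obj)
              if isinstance(obj, np.bool_):
                  return bool(obj)
              if isinstance(obj, np.ndarray):
                  return obj.tolist()
              # Anything else: let the base class raise TypeError.
              return super().default(obj)


      print(json.dumps({"mean": np.array([1.5, 2.5]), "n": np.int64(4)}, cls=NumpyEncoderSketch))
      # {"mean": [1.5, 2.5], "n": 4}

   ``default`` is only called for objects the base encoder cannot handle natively, so types
   such as ``np.float64`` (a subclass of Python ``float``) may never reach it.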




.. py:function:: scale_var_dict(var_dict, scalers, method, var_list=None, _key_path=())

   Recursively traverses a nested dict of tensor variables and applies a scaler method to each variable.

   :param var_dict: A nested dictionary whose leaves are the ``torch.Tensor`` variables to be scaled.
   :type var_dict: dict
   :param scalers: A single scaler instance (for ``fit`` and
                   ``fit_transform``) or a nested dict of scalers matching the structure
                   of ``var_dict`` (for ``transform`` and ``inverse_transform``).
   :type scalers: object or dict
   :param method: The scaler method to apply. Must be one of ``fit``,
                  ``transform``, ``inverse_transform``, or ``fit_transform``.
   :type method: str
   :param var_list: A list of leaf key names to apply the
                    scaler method to. Keys not in ``var_list`` are skipped during ``fit``,
                    and left unchanged during ``transform``, ``inverse_transform``, and
                    ``fit_transform``. If ``None``, all leaf keys are processed.
   :type var_list: list of str, optional

   :returns: A nested dictionary with the same structure as ``var_dict``, where
             each leaf is either a fitted scaler (for ``fit``) or a transformed
             variable (for ``transform``, ``inverse_transform``, and
             ``fit_transform``). Keys named ``metadata`` and keys excluded by
             ``var_list`` are omitted for ``fit`` and passed through unchanged
             for the other methods.
   :rtype: dict

   :raises AssertionError: If ``var_dict`` is not a dict.
   :raises AssertionError: If ``method`` is not one of the valid methods.
   :raises AssertionError: If ``scalers`` is not a dict when using ``transform``
       or ``inverse_transform``.
   :raises AssertionError: If a key path in ``var_dict`` is missing in ``scalers``.
   :raises AssertionError: If a scaler at a given key path does not have the
       requested ``method``.

   .. rubric:: Example

   >>> import torch
   >>> from bridgescaler.distributed_tensor import DStandardScalerTensor
   >>> from bridgescaler.backend import scale_var_dict
   >>> T = torch.randn((20, 5, 4, 8))
   >>> var_dict = {
   ...     "era5": {
   ...         "input": {"era5/prognostic/3d/T": T},
   ...         "target": {"era5/prognostic/3d/T": T},
   ...         "metadata": {"input_datetime": int, "target_datetime": int},
   ...     }
   ... }
   >>> scalers = DStandardScalerTensor(channels_last=False)
   >>> scaler_dict = scale_var_dict(var_dict, scalers, method="fit")
   >>> transformed = scale_var_dict(var_dict, scaler_dict, method="transform")
   >>> inverse_transformed = scale_var_dict(transformed, scaler_dict, method="inverse_transform")
   >>> fitted_transformed = scale_var_dict(var_dict, scalers, method="fit_transform")
   >>> # Only scale specific variables
   >>> filtered = scale_var_dict(var_dict, scaler_dict, method="transform", var_list=["era5/prognostic/3d/T"])
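
   A simplified, framework-free sketch of the traversal described above. It is not
   bridgescaler's implementation: it uses plain Python lists instead of tensors, a
   hypothetical ``MeanScaler`` in place of ``DStandardScalerTensor``, and supports only
   ``fit`` and ``transform``. Deep-copying the single template scaler for every leaf during
   ``fit`` is an assumption about how one scaler instance yields a separate fitted scaler per
   variable:

   .. code-block:: python

      import copy


      class MeanScaler:
          """Hypothetical scaler that subtracts the mean of the data it was fit on."""

          def fit(self, x):
              self.mean_ = sum(x) / len(x)
              return self

          def transform(self, x):
              return [v - self.mean_ for v in x]


      def scale_dict_sketch(var_dict, scalers, method, var_list=None):
          """Simplified recursive traversal supporting only "fit" and "transform"."""
          assert isinstance(var_dict, dict), "var_dict must be a dict"
          assert method in ("fit", "transform"), "this sketch supports fit and transform only"
          out = {}
          for key, value in var_dict.items():
              is_leaf = not isinstance(value, dict)
              excluded = is_leaf and var_list is not None and key not in var_list
              if key == "metadata" or excluded:
                  # fit drops these keys; transform passes them through unchanged.
                  if method != "fit":
                      out[key] = value
              elif not is_leaf:
                  # During fit the same template scaler is passed down; during
                  # transform the nested scaler dict is walked in parallel.
                  sub_scalers = scalers if method == "fit" else scalers[key]
                  out[key] = scale_dict_sketch(value, sub_scalers, method, var_list)
              elif method == "fit":
                  # Assumption: each leaf gets its own fitted copy of the template scaler.
                  out[key] = copy.deepcopy(scalers).fit(value)
              else:
                  out[key] = scalers[key].transform(value)
          return out


      data = {"era5": {"input": {"T": [1.0, 3.0]}, "metadata": {"time": 0}}}
      fitted = scale_dict_sketch(data, MeanScaler(), "fit")
      print(scale_dict_sketch(data, fitted, "transform"))
      # {'era5': {'input': {'T': [-1.0, 1.0]}, 'metadata': {'time': 0}}}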


.. py:function:: create_synthetic_data()

