Source code for torch_openreml.covariance.transform.transform_scale_shift
"""
Scale-shift transform module for affine parameter mappings.
Provides a differentiable bijective transform from
:math:`\\mathbb{R} \\rightarrow \\mathbb{R}` using a configurable
affine function.
Classes:
TransformScaleShift:
Affine transform (:math:`f(x) = ax + b`)
"""
import torch
from torch_openreml.covariance.transform.transform import Transform
[docs]
class TransformScaleShift(Transform):
r"""
Affine transform with configurable scale and shift.
.. math::
f(x) = ax + b
"""
domain = "\u211D"
codomain = "\u211D"
def __init__(self, a, b=0.0):
r"""
Initialize the affine transform.
Args:
a (float): Scale factor.
b (float): Shift offset. Defaults to ``0.0``.
"""
self.a = a
self.b = b
[docs]
def __call__(self, x):
r"""
Apply the affine transform.
Args:
x (torch.Tensor): Input tensor in :math:`\mathbb{R}`.
Returns:
torch.Tensor: Element-wise :math:`ax + b`.
Example:
.. jupyter-execute::
import torch
from torch_openreml.covariance.transform import TransformScaleShift
t = TransformScaleShift(a=2.0, b=1.0)
x = torch.tensor([0.0, 1.0, 2.0])
t(x)
"""
return self.a * x + self.b
[docs]
def inverse(self, x):
r"""
Apply the inverse transform.
Args:
x (torch.Tensor): Input tensor in :math:`\mathbb{R}`.
Returns:
torch.Tensor: Element-wise :math:`\frac{x - b}{a}`.
Example:
.. jupyter-execute::
import torch
from torch_openreml.covariance.transform import TransformScaleShift
t = TransformScaleShift(a=2.0, b=1.0)
x = torch.tensor([1.0, 3.0, 5.0])
t.inverse(x)
"""
return (x - self.b) / self.a
[docs]
def grad(self, x):
r"""
Compute derivative of :math:`ax + b` for chain rule propagation.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: [:math:`a`].
Example:
.. jupyter-execute::
import torch
from torch_openreml.covariance.transform import TransformScaleShift
t = TransformScaleShift(a=2.0, b=1.0)
x = torch.tensor([0.0])
t.grad(x)
"""
return torch.tensor([self.a], dtype=x.dtype, device=x.device)
def __repr__(self):
return f"{self.__class__.__name__}(a={self.a}, b={self.b})"