torch_openreml.covariance.transform.TransformScaleShift
- class torch_openreml.covariance.transform.TransformScaleShift(a, b=0.0)
Bases: Transform
Affine transform with configurable scale and shift.
\[f(x) = ax + b\]
Initialize the affine transform.
- Parameters:
a (float) – Scale factor.
b (float) – Shift offset. Defaults to 0.0.
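To make the interface concrete, here is a minimal, hypothetical sketch of how a scale-shift transform with these methods could be written in plain PyTorch. The class name `ScaleShiftSketch` is illustrative only: it is not the library's source and does not subclass `Transform`. Returning the derivative as a one-element tensor mirrors the `[a]` listed under `grad`, which is an assumption about its exact shape.

```python
import torch


class ScaleShiftSketch:
    """Hypothetical stand-in for the affine map f(x) = a * x + b."""

    domain = "ℝ"
    codomain = "ℝ"

    def __init__(self, a: float, b: float = 0.0) -> None:
        self.a = float(a)
        self.b = float(b)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Forward map, applied element-wise.
        return self.a * x + self.b

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        # Inverse map (x - b) / a; only defined when a != 0.
        return (x - self.b) / self.a

    def grad(self, x: torch.Tensor) -> torch.Tensor:
        # f'(x) = a is constant, so the values of x are not used.
        # Assumption: return it as a one-element tensor with x's dtype.
        return torch.tensor([self.a], dtype=x.dtype)


t = ScaleShiftSketch(a=2.0, b=1.0)
x = torch.tensor([0.0, 1.0, 2.0])
print(t(x))             # tensor([1., 3., 5.])
print(t.inverse(t(x)))  # tensor([0., 1., 2.])
print(t.grad(x))        # tensor([2.])
```

Keeping `a` and `b` as plain Python floats makes the sketch a fixed, non-learnable transform; a parameterized variant would store them as tensors instead.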
Methods
__call__(x) – Apply the affine transform.
grad(x) – Compute the derivative of \(ax + b\) for chain-rule propagation.
inverse(x) – Apply the inverse transform.
Attributes
- domain = 'ℝ'
Domain of the transform.
- codomain = 'ℝ'
Codomain of the transform.
- __call__(x)
Apply the affine transform.
- Parameters:
x (torch.Tensor) – Input tensor in \(\mathbb{R}\).
- Returns:
Element-wise \(ax + b\).
- Return type:
torch.Tensor
Example:
>>> import torch
>>> from torch_openreml.covariance.transform import TransformScaleShift
>>> t = TransformScaleShift(a=2.0, b=1.0)
>>> x = torch.tensor([0.0, 1.0, 2.0])
>>> t(x)
tensor([1., 3., 5.])
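Because the map is element-wise, it should apply to tensors of any shape and return a tensor of the same shape. The short check below uses a hypothetical helper `scale_shift` that evaluates the documented formula directly in plain PyTorch rather than calling the library.

```python
import torch


def scale_shift(x: torch.Tensor, a: float, b: float = 0.0) -> torch.Tensor:
    # Hypothetical helper: the documented formula a * x + b, element-wise.
    return a * x + b


x = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
y = scale_shift(x, a=2.0, b=1.0)
print(y.shape)  # torch.Size([2, 2])
print(y)        # tensor([[1., 3.],
                #         [5., 7.]])
```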
- inverse(x)
Apply the inverse transform.
- Parameters:
x (torch.Tensor) – Input tensor in \(\mathbb{R}\).
- Returns:
Element-wise \(\frac{x - b}{a}\), which is defined only for \(a \neq 0\).
- Return type:
torch.Tensor
Example:
>>> import torch
>>> from torch_openreml.covariance.transform import TransformScaleShift
>>> t = TransformScaleShift(a=2.0, b=1.0)
>>> x = torch.tensor([1.0, 3.0, 5.0])
>>> t.inverse(x)
tensor([0., 1., 2.])
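Solving \(y = ax + b\) for \(x\) gives \(x = (y - b)/a\), so the inverse exists only when \(a \neq 0\). The sketch below uses hypothetical helpers `forward` and `inverse` (not library functions) to check the round trip with a negative scale and to guard the degenerate case with a `ValueError`; whether `TransformScaleShift` itself validates `a` is not stated in this reference.

```python
import torch


def forward(x: torch.Tensor, a: float, b: float) -> torch.Tensor:
    # Hypothetical helper: f(x) = a * x + b.
    return a * x + b


def inverse(y: torch.Tensor, a: float, b: float) -> torch.Tensor:
    # Hypothetical helper: f^{-1}(y) = (y - b) / a, undefined when a == 0.
    if a == 0:
        raise ValueError("scale factor a must be nonzero to invert a*x + b")
    return (y - b) / a


x = torch.tensor([-1.5, 0.0, 4.0])
a, b = -0.5, 3.0
y = forward(x, a, b)
# The round trip recovers the input up to floating-point rounding.
print(torch.allclose(inverse(y, a, b), x))  # True
```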
- grad(x)
Compute the derivative of \(ax + b\) for chain-rule propagation.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
[\(a\)], the constant derivative \(f'(x) = a\), which does not depend on the values of x.
- Return type:
torch.Tensor
Example:
>>> import torch
>>> from torch_openreml.covariance.transform import TransformScaleShift
>>> t = TransformScaleShift(a=2.0, b=1.0)
>>> x = torch.tensor([0.0])
>>> t.grad(x)
tensor([2.])
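To illustrate the chain-rule role of this derivative: if \(\theta = ax + b\) and a loss \(L\) depends on \(\theta\), then \(\partial L / \partial x = a \, \partial L / \partial \theta\). The sketch below checks this identity against PyTorch autograd in plain PyTorch; the loss \(L = \sum \theta^2\) and all variable names are arbitrary choices for illustration, and the library is not called.

```python
import torch

a, b = 2.0, 1.0
x = torch.tensor([0.0, 0.5, -1.0], requires_grad=True)

theta = a * x + b              # forward transform
loss = (theta ** 2).sum()      # arbitrary downstream loss L(theta)
loss.backward()                # autograd fills x.grad with dL/dx

# Manual chain rule: dL/dtheta = 2 * theta, and dtheta/dx = a.
manual = (2 * theta * a).detach()
print(torch.allclose(x.grad, manual))  # True
```

Because \(\partial \theta / \partial x = a\) is constant, propagating a gradient through this transform amounts to multiplying the upstream gradient by \(a\).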