torch_openreml.covariance.transform.TransformScaleShift

class torch_openreml.covariance.transform.TransformScaleShift(a, b=0.0)

Bases: Transform

Affine transform with configurable scale and shift.

\[f(x) = ax + b\]
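Solving \(y = ax + b\) for \(x\) and differentiating gives the expressions that the inverse and grad methods document below. The inverse exists only when \(a \neq 0\).

\[y = ax + b \;\Longrightarrow\; ax = y - b \;\Longrightarrow\; f^{-1}(y) = \frac{y - b}{a}, \quad a \neq 0\]

\[f'(x) = \frac{d}{dx}(ax + b) = a\]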

Initialize the affine transform.

Parameters:
  • a (float) – Scale factor.

  • b (float) – Shift offset. Defaults to 0.0.
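For intuition, the three operations can be sketched on plain Python floats. The class name AffineSketch is hypothetical and this is not the torch_openreml implementation. In particular, rejecting a == 0 in the constructor is an assumption made only so the sketch's inverse is always defined.

```python
class AffineSketch:
    """Hypothetical stand-alone sketch of f(x) = a*x + b on plain floats.

    Not the torch_openreml implementation: it works on Python floats and
    rejects a == 0 (an assumption of this sketch) so the inverse is defined.
    """

    def __init__(self, a: float, b: float = 0.0) -> None:
        if a == 0:
            raise ValueError("scale a must be non-zero for the map to be invertible")
        self.a = float(a)
        self.b = float(b)

    def __call__(self, x: float) -> float:
        # Forward map: f(x) = a*x + b
        return self.a * x + self.b

    def inverse(self, y: float) -> float:
        # Inverse map: f^{-1}(y) = (y - b) / a
        return (y - self.b) / self.a

    def grad(self, x: float) -> float:
        # Derivative f'(x) = a, independent of x
        return self.a


f = AffineSketch(a=2.0, b=1.0)
print(f(2.0))          # 5.0
print(f.inverse(5.0))  # 2.0
print(f.grad(123.0))   # 2.0
```

Because the derivative does not depend on x, the sketch's grad ignores its argument. It keeps the argument only to mirror the signature documented below.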

Methods

__call__(x)

Apply the affine transform.

grad(x)

Compute the derivative of \(ax + b\) for chain-rule propagation.

inverse(x)

Apply the inverse transform.

Attributes

codomain

Codomain of the transform.

domain

Domain of the transform.

domain = 'ℝ'

Domain of the transform.

codomain = 'ℝ'

Codomain of the transform.

__call__(x)

Apply the affine transform.

Parameters:

x (torch.Tensor) – Input tensor in \(\mathbb{R}\).

Returns:

Element-wise \(ax + b\).

Return type:

torch.Tensor

Example:

>>> import torch
>>> from torch_openreml.covariance.transform import TransformScaleShift
>>> t = TransformScaleShift(a=2.0, b=1.0)
>>> x = torch.tensor([0.0, 1.0, 2.0])
>>> t(x)
tensor([1., 3., 5.])

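Because the map acts element-wise, the tensor example above can be mirrored with a plain-Python list comprehension. The helper affine_elementwise is hypothetical and only illustrates the arithmetic. The second call shows that a negative scale makes the map strictly decreasing, so it reverses the order of sorted inputs.

```python
def affine_elementwise(xs: list[float], a: float, b: float = 0.0) -> list[float]:
    # Apply f(x) = a*x + b to each entry independently; output length equals input length
    return [a * x + b for x in xs]


print(affine_elementwise([0.0, 1.0, 2.0], a=2.0, b=1.0))   # [1.0, 3.0, 5.0]
# With a < 0 the map is strictly decreasing, so sorted inputs come out in reverse order
print(affine_elementwise([0.0, 1.0, 2.0], a=-1.0, b=3.0))  # [3.0, 2.0, 1.0]
```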
inverse(x)

Apply the inverse transform. The inverse is defined only when \(a \neq 0\).

Parameters:

x (torch.Tensor) – Input tensor in \(\mathbb{R}\).

Returns:

Element-wise \(\frac{x - b}{a}\).

Return type:

torch.Tensor

Example:

>>> import torch
>>> from torch_openreml.covariance.transform import TransformScaleShift
>>> t = TransformScaleShift(a=2.0, b=1.0)
>>> x = torch.tensor([1.0, 3.0, 5.0])
>>> t.inverse(x)
tensor([0., 1., 2.])

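Composing the forward map with its inverse should recover the input. The following minimal pure-Python round-trip check uses hypothetical helpers forward and inverse, not the library's methods. The tolerance passed to math.isclose is an illustrative choice, since floating-point rounding can make exact equality fail. Raising on a == 0 is this sketch's own choice. How the library handles that case is not documented here.

```python
import math


def forward(x: float, a: float, b: float) -> float:
    # f(x) = a*x + b
    return a * x + b


def inverse(y: float, a: float, b: float) -> float:
    # f^{-1}(y) = (y - b) / a; undefined for a == 0
    if a == 0:
        raise ZeroDivisionError("affine map with a == 0 has no inverse")
    return (y - b) / a


a, b = 0.3, -1.7
xs = [-2.5, 0.0, 0.1, 4.2]
# Round trip x -> f(x) -> f^{-1}(f(x)) should recover x up to rounding error
recovered = [inverse(forward(x, a, b), a, b) for x in xs]
print(all(math.isclose(x, r, rel_tol=1e-12, abs_tol=1e-12) for x, r in zip(xs, recovered)))  # True
```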
grad(x)

Compute the derivative of \(ax + b\) for chain-rule propagation.

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

A tensor containing the constant derivative \(a\).

Return type:

torch.Tensor

Example:

>>> import torch
>>> from torch_openreml.covariance.transform import TransformScaleShift
>>> t = TransformScaleShift(a=2.0, b=1.0)
>>> x = torch.tensor([0.0])
>>> t.grad(x)
tensor([2.])
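Because \(f'(x) = a\) is constant, propagating a derivative through this transform multiplies the downstream derivative by \(a\). For an outer function \(g\), \((g \circ f)'(x) = g'(ax + b) \cdot a\). The sketch below checks this against a central finite difference. The choice g = sin is arbitrary and for illustration only, and nothing here calls the library.

```python
import math

a, b = 2.0, 1.0


def f(x: float) -> float:
    # Inner affine transform f(x) = a*x + b
    return a * x + b


def g(u: float) -> float:
    # Arbitrary smooth outer function chosen for illustration
    return math.sin(u)


def analytic_grad(x: float) -> float:
    # Chain rule: d/dx g(f(x)) = g'(f(x)) * f'(x) = cos(a*x + b) * a
    return math.cos(f(x)) * a


def numeric_grad(x: float, h: float = 1e-6) -> float:
    # Central finite-difference approximation of d/dx g(f(x))
    return (g(f(x + h)) - g(f(x - h))) / (2 * h)


x0 = 0.4
print(analytic_grad(x0), numeric_grad(x0))
print(math.isclose(analytic_grad(x0), numeric_grad(x0), rel_tol=1e-6))  # True
```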