torch_openreml.covariance.transform.Transform

class torch_openreml.covariance.transform.Transform

Bases: ABC

Abstract base class for mathematical transforms.

A transform maps inputs from a domain to a codomain. Subclasses define the forward map, its inverse, and a derivative factor used to apply the chain rule when differentiating through the transform.
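A minimal sketch of a concrete transform. It assumes plain Python floats rather than torch tensors so that it runs without dependencies. The `Transform` class below is a local stand-in that mirrors the documented interface, not the library class; in real code you would subclass `torch_openreml.covariance.transform.Transform`. `ExpTransform` is hypothetical: it maps an unconstrained real to a positive value, a common way to keep a variance parameter positive.

```python
import math
from abc import ABC, abstractmethod


class Transform(ABC):
    # Local stand-in for torch_openreml.covariance.transform.Transform so the
    # sketch runs without the library; defaults match the documented attributes.
    domain = "ℝ"
    codomain = "ℝ"

    @abstractmethod
    def __call__(self, x): ...

    @abstractmethod
    def inverse(self, x): ...

    @abstractmethod
    def grad(self, x): ...


class ExpTransform(Transform):
    """Hypothetical transform: unconstrained real -> positive real."""

    codomain = "ℝ⁺"

    def __call__(self, x):
        return math.exp(x)  # forward map

    def inverse(self, x):
        return math.log(x)  # inverse map, defined for x > 0

    def grad(self, x):
        return math.exp(x)  # d/dx exp(x) = exp(x), evaluated at domain point x


t = ExpTransform()
sigma2 = t(0.5)            # positive value on the constrained scale
theta = t.inverse(sigma2)  # back to the unconstrained scale (about 0.5)

# The abstract base class itself cannot be instantiated.
base_is_abstract = False
try:
    Transform()
except TypeError:
    base_is_abstract = True
```

Because the base class derives from `ABC` and marks its methods abstract, Python refuses to instantiate it, or any subclass that leaves one of the three methods unimplemented.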

Methods

__call__(x)

Apply the forward transformation.

grad(x)

Compute the derivative factor for chain rule propagation.

inverse(x)

Apply the inverse transformation.

Attributes

codomain

Codomain of the transform.

domain

Domain of the transform.

domain = 'ℝ'

Domain of the transform.

codomain = 'ℝ'

Codomain of the transform.
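The `domain` and `codomain` attributes are plain class attributes that default to `'ℝ'`, so a subclass only needs to override the one that changes. The sketch below uses a hypothetical `TanhTransform` (for a correlation-like parameter in (-1, 1)) and a hypothetical `describe` helper. It redefines a minimal stand-in base class so that it runs on its own; it is not the library's code.

```python
import math
from abc import ABC, abstractmethod


class Transform(ABC):
    # Minimal stand-in for the documented base class.
    domain = "ℝ"
    codomain = "ℝ"

    @abstractmethod
    def __call__(self, x): ...

    @abstractmethod
    def inverse(self, x): ...

    @abstractmethod
    def grad(self, x): ...


class TanhTransform(Transform):
    """Hypothetical transform for a correlation-like parameter."""

    codomain = "(-1, 1)"  # override only what differs; domain stays 'ℝ'

    def __call__(self, x):
        return math.tanh(x)

    def inverse(self, x):
        return math.atanh(x)

    def grad(self, x):
        return 1.0 - math.tanh(x) ** 2  # derivative of tanh at x


def describe(cls):
    # Class attributes can be read without creating an instance.
    return f"{cls.__name__}: {cls.domain} -> {cls.codomain}"


summary = describe(TanhTransform)  # "TanhTransform: ℝ -> (-1, 1)"
```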

abstractmethod __call__(x)

Apply the forward transformation.

Parameters:

x – Input value in the transform’s domain.

Returns:

Transformed value in the codomain.

abstractmethod inverse(x)

Apply the inverse transformation.

Parameters:

x – Input value in the codomain.

Returns:

Value mapped back to the domain.
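A quick way to sanity-check an `inverse` implementation is to confirm that `inverse(t(x))` recovers `x` over a range of inputs. The sketch below assumes a hypothetical softplus transform (ℝ to positive reals) written in a numerically stable form, plus a hypothetical `round_trip_error` helper. The class is duck-typed for brevity, whereas real code would subclass `Transform`.

```python
import math


class SoftplusTransform:
    """Hypothetical softplus transform (duck-typed stand-in for a Transform subclass)."""

    def __call__(self, x):
        # softplus(x) = log(1 + e^x), rearranged to avoid overflow for large |x|
        return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

    def inverse(self, y):
        # log(e^y - 1) = y + log(1 - e^-y), valid for y > 0
        return y + math.log(-math.expm1(-y))

    def grad(self, x):
        # d/dx softplus(x) = sigmoid(x), computed in a stable branch-wise form
        if x >= 0:
            return 1.0 / (1.0 + math.exp(-x))
        e = math.exp(x)
        return e / (1.0 + e)


def round_trip_error(transform, xs):
    # Largest |inverse(forward(x)) - x| over the sample points.
    return max(abs(transform.inverse(transform(x)) - x) for x in xs)


err = round_trip_error(SoftplusTransform(), [-5.0, -1.0, 0.0, 0.5, 3.0, 20.0])
```

A tiny `err` suggests that `inverse` and `__call__` agree; a large one usually points at a mismatched formula or a numerically unstable branch.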

abstractmethod grad(x)

Compute the derivative factor for chain rule propagation.

Parameters:

x – Input value at which to evaluate the derivative factor.

Returns:

Scalar or tensor holding the local derivative (Jacobian) of the transform.
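The derivative factor is what lets a derivative computed on the codomain scale be carried back to the unconstrained domain scale. The sketch below assumes that `grad(x)` returns the derivative of the forward map at a domain point `x` (the page does not state the convention explicitly) and checks the chain rule against a central finite difference. `ExpTransform`, `nll`, and `dnll_dsigma2` are hypothetical, and the objective is a simple normal negative log-likelihood, not the library's REML criterion.

```python
import math


class ExpTransform:
    """Hypothetical exp transform (duck-typed; real code would subclass Transform)."""

    def __call__(self, x):
        return math.exp(x)

    def inverse(self, x):
        return math.log(x)

    def grad(self, x):
        # Assumed convention: derivative of the forward map at domain point x.
        return math.exp(x)


def nll(sigma2, ys):
    # Negative log-likelihood of zero-mean normal data with variance sigma2.
    return 0.5 * sum(math.log(2 * math.pi * sigma2) + y * y / sigma2 for y in ys)


def dnll_dsigma2(sigma2, ys):
    # Analytic derivative of nll with respect to sigma2.
    return 0.5 * sum(1.0 / sigma2 - y * y / sigma2**2 for y in ys)


t = ExpTransform()
ys = [0.3, -1.2, 0.8, 2.1]
theta = 0.25        # unconstrained parameter
sigma2 = t(theta)   # constrained (positive) variance

# Chain rule: dL/dtheta = dL/dsigma2 * dsigma2/dtheta
grad_theta = dnll_dsigma2(sigma2, ys) * t.grad(theta)

# Central finite difference on the unconstrained scale as a sanity check.
h = 1e-6
fd = (nll(t(theta + h), ys) - nll(t(theta - h), ys)) / (2 * h)
```

With torch tensors and autograd enabled this propagation happens automatically; an explicit factor like this is mainly useful when derivatives are assembled analytically.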