torch_openreml.covariance.transform.TransformChain

class torch_openreml.covariance.transform.TransformChain(chain)

Bases: Transform

Composition of multiple Transform objects applied sequentially.

The chain behaves as a single transform: the forward pass applies the transforms in list order, and the inverse pass applies their inverses in reverse order.
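
The ordering rule can be sketched in plain Python without the library. The `steps` list and the `chain_forward` and `chain_inverse` functions below are hypothetical stand-ins, not torch_openreml's API. The second step mirrors TransformPow(factor=2.0), assuming (as the documented example outputs suggest) that `factor` acts as an exponent.

```python
import math

# Hypothetical (forward, inverse) pairs standing in for Transform objects.
# These are NOT the torch_openreml classes.
steps = [
    (math.exp, math.log),                      # plays the role of TransformExp
    (lambda v: v ** 2.0, lambda v: v ** 0.5),  # plays the role of TransformPow(factor=2.0)
]

def chain_forward(x):
    # Forward pass: apply each step in list order.
    for fwd, _ in steps:
        x = fwd(x)
    return x

def chain_inverse(y):
    # Inverse pass: apply each step's inverse in reverse list order.
    for _, inv in reversed(steps):
        y = inv(y)
    return y

y = chain_forward(1.0)     # (e ** 1) ** 2 = e ** 2
x_back = chain_inverse(y)  # log(sqrt(e ** 2)) recovers 1.0 up to rounding
print(round(y, 4), round(x_back, 4))  # 7.3891 1.0
```

Applying the inverses in forward order instead would compute sqrt(log(y)) rather than log(sqrt(y)) and would not recover the original input.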

Initialize a transform chain.

Parameters:

chain (Transform or list/tuple of Transform) – A single transform, or a sequence of transforms to compose in the order given.

Raises:

TypeError – If any element in chain is not a Transform instance.
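
The argument handling described by the Parameters and Raises entries can be sketched as follows. The `Transform` and `Exp` classes and the `normalize_chain` helper are hypothetical; rejecting inputs that are neither a Transform nor a list/tuple is an assumption, since the documentation only states that a non-Transform element raises TypeError.

```python
class Transform:
    """Hypothetical stand-in for the library's Transform base class."""


class Exp(Transform):
    """Hypothetical concrete transform, used only to exercise the check."""


def normalize_chain(chain):
    """Return the chain as a list of Transform objects or raise TypeError."""
    # A single Transform is accepted and wrapped in a one-element list.
    if isinstance(chain, Transform):
        return [chain]
    # Assumption: anything other than a Transform, list or tuple is rejected.
    if not isinstance(chain, (list, tuple)):
        raise TypeError("chain must be a Transform or a list/tuple of Transform")
    # Every element must itself be a Transform instance.
    for i, item in enumerate(chain):
        if not isinstance(item, Transform):
            raise TypeError(f"chain[{i}] is {type(item).__name__}, not a Transform")
    return list(chain)


single = normalize_chain(Exp())         # one-element list
pair = normalize_chain((Exp(), Exp()))  # tuple converted to a list
message = None
try:
    normalize_chain([Exp(), 2.0])
except TypeError as err:
    message = str(err)                  # "chain[1] is float, not a Transform"
```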

Methods

__call__(x)

Apply the chained transformation forward.

grad(x)

Compute the chain-rule derivative factor of the full composed transform.

inverse(x)

Apply the inverse of the chained transformation.

Attributes

codomain

Codomain of the transform.

domain

Domain of the transform.

__call__(x)

Apply the chained transformation forward.

Parameters:

x – Input value in the domain of the first transform.

Returns:

Output after applying all transforms sequentially.

Example:

>>> import torch
>>> from torch_openreml.covariance.transform import TransformExp, TransformPow, TransformChain

>>> t = TransformChain([TransformExp(), TransformPow(factor=2.0)])  # exp first, then square
>>> x = torch.tensor([1.0])
>>> t(x)  # exp(1.0) ** 2 = e ** 2
tensor([7.3891])
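
Written out, the forward pass threads the input through each transform in turn. The notation below (f_1, ..., f_n for the transforms in list order, y_k for the intermediate values) is introduced here for illustration; the second line evaluates the example with f_1 = exp and f_2(y) = y^2.

```latex
\begin{aligned}
y_0 &= x, \qquad y_k = f_k(y_{k-1}) \quad (k = 1, \dots, n), \qquad t(x) = y_n \\
t(1) &= f_2\bigl(f_1(1)\bigr) = \bigl(e^{1}\bigr)^{2} = e^{2} \approx 7.3891
\end{aligned}
```
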

inverse(x)

Apply the inverse of the chained transformation.

Parameters:

x – Input value in the codomain of the last transform.

Returns:

Value mapped back through the inverse chain to the original domain.

Example:

>>> import torch
>>> from torch_openreml.covariance.transform import TransformExp, TransformPow, TransformChain

>>> t = TransformChain([TransformExp(), TransformPow(factor=2.0)])
>>> x = torch.tensor([4.0])
>>> t.inverse(x)  # undo the square first, then the exp: log(sqrt(4.0)) = log(2.0)
tensor([0.6931])
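
Because the forward pass is the composition t = f_n ∘ ... ∘ f_1, its inverse undoes the last transform first. With the illustrative notation f_1 = exp and f_2(y) = y^2 (inverted by the square root on the positive values that exp produces):

```latex
\begin{aligned}
t^{-1}(y) &= \bigl(f_1^{-1} \circ f_2^{-1} \circ \cdots \circ f_n^{-1}\bigr)(y) \\
t^{-1}(4) &= \ln\bigl(4^{1/2}\bigr) = \ln 2 \approx 0.6931
\end{aligned}
```
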

grad(x)

Compute the chain-rule derivative factor of the full composed transform.

Note

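This assumes each transform's local derivative is evaluated at that transform's own input, i.e. at the intermediate value produced by the preceding transforms during the forward pass.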

Parameters:

x – Input value in the domain of the first transform.

Returns:

Product of the local derivative factors of all transforms in the chain.
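
For scalar, elementwise transforms the chain rule gives the combined factor as a product of local derivatives, each evaluated at the intermediate value reaching that step. The notation f_k and y_k is introduced here for illustration; the second line evaluates it with f_1 = exp and f_2(y) = y^2 at x = 1.

```latex
\begin{aligned}
t'(x) &= \prod_{k=1}^{n} f_k'(y_{k-1}), \qquad y_0 = x, \quad y_k = f_k(y_{k-1}) \\
t'(1) &= f_2'(y_1)\, f_1'(y_0) = 2e^{1} \cdot e^{1} = 2e^{2} \approx 14.7781
\end{aligned}
```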

Example:

>>> import torch
>>> from torch_openreml.covariance.transform import TransformExp, TransformPow, TransformChain

>>> t = TransformChain([TransformExp(), TransformPow(factor=2.0)])
>>> x = torch.tensor([1.0])
>>> t.grad(x)  # d/dx exp(x) ** 2 = 2 * exp(2x), evaluated at x = 1
tensor([14.7781])
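
The product rule for the chain can be checked numerically in plain Python. The (function, derivative) pairs below are hypothetical stand-ins for TransformExp and TransformPow(factor=2.0), not the library's classes. The sketch assumes, as the note describes, that each local derivative is evaluated at the intermediate value produced by the forward pass, and compares the product against a central finite difference.

```python
import math

# Hypothetical (function, derivative) pairs; NOT the torch_openreml classes.
steps = [
    (math.exp, math.exp),                # exp and its derivative
    (lambda v: v ** 2, lambda v: 2 * v), # squaring and its derivative
]

def forward(x):
    # Apply each step in list order.
    for f, _ in steps:
        x = f(x)
    return x

def chain_grad(x):
    # Multiply the local derivatives, each evaluated at its step's own input
    # (the intermediate value produced by the earlier steps).
    factor = 1.0
    for f, df in steps:
        factor *= df(x)
        x = f(x)
    return factor

x0 = 1.0
h = 1e-6
numeric = (forward(x0 + h) - forward(x0 - h)) / (2 * h)  # central difference
analytic = chain_grad(x0)
print(round(analytic, 4))              # 14.7781
print(abs(analytic - numeric) < 1e-4)  # True
```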