torch_openreml.covariance.transform.TransformChain
- class torch_openreml.covariance.transform.TransformChain(chain)
Bases: Transform
Composition of multiple Transform objects applied sequentially.
The chain behaves as a single transform: the forward pass applies the transforms in the order given, while the inverse applies their inverses in reverse order.
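The ordering rule can be illustrated with a minimal pure-Python sketch. ExpSketch, SquareSketch and ChainSketch are hypothetical stand-ins (not torch_openreml classes) that work on plain floats instead of tensors:

```python
import math


# Hypothetical scalar transforms for illustration only.
class ExpSketch:
    def __call__(self, x):
        return math.exp(x)

    def inverse(self, y):
        return math.log(y)


class SquareSketch:
    def __call__(self, x):
        return x ** 2

    def inverse(self, y):
        return math.sqrt(y)


class ChainSketch:
    """Compose transforms: forward in list order, inverse in reverse order."""

    def __init__(self, chain):
        self.chain = list(chain)

    def __call__(self, x):
        for t in self.chain:            # the first transform is applied first
            x = t(x)
        return x

    def inverse(self, y):
        for t in reversed(self.chain):  # the last transform is undone first
            y = t.inverse(y)
        return y


c = ChainSketch([ExpSketch(), SquareSketch()])
y = c(1.0)          # (e ** 1) ** 2 = e ** 2
x = c.inverse(y)    # log(sqrt(e ** 2)) = 1.0, up to rounding
```

Order matters: ChainSketch([SquareSketch(), ExpSketch()]) maps 1.0 to exp(1.0 ** 2) = e rather than e squared.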
Initialize a transform chain.
- Parameters:
chain (Transform or list/tuple of Transform) – A single transform, or a sequence of transforms to compose in the order they are applied.
- Raises:
TypeError – If any element in chain is not a Transform instance.
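A minimal sketch of the constructor contract described above: a single transform or a list/tuple is accepted, and any non-transform element raises TypeError. BaseTransform, Identity and ChainSketch are hypothetical stand-ins, and the exact error message is an assumption, not the library's:

```python
class BaseTransform:
    """Hypothetical stand-in for the library's Transform base class."""


class Identity(BaseTransform):
    def __call__(self, x):
        return x


class ChainSketch(BaseTransform):
    def __init__(self, chain):
        # A lone transform is treated as a chain of length one.
        if isinstance(chain, BaseTransform):
            chain = [chain]
        chain = list(chain)
        # Reject any element that is not a transform.
        for i, t in enumerate(chain):
            if not isinstance(t, BaseTransform):
                raise TypeError(f"chain[{i}] is {type(t).__name__}, expected a transform")
        self.chain = chain


single = ChainSketch(Identity())                 # a lone transform is allowed
pair = ChainSketch((Identity(), Identity()))     # tuples work as well as lists
try:
    ChainSketch([Identity(), "exp"])
except TypeError as err:
    print(err)   # chain[1] is str, expected a transform
```

Subclassing the base transform type, as the "Bases: Transform" line suggests the real class does, lets a chain be used anywhere a single transform is expected, including inside another chain.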
Methods
__call__(x) – Apply the chained transformation forward.
grad(x) – Compute the chain-rule factor (derivative) of the full composed transform.
inverse(x) – Apply the inverse of the chained transformation.
Attributes
codomain – Codomain of the transform.
domain – Domain of the transform.
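The attribute entries do not say how a chain derives its domain and codomain. Going by the parameter descriptions of __call__ (input in the domain of the first transform) and inverse (input in the codomain of the last transform), one plausible reading is sketched below; StepSketch, ChainSketch and the string labels for the spaces are all hypothetical:

```python
from dataclasses import dataclass


@dataclass
class StepSketch:
    # Hypothetical stand-in for a transform that only records its spaces.
    name: str
    domain: str
    codomain: str


class ChainSketch:
    """Assumed behaviour: domain from the first step, codomain from the last."""

    def __init__(self, chain):
        self.chain = list(chain)

    @property
    def domain(self):
        return self.chain[0].domain      # inputs enter through the first transform

    @property
    def codomain(self):
        return self.chain[-1].codomain   # outputs leave through the last transform


c = ChainSketch([
    StepSketch("exp", domain="real", codomain="positive"),
    StepSketch("square", domain="positive", codomain="positive"),
])
print(c.domain, c.codomain)   # real positive
```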
- __call__(x)
Apply the chained transformation forward.
- Parameters:
x – Input value in the domain of the first transform.
- Returns:
Output after applying all transforms sequentially.
Example:
>>> import torch
>>> from torch_openreml.covariance.transform import TransformExp, TransformPow, TransformChain
>>> t = TransformChain([TransformExp(), TransformPow(factor=2.0)])
>>> x = torch.tensor([1.0])
>>> t(x)
tensor([7.3891])
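The printed value is e squared: TransformExp maps 1.0 to e, and TransformPow(factor=2.0) then squares it. Reading factor as the exponent is an assumption, but it is the one the printed output agrees with. The same arithmetic with plain floats:

```python
import math

step1 = math.exp(1.0)     # TransformExp: 1.0 -> e
step2 = step1 ** 2.0      # TransformPow(factor=2.0), read as squaring
print(round(step2, 4))    # 7.3891
```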
- inverse(x)
Apply the inverse of the chained transformation.
- Parameters:
x – Input value in the codomain of the last transform.
- Returns:
Value mapped back through the inverse chain to the original domain.
Example:
>>> import torch
>>> from torch_openreml.covariance.transform import TransformExp, TransformPow, TransformChain
>>> t = TransformChain([TransformExp(), TransformPow(factor=2.0)])
>>> x = torch.tensor([4.0])
>>> t.inverse(x)
tensor([0.6931])
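The inverse undoes the last transform first: the square root of 4.0 is 2.0, and the logarithm of 2.0 is about 0.6931. A plain-float sketch of those two steps, again assuming factor is the exponent, plus a round-trip check through the forward chain:

```python
import math

y = 4.0
u = math.sqrt(y)       # undo TransformPow(factor=2.0) first: 4.0 -> 2.0
x = math.log(u)        # then undo TransformExp: 2.0 -> log 2
print(round(x, 4))     # 0.6931

# Round trip: the forward chain maps x back to (approximately) 4.0.
roundtrip = math.exp(x) ** 2.0
```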
- grad(x)
Compute the chain-rule factor (derivative) of the full composed transform.
Note
This assumes each transform's local derivative is evaluated at the intermediate value produced by the transforms before it, that is, consistently along the forward pass.
- Parameters:
x – Input value in the original domain.
- Returns:
Combined derivative factor of all transforms in the chain (the product of their local derivatives).
Example:
>>> import torch
>>> from torch_openreml.covariance.transform import TransformExp, TransformPow, TransformChain
>>> t = TransformChain([TransformExp(), TransformPow(factor=2.0)])
>>> x = torch.tensor([1.0])
>>> t.grad(x)
tensor([14.7781])
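The printed factor equals d/dx (e^x)^2 = 2e^(2x) at x = 1, i.e. 2e^2. The sketch below accumulates that factor step by step with hypothetical (function, derivative) pairs standing in for TransformExp and TransformPow(factor=2.0); it is not the library's implementation. Each local derivative is taken at that step's own input, which is what the Note requires, and the result is checked against a central finite difference of the forward map:

```python
import math

# Hypothetical (forward, local derivative) pairs; not torch_openreml code.
steps = [
    (math.exp, math.exp),                 # exp and its derivative
    (lambda u: u ** 2, lambda u: 2 * u),  # squaring and its derivative
]


def chain_forward(x):
    for fn, _ in steps:
        x = fn(x)
    return x


def chain_grad(x):
    """Multiply local derivatives, each evaluated along the forward pass."""
    factor = 1.0
    for fn, dfn in steps:
        factor *= dfn(x)   # local derivative at the current intermediate value
        x = fn(x)          # advance the forward pass
    return factor


g = chain_grad(1.0)
h = 1e-6
fd = (chain_forward(1.0 + h) - chain_forward(1.0 - h)) / (2 * h)
print(round(g, 4))   # 14.7781

# Evaluating every local derivative at the original input instead is wrong:
wrong = math.exp(1.0) * (2 * 1.0)   # 2e, about 5.4366
```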