# Source code for torch_openreml.covariance.transform.transform_identity

"""
Identity transform module for unconstrained parameter mappings.

Provides a trivial bijective transform from
:math:`\\mathbb{R} \\rightarrow \\mathbb{R}` that leaves inputs unchanged.

Classes:
    TransformIdentity:
        Identity transform (:math:`f(x) = x`)
"""

from torch_openreml.covariance.transform.transform import Transform
import torch

class TransformIdentity(Transform):
    r"""
    Identity transform.

    .. math::

        f(x) = x

    Both the forward map and its inverse return their input unchanged, so
    the transform maps :math:`\mathbb{R}` onto itself.
    """

    # U+211D DOUBLE-STRUCK CAPITAL R: the transform accepts and produces
    # unconstrained real values.
    domain = "\u211D"
    codomain = "\u211D"

    def __init__(self):
        r"""
        Initialize the identity transform.

        The transform holds no state, so there is nothing to set up.
        """
        # NOTE(review): the base ``Transform.__init__`` is not invoked here
        # (unchanged from the original); confirm the base class needs no setup.

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Apply the identity transform.

        Args:
            x (torch.Tensor): Input tensor in :math:`\mathbb{R}`.

        Returns:
            torch.Tensor: Unchanged input :math:`x` (the same object is
            returned, not a copy).

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance.transform import TransformIdentity

                t = TransformIdentity()
                x = torch.tensor([0.0, 1.0, -3.5])
                t(x)
        """
        return x

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Apply the inverse transform (identity).

        Args:
            x (torch.Tensor): Input tensor in :math:`\mathbb{R}`.

        Returns:
            torch.Tensor: Unchanged input :math:`x` (the same object is
            returned, not a copy).

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance.transform import TransformIdentity

                t = TransformIdentity()
                x = torch.tensor([2.0, -1.0])
                t.inverse(x)
        """
        # The identity is its own inverse.
        return x

    def grad(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Compute the derivative of :math:`f(x) = x` for chain rule propagation.

        Args:
            x (torch.Tensor): Input tensor. Only its ``dtype`` and ``device``
                are used; its values are ignored.

        Returns:
            torch.Tensor: The one-element tensor ``[1.0]`` with the same
            ``dtype`` and ``device`` as ``x``. Its shape is ``(1,)``
            regardless of ``x.shape``, so callers that combine it elementwise
            with tensors shaped like ``x`` rely on broadcasting.

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance.transform import TransformIdentity

                t = TransformIdentity()
                x = torch.tensor([0.0])
                t.grad(x)
        """
        # f'(x) = 1 everywhere. Matching dtype/device lets the result mix with
        # x in downstream tensor arithmetic without casts or device transfers.
        return torch.tensor([1.0], dtype=x.dtype, device=x.device)