# Source code for torch_openreml.covariance.transform.transform_exppow2

"""
Scaled exponential transform module for constrained parameter mappings.

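Provides a differentiable, strictly increasing transform from
:math:`\\mathbb{R}` to :math:`\\mathbb{R}_{0+}` using a natural
exponential with the exponent scaled by 2. In exact arithmetic the map is a
bijection onto the positive reals :math:`(0, \\infty)`.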

Classes:
    TransformExpPow2:
        Scaled exponential transform (:math:`e^{2x}`)
"""

from torch_openreml.covariance.transform.transform import Transform
import torch


class TransformExpPow2(Transform):
    r"""
    Exponential transform with exponent scaled by 2.

    .. math::

        f(x) = e^{2x}

    The map is smooth and strictly increasing. In exact arithmetic its image
    is the open interval :math:`(0, \infty)`. In floating point, sufficiently
    negative inputs underflow to exactly ``0``, so outputs lie in the declared
    codomain :math:`\mathbb{R}_{0+}`.

    Every method accepts a :class:`torch.Tensor` or anything that
    :func:`torch.as_tensor` can convert (Python scalars, sequences, NumPy
    arrays). Tensor inputs are passed through unchanged, so autograd graphs
    are preserved.
    """

    # Display labels for the input and output spaces: "ℝ" and "ℝ₀⁺".
    domain = "\u211D"
    codomain = "\u211D\u2080\u207A"

    def __init__(self):
        r"""
        Initialize the scaled exponential transform.

        The transform has no parameters, so nothing is stored on the instance.
        """

    def __call__(self, x) -> torch.Tensor:
        r"""
        Apply the scaled exponential transform.

        Args:
            x (torch.Tensor or array-like): Input in :math:`\mathbb{R}`.
                Non-tensor inputs are converted with :func:`torch.as_tensor`.

        Returns:
            torch.Tensor: Element-wise :math:`e^{2x}`.

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance.transform import TransformExpPow2

                t = TransformExpPow2()
                x = torch.tensor([0.0, 1.0])
                t(x)
        """
        # as_tensor returns an existing tensor unchanged (no copy, graph kept)
        # and lifts Python scalars/sequences, which torch.exp would reject.
        return torch.exp(2.0 * torch.as_tensor(x))

    def inverse(self, x) -> torch.Tensor:
        r"""
        Apply the inverse transform.

        Args:
            x (torch.Tensor or array-like): Input in :math:`\mathbb{R}_{0+}`.
                Non-tensor inputs are converted with :func:`torch.as_tensor`.

        Returns:
            torch.Tensor: Element-wise :math:`\frac{\log(x)}{2}`. Following
            :func:`torch.log`, an input of ``0`` maps to ``-inf`` and
            negative inputs map to ``nan``.

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance.transform import TransformExpPow2

                t = TransformExpPow2()
                x = torch.tensor([1.0])
                t.inverse(x)
        """
        return torch.log(torch.as_tensor(x)) / 2.0

    def grad(self, x) -> torch.Tensor:
        r"""
        Compute the derivative of :math:`e^{2x}` for chain-rule propagation.

        Note:
            .. math::

                \frac{d}{dx} e^{2x} = 2e^{2x}

        Args:
            x (torch.Tensor or array-like): Input tensor.
                Non-tensor inputs are converted with :func:`torch.as_tensor`.

        Returns:
            torch.Tensor: Element-wise :math:`2e^{2x}`.

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance.transform import TransformExpPow2

                t = TransformExpPow2()
                x = torch.tensor([0.0, 1.0])
                t.grad(x)
        """
        # Float literals match __call__, so integer tensors are promoted the
        # same way in both the forward map and its derivative.
        return 2.0 * torch.exp(2.0 * torch.as_tensor(x))