Source code for torch_openreml.covariance.transform.transform_exp

"""
Exponential transform module for constrained parameter mappings.

Provides a set of differentiable bijective transforms from
:math:`\\mathbb{R} \\rightarrow \\mathbb{R}_{+}` using different
exponential bases.

Classes:
    TransformExp:
        Natural exponential transform (:math:`e^x`)

    TransformExp2:
        Base-2 exponential transform (:math:`2^x`)

    TransformExp10:
        Base-10 exponential transform (:math:`10^x`)
"""

from torch_openreml.covariance.transform.transform import Transform
import torch


[docs] class TransformExp(Transform): r""" Exponential transform using the natural exponential function. .. math:: f(x) = e^x """ domain = "\u211D" codomain = "\u211D\u207A" def __init__(self): r""" Initialize the exponential transform. """ pass
[docs] def __call__(self, x): r""" Apply the natural exponential transform. Args: x (torch.Tensor): Input tensor in :math:`\mathbb{R}`. Returns: torch.Tensor: Element-wise :math:`e^x`. Example: .. jupyter-execute:: import torch from torch_openreml.covariance.transform import TransformExp t = TransformExp() x = torch.tensor([0.0, 1.0]) t(x) """ return torch.exp(x)
[docs] def inverse(self, x): r""" Apply the inverse transform (natural logarithm). Args: x (torch.Tensor): Input tensor in :math:`\mathbb{R}_{+}`. Returns: torch.Tensor: :math:`\log(x)`. .. jupyter-execute:: import torch from torch_openreml.covariance.transform import TransformExp t = TransformExp() x = torch.tensor([1.0]) t.inverse(x) """ return torch.log(x)
[docs] def grad(self, x): r""" Compute derivative of :math:`e^x` for chain rule propagation. Args: x (torch.Tensor): Input tensor. Returns: torch.Tensor: :math:`e^x`. Example: .. jupyter-execute:: import torch from torch_openreml.covariance.transform import TransformExp t = TransformExp() x = torch.tensor([0.0]) t.grad(x) """ return torch.exp(x)
[docs] class TransformExp2(Transform): r""" Base-2 exponential transform. .. math:: f(x) = 2^x """ domain = "\u211D" codomain = "\u211D\u207A" def __init__(self): r""" Initialize base-2 exponential transform. """ pass
[docs] def __call__(self, x): r""" Apply base-2 exponential transform. Args: x (torch.Tensor): Input tensor in :math:`\mathbb{R}`. Returns: torch.Tensor: :math:`2^x`. Example: .. jupyter-execute:: import torch from torch_openreml.covariance.transform import TransformExp2 t = TransformExp2() x = torch.tensor([0.0, 1.0]) t(x) """ return torch.exp2(x)
[docs] def inverse(self, x): r""" Apply inverse base-2 logarithm. Args: x (torch.Tensor): Input tensor in :math:`\mathbb{R}_{+}`. Returns: torch.Tensor: :math:`\log_{2}(x)`. Example: .. jupyter-execute:: import torch from torch_openreml.covariance.transform import TransformExp2 t = TransformExp2() x = torch.tensor([1.0, 2.0]) t.inverse(x) """ return torch.log2(x)
[docs] def grad(self, x): r""" Compute derivative of :math:`2^x`. Note: .. math:: \frac{d}{dx} 2^x = 2^x \ln 2 Args: x (torch.Tensor): Input tensor. Returns: torch.Tensor: :math:`2^x \ln 2`. Example: .. jupyter-execute:: import torch from torch_openreml.covariance.transform import TransformExp2 t = TransformExp2() x = torch.tensor([1.0]) t.grad(x) """ return torch.exp2(x) * torch.log( torch.tensor([2], dtype=x.dtype, device=x.device) )
[docs] class TransformExp10(Transform): r""" Base-10 exponential transform. .. math:: f(x) = 10^x """ domain = "\u211D" codomain = "\u211D\u207A" def __init__(self): r""" Initialize base-10 exponential transform. """ pass
[docs] def __call__(self, x): r""" Apply base-10 exponential transform. Args: x (torch.Tensor): Input tensor in :math:`\mathbb{R}`. Returns: torch.Tensor: :math:`10^x`. Example: .. jupyter-execute:: import torch from torch_openreml.covariance.transform import TransformExp10 t = TransformExp10() x = torch.tensor([0.0, 1.0]) t(x) """ return torch.pow(10.0, x)
[docs] def inverse(self, x): r""" Apply inverse base-10 logarithm. Args: x (torch.Tensor): Input tensor in :math:`\mathbb{R}_{+}`. Returns: torch.Tensor: :math:`\log_{10}(x)`. Example: .. jupyter-execute:: import torch from torch_openreml.covariance.transform import TransformExp10 t = TransformExp10() x = torch.tensor([1.0, 10.0]) t.inverse(x) """ return torch.log10(x)
[docs] def grad(self, x): r""" Compute derivative of :math:`10^x`. Note: .. math:: \frac{d}{dx} 10^x = 10^x \ln 10 Args: x (torch.Tensor): Input tensor. Returns: torch.Tensor: :math:`10^x \ln 10`. Example: .. jupyter-execute:: import torch from torch_openreml.covariance.transform import TransformExp10 t = TransformExp10() x = torch.tensor([1.0]) t.grad(x) """ return torch.pow(10.0, x) * torch.log( torch.tensor([10], dtype=x.dtype, device=x.device) )