Source code for torch_openreml.covariance.transform.transform_pow

"""
Power transform module for parameter mappings.

Provides a differentiable transform from
:math:`\\mathbb{R} \\rightarrow \\mathbb{R}` using a configurable
power function.

Classes:
    TransformPow:
        Power transform (:math:`f(x) = x^p`)
"""

from torch_openreml.covariance.transform.transform import Transform
import torch


class TransformPow(Transform):
    r"""
    Power transform with configurable exponent.

    .. math::

        f(x) = x^p

    All operations delegate to :func:`torch.pow`, so they inherit its
    semantics: a negative input raised to a non-integer power yields ``NaN``.
    """

    domain = "\u211D"
    codomain = "\u211D"

    def __init__(self, factor=2.0):
        r"""
        Initialize the power transform.

        Args:
            factor (float): Exponent :math:`p`. Defaults to ``2.0``.
        """
        self.factor = factor

    def __call__(self, x):
        r"""
        Apply the power transform.

        Args:
            x (torch.Tensor): Input tensor in :math:`\mathbb{R}`.

        Returns:
            torch.Tensor: Element-wise :math:`x^p`.

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance.transform import TransformPow

                t = TransformPow(factor=3.0)
                x = torch.tensor([1.0, 2.0, 3.0])
                t(x)
        """
        return torch.pow(x, self.factor)

    def inverse(self, x):
        r"""
        Apply the inverse transform (principal :math:`p`-th root).

        For ``factor=2.0`` this is the element-wise square root.

        Args:
            x (torch.Tensor): Input tensor, typically an output of this
                transform.

        Returns:
            torch.Tensor: Element-wise :math:`x^{1/p}`. Negative entries give
            ``NaN`` unless :math:`1/p` is an integer (``torch.pow`` semantics).

        Raises:
            ValueError: If ``factor`` is zero, since :math:`x^0` is constant
                and has no inverse.

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance.transform import TransformPow

                t = TransformPow(factor=3.0)
                x = torch.tensor([1.0, 8.0, 27.0])
                t.inverse(x)
        """
        if self.factor == 0:
            raise ValueError("TransformPow.inverse is undefined for factor=0")
        # Invert with the configured exponent, not a hard-coded square root,
        # so that inverse(f(x)) == x holds for every non-zero factor.
        return torch.pow(x, 1.0 / self.factor)

    def grad(self, x):
        r"""
        Compute derivative of :math:`x^p` for chain rule propagation.

        Note:
            .. math::

                \frac{d}{dx} x^p = p x^{p-1}

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Element-wise :math:`p x^{p-1}`.

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance.transform import TransformPow

                t = TransformPow(factor=3.0)
                x = torch.tensor([2.0, 3.0])
                t.grad(x)
        """
        # The x^(p-1) factor is required for every p; p * x alone is only
        # the derivative in the special case p == 2.
        return self.factor * torch.pow(x, self.factor - 1)

    def __repr__(self):
        return f"{self.__class__.__name__}(factor={self.factor})"