Source code for torch_openreml.covariance.operator_covariance_propagation

"""
Covariance propagation operator.

This module provides a covariance propagation operator that transforms
a covariance matrix through a design matrix, for use in linear
mixed-effects models.

Classes:
    CovariancePropagation:
        A covariance propagation operator :math:`V = Z G Z^\\top`.
"""

from torch_openreml.covariance.operator import Operator
import torch


[docs] class CovariancePropagation(Operator): r""" Covariance propagation operator. .. math:: \symbf{V} = \symbf{Z} \symbf{G} \symbf{Z}^\top Propagates the covariance matrix :math:`\symbf{G}` through the design matrix :math:`\symbf{Z}`. The first operand is treated as :math:`\symbf{Z}` and the second as :math:`\symbf{G}`. Either or both may be trainable :class:`~torch_openreml.covariance.matrix.Matrix` instances or fixed :class:`torch.Tensor` values. This structure arises naturally in linear mixed-effects models where :math:`\symbf{Z}` is the random-effect design matrix and :math:`\symbf{G}` is the random-effect covariance matrix, giving the random-effect contribution :math:`\symbf{Z}\symbf{G}\symbf{Z}^\top` to the marginal covariance. """ def __init__(self, *args, **kwargs): """ Initialize a covariance propagation operator from exactly two operands. Args: *args: Exactly two operands as positional arguments or a single dict. The first is :math:`\\symbf{Z}`, the second :math:`\\symbf{G}`. **kwargs: Exactly two operands as keyword arguments. Raises: ValueError: If the number of operands is not exactly two. Example: .. jupyter-execute:: import torch from torch_openreml.covariance import DummyMatrix, DiagonalMatrix, CovariancePropagation z = DummyMatrix(["a", "b", "c", "a"]) z() .. jupyter-execute:: g = DiagonalMatrix(3) op = CovariancePropagation(z=z, g=g) free_params = torch.tensor([0.0, 0.5, 1.0]) op(free_params) """ super().__init__(*args, **kwargs) if len(self.operands) != 2: raise ValueError("Two operands are required") def _get_or_build_intermediates(self, free_params): built_params = self.build_params(free_params) cache = self.get_intermediates(built_params) if cache is None: v_groups = self.build_operands(free_params) z = v_groups[0] g = v_groups[1] v = z @ g @ z.T cache = {"z": z, "g": g, "v": v} self.set_intermediates(built_params, cache) return cache
[docs] def __call__(self, free_params=None): if free_params is None: free_params = self.free_param_defaults cache = self._get_or_build_intermediates(free_params) v = cache["v"] self._shape = tuple(v.shape) return v
[docs] def manual_grad(self, free_params=None): """ Compute the Jacobian of :meth:`__call__` with respect to trainable parameters using a closed-form analytic expression. Applies the product rule to :math:`\\symbf{V} = \\symbf{Z} \\symbf{G} \\symbf{Z}^\\top`: - With respect to :math:`\\theta_{\\symbf{Z}}`: :math:`\\frac{\\partial \\symbf{V}}{\\partial \\theta} = \\frac{\\partial \\symbf{Z}}{\\partial \\theta} \\symbf{G} \\symbf{Z}^\\top + \\symbf{Z} \\symbf{G} \\frac{\\partial \\symbf{Z}^\\top}{\\partial \\theta}` (two terms because :math:`\\symbf{Z}` appears twice). - With respect to :math:`\\theta_{\\symbf{G}}`: :math:`\\frac{\\partial \\symbf{V}}{\\partial \\theta} = \\symbf{Z} \\frac{\\partial \\symbf{G}}{\\partial \\theta} \\symbf{Z}^\\top` (linear in :math:`\\symbf{G}`). Per-operand Jacobians from :meth:`~torch_openreml.covariance.operator.Operator.operands_grad` are propagated through the same structure. Args: free_params (torch.Tensor or dict): Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: ``None``. Returns: tuple: ``(grad, grad_names)``, where ``grad`` is a 3D tensor of shape ``(num_free_params, *shape)`` and ``grad_names`` is a list of the corresponding parameter names. Returns ``(None, [])`` if all parameters are fixed. Raises: TypeError: If ``free_params`` is not a Torch tensor. ValueError: If ``free_params`` is not a 1D tensor or has the wrong length, or if ``free_params`` is a dict with missing or unexpected keys. Example: .. jupyter-execute:: import torch from torch_openreml.covariance import DummyMatrix, DiagonalMatrix, CovariancePropagation z = DummyMatrix(["a", "b", "c", "a"]) z() .. jupyter-execute:: g = DiagonalMatrix(3) op = CovariancePropagation(z=z, g=g) free_params = torch.tensor([0.0, 0.5, 1.0]) grad, grad_names = op.manual_grad(free_params) grad .. jupyter-execute:: grad_names """ if free_params is None: free_params = self.free_param_defaults grad_groups, grad_name_groups = self.operands_grad(free_params) cache = self._get_or_build_intermediates(free_params) z = cache["z"] g = cache["g"] grad_list = [] grad_names = [] dz = grad_groups[0] if dz is not None: grad_z = dz @ g @ z.T + z @ g @ dz.mT grad_list.append(grad_z) grad_names.extend(grad_name_groups[0]) dg = grad_groups[1] if dg is not None: grad_g = z @ dg @ z.T grad_list.append(grad_g) grad_names.extend(grad_name_groups[1]) if len(grad_list) > 0: grad = torch.cat(grad_list) return grad, grad_names else: return None, []