Source code for torch_openreml.covariance.operator_block_diagonal

"""
Block diagonal covariance matrix operator.

This module provides a block diagonal covariance matrix formed from two
or more constituent covariance matrices, for use in linear mixed-effects
models.

Classes:
    BlockDiagonal:
        A block diagonal covariance matrix operator.
"""

from torch_openreml.covariance.operator import Operator
import torch


class BlockDiagonal(Operator):
    r"""
    Block diagonal covariance matrix formed from two or more operands.

    .. math::

        \symbf{V} = \mathrm{blockdiag}(\symbf{V}_0, \symbf{V}_1, \ldots)

    Each operand contributes a contiguous block along the diagonal. Parameters
    and gradients are namespaced by operand name and aggregated into a single
    joint parameter tensor, following the convention of
    :class:`~torch_openreml.covariance.operator.Operator`.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize a block diagonal operator from two or more operands.

        Args:
            *args: Two or more operands as positional arguments, each a
                :class:`~torch_openreml.covariance.matrix.Matrix` or
                :class:`torch.Tensor`. A single dict argument is also accepted.
            **kwargs: Two or more operands as keyword arguments.

        Raises:
            ValueError: If fewer than two operands are provided.

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance import ScalarMatrix, DiagonalMatrix, BlockDiagonal

                block = BlockDiagonal(
                    residual=ScalarMatrix(3),
                    random=DiagonalMatrix(2)
                )
                free_params = torch.tensor([0.5, 0.0, 1.0])
                block(free_params)
        """
        super().__init__(*args, **kwargs)
        if len(self.operands) < 2:
            raise ValueError("At least two operands are required")

    def _get_or_build_intermediates(self, free_params):
        """
        Return the intermediates for ``free_params``, building them on a miss.

        The lookup key is the result of ``self.build_params(free_params)``.
        When ``self.get_intermediates`` returns ``None`` for that key, the
        operands are evaluated, assembled with :func:`torch.block_diag`, and
        the result is stored through ``self.set_intermediates``.

        Args:
            free_params: Parameters forwarded unchanged to
                ``self.build_params`` and ``self.build_operands``.

        Returns:
            dict: A dictionary with the keys

            - ``"v_groups"``: the evaluated operands, in operand order;
            - ``"v"``: the assembled block diagonal matrix;
            - ``"row_offsets"``: one ``(start, stop)`` row range per operand;
            - ``"col_offsets"``: one ``(start, stop)`` column range per
              operand.
        """
        built_params = self.build_params(free_params)
        cache = self.get_intermediates(built_params)
        if cache is None:
            v_groups = self.build_operands(free_params)
            v = torch.block_diag(*v_groups)

            # Record where each operand's block sits inside ``v`` so that
            # gradients can later be written into the matching sub-block.
            row_offsets = []
            col_offsets = []
            n = 0
            m = 0
            for vg in v_groups:
                rows, cols = vg.shape
                row_offsets.append((n, n + rows))
                col_offsets.append((m, m + cols))
                n += rows
                m += cols

            cache = {
                "v_groups": v_groups,
                "v": v,
                "row_offsets": row_offsets,
                "col_offsets": col_offsets
            }
            self.set_intermediates(built_params, cache)
        return cache

    def __call__(self, free_params=None):
        """
        Evaluate the block diagonal matrix.

        As a side effect, ``self._shape`` is updated to the shape of the
        returned matrix.

        Args:
            free_params: Parameters used to evaluate the operands. If omitted,
                ``self.free_param_defaults`` is used. Default: ``None``.

        Returns:
            torch.Tensor: The assembled block diagonal matrix.
        """
        if free_params is None:
            free_params = self.free_param_defaults
        cache = self._get_or_build_intermediates(free_params)
        v = cache["v"]
        self._shape = tuple(v.shape)
        return v

    def manual_grad(self, free_params=None):
        """
        Compute the Jacobian of :meth:`__call__` with respect to trainable
        parameters using a closed-form analytic expression.

        Evaluates each operand's gradient via
        :meth:`~torch_openreml.covariance.operator.Operator.operands_grad`,
        then places each per-operand Jacobian into the corresponding block of
        a zero-initialised full Jacobian matching the block diagonal output
        shape. Operands whose gradient is ``None`` are skipped and contribute
        no entries to the Jacobian.

        Args:
            free_params (torch.Tensor or dict): Flat 1D parameter tensor or
                parameter dictionary. If omitted, default values are used.
                Default: ``None``.

        Returns:
            tuple: ``(grad, grad_names)``, where ``grad`` is a 3D tensor of
            shape ``(num_free_params, *shape)`` and ``grad_names`` is a list
            of the corresponding parameter names. Returns ``(None, [])`` if
            all parameters are fixed.

        Raises:
            TypeError: If ``free_params`` is not a Torch tensor.
            ValueError: If ``free_params`` is not a 1D tensor or has the wrong
                length, or if ``free_params`` is a dict with missing or
                unexpected keys.

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance import ScalarMatrix, DiagonalMatrix, BlockDiagonal

                block = BlockDiagonal(
                    A=ScalarMatrix(3),
                    B=DiagonalMatrix(2)
                )
                free_params = torch.tensor([0.5, 0.0, 1.0])
                grad, grad_names = block.manual_grad(free_params)
                grad

            .. jupyter-execute::

                grad_names
        """
        if free_params is None:
            free_params = self.free_param_defaults

        grad_groups, grad_name_groups = self.operands_grad(free_params)
        cache = self._get_or_build_intermediates(free_params)
        v = cache["v"]
        row_offsets = cache["row_offsets"]
        col_offsets = cache["col_offsets"]

        # Keep only operands that actually produced a Jacobian, remembering
        # their position so the matching block offsets can be looked up.
        active = [
            (i, group) for i, group in enumerate(grad_groups)
            if group is not None
        ]
        if not active:
            return None, []

        # Allocate the joint Jacobian once instead of one full-size zero
        # tensor per operand followed by a concatenation, which would hold
        # roughly twice the final Jacobian in memory at its peak.
        num_params = sum(group.shape[0] for _, group in active)
        grad = torch.zeros(
            (num_params,) + tuple(v.shape),
            dtype=v.dtype,
            device=v.device,
        )

        grad_names = []
        p0 = 0
        for i, group in active:
            p1 = p0 + group.shape[0]
            r0, r1 = row_offsets[i]
            c0, c1 = col_offsets[i]
            # Parameter rows p0:p1 belong to operand ``i``; only its own
            # diagonal block is non-zero.
            grad[p0:p1, r0:r1, c0:c1] = group
            grad_names.extend(grad_name_groups[i])
            p0 = p1

        return grad, grad_names