Source code for torch_openreml.covariance.operator

"""
Operator base class for composite covariance matrices.

This module provides an abstract base class for combining multiple
covariance matrices into a single composite structure. Operators
delegate parameter management, transforms, and gradient computation
to their constituent operands, presenting a unified
:class:`~torch_openreml.covariance.matrix.Matrix` interface to the rest
of the library.

Classes:
    Operator:
        Base class for composite covariance matrix operators.
"""

from functools import reduce
from torch_openreml.covariance.matrix import Matrix
import torch

class Operator(Matrix):
    r"""
    Abstract base class for composite covariance matrix operators.

    An operator combines one or more
    :class:`~torch_openreml.covariance.matrix.Matrix` instances and optional
    fixed :class:`torch.Tensor` operands into a single covariance matrix.
    Parameters and gradients are namespaced by operand — a parameter
    ``"sigma^2"`` belonging to operand ``"residual"`` is exposed as
    ``"residual/sigma^2"`` in the combined :attr:`param_names`.

    Operands may be passed as positional arguments, as keyword arguments, or
    as a single dictionary. Mixing positional and keyword arguments is not
    permitted. Positional operands are assigned names ``"op_0"``, ``"op_1"``,
    etc.

    At least one operand must be a
    :class:`~torch_openreml.covariance.matrix.Matrix` instance. Pure-tensor
    operands are treated as fixed matrices with no free parameters.
    """

    _repr_single_line = False

    def __init__(self, *args, **kwargs):
        """
        Initialize the operator from positional or keyword operands.

        Args:
            *args: Operands as positional arguments, each a
                :class:`~torch_openreml.covariance.matrix.Matrix` or
                :class:`torch.Tensor`. A single dict argument is also
                accepted and treated as a named operand mapping; the dict is
                shallow-copied, so later changes to the caller's dict do not
                affect the operator.
            **kwargs: Operands as keyword arguments, mapping operand names to
                :class:`~torch_openreml.covariance.matrix.Matrix` or
                :class:`torch.Tensor` instances. Operand names must not
                contain ``"/"``.

        Raises:
            ValueError: If both positional and keyword arguments are
                provided, or if any operand name contains ``"/"``.
            TypeError: If any operand name is not a string, if any operand is
                not a :class:`~torch_openreml.covariance.matrix.Matrix` or
                :class:`torch.Tensor`, or if no operand is a
                :class:`~torch_openreml.covariance.matrix.Matrix`.

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance import Sum, ScalarMatrix

                x = Sum(ScalarMatrix(2), ScalarMatrix(2))
                x

            .. jupyter-execute::

                x = Sum(A = ScalarMatrix(2), B = ScalarMatrix(2))
                x

            .. jupyter-execute::

                x.param_names

            .. jupyter-execute::

                x(torch.zeros(2))

            .. jupyter-execute::

                x({"A/sigma^2": torch.zeros(1), "B/sigma^2": torch.zeros(1)})
        """
        if args and kwargs:
            raise ValueError(
                'Operands must be provided either as keyword arguments or as '
                'positional arguments, but not both!'
            )

        if len(args) == 1 and isinstance(args[0], dict):
            # Copy so that later mutation of the caller's dict cannot slip
            # unvalidated operands past _check_operands.
            operands = dict(args[0])
        elif args:
            operands = {f"op_{i}": arg for i, arg in enumerate(args)}
        else:
            operands = kwargs

        self._check_operands(operands)
        # The operands are stored before the base initializer runs, as in the
        # original ordering (presumably because Matrix.__init__ reads
        # operand-derived properties such as param_specs — verify).
        self._operands = operands
        super().__init__(None, {})

    def _check_operands(self, operands):
        """
        Validate the operand dictionary.

        Ensures that all keys are strings without ``"/"``, all values are
        :class:`~torch_openreml.covariance.matrix.Matrix` or
        :class:`torch.Tensor` instances, and that at least one value is a
        :class:`~torch_openreml.covariance.matrix.Matrix`.

        Args:
            operands (dict): Mapping from operand names to operands.

        Raises:
            TypeError: If ``operands`` is not a dict, if any key is not a
                string, if any value is not a
                :class:`~torch_openreml.covariance.matrix.Matrix` or
                :class:`torch.Tensor`, or if no value is a
                :class:`~torch_openreml.covariance.matrix.Matrix`.
            ValueError: If any key contains ``"/"``.
        """
        if not isinstance(operands, dict):
            raise TypeError(f"operands must be a dict, got {type(operands).__name__}!")

        for key, value in operands.items():
            if not isinstance(key, str):
                raise TypeError(f"Operand name must be a string, got {type(key).__name__}!")
            # "/" is reserved as the separator in "<operand>/<param>" names.
            if "/" in key:
                raise ValueError(f"Invalid operand name '{key}': '/' is not allowed!")
            if not isinstance(value, (Matrix, torch.Tensor)):
                raise TypeError(
                    f"Operand '{key}' must be a Matrix or torch.Tensor, "
                    f"got {type(value).__name__}!"
                )

        if not any(isinstance(v, Matrix) for v in operands.values()):
            raise TypeError("operands must include at least one Matrix!")

    def build_params(self, free_params=None, include_fixed=True, trans=True, out_format="tensor"):
        """
        Construct the full parameter tensor by delegating to each operand.

        Splits ``free_params`` into per-operand slices according to each
        operand's
        :attr:`~torch_openreml.covariance.matrix.Matrix.num_free_params`,
        calls :meth:`~torch_openreml.covariance.matrix.Matrix.build_params`
        on each :class:`~torch_openreml.covariance.matrix.Matrix` operand,
        and concatenates the results. Fixed tensor operands are skipped.

        Args:
            free_params (torch.Tensor or dict): Flat 1D tensor of free
                parameters or a dictionary mapping parameter names to
                tensors. If omitted, default values are used.
                Default: ``None``.
            include_fixed (bool, optional): Whether to include fixed
                parameters in the output. Passed through to each operand's
                :meth:`~torch_openreml.covariance.matrix.Matrix.build_params`.
                Default: ``True``.
            trans (bool, optional): Whether to apply parameter transforms to
                the output. Passed through to each operand's
                :meth:`~torch_openreml.covariance.matrix.Matrix.build_params`.
                Default: ``True``.
            out_format (str, optional): Output format. One of ``"tensor"`` or
                ``"dict"``. Default: ``"tensor"``.

        Returns:
            torch.Tensor or dict: Concatenated parameter tensor, or a
            dictionary mapping namespaced parameter names to values. In
            ``"dict"`` format the keys are :attr:`param_names` when
            ``include_fixed`` is true and :attr:`free_param_names` otherwise.

        Raises:
            TypeError: If ``free_params`` is not a Torch tensor.
            ValueError: If ``free_params`` is not a 1D tensor or has the
                wrong length, or if ``free_params`` is a dict with missing or
                unexpected keys.
            ValueError: If ``out_format`` is not ``"tensor"`` or ``"dict"``.

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance import Sum, ScalarMatrix

                x = Sum(ScalarMatrix(2), ScalarMatrix(2))
                free_params = torch.tensor([0.0, 0.5])
                x.build_params(free_params)

            .. jupyter-execute::

                x.build_params()

            .. jupyter-execute::

                x.build_params(free_params, trans=False)

            .. jupyter-execute::

                x.build_params(free_params, out_format="dict")
        """
        if free_params is None:
            free_params = self.free_param_defaults
        free_params = self._from_free_param_dict(free_params)
        self._check_param_tensor(free_params, length=self.num_free_params)

        pieces = []
        offset = 0
        for operand in self.operands.values():
            if not isinstance(operand, Matrix):
                continue  # fixed tensors carry no parameters
            stop = offset + operand.num_free_params
            pieces.append(
                operand.build_params(
                    free_params[offset:stop],
                    include_fixed=include_fixed,
                    trans=trans,
                    out_format="tensor",
                )
            )
            offset = stop
        result = torch.cat(pieces)

        if out_format == "tensor":
            return result
        elif out_format == "dict":
            # With include_fixed=True the tensor also holds fixed parameters,
            # so it must be labelled with the full name list; pairing it with
            # the free names would misalign and silently truncate.
            # NOTE(review): assumes param_names follows param_specs order —
            # confirm against Matrix.
            names = self.param_names if include_fixed else self.free_param_names
            return dict(zip(names, result))
        else:
            raise ValueError(f"Unexpected 'out_format': {out_format}!")

    def build_operands(self, free_params=None):
        """
        Evaluate each operand at the current free parameters.

        Splits ``free_params`` into per-operand slices and calls each
        :class:`~torch_openreml.covariance.matrix.Matrix` operand to produce
        its matrix. Fixed tensor operands are included as-is.

        Args:
            free_params (torch.Tensor or dict): Flat 1D joint parameter
                tensor or parameter dictionary of length
                :attr:`num_free_params`. If omitted, default values are used.
                Default: ``None``.

        Returns:
            list of torch.Tensor: Evaluated operand matrices in the same
            order as :attr:`operands`.

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance import Sum, ScalarMatrix

                x = Sum(ScalarMatrix(2), ScalarMatrix(2))
                v_groups = x.build_operands(torch.tensor([1.0, 2.0]))
                print(v_groups[0])
                print(v_groups[1])
        """
        if free_params is None:
            free_params = self.free_param_defaults
        free_params = self._from_free_param_dict(free_params)
        self._check_param_tensor(free_params, length=self.num_free_params)

        v_groups = []
        offset = 0
        for operand in self.operands.values():
            if isinstance(operand, Matrix):
                stop = offset + operand.num_free_params
                v_groups.append(operand(free_params[offset:stop]))
                offset = stop
            else:
                v_groups.append(operand)
        return v_groups

    def operands_grad(self, free_params=None):
        """
        Compute the Jacobian of each operand with respect to its parameters.

        Splits ``free_params`` into per-operand slices, calls
        :meth:`~torch_openreml.covariance.matrix.Matrix.grad` on each
        :class:`~torch_openreml.covariance.matrix.Matrix` operand, and
        prefixes the returned names with the operand name. Fixed tensor
        operands, and operands whose ``grad`` returns ``None``, contribute
        ``None`` and an empty name list.

        Args:
            free_params (torch.Tensor or dict): Flat 1D joint parameter
                tensor or parameter dictionary of length
                :attr:`num_free_params`. If omitted, default values are used.
                Default: ``None``.

        Returns:
            tuple: ``(grad_groups, grad_name_groups)``, where ``grad_groups``
            is a list of per-operand Jacobian tensors or ``None`` for fixed
            operands, and ``grad_name_groups`` is a list of corresponding
            namespaced parameter name lists.

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance import Sum, ScalarMatrix

                x = Sum(ScalarMatrix(2), ScalarMatrix(2))
                grad_groups, grad_name_groups = x.operands_grad(torch.tensor([1.0, 2.0]))
                print(grad_groups[0])
                print(grad_groups[1])
                print(grad_name_groups[0])
                print(grad_name_groups[1])
        """
        if free_params is None:
            free_params = self.free_param_defaults
        free_params = self._from_free_param_dict(free_params)
        self._check_param_tensor(free_params, length=self.num_free_params)

        grad_groups = []
        grad_name_groups = []
        offset = 0
        for name, operand in self.operands.items():
            grad, grad_names = None, []
            if isinstance(operand, Matrix):
                stop = offset + operand.num_free_params
                grad, grad_names = operand.grad(free_params[offset:stop])
                offset = stop

            if grad is None:
                # Keep one entry per operand so both lists stay aligned with
                # the order of self.operands.
                grad_groups.append(None)
                grad_name_groups.append([])
            else:
                grad_groups.append(grad)
                grad_name_groups.append([f"{name}/{grad_name}" for grad_name in grad_names])
        return grad_groups, grad_name_groups

    def auto_grad(self, free_params=None):
        """
        Reset operand intermediates, then defer to the base ``auto_grad``.

        Calls ``reset_intermediates()`` on every
        :class:`~torch_openreml.covariance.matrix.Matrix` operand before
        delegating to the parent implementation.

        Args:
            free_params (torch.Tensor or dict): Forwarded unchanged to the
                parent ``auto_grad``. Default: ``None``.

        Returns:
            Whatever the parent class's ``auto_grad`` returns.
        """
        # Iterate the values: looping over the dict itself yields the operand
        # *names* (strings), which would never match Matrix and would leave
        # every operand's intermediates un-reset.
        for operand in self.operands.values():
            if isinstance(operand, Matrix):
                operand.reset_intermediates()
        return super().auto_grad(free_params)

    @property
    def operands(self):
        """dict: Mapping from operand names to operand matrices or tensors."""
        return self._operands

    @property
    def param_specs(self):
        """dict: Operand parameter specs keyed by ``"<operand>/<param>"``.

        Only :class:`~torch_openreml.covariance.matrix.Matrix` operands
        contribute; entries follow operand order, then each operand's own
        ``param_specs`` order.
        """
        return {
            f"{name}/{param_name}": spec
            for name, operand in self.operands.items()
            if isinstance(operand, Matrix)
            for param_name, spec in operand.param_specs.items()
        }

    @property
    def repr_dict(self):
        """dict: Key-value pairs used to build the string representation."""
        return {"operands": self.operands}