torch_openreml.covariance.Matrix

class torch_openreml.covariance.Matrix(shape, param_specs)[source]

Bases: ABC

Abstract base class for covariance matrices whose structure is parameterized as

\[\symbf{V} = \symbf{V}(\symbf{\theta})\]

where \(\symbf{\theta}\) denotes the collection of variance component parameters that define the matrix entries.

This class provides utilities for parameter validation, transform application, and Jacobian computation (both manual and automatic). Subclasses must implement __call__() to construct their specific matrix structure from the provided parameters.

Initialize a covariance matrix with parameter specifications.

Parameters:
  • shape (tuple or None) – Expected output dimensions of the constructed matrix. Used for validation; the actual shape may be set by subclasses.

  • param_specs (dict) – Parameter specifications. Keys should be strings representing parameter names. Values should be dictionaries containing the specification for each parameter. Each specification dictionary should contain the keys "fixed", "default", and "trans", representing whether the parameter is fixed or free (bool), the default value (1D torch.Tensor), and the transform (Transform), respectively.

Raises:
  • TypeError – If param_specs violates any of the requirements listed in the argument description, or if shape is not a tuple or torch.Size.

  • ValueError – If shape contains negative values.

Methods

__call__([free_params])

Construct the matrix from a flat parameter tensor.

auto_grad([free_params])

Compute the Jacobian of __call__() with respect to free parameters using automatic differentiation.

build_params([free_params, include_fixed, ...])

Construct the full parameter tensor from free parameters.

get_intermediates(params)

Retrieve cached intermediate computation results if still valid.

grad([free_params])

Compute the Jacobian of __call__() with respect to trainable parameters.

manual_grad([free_params])

Compute the Jacobian of __call__() with respect to free parameters using a closed-form analytic expression.

map_theta_to_dv(theta)

An interface compatible with torch_openreml.REML that maps parameters to the matrix Jacobian.

map_theta_to_v(theta)

An interface compatible with torch_openreml.REML that maps parameters to a matrix.

reset_intermediates()

Clear the intermediate computation cache.

set_intermediates(params, intermediates)

Cache intermediate computation results keyed by parameter hash.

trans_grad([free_params])

Compute the element-wise derivative of the free parameter transforms.

Attributes

fixed_param_defaults

Fixed parameter defaults.

fixed_param_index

Index of fixed parameters.

fixed_param_names

Fixed parameter names.

fixed_param_trans

Transforms for fixed parameters.

free_param_defaults

Free parameter defaults.

free_param_index

Index of free parameters.

free_param_names

Free parameter names.

free_param_trans

Transforms for free parameters.

num_fixed_params

Total number of fixed parameters.

num_free_params

Total number of free parameters.

num_params

Total number of parameters.

param_defaults

Parameter defaults.

param_names

Parameter names.

param_specs

Parameter specifications.

param_trans

Parameter transforms.

repr_dict

Key-value pairs used to build the string representation.

shape

Output matrix shape.

set_intermediates(params, intermediates)[source]

Cache intermediate computation results keyed by parameter hash.

Stores arbitrary intermediate values alongside a hash of the current parameter tensor, dtype, and device. Cached values can be retrieved via get_intermediates() to avoid redundant computation across multiple calls with identical parameters.

Parameters:
  • params (torch.Tensor) – Current parameter tensor.

  • intermediates – Arbitrary object to cache (e.g. Cholesky factors, eigendecompositions, or any reusable computation).

Raises:
  • TypeError – If params is not a torch.Tensor.

  • ValueError – If params is not a 1D tensor.

Note

If params has length 0 (no free parameters), this is a no-op.

Example:

import torch
from torch_openreml.covariance import DiagonalMatrix

mat = DiagonalMatrix(3)
free_params = torch.tensor([0.0, 0.5, 1.0])
params = mat.build_params(free_params)
mat.set_intermediates(params, {"log(sigma^2)/2": torch.log(params) / 2})
mat.get_intermediates(params)
# {'log(sigma^2)/2': tensor([0.0000, 0.5000, 1.0000])}

get_intermediates(params)[source]

Retrieve cached intermediate computation results if still valid.

Compares the hash, dtype, and device of params against the stored cache from the last set_intermediates() call. Returns the cached value only if all three match, ensuring stale results are never returned after a parameter update, device transfer, or dtype cast.

Parameters:

params (torch.Tensor) – Current parameter tensor.

Raises:
  • TypeError – If params is not a torch.Tensor.

  • ValueError – If params is not a 1D tensor.

Returns:

The cached intermediate object if the cache is valid, or None if the cache is missing, stale, or params has length 0.
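The validity check described above can be pictured with a small stand-alone sketch. It is not the library's implementation: the IntermediateCache class, its method names, and the use of plain Python lists and strings in place of tensors, dtypes, and devices are all assumptions made for illustration.

```python
import hashlib
import struct


class IntermediateCache:
    """Hypothetical stand-in for the caching behaviour described above."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Plays the role of reset_intermediates(): forget everything.
        self._key = None
        self._value = None

    @staticmethod
    def _make_key(params, dtype, device):
        # Hash the parameter values and pair the digest with dtype and device.
        payload = struct.pack(f"{len(params)}d", *params)
        return (hashlib.sha256(payload).hexdigest(), dtype, device)

    def set(self, params, intermediates, dtype="float32", device="cpu"):
        if len(params) == 0:
            return  # no free parameters: nothing is stored
        self._key = self._make_key(params, dtype, device)
        self._value = intermediates

    def get(self, params, dtype="float32", device="cpu"):
        if len(params) == 0 or self._key is None:
            return None  # empty parameters or empty cache
        if self._make_key(params, dtype, device) != self._key:
            return None  # stale: values, dtype, or device changed
        return self._value


cache = IntermediateCache()
cache.set([1.0, 2.0], {"chol": "factor"})
hit = cache.get([1.0, 2.0])                         # same values: cached object
miss_values = cache.get([1.0, 2.5])                 # changed values: None
miss_device = cache.get([1.0, 2.0], device="cuda")  # changed device: None
```

Keying on all three components is what lets a lookup fail safely after a parameter update, a device transfer, or a dtype cast.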

Example:

import torch
from torch_openreml.covariance import DiagonalMatrix

mat = DiagonalMatrix(3)
free_params = torch.tensor([0.0, 0.5, 1.0])
params = mat.build_params(free_params)
mat.set_intermediates(params, {"log(sigma^2)/2": torch.log(params) / 2})
mat.get_intermediates(params)
# {'log(sigma^2)/2': tensor([0.0000, 0.5000, 1.0000])}

reset_intermediates()[source]

Clear the intermediate computation cache.

Resets all cached values set by set_intermediates() to None, forcing subsequent calls to get_intermediates() to return None until the cache is repopulated. Called automatically in __init__() and within auto_grad() before triggering a fresh Jacobian computation.

Example:

import torch
from torch_openreml.covariance import DiagonalMatrix

mat = DiagonalMatrix(3)
free_params = torch.tensor([0.0, 0.5, 1.0])
params = mat.build_params(free_params)
mat.set_intermediates(params, {"log(sigma^2)/2": torch.log(params) / 2})
print(mat.get_intermediates(params))
# {'log(sigma^2)/2': tensor([0.0000, 0.5000, 1.0000])}
mat.reset_intermediates()
print(mat.get_intermediates(params))
# None

build_params(free_params=None, include_fixed=True, trans=True, out_format='tensor')[source]

Construct the full parameter tensor from free parameters.

Merges free (trainable) parameters with fixed parameter defaults and applies parameter transforms. Optionally returns a dictionary mapping parameter names to their transformed values.

Parameters:
  • free_params (torch.Tensor or dict) – Flat 1D tensor of free parameters or a dictionary mapping parameter names to tensors. If omitted, default values are used. Default: None.

  • include_fixed (bool, optional) – Whether to include fixed parameters in the output. Default: True.

  • trans (bool, optional) – Whether to apply parameter transforms to the output. Default: True.

  • out_format (str, optional) – Output format. One of "tensor" or "dict". Default: "tensor".

Returns:

Full parameter tensor of length num_params (or num_free_params when include_fixed=False), or a dictionary mapping parameter names to value tensors.

Return type:

torch.Tensor or dict

Raises:
  • ValueError – If out_format is not "tensor" or "dict".

  • TypeError – If free_params is neither a torch.Tensor nor a dict.

  • ValueError – If free_params is not a 1D tensor or has the wrong length, or if free_params is a dict with missing or unexpected keys.
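The merge-then-transform behaviour can be sketched without the library. The following is a simplified illustration only: the function below is a hypothetical re-implementation over plain Python floats, the parameter names "a", "b", "c" are made up, and the exp(2x) transform and the choice to transform fixed defaults as well are assumptions inferred from the example outputs on this page.

```python
import math

# Hypothetical specification: names, defaults, and transform are illustrative only.
specs = {
    "a": {"fixed": False, "default": 0.0, "trans": lambda x: math.exp(2 * x)},
    "b": {"fixed": True, "default": 0.0, "trans": lambda x: math.exp(2 * x)},
    "c": {"fixed": False, "default": 0.0, "trans": lambda x: math.exp(2 * x)},
}


def build_params(specs, free_params=None, include_fixed=True, trans=True):
    free_names = [name for name, spec in specs.items() if not spec["fixed"]]
    if free_params is None:
        # Fall back to the defaults of the free parameters.
        free_params = [specs[name]["default"] for name in free_names]
    if len(free_params) != len(free_names):
        raise ValueError("free_params has the wrong length")
    free_iter = iter(free_params)
    out = []
    for name, spec in specs.items():
        if spec["fixed"]:
            if not include_fixed:
                continue  # drop fixed parameters from the output
            raw = spec["default"]
        else:
            raw = next(free_iter)  # consume free values in declaration order
        out.append(spec["trans"](raw) if trans else raw)
    return out


full = build_params(specs, [0.5, 1.0])
free_only_raw = build_params(specs, [0.5, 1.0], include_fixed=False, trans=False)
```

Here full holds three transformed values with the fixed default slotted into the middle position, while free_only_raw returns the two free values unchanged.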

Example:

import torch
from torch_openreml.covariance import DiagonalMatrix

mat = DiagonalMatrix(3)
free_params = torch.tensor([0.0, 0.5, 1.0])
mat.build_params(free_params)
# tensor([1.0000, 2.7183, 7.3891])
mat.build_params()
# tensor([1., 1., 1.])
mat.param_specs["sigma^2_2"]["fixed"] = True
mat.build_params(free_params[0:2])
# tensor([1.0000, 2.7183, 1.0000])
mat.build_params(free_params[0:2], include_fixed=False)
# tensor([1.0000, 2.7183])
mat.build_params(free_params[0:2], include_fixed=False, trans=False)
# tensor([0.0000, 0.5000])

trans_grad(free_params=None)[source]

Compute the element-wise derivative of the free parameter transforms.

Evaluates the derivative of each free parameter’s transform function at the current parameter values. Used in the chain rule when computing manual gradients of the matrix with respect to the original (untransformed) parameterisation.

Parameters:

free_params (torch.Tensor or dict) – Flat 1D parameter tensor or dictionary of free parameters. If omitted, default values are used. Default: None.

Raises:
  • TypeError – If free_params is neither a torch.Tensor nor a dict.

  • ValueError – If free_params is not a 1D tensor or has the wrong length, or if free_params is a dict with missing or unexpected keys.

Returns:

1D tensor of element-wise transform derivatives, of the same length as free_params.

Return type:

torch.Tensor
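As a worked illustration of the chain rule mentioned above, write \(\eta_i = g_i(\theta_i)\) for the transformed value of the raw free parameter \(\theta_i\). The symbols \(\eta_i\) and \(g_i\) are notation introduced here, not names used by the library. Then

\[\frac{\partial \symbf{V}}{\partial \theta_i} = \frac{\partial \symbf{V}}{\partial \eta_i}\, g_i'(\theta_i)\]

and this method returns the factors \(g_i'(\theta_i)\). Assuming a transform of the form \(g(\theta) = e^{2\theta}\), which is inferred from the example outputs on this page rather than stated by the source, the derivative is

\[g'(\theta) = 2e^{2\theta}, \qquad g'(0) = 2, \qquad g'(0.5) = 2e \approx 5.4366, \qquad g'(1) = 2e^{2} \approx 14.7781\]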

Example:

import torch
from torch_openreml.covariance import DiagonalMatrix

mat = DiagonalMatrix(3)
free_params = torch.tensor([0.0, 0.5, 1.0])
mat.trans_grad(free_params)
# tensor([ 2.0000,  5.4366, 14.7781])
mat.trans_grad()
# tensor([2., 2., 2.])

auto_grad(free_params=None)[source]

Compute the Jacobian of __call__() with respect to free parameters using automatic differentiation.

Uses torch.func.jacrev() to compute the full Jacobian.

If all parameters are fixed, returns (None, []).

Parameters:

free_params (torch.Tensor or dict) – Flat 1D parameter tensor or dict. If omitted, default values are used. Default: None.

Raises:
  • TypeError – If free_params is neither a torch.Tensor nor a dict.

  • ValueError – If free_params is not a 1D tensor or has the wrong length, or if free_params is a dict with missing or unexpected keys.

Returns:

(grad, grad_names), where grad is a 3D tensor of shape (num_free_params, *shape) and grad_names is a list of the corresponding parameter names. Returns (None, []) if all parameters are fixed.

Return type:

tuple

Example:

import torch
from torch_openreml.covariance import DiagonalMatrix

mat = DiagonalMatrix(2)
free_params = torch.tensor([0.0, 0.5])
grad, grad_names = mat.auto_grad(free_params)
grad, grad_names
# (tensor([[[2.0000, 0.0000],
#           [0.0000, 0.0000]],
#
#          [[0.0000, 0.0000],
#           [0.0000, 5.4366]]]),
#  ['sigma^2_0', 'sigma^2_1'])

manual_grad(free_params=None)[source]

Compute the Jacobian of __call__() with respect to free parameters using a closed-form analytic expression.

This method is optional. When implemented by a subclass, grad() will invoke it in preference to auto_grad() under the default grad mode. If not implemented, calling this method raises NotImplementedError and grad() falls back to automatic differentiation.

Implementations must satisfy the following contract:

  • Return (None, []) if all parameters are fixed.

  • Return a 3D gradient tensor of shape (num_free_params, *shape) and a matching list of parameter names.

  • Apply transform derivatives from trans_grad() via the chain rule so that gradients are with respect to the raw (untransformed) parameters.

Parameters:

free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.

Returns:

(grad, grad_names), where grad is a 3D tensor of shape (num_free_params, *shape) and grad_names is a list of the corresponding parameter names. Returns (None, []) if all parameters are fixed.

Return type:

tuple

Raises:

NotImplementedError – If the subclass does not provide an analytic gradient. grad() catches this and falls back to auto_grad().
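To make the contract concrete, here is a minimal sketch of an analytic gradient for a diagonal matrix. It is not taken from the library: diagonal_manual_grad is a hypothetical free function, nested Python lists stand in for the 3D tensor, and the exp(2 * theta) transform is an assumption inferred from the example outputs on this page.

```python
import math


def diagonal_manual_grad(free_params, names):
    """Analytic Jacobian of V = diag(exp(2 * theta)) w.r.t. raw theta.

    Hypothetical helper; the exp(2 * theta) transform is an assumption.
    """
    n = len(free_params)
    if n == 0:
        return None, []  # contract: nothing to differentiate
    # Transform derivative g'(theta) = 2 * exp(2 * theta),
    # i.e. the role trans_grad() plays in the contract.
    trans_grad = [2.0 * math.exp(2.0 * t) for t in free_params]
    grad = []
    for i in range(n):
        # dV/d(eta_i) has a single 1 at (i, i); the chain rule scales it by g'(theta_i).
        slice_i = [[0.0] * n for _ in range(n)]
        slice_i[i][i] = 1.0 * trans_grad[i]
        grad.append(slice_i)
    return grad, list(names)


grad, grad_names = diagonal_manual_grad([0.0, 0.5], ["sigma^2_0", "sigma^2_1"])
```

Each entry of grad is one (n, n) slice, so the outer list has length num_free_params, matching the documented (num_free_params, *shape) layout.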

abstractmethod __call__(free_params=None)[source]

Construct the matrix from a flat parameter tensor.

Must be implemented by subclasses. Implementations should convert free_params via build_params() to validate, include fixed parameters, and apply transforms before any computation.

Parameters:

free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.

Returns:

Constructed matrix of shape shape.

Return type:

torch.Tensor

grad(free_params=None)[source]

Compute the Jacobian of __call__() with respect to trainable parameters.

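Dispatches to manual_grad() or auto_grad() according to grad_mode. Under the default mode, manual_grad() is preferred when a subclass implements it, with auto_grad() as the fallback.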

Parameters:

free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.

Returns:

(grad, grad_names) as described in manual_grad() and auto_grad().

Return type:

tuple

Raises:

RuntimeError – If grad_mode is not a recognised value.
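The default fallback behaviour described for manual_grad() can be sketched as follows. This is an illustration of the pattern only: the classes are hypothetical, the return values are placeholders, and handling of grad_mode is left out because its accepted values are not listed here.

```python
class Base:
    """Hypothetical base class showing only the default dispatch path."""

    def manual_grad(self, free_params=None):
        # No analytic gradient unless a subclass overrides this.
        raise NotImplementedError

    def auto_grad(self, free_params=None):
        return "auto", ["p"]  # placeholder for the autodiff Jacobian

    def grad(self, free_params=None):
        try:
            return self.manual_grad(free_params)  # preferred when available
        except NotImplementedError:
            return self.auto_grad(free_params)  # fallback


class WithAnalytic(Base):
    def manual_grad(self, free_params=None):
        return "manual", ["p"]  # placeholder for the closed-form Jacobian


fallback_result = Base().grad()
analytic_result = WithAnalytic().grad()
```

A subclass therefore opts into the analytic path simply by overriding manual_grad(); nothing else in the dispatch needs to change.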

Example:

import torch
from torch_openreml.covariance import DiagonalMatrix

mat = DiagonalMatrix(2)
free_params = torch.tensor([0.0, 0.5])
grad, grad_names = mat.grad(free_params)
grad, grad_names
# (tensor([[[2.0000, 0.0000],
#           [0.0000, 0.0000]],
#
#          [[0.0000, 0.0000],
#           [0.0000, 5.4366]]]),
#  ['sigma^2_0', 'sigma^2_1'])

map_theta_to_v(theta)[source]

An interface compatible with torch_openreml.REML that maps parameters to a matrix.

Invokes __call__().

Parameters:

theta (torch.Tensor) – Flat 1D parameter tensor.

Returns:

Constructed matrix.

Return type:

torch.Tensor

map_theta_to_dv(theta)[source]

An interface compatible with torch_openreml.REML that maps parameters to the matrix Jacobian.

Invokes grad().

Parameters:

theta (torch.Tensor) – Flat 1D parameter tensor.

Raises:

RuntimeError – If grad_mode is not a recognised value.

Returns:

Jacobian tensor of shape (num_free_params, *shape), or None if all parameters are fixed.

Return type:

torch.Tensor or None

property shape

Output matrix shape.

Type:

tuple

property param_specs

Parameter specifications.

Type:

dict

property param_names

Parameter names.

Type:

list of str

property free_param_names

Free parameter names.

Type:

list of str

property fixed_param_names

Fixed parameter names.

Type:

list of str

property free_param_index

Index of free parameters.

Type:

tuple

property fixed_param_index

Index of fixed parameters.

Type:

tuple

property num_params

Total number of parameters.

Type:

int

property num_free_params

Total number of free parameters.

Type:

int

property num_fixed_params

Total number of fixed parameters.

Type:

int

property param_defaults

Parameter defaults.

Type:

dict of torch.Tensor

property free_param_defaults

Free parameter defaults.

Type:

dict of torch.Tensor

property fixed_param_defaults

Fixed parameter defaults.

Type:

dict of torch.Tensor

property param_trans

Parameter transforms.

Type:

dict of Transform

property free_param_trans

Transforms for free parameters.

Type:

dict of Transform

property fixed_param_trans

Transforms for fixed parameters.

Type:

dict of Transform

property repr_dict

Key-value pairs used to build the string representation.

Type:

dict