torch_openreml.covariance.Matrix
- class torch_openreml.covariance.Matrix(shape, param_names, trans=None, no_grad_index=None)
Bases: ABC
Abstract base class for covariance matrices with parameterized structure.
\[\symbf{V} = \symbf{V}(\symbf{\theta}),\]
where \(\symbf{\theta}\) denotes the collection of variance component parameters that define the matrix entries.
This class provides utilities for parameter validation, transform application, and Jacobian computation (both manual and automatic). Subclasses must implement __call__() to construct their specific matrix structure from the provided parameters.
Initialize a covariance matrix with optional parameter transforms.
- Parameters:
shape (tuple or None) – Expected output dimensions of the constructed matrix. Used for validation; the actual shape may be set by subclasses.
param_names (list of str) – Ordered names of the parameters in params. Use an empty list if there are no trainable parameters (e.g., fixed matrices).
trans (list of Transform or None) – Transforms applied to the parameters before constructing the matrix. If None, no transforms are used. Typically used to enforce positive variances (e.g., \(\exp(2\theta) > 0\)) or bounded correlations (\(\rho \in (-1, 1)\)).
no_grad_index (list of int or None) – Indices of parameters to exclude from gradient computation. Parameters at these indices are omitted from grad and grad_names. For convenience, use set_no_grad() instead.
Note
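The transforms map the raw parameters to the transformed parameter vector
\[\tilde{\symbf{\theta}} = \left[f_0(\theta_0), \ldots, f_{p-1}(\theta_{p-1}) \right]^\top,\]
where each \(f_i\) is the i-th transform in trans and \(p\) is the number of parameters. The matrix is then constructed from \(\tilde{\symbf{\theta}}\). If trans has length 1, the single transform is broadcast and applied elementwise to all parameters.
- Raises: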
TypeError – If param_names is not a list of strings, or if trans contains non-Transform objects.
ValueError – If parameter names are not unique, or if any index in no_grad_index is out of range.
Methods
__call__(params) – Construct the matrix from a flat parameter tensor.
auto_grad(params) – Compute the Jacobian of __call__() with respect to trainable parameters using automatic differentiation.
check_params(params) – Validate a parameter tensor and return its device and dtype.
from_param_dict(param_dict) – Extract parameter tensors from a dictionary into a flat 1D tensor.
get_intermediates(params) – Retrieve cached intermediate computation results if still valid.
grad(params) – Compute the Jacobian of __call__() with respect to trainable parameters.
manual_grad(params) – Compute the Jacobian of __call__() with respect to trainable parameters using a closed-form analytic expression.
map_theta_to_dv(theta) – An interface compatible with torch_openreml.REML that maps parameters to the matrix Jacobian.
map_theta_to_v(theta) – An interface compatible with torch_openreml.REML that maps parameters to a matrix.
reset_intermediates() – Clear the intermediate computation cache.
set_intermediates(params, intermediates) – Cache intermediate computation results keyed by parameter hash.
set_no_grad([index, param_name]) – Set the indices of parameters to exclude from gradient computation.
to_param_dict(params) – Convert a flat parameter tensor to a parameter dictionary.
trans_grad(params) – Compute the element-wise derivative of the parameter transforms.
trans_params(params) – Apply parameter transforms to a flat parameter tensor.
Attributes
no_grad_index – Indices of parameters excluded from gradient computation.
num_params – Total number of parameters.
param_names – Ordered parameter names.
repr_dict – Key-value pairs used to build the string representation.
shape – Output matrix shape.
trans – Parameter transforms.
- set_intermediates(params, intermediates)
Cache intermediate computation results keyed by parameter hash.
Stores arbitrary intermediate values together with a hash of the current parameter tensor and its dtype and device. Cached values can be retrieved via get_intermediates() to avoid redundant computation across repeated calls with identical parameters.
- Parameters:
params (torch.Tensor or dict) – Current parameter tensor or dictionary. Converted to a flat tensor via from_param_dict() before hashing.
intermediates – Arbitrary object to cache (e.g. Cholesky factors, eigendecompositions, or any reusable computation).
Note
If params has length 0 (no trainable parameters), this is a no-op.
Example:
import torch
from torch_openreml.covariance import DiagonalMatrix
mat = DiagonalMatrix(3)
params = torch.tensor([0.0, 0.5, 1.0])
sigma2 = mat.trans_params(params)  # transformed variances, exp(2 * params)
mat.set_intermediates(params, {"sigma2": sigma2})
mat.get_intermediates(params)
{'sigma2': tensor([1.0000, 2.7183, 7.3891])}
- get_intermediates(params)
Retrieve cached intermediate computation results if still valid.
Compares the hash, dtype, and device of params against the cache stored by the last set_intermediates() call. Returns the cached value only if all three match, ensuring stale results are never returned after a parameter update, device transfer, or dtype cast.
- Parameters:
params (torch.Tensor or dict) – Current parameter tensor or dictionary. Converted to a flat tensor via from_param_dict() before comparison.
- Returns:
The cached intermediate object if the cache is valid, or None if the cache is missing, stale, or params has length 0.
Example:
import torch
from torch_openreml.covariance import DiagonalMatrix
mat = DiagonalMatrix(3)
params = torch.tensor([0.0, 0.5, 1.0])
sigma2 = mat.trans_params(params)  # transformed variances, exp(2 * params)
mat.set_intermediates(params, {"sigma2": sigma2})
mat.get_intermediates(params)
{'sigma2': tensor([1.0000, 2.7183, 7.3891])}
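The cache validation described above can be sketched as follows. This is a minimal, hypothetical stand-in: the class name IntermediateCache and the way the key is built are assumptions, not the library's internals. The key combines a hash of the parameter values with the dtype and device as separate fields, so a value update, a dtype cast, or a device move each invalidate the cached object.

```python
import torch


class IntermediateCache:
    """Hypothetical stand-in for set/get/reset_intermediates (not library code)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Analogue of reset_intermediates(): drop both the key and the value.
        self._key = None
        self._value = None

    @staticmethod
    def _make_key(params):
        # Assumed key: hash of the values, plus dtype and device as separate fields.
        flat = params.detach().reshape(-1)
        return (hash(tuple(flat.cpu().tolist())), flat.dtype, flat.device)

    def set(self, params, value):
        if params.numel() == 0:
            return  # no trainable parameters: caching is a no-op
        self._key = self._make_key(params)
        self._value = value

    def get(self, params):
        if params.numel() == 0 or self._key is None:
            return None
        # Return the value only if hash, dtype, and device all match.
        return self._value if self._make_key(params) == self._key else None


cache = IntermediateCache()
params = torch.tensor([0.0, 0.5, 1.0])
cache.set(params, {"sigma2": torch.exp(2 * params)})
```

Casting params to float64 leaves the hashed values unchanged here, so it is the separate dtype field that makes the lookup miss. A value hash can in principle collide, which is one reason a sketch like this keeps the check cheap rather than treating it as a proof of equality.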
- reset_intermediates()
Clear the intermediate computation cache.
Resets all cached values set by set_intermediates() to None, forcing subsequent calls to get_intermediates() to return None until the cache is repopulated. Called automatically in __init__() and within auto_grad() before triggering a fresh Jacobian computation.
Example:
import torch
from torch_openreml.covariance import DiagonalMatrix
mat = DiagonalMatrix(3)
params = torch.tensor([0.0, 0.5, 1.0])
sigma2 = mat.trans_params(params)
mat.set_intermediates(params, {"sigma2": sigma2})
print(mat.get_intermediates(params))
mat.reset_intermediates()
print(mat.get_intermediates(params))  # cache cleared
{'sigma2': tensor([1.0000, 2.7183, 7.3891])}
None
- set_no_grad(index=None, param_name=None)
Set the indices of parameters to exclude from gradient computation.
Replaces no_grad_index with the provided indices. Exactly one of index or param_name must be supplied; providing both or neither raises ValueError.
- Parameters:
index (int or list of int, optional) – Zero-based index or list of indices into param_names to exclude from gradient computation.
param_name (str or list of str, optional) – Parameter name or list of names to exclude. Names must exist in param_names.
- Raises:
ValueError – If both or neither of index and param_name are provided, or if any index is out of range.
KeyError – If any name in param_name is not found in param_names.
Example:
import torch
from torch_openreml.covariance import DiagonalMatrix
mat = DiagonalMatrix(3)
mat.set_no_grad(index=0)  # exclude sigma^2_0 from the Jacobian
print(mat.no_grad_index)
print(mat.grad(torch.zeros(3)))
[0]
(tensor([[[0., 0., 0.], [0., 2., 0.], [0., 0., 0.]], [[0., 0., 0.], [0., 0., 0.], [0., 0., 2.]]]), ['sigma^2_1', 'sigma^2_2'])
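The argument handling of set_no_grad() amounts to resolving either an index or a name into a list of indices. A minimal sketch of that logic, using a hypothetical helper resolve_no_grad (not part of torch_openreml); whether the library preserves the caller's order, as this sketch does, is an assumption:

```python
def resolve_no_grad(param_names, index=None, param_name=None):
    """Hypothetical: map index/param_name to a list of indices into param_names."""
    # Exactly one of the two selectors must be given.
    if (index is None) == (param_name is None):
        raise ValueError("provide exactly one of index or param_name")
    if param_name is not None:
        names = [param_name] if isinstance(param_name, str) else list(param_name)
        unknown = [n for n in names if n not in param_names]
        if unknown:
            raise KeyError(f"unknown parameter names: {unknown}")
        return [param_names.index(n) for n in names]
    indices = [index] if isinstance(index, int) else list(index)
    out_of_range = [i for i in indices if not 0 <= i < len(param_names)]
    if out_of_range:
        raise ValueError(f"indices out of range: {out_of_range}")
    return indices


names = ["sigma^2_0", "sigma^2_1", "sigma^2_2"]
print(resolve_no_grad(names, index=0))                 # [0]
print(resolve_no_grad(names, param_name="sigma^2_2"))  # [2]
```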
- from_param_dict(param_dict)
Extract parameter tensors from a dictionary into a flat 1D tensor.
Converts a parameter dictionary to a concatenated 1D tensor ordered according to param_names. The inverse operation is provided by to_param_dict().
- Parameters:
param_dict (torch.Tensor or dict) – Either a flat parameter tensor (returned as-is), or a dictionary mapping parameter names to tensors. The dictionary must contain every name in param_names and no other keys.
- Returns:
Concatenated 1D tensor containing all parameters in the order specified by param_names.
- Return type:
torch.Tensor
- Raises:
ValueError – If param_dict is a dictionary that is missing required keys or contains unexpected keys, or if the tensor length does not match the number of parameters.
Example:
import torch
from torch_openreml.covariance import DiagonalMatrix
mat = DiagonalMatrix(3)
param_dict = {
    "sigma^2_0": torch.tensor([0.0]),
    "sigma^2_1": torch.tensor([0.5]),
    "sigma^2_2": torch.tensor([1.0]),
}
mat.from_param_dict(param_dict)
tensor([0.0000, 0.5000, 1.0000])
- to_param_dict(params)
Convert a flat parameter tensor to a parameter dictionary.
Maps each element of a 1D parameter tensor to its corresponding name in param_names, returning a dictionary of single-element 1D tensors. This is the inverse of from_param_dict().
- Parameters:
params (torch.Tensor or dict) – Either a flat 1D tensor of length num_params (converted to a dict), or a dict (returned as-is).
- Returns:
Mapping from each name in param_names to a 1D single-element tensor.
- Return type:
dict
- Raises:
ValueError – If params is a tensor whose length does not equal num_params.
Example:
import torch
from torch_openreml.covariance import DiagonalMatrix
mat = DiagonalMatrix(3)
params = torch.tensor([0.0, 0.5, 1.0])
mat.to_param_dict(params)
{'sigma^2_0': tensor([0.]), 'sigma^2_1': tensor([0.5000]), 'sigma^2_2': tensor([1.])}
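from_param_dict() and to_param_dict() are inverses of each other. A minimal sketch of that round trip with plain functions (illustrative stand-ins, not the library's methods), using slicing so each dictionary value stays a 1D single-element tensor as in the output above:

```python
import torch


def to_param_dict(params, param_names):
    """Hypothetical: flat 1D tensor -> {name: 1D single-element tensor}."""
    if isinstance(params, dict):
        return params  # already a dict: returned as-is
    if params.ndim != 1 or params.numel() != len(param_names):
        raise ValueError("expected a 1D tensor with one entry per parameter name")
    # params[i:i + 1] keeps shape (1,); params[i] would give a 0-d scalar.
    return {name: params[i:i + 1] for i, name in enumerate(param_names)}


def from_param_dict(param_dict, param_names):
    """Hypothetical: {name: tensor} -> flat 1D tensor ordered by param_names."""
    if isinstance(param_dict, torch.Tensor):
        return param_dict  # already flat: returned as-is
    if set(param_dict) != set(param_names):
        raise ValueError("dictionary keys must match param_names exactly")
    # Concatenate in param_names order, regardless of the dict's insertion order.
    return torch.cat([param_dict[name].reshape(-1) for name in param_names])


names = ["sigma^2_0", "sigma^2_1", "sigma^2_2"]
params = torch.tensor([0.0, 0.5, 1.0])
round_trip = from_param_dict(to_param_dict(params, names), names)
```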
- trans_params(params)
Apply parameter transforms to a flat parameter tensor.
Applies the transforms in trans element-wise to params. If trans is None or empty, returns params unchanged. If trans has a single entry, that transform is broadcast and applied to all parameters simultaneously. Otherwise, each transform is applied to its corresponding parameter individually.
- Parameters:
params (torch.Tensor or dict) – Flat 1D parameter tensor or dictionary. Converted via from_param_dict() before transformation.
- Returns:
Transformed parameter tensor of the same shape as params.
- Return type:
torch.Tensor
Example:
import torch
from torch_openreml.covariance import DiagonalMatrix
mat = DiagonalMatrix(3)
params = torch.tensor([0.0, 0.5, 1.0])
mat.trans_params(params)  # exp(2 * params)
tensor([1.0000, 2.7183, 7.3891])
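The three cases of the broadcasting rule can be sketched as below. Transforms are modelled as plain callables, which is an assumption: the library's Transform objects may expose a different interface. exp2 (the \(\exp(2\theta)\) variance map) and torch.tanh (one possible map into \((-1, 1)\) for correlations) are illustrative choices.

```python
import torch


def apply_transforms(params, trans):
    """Hypothetical sketch of the trans_params() broadcasting rule."""
    if not trans:  # None or empty: identity
        return params
    if len(trans) == 1:  # a single transform is applied to every element at once
        return trans[0](params)
    if len(trans) != params.numel():
        raise ValueError("expected one transform per parameter")
    # Otherwise transform i acts on parameter i only; iterating yields 0-d tensors.
    return torch.stack([f(p) for f, p in zip(trans, params)])


def exp2(x):
    # Assumed variance transform: sigma^2 = exp(2 * theta) > 0.
    return torch.exp(2 * x)


params = torch.tensor([0.0, 0.5, 1.0])
broadcast = apply_transforms(params, [exp2])  # same values as the trans_params() example
mixed = apply_transforms(params, [exp2, torch.tanh, exp2])
```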
- trans_grad(params)
Compute the element-wise derivative of the parameter transforms.
Returns the diagonal of the Jacobian of trans_params() with respect to the raw (untransformed) parameters; the Jacobian is diagonal because each transform acts on one parameter at a time. The result is used in the chain rule when computing gradients of the matrix with respect to the original parameterisation.
If trans is None or empty, returns a tensor of ones (identity derivative). If trans has a single entry, its derivative is evaluated elementwise at every parameter. Otherwise, each transform's derivative is evaluated at its corresponding parameter.
- Parameters:
params (torch.Tensor or dict) – Flat 1D parameter tensor or dictionary. Converted via from_param_dict() before evaluation.
- Returns:
1D tensor of element-wise transform derivatives, of the same length as params.
- Return type:
torch.Tensor
Example:
import torch
from torch_openreml.covariance import DiagonalMatrix
mat = DiagonalMatrix(3)
params = torch.tensor([0.0, 0.5, 1.0])
mat.trans_grad(params)  # 2 * exp(2 * params)
tensor([ 2.0000, 5.4366, 14.7781])
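Because each transform acts elementwise, the chain rule reduces to \(\partial \symbf{V} / \partial \theta_i = f_i'(\theta_i)\, \partial \symbf{V} / \partial \tilde{\theta}_i\), where \(\tilde{\theta}_i = f_i(\theta_i)\) is the transformed parameter. The sketch below checks this numerically with torch.func.jacrev for an assumed structure \(\symbf{V} = \operatorname{diag}(\exp(2\symbf{\theta}))\); all helper names are illustrative.

```python
import torch
from torch.func import jacrev


def exp2(theta):
    # Assumed variance transform f(theta) = exp(2 * theta).
    return torch.exp(2 * theta)


def exp2_grad(theta):
    # Its derivative f'(theta) = 2 * exp(2 * theta): the analogue of trans_grad().
    return 2 * torch.exp(2 * theta)


def build(theta):
    # Assumed structure: V = diag(f(theta)); diag_embed maps a 1D tensor to a diagonal matrix.
    return torch.diag_embed(exp2(theta))


theta = torch.tensor([0.0, 0.5, 1.0])
n = theta.numel()

# An elementwise transform has a diagonal Jacobian whose diagonal is f'(theta).
trans_jac = jacrev(exp2)(theta)

# Chain rule: dV/dtheta_i = f'(theta_i) * dV/dtilde_theta_i, and for this
# diagonal structure dV/dtilde_theta_i is the unit matrix E_ii.
unit = torch.eye(n)
chain = torch.stack([exp2_grad(theta)[i] * torch.diag(unit[i]) for i in range(n)])

# Reference: autodiff through build(), with the parameter axis moved to the front.
auto = jacrev(build)(theta).movedim(-1, 0)
```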
- auto_grad(params)
Compute the Jacobian of __call__() with respect to trainable parameters using automatic differentiation.
Uses torch.func.jacrev() to compute the full Jacobian, then removes the slices for parameters listed in no_grad_index. If all parameters are excluded via no_grad_index, returns (None, []).
- Parameters:
params (torch.Tensor) – Flat 1D parameter tensor.
- Returns:
(grad, grad_names), where grad is a 3D tensor of shape (num_params - len(no_grad_index), *shape), and grad_names has the same length as grad.
- Return type:
tuple
Example:
import torch
from torch_openreml.covariance import DiagonalMatrix
mat = DiagonalMatrix(2)
params = torch.tensor([0.0, 0.5])
grad, grad_names = mat.auto_grad(params)
grad, grad_names
(tensor([[[2.0000, 0.0000], [0.0000, 0.0000]], [[0.0000, 0.0000], [0.0000, 5.4366]]]), ['sigma^2_0', 'sigma^2_1'])
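A minimal sketch of the jacrev-then-mask approach, assuming a plain build function instead of a Matrix subclass (auto_jacobian and build_diag are hypothetical names). jacrev places the differentiated-input axis last, so it is moved to the front to match the (num_params - len(no_grad_index), *shape) layout documented above.

```python
import torch
from torch.func import jacrev


def auto_jacobian(build, params, param_names, no_grad_index=()):
    """Hypothetical sketch of auto_grad(): reverse-mode Jacobian, then masking."""
    skip = set(no_grad_index)
    keep = [i for i in range(params.numel()) if i not in skip]
    if not keep:
        return None, []  # every parameter excluded
    # jacrev appends the input axis last: shape (*V.shape, num_params).
    jac = jacrev(build)(params)
    # Parameter axis first, (num_params, *V.shape), then keep the trainable slices.
    jac = jac.movedim(-1, 0)
    return jac[keep], [param_names[i] for i in keep]


def build_diag(theta):
    # Assumed structure: V = diag(exp(2 * theta)).
    return torch.diag_embed(torch.exp(2 * theta))


names = ["sigma^2_0", "sigma^2_1"]
params = torch.tensor([0.0, 0.5])
grad, grad_names = auto_jacobian(build_diag, params, names)
```

Differentiating everything and discarding slices afterwards is simple, but it still pays for the excluded parameters; with many frozen parameters, differentiating only the trainable subset would be cheaper.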
- manual_grad(params)
Compute the Jacobian of __call__() with respect to trainable parameters using a closed-form analytic expression.
This method is optional. When implemented by a subclass, grad() invokes it in preference to auto_grad() under the default grad mode. If it is not implemented, calling this method raises NotImplementedError and grad() falls back to automatic differentiation.
Implementations must satisfy the following contract:
- Return (None, []) if all parameters are excluded via no_grad_index.
- Otherwise, return a 3D gradient tensor of shape (num_params - len(no_grad_index), *shape) and a matching list of parameter names, omitting any index in no_grad_index.
- Apply the transform derivatives from trans_grad() via the chain rule, so that gradients are taken with respect to the raw (untransformed) parameters.
- Parameters:
params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary.
- Returns:
(grad, grad_names), where grad is a 3D tensor of shape (num_params - len(no_grad_index), *shape) and grad_names is a list of the corresponding parameter names. Returns (None, []) if all parameters are excluded from gradient computation.
- Return type:
tuple
- Raises:
NotImplementedError – If the subclass does not provide an analytic gradient.
grad() catches this and falls back to auto_grad().
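A closed-form implementation that follows the three contract rules can be sketched for an assumed structure \(\symbf{V} = \operatorname{diag}(\exp(2\symbf{\theta}))\). manual_diag_jacobian is a hypothetical standalone function, not a library method; with zero parameters and index 0 excluded, it reproduces the values in the set_no_grad() example.

```python
import torch


def manual_diag_jacobian(theta, param_names, no_grad_index=()):
    """Hypothetical closed-form Jacobian of V = diag(exp(2 * theta))."""
    p = theta.numel()
    skip = set(no_grad_index)
    keep = [i for i in range(p) if i not in skip]
    if not keep:
        return None, []  # contract: every parameter excluded
    # trans_grad analogue: f'(theta) for the assumed transform f = exp(2 * theta).
    d_trans = 2 * torch.exp(2 * theta)
    grad = torch.zeros(len(keep), p, p, dtype=theta.dtype, device=theta.device)
    for k, i in enumerate(keep):
        # dV/dtilde_theta_i is the unit matrix E_ii; the chain rule scales it by f'(theta_i).
        grad[k, i, i] = d_trans[i]
    return grad, [param_names[i] for i in keep]


names = ["sigma^2_0", "sigma^2_1", "sigma^2_2"]
grad, grad_names = manual_diag_jacobian(torch.zeros(3), names, no_grad_index=[0])
```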
- abstractmethod __call__(params)
Construct the matrix from a flat parameter tensor.
Must be implemented by subclasses. Implementations should convert params via from_param_dict() or to_param_dict(), then call check_params() to validate and trans_params() to apply transforms before any computation.
- Parameters:
params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary.
- Returns:
Constructed matrix whose dimensions are given by shape.
- Return type:
torch.Tensor
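The implementation steps above (validate, transform, build) can be sketched with a toy abstract base class. ToyMatrix and ToyDiagonal are hypothetical names; dictionary conversion, configurable transforms, caching, and gradients are deliberately left out, and the \(\exp(2\theta)\) transform is hard-coded as an assumption.

```python
from abc import ABC, abstractmethod

import torch


class ToyMatrix(ABC):
    """Hypothetical, stripped-down stand-in for Matrix (not the library class)."""

    def __init__(self, param_names):
        self.param_names = list(param_names)

    def check_params(self, params):
        # Minimal validation in the spirit of check_params().
        if not isinstance(params, torch.Tensor):
            raise TypeError("params must be a tensor")
        if params.ndim != 1 or params.numel() != len(self.param_names):
            raise ValueError("params must be 1D with one entry per parameter")
        return params.device, params.dtype

    @abstractmethod
    def __call__(self, params):
        """Build the matrix from a flat parameter tensor."""


class ToyDiagonal(ToyMatrix):
    def __init__(self, n):
        super().__init__([f"sigma^2_{i}" for i in range(n)])

    def __call__(self, params):
        self.check_params(params)        # 1) validate
        sigma2 = torch.exp(2 * params)   # 2) transform (assumed exp(2 theta))
        return torch.diag(sigma2)        # 3) build the structured matrix


V = ToyDiagonal(3)(torch.tensor([0.0, 0.5, 1.0]))
```

Declaring __call__ abstract means Python refuses to instantiate ToyMatrix itself, so a subclass that forgets to build a matrix fails at construction time rather than on first use.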
- grad(params)
Compute the Jacobian of __call__() with respect to trainable parameters.
Dispatches to manual_grad() or auto_grad() according to grad_mode:
- "default": attempts manual_grad(), falling back to auto_grad() if not implemented.
- "auto": always uses auto_grad().
- Parameters:
params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary.
- Returns:
(grad, grad_names)as described inmanual_grad()andauto_grad().- Return type:
tuple
- Raises:
RuntimeError – If grad_mode is not a recognised value.
Example:
import torch
from torch_openreml.covariance import DiagonalMatrix
mat = DiagonalMatrix(2)
params = torch.tensor([0.0, 0.5])
grad, grad_names = mat.grad(params)
grad, grad_names
(tensor([[[2.0000, 0.0000], [0.0000, 0.0000]], [[0.0000, 0.0000], [0.0000, 5.4366]]]), ['sigma^2_0', 'sigma^2_1'])
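The dispatch rule reduces to a try/except around the closed-form path. In this sketch, dispatch_grad and the stand-in strategy functions are hypothetical, and the strategies return placeholder tuples instead of Jacobians so that only the control flow is shown.

```python
def dispatch_grad(params, grad_mode, manual_grad, auto_grad):
    """Hypothetical sketch of how grad() chooses between the two strategies."""
    if grad_mode == "default":
        try:
            return manual_grad(params)
        except NotImplementedError:
            # No closed form available: fall back to automatic differentiation.
            return auto_grad(params)
    if grad_mode == "auto":
        return auto_grad(params)
    raise RuntimeError(f"unrecognised grad_mode: {grad_mode!r}")


def no_closed_form(params):
    # Stand-in for a subclass that does not override manual_grad().
    raise NotImplementedError


def closed_form(params):
    return "manual", []


def autodiff(params):
    return "auto", []
```

Catching only NotImplementedError means that a genuine bug inside a closed-form implementation surfaces as an error instead of silently switching to automatic differentiation.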
- map_theta_to_v(theta)
An interface compatible with torch_openreml.REML that maps parameters to a matrix.
Invokes __call__().
- Parameters:
theta (torch.Tensor) – Flat 1D parameter tensor.
- Returns:
Constructed matrix.
- Return type:
torch.Tensor
- map_theta_to_dv(theta)
An interface compatible with torch_openreml.REML that maps parameters to the matrix Jacobian.
Invokes grad() and returns only the Jacobian tensor, without the parameter names.
- Parameters:
theta (torch.Tensor) – Flat 1D parameter tensor.
- Returns:
Jacobian tensor of shape (num_params - len(no_grad_index), *shape), or None if all parameters are excluded from gradient computation.
- Return type:
torch.Tensor or None
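Both REML-facing adapters are thin wrappers: map_theta_to_v() forwards to the matrix construction, and map_theta_to_dv() keeps only the tensor from the (grad, grad_names) pair. A minimal sketch with a hypothetical ToyAdapter class; the diagonal structure and the naming scheme are assumptions.

```python
import torch


class ToyAdapter:
    """Hypothetical object exposing the two REML-facing adapters."""

    def __init__(self, no_grad_index=()):
        self.no_grad_index = set(no_grad_index)

    def __call__(self, theta):
        # Assumed structure: V = diag(exp(2 * theta)).
        return torch.diag(torch.exp(2 * theta))

    def grad(self, theta):
        keep = [i for i in range(theta.numel()) if i not in self.no_grad_index]
        if not keep:
            return None, []
        # Slice i of diag_embed(diag(d)) is d_i * E_ii, the Jacobian of diag(exp(2 theta)).
        full = torch.diag_embed(torch.diag(2 * torch.exp(2 * theta)))
        return full[keep], [f"sigma^2_{i}" for i in keep]

    def map_theta_to_v(self, theta):
        return self(theta)

    def map_theta_to_dv(self, theta):
        dv, _names = self.grad(theta)  # drop the names; only the tensor is returned
        return dv
```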
- check_params(params)
Validate a parameter tensor and return its device and dtype.
Also accepts a parameter dictionary, which is converted to a flat tensor via from_param_dict() before validation.
- Parameters:
params (torch.Tensor or dict) – Parameters to validate.
- Returns:
(device, dtype)of the parameter tensor.- Return type:
tuple
- Raises:
TypeError – If params (after any dictionary conversion) is not a tensor.
ValueError – If params is not 1D or its length does not equal num_params.
- property shape
Output matrix shape.
- Type:
tuple
- property param_names
Ordered parameter names.
- Type:
list of str
- property num_params
Total number of parameters.
- Type:
int
- property no_grad_index
Indices of parameters excluded from gradient computation.
- Type:
list of int
- property repr_dict
Key-value pairs used to build the string representation.
- Type:
dict