torch_openreml.covariance.Matrix
- class torch_openreml.covariance.Matrix(shape, param_specs)
Bases: ABC
Abstract base class for covariance matrices with parameterized structure.
\[ \symbf{V} = \symbf{V}(\symbf{\theta}) \]
where \(\symbf{\theta}\) denotes the collection of variance component parameters that define the matrix entries.
This class provides utilities for parameter validation, transform application, and Jacobian computation (both manual and automatic). Subclasses must implement __call__() to construct their specific matrix structure from the provided parameters.
Initialize a covariance matrix with parameter specifications.
- Parameters:
shape (tuple or None) – Expected output dimensions of the constructed matrix. Used for validation; the actual shape may be set by subclasses.
param_specs (dict) – Parameter specifications. Keys should be strings representing parameter names. Values should be dictionaries containing the specification for each parameter. Each specification dictionary should contain the keys "fixed", "default", and "trans", representing whether the parameter is fixed or free (bool), the default value (1D torch.Tensor), and the transform (Transform), respectively.
- Raises:
TypeError – If param_specs violates any of the requirements listed in the argument description, or if shape is not a tuple or torch.Size.
ValueError – If any shape values are negative.
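The constructor requirements above can be illustrated with a small, self-contained checker. This is a sketch under stated assumptions, not the library's validation code: `validate_specs` and `ExpTransform` are hypothetical names, the real `Transform` class is replaced by a plain callable, and the exact checks torch_openreml performs may differ.

```python
import torch


class ExpTransform:
    """Hypothetical stand-in for a torch_openreml Transform: x -> exp(2x)."""

    def __call__(self, x):
        return torch.exp(2 * x)


def validate_specs(shape, param_specs):
    """Hypothetical re-creation of the documented constructor checks."""
    if shape is not None:
        if not isinstance(shape, (tuple, torch.Size)):
            raise TypeError("shape must be a tuple or torch.Size")
        if any(s < 0 for s in shape):
            raise ValueError("shape values must be non-negative")
    if not isinstance(param_specs, dict):
        raise TypeError("param_specs must be a dict")
    for name, spec in param_specs.items():
        if not isinstance(name, str):
            raise TypeError(f"parameter name {name!r} is not a string")
        if not isinstance(spec, dict):
            raise TypeError(f"specification for {name!r} is not a dict")
        missing = {"fixed", "default", "trans"} - set(spec)
        if missing:
            raise TypeError(f"specification for {name!r} lacks keys {sorted(missing)}")
        if not isinstance(spec["fixed"], bool):
            raise TypeError(f"'fixed' for {name!r} must be a bool")
        default = spec["default"]
        if not (isinstance(default, torch.Tensor) and default.dim() == 1):
            raise TypeError(f"'default' for {name!r} must be a 1D torch.Tensor")


# One free, scalar-valued variance component per diagonal entry.
specs = {
    f"sigma^2_{i}": {"fixed": False, "default": torch.tensor([0.0]), "trans": ExpTransform()}
    for i in range(3)
}
validate_specs((3, 3), specs)  # passes without raising
```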
Methods
- __call__([free_params]) – Construct the matrix from a flat parameter tensor.
- auto_grad([free_params]) – Compute the Jacobian of __call__() with respect to free parameters using automatic differentiation.
- build_params([free_params, include_fixed, ...]) – Construct the full parameter tensor from free parameters.
- get_intermediates(params) – Retrieve cached intermediate computation results if still valid.
- grad([free_params]) – Compute the Jacobian of __call__() with respect to trainable parameters.
- manual_grad([free_params]) – Compute the Jacobian of __call__() with respect to free parameters using a closed-form analytic expression.
- map_theta_to_dv(theta) – An interface compatible with torch_openreml.REML that maps parameters to the matrix Jacobian.
- map_theta_to_v(theta) – An interface compatible with torch_openreml.REML that maps parameters to a matrix.
- reset_intermediates() – Clear the intermediate computation cache.
- set_intermediates(params, intermediates) – Cache intermediate computation results keyed by parameter hash.
- trans_grad([free_params]) – Compute the element-wise derivative of the free parameter transforms.
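Attributes
- fixed_param_defaults – Fixed parameter defaults.
- fixed_param_index – Indices of fixed parameters.
- fixed_param_names – Fixed parameter names.
- Transforms for fixed parameters.
- free_param_defaults – Free parameter defaults.
- free_param_index – Indices of free parameters.
- free_param_names – Free parameter names.
- Transforms for free parameters.
- num_fixed_params – Total number of fixed parameters.
- num_free_params – Total number of free parameters.
- num_params – Total number of parameters.
- param_defaults – Parameter defaults.
- param_names – Parameter names.
- param_specs – Parameter specifications.
- Parameter transforms.
- repr_dict – Key-value pairs used to build the string representation.
- shape – Output matrix shape.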
- set_intermediates(params, intermediates)
Cache intermediate computation results keyed by parameter hash.
Stores arbitrary intermediate values alongside a hash of the current parameter tensor, together with its dtype and device. Cached values can be retrieved via get_intermediates() to avoid redundant computation across multiple calls with identical parameters.
- Parameters:
params (torch.Tensor) – Current parameter tensor.
intermediates – Arbitrary object to cache (e.g. Cholesky factors, eigendecompositions, or any reusable computation).
- Raises:
TypeError – If params is not a Torch tensor.
ValueError – If params is not a 1D tensor.
Note
If params has length 0 (no free parameters), this is a no-op.
Example:
>>> import torch
>>> from torch_openreml.covariance import DiagonalMatrix
>>> mat = DiagonalMatrix(3)
>>> free_params = torch.tensor([0.0, 0.5, 1.0])
>>> params = mat.build_params(free_params)
>>> mat.set_intermediates(params, {"log(sigma^2)/2": torch.log(params) / 2})
>>> mat.get_intermediates(params)
{'log(sigma^2)/2': tensor([0.0000, 0.5000, 1.0000])}
- get_intermediates(params)
Retrieve cached intermediate computation results if still valid.
Compares the hash, dtype, and device of params against the cache stored by the last set_intermediates() call. Returns the cached value only if all three match, ensuring stale results are never returned after a parameter update, device transfer, or dtype cast.
- Parameters:
params (torch.Tensor) – Current parameter tensor.
- Raises:
TypeError – If params is not a Torch tensor.
ValueError – If params is not a 1D tensor.
- Returns:
The cached intermediate object if the cache is valid, or None if the cache is missing, stale, or params has length 0.
Example:
>>> import torch
>>> from torch_openreml.covariance import DiagonalMatrix
>>> mat = DiagonalMatrix(3)
>>> free_params = torch.tensor([0.0, 0.5, 1.0])
>>> params = mat.build_params(free_params)
>>> mat.set_intermediates(params, {"log(sigma^2)/2": torch.log(params) / 2})
>>> mat.get_intermediates(params)
{'log(sigma^2)/2': tensor([0.0000, 0.5000, 1.0000])}
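The hash, dtype, and device check described for set_intermediates() and get_intermediates() can be sketched as a small standalone cache. The class name `IntermediateCache`, its method names, and the use of Python's `hash()` over the tensor's values are assumptions for illustration only; the library's actual hashing scheme is not documented here, and a value hash can in principle collide.

```python
import torch


class IntermediateCache:
    """Hypothetical single-entry cache valid only for identical params."""

    def __init__(self):
        self.reset()

    def reset(self):
        self._key = None
        self._value = None

    @staticmethod
    def _make_key(params):
        if not isinstance(params, torch.Tensor):
            raise TypeError("params must be a torch.Tensor")
        if params.dim() != 1:
            raise ValueError("params must be a 1D tensor")
        # Hash the values; dtype and device are part of the key so that a
        # cast or device move invalidates the cache even if values agree.
        content = hash(tuple(params.detach().cpu().tolist()))
        return (content, params.dtype, params.device)

    def set(self, params, intermediates):
        key = self._make_key(params)
        if params.numel() == 0:
            return  # mirrors the documented no-op for empty params
        self._key, self._value = key, intermediates

    def get(self, params):
        key = self._make_key(params)
        if params.numel() == 0 or key != self._key:
            return None
        return self._value


cache = IntermediateCache()
p = torch.tensor([1.0, 2.0, 3.0])
cache.set(p, {"log(sigma^2)/2": torch.log(p) / 2})
```

A single-entry cache suits the typical REML loop, where the same parameter vector is reused by several calls before the optimiser moves on.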
- reset_intermediates()
Clear the intermediate computation cache.
Resets all cached values set by set_intermediates() to None, so that subsequent calls to get_intermediates() return None until the cache is repopulated. Called automatically in __init__() and within auto_grad() before triggering a fresh Jacobian computation.
Example:
>>> import torch
>>> from torch_openreml.covariance import DiagonalMatrix
>>> mat = DiagonalMatrix(3)
>>> free_params = torch.tensor([0.0, 0.5, 1.0])
>>> params = mat.build_params(free_params)
>>> mat.set_intermediates(params, {"log(sigma^2)/2": torch.log(params) / 2})
>>> print(mat.get_intermediates(params))
{'log(sigma^2)/2': tensor([0.0000, 0.5000, 1.0000])}
>>> mat.reset_intermediates()
>>> print(mat.get_intermediates(params))
None
- build_params(free_params=None, include_fixed=True, trans=True, out_format='tensor')
Construct the full parameter tensor from free parameters.
Merges free (trainable) parameters with fixed parameter defaults and applies parameter transforms. Optionally returns a dictionary mapping parameter names to their values.
- Parameters:
free_params (torch.Tensor or dict) – Flat 1D tensor of free parameters or a dictionary mapping parameter names to tensors. If omitted, default values are used. Default: None.
include_fixed (bool, optional) – Whether to include fixed parameters in the output. Default: True.
trans (bool, optional) – Whether to apply parameter transforms to the output. Default: True.
out_format (str, optional) – Output format. One of "tensor" or "dict". Default: "tensor".
- Returns:
Full parameter tensor of length num_params (or num_free_params when include_fixed=False), or a dictionary mapping parameter names to value tensors.
- Return type:
torch.Tensor or dict
- Raises:
ValueError – If out_format is not "tensor" or "dict".
TypeError – If free_params is neither a Torch tensor nor a dict.
ValueError – If free_params is not a 1D tensor or has the wrong length, or if free_params is a dict with missing or unexpected keys.
Example:
>>> import torch
>>> from torch_openreml.covariance import DiagonalMatrix
>>> mat = DiagonalMatrix(3)
>>> free_params = torch.tensor([0.0, 0.5, 1.0])
>>> mat.build_params(free_params)
tensor([1.0000, 2.7183, 7.3891])
>>> mat.build_params()
tensor([1., 1., 1.])
>>> mat.param_specs["sigma^2_2"]["fixed"] = True
>>> mat.build_params(free_params[0:2])
tensor([1.0000, 2.7183, 1.0000])
>>> mat.build_params(free_params[0:2], include_fixed=False)
tensor([1.0000, 2.7183])
>>> mat.build_params(free_params[0:2], include_fixed=False, trans=False)
tensor([0.0000, 0.5000])
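The merge-then-transform logic of build_params() can be sketched as follows. This is a hypothetical helper (`build_params_sketch`) that handles only tensor input and output, and it assumes each parameter is a single scalar, parameters are ordered as in param_specs, and the transform is exp(2x). That transform is chosen only because it reproduces the DiagonalMatrix numbers in the example; the transform DiagonalMatrix actually uses is not stated here.

```python
import torch


def build_params_sketch(free_params, specs, include_fixed=True, trans=True):
    """Hypothetical re-creation of build_params() for scalar parameters."""
    free_names = [name for name, spec in specs.items() if not spec["fixed"]]
    if free_params is None:
        # Fall back to the stored (raw, untransformed) defaults.
        free_params = (torch.cat([specs[n]["default"] for n in free_names])
                       if free_names else torch.empty(0))
    if free_params.dim() != 1 or free_params.numel() != len(free_names):
        raise ValueError("free_params must be 1D with one entry per free parameter")

    free_iter = iter(free_params)
    values = []
    for name, spec in specs.items():
        if spec["fixed"]:
            if not include_fixed:
                continue
            raw = spec["default"]  # fixed parameters always use their default
        else:
            raw = next(free_iter).reshape(1)
        values.append(spec["trans"](raw) if trans else raw)
    return torch.cat(values) if values else torch.empty(0)


def exp2(x):
    return torch.exp(2 * x)  # assumed transform


specs = {
    f"sigma^2_{i}": {"fixed": False, "default": torch.tensor([0.0]), "trans": exp2}
    for i in range(3)
}
specs["sigma^2_2"]["fixed"] = True
print(build_params_sketch(torch.tensor([0.0, 0.5]), specs))
```

With the last component fixed, the printed tensor agrees with the mat.build_params(free_params[0:2]) output shown in the example.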
- trans_grad(free_params=None)
Compute the element-wise derivative of the free parameter transforms.
Evaluates the derivative of each free parameter’s transform function at the current parameter values. Used in the chain rule when computing manual gradients of the matrix with respect to the original (untransformed) parameterisation.
- Parameters:
free_params (torch.Tensor or dict) – Flat 1D parameter tensor or dictionary of free parameters. If omitted, default values are used. Default: None.
- Raises:
TypeError – If free_params is neither a Torch tensor nor a dict.
ValueError – If free_params is not a 1D tensor or has the wrong length, or if free_params is a dict with missing or unexpected keys.
- Returns:
1D tensor of element-wise transform derivatives, of length num_free_params.
- Return type:
torch.Tensor
Example:
>>> import torch
>>> from torch_openreml.covariance import DiagonalMatrix
>>> mat = DiagonalMatrix(3)
>>> free_params = torch.tensor([0.0, 0.5, 1.0])
>>> mat.trans_grad(free_params)
tensor([ 2.0000, 5.4366, 14.7781])
>>> mat.trans_grad()
tensor([2., 2., 2.])
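For an element-wise transform, the derivative that trans_grad() returns can be obtained generically with torch.func. The sketch below is an assumption-based illustration, not the library's implementation: it uses a hypothetical helper `trans_grad_sketch` and the assumed exp(2x) transform, whose values match the DiagonalMatrix example.

```python
import torch
from torch.func import grad as func_grad, vmap


def exp2(x):
    # Assumed transform x -> exp(2x); its derivative is 2 * exp(2x).
    return torch.exp(2 * x)


def trans_grad_sketch(free_params, transform=exp2):
    """Element-wise derivative of a scalar transform (hypothetical helper)."""
    # func_grad differentiates a function with a scalar output;
    # vmap applies it independently to each entry of the 1D tensor.
    return vmap(func_grad(transform))(free_params)


free_params = torch.tensor([0.0, 0.5, 1.0])
print(trans_grad_sketch(free_params))
```

Because exp(2x) has the closed-form derivative 2·exp(2x), the result can be cross-checked against `2 * torch.exp(2 * free_params)`; a subclass with a known analytic derivative would likely prefer that cheaper form.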
- auto_grad(free_params=None)
Compute the Jacobian of __call__() with respect to free parameters using automatic differentiation.
Uses torch.func.jacrev() to compute the full Jacobian. If all parameters are fixed, returns (None, []).
- Parameters:
free_params (torch.Tensor or dict) – Flat 1D parameter tensor or dict. If omitted, default values are used. Default: None.
- Raises:
TypeError – If free_params is neither a Torch tensor nor a dict.
ValueError – If free_params is not a 1D tensor or has the wrong length, or if free_params is a dict with missing or unexpected keys.
- Returns:
(grad, grad_names), where grad is a 3D tensor of shape (num_free_params, *shape) and grad_names is a list of the corresponding parameter names, one per slice along the first dimension of grad.
- Return type:
tuple
Example:
>>> import torch
>>> from torch_openreml.covariance import DiagonalMatrix
>>> mat = DiagonalMatrix(2)
>>> free_params = torch.tensor([0.0, 0.5])
>>> grad, grad_names = mat.auto_grad(free_params)
>>> grad, grad_names
(tensor([[[2.0000, 0.0000], [0.0000, 0.0000]], [[0.0000, 0.0000], [0.0000, 5.4366]]]), ['sigma^2_0', 'sigma^2_1'])
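The reason auto_grad() returns a tensor of shape (num_free_params, *shape) can be seen from how torch.func.jacrev() lays out its result: output dimensions come first and input dimensions last, so the parameter axis has to be moved to the front. The sketch below assumes a hypothetical diagonal structure V(θ) = diag(exp(2θ)) and hypothetical helper names; it is not the library's code.

```python
import torch
from torch.func import jacrev


def build_diag(theta):
    # Hypothetical V(theta) = diag(exp(2 * theta)). Multiplying eye(n) by a
    # length-n vector scales column j by entry j, leaving only the diagonal.
    n = theta.shape[0]
    eye = torch.eye(n, dtype=theta.dtype, device=theta.device)
    return eye * torch.exp(2 * theta)


def auto_grad_sketch(build, theta, names):
    """Jacobian of build() with the parameter axis moved to the front."""
    if theta.numel() == 0:
        return None, []  # nothing to differentiate: all parameters fixed
    jac = jacrev(build)(theta)  # shape (*matrix_shape, num_free_params)
    return jac.movedim(-1, 0), list(names)  # shape (num_free_params, *matrix_shape)


grad, grad_names = auto_grad_sketch(
    build_diag, torch.tensor([0.0, 0.5]), ["sigma^2_0", "sigma^2_1"]
)
print(grad.shape)  # torch.Size([2, 2, 2])
```

Slice k of the result is the derivative of the whole matrix with respect to θ_k, which is the layout the documented example shows.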
- manual_grad(free_params=None)
Compute the Jacobian of __call__() with respect to free parameters using a closed-form analytic expression.
This method is optional. When implemented by a subclass, grad() will invoke it in preference to auto_grad() under the default grad mode. If it is not implemented, calling this method raises NotImplementedError and grad() falls back to automatic differentiation.
Implementations must satisfy the following contract:
  - Return (None, []) if all parameters are fixed.
  - Return a 3D gradient tensor of shape (num_free_params, *shape) and a matching list of parameter names.
  - Apply transform derivatives from trans_grad() via the chain rule so that gradients are with respect to the raw (untransformed) parameters.
- Parameters:
free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.
- Returns:
(grad, grad_names), where grad is a 3D tensor of shape (num_free_params, *shape) and grad_names is a list of the corresponding parameter names. Returns (None, []) if all parameters are fixed.
- Return type:
tuple
- Raises:
NotImplementedError – If the subclass does not provide an analytic gradient. grad() catches this and falls back to auto_grad().
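As an illustration of the manual_grad() contract, consider a hypothetical diagonal structure \(\symbf{V} = \mathrm{diag}(t(\theta_1), \ldots, t(\theta_n))\). The chain rule gives \(\partial \symbf{V} / \partial \theta_k = t'(\theta_k)\, \symbf{e}_k \symbf{e}_k^\top\), so each Jacobian slice has a single non-zero entry. The sketch below assumes t(x) = exp(2x) and one free parameter per diagonal entry; `manual_grad_diag` is a hypothetical name, not a torch_openreml function.

```python
import torch


def manual_grad_diag(theta, names):
    """Closed-form Jacobian of V(theta) = diag(exp(2 * theta)).

    Assumes every diagonal entry has its own free parameter, so slice k of
    the Jacobian is t'(theta_k) * e_k e_k^T with t(x) = exp(2x).
    """
    n = theta.numel()
    if n == 0:
        return None, []  # contract: all parameters fixed
    dt = 2 * torch.exp(2 * theta)  # chain-rule factor t'(theta_k)
    grad = torch.zeros(n, n, n, dtype=theta.dtype, device=theta.device)
    idx = torch.arange(n, device=theta.device)
    grad[idx, idx, idx] = dt  # only entry (k, k) of slice k is non-zero
    return grad, list(names)


grad, grad_names = manual_grad_diag(torch.tensor([0.0, 0.5]), ["sigma^2_0", "sigma^2_1"])
```

Writing the sparse structure down directly avoids the reverse-mode passes that jacrev() needs, which is the usual motivation for providing manual_grad().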
- abstractmethod __call__(free_params=None)
Construct the matrix from a flat parameter tensor.
Must be implemented by subclasses. Implementations should pass free_params through build_params() to validate it, include fixed parameters, and apply transforms before any computation.
- Parameters:
free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.
- Returns:
Constructed matrix with dimensions given by shape.
- Return type:
torch.Tensor
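The __call__() contract (normalise the input through build_params(), then assemble the matrix) can be illustrated with a self-contained abstract base class. `MiniMatrix` and `MiniDiagonal` are hypothetical stand-ins, not torch_openreml classes, and the simplified build_params() here ignores fixed parameters and dict input, assuming raw defaults of 0 and an exp(2x) transform.

```python
from abc import ABC, abstractmethod

import torch


class MiniMatrix(ABC):
    """Hypothetical, heavily simplified stand-in for the Matrix base class."""

    def __init__(self, n):
        self.n = n

    def build_params(self, free_params=None):
        # Simplified: raw defaults of 0 and an assumed exp(2x) transform.
        if free_params is None:
            free_params = torch.zeros(self.n)
        return torch.exp(2 * free_params)

    @abstractmethod
    def __call__(self, free_params=None):
        ...


class MiniDiagonal(MiniMatrix):
    def __call__(self, free_params=None):
        # Validate/merge/transform first, then assemble the matrix.
        params = self.build_params(free_params)
        return torch.diag(params)


m = MiniDiagonal(3)
v = m(torch.tensor([0.0, 0.5, 1.0]))
```

Routing every subclass through build_params() keeps parameter handling in one place, so a subclass only has to describe how transformed parameters map to matrix entries.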
- grad(free_params=None)
Compute the Jacobian of __call__() with respect to trainable parameters.
Dispatches to manual_grad() or auto_grad() according to grad_mode:
  - "default": attempts manual_grad(), falling back to auto_grad() if it is not implemented.
  - "auto": always uses auto_grad().
- Parameters:
free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.
- Returns:
(grad, grad_names) as described in manual_grad() and auto_grad().
- Return type:
tuple
- Raises:
RuntimeError – If grad_mode is not a recognised value.
Example:
>>> import torch
>>> from torch_openreml.covariance import DiagonalMatrix
>>> mat = DiagonalMatrix(2)
>>> free_params = torch.tensor([0.0, 0.5])
>>> grad, grad_names = mat.grad(free_params)
>>> grad, grad_names
(tensor([[[2.0000, 0.0000], [0.0000, 0.0000]], [[0.0000, 0.0000], [0.0000, 5.4366]]]), ['sigma^2_0', 'sigma^2_1'])
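The grad_mode dispatch described for grad() amounts to a try/except fallback. In the sketch below, `DispatchSketch` and `WithManual` are hypothetical stub classes that return placeholder strings instead of tensors, so only the control flow is illustrated.

```python
class DispatchSketch:
    """Hypothetical illustration of the documented grad_mode dispatch."""

    def __init__(self, grad_mode="default"):
        self.grad_mode = grad_mode

    def manual_grad(self, free_params=None):
        raise NotImplementedError  # no closed form supplied

    def auto_grad(self, free_params=None):
        return "auto", []

    def grad(self, free_params=None):
        if self.grad_mode == "default":
            try:
                return self.manual_grad(free_params)
            except NotImplementedError:
                return self.auto_grad(free_params)
        if self.grad_mode == "auto":
            return self.auto_grad(free_params)
        raise RuntimeError(f"unrecognised grad_mode: {self.grad_mode!r}")


class WithManual(DispatchSketch):
    def manual_grad(self, free_params=None):
        return "manual", []


print(DispatchSketch().grad(), WithManual().grad())
```

Setting grad_mode to "auto" on a subclass that does define manual_grad() is a convenient way to cross-check the analytic Jacobian against automatic differentiation.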
- map_theta_to_v(theta)
An interface compatible with torch_openreml.REML that maps parameters to a matrix. Invokes __call__().
- Parameters:
theta (torch.Tensor) – Flat 1D parameter tensor.
- Returns:
Constructed matrix.
- Return type:
torch.Tensor
- map_theta_to_dv(theta)
An interface compatible with torch_openreml.REML that maps parameters to the matrix Jacobian. Invokes grad().
- Parameters:
theta (torch.Tensor) – Flat 1D parameter tensor.
- Raises:
RuntimeError – If grad_mode is not a recognised value.
- Returns:
Jacobian tensor of shape (num_free_params, *shape), or None if all parameters are fixed.
- Return type:
torch.Tensor or None
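map_theta_to_v() and map_theta_to_dv() are thin wrappers, so their relationship can be shown with a toy object. This is an assumption-labelled sketch: `AdapterSketch` and its diagonal exp(2x) structure are hypothetical, discarding the parameter names is inferred from the documented return type (a tensor or None rather than a tuple), and how torch_openreml.REML consumes these methods is not shown here.

```python
import torch


class AdapterSketch:
    """Hypothetical object showing how the two REML-facing adapters relate."""

    def __init__(self, num_free_params):
        self.num_free_params = num_free_params

    def __call__(self, free_params=None):
        # Assumed structure: V(theta) = diag(exp(2 * theta)).
        return torch.diag(torch.exp(2 * free_params))

    def grad(self, free_params=None):
        if self.num_free_params == 0:
            return None, []
        dt = 2 * torch.exp(2 * free_params)
        # Row k of diag(dt) is dt[k] * e_k; diag_embed turns each row into
        # the matrix dt[k] * e_k e_k^T, giving shape (n, n, n).
        names = [f"sigma^2_{k}" for k in range(dt.numel())]
        return torch.diag_embed(torch.diag(dt)), names

    def map_theta_to_v(self, theta):
        return self(theta)

    def map_theta_to_dv(self, theta):
        jacobian, _names = self.grad(theta)  # parameter names are not passed on
        return jacobian


adapter = AdapterSketch(2)
theta = torch.tensor([0.0, 0.5])
v, dv = adapter.map_theta_to_v(theta), adapter.map_theta_to_dv(theta)
```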
- property shape
Output matrix shape.
- Type:
tuple
- property param_specs
Parameter specifications.
- Type:
dict
- property param_names
Parameter names.
- Type:
list of str
- property free_param_names
Free parameter names.
- Type:
list of str
- property fixed_param_names
Fixed parameter names.
- Type:
list of str
- property free_param_index
Indices of free parameters.
- Type:
tuple
- property fixed_param_index
Indices of fixed parameters.
- Type:
tuple
- property num_params
Total number of parameters.
- Type:
int
- property num_free_params
Total number of free parameters.
- Type:
int
- property num_fixed_params
Total number of fixed parameters.
- Type:
int
- property param_defaults
Parameter defaults.
- Type:
dict of torch.Tensor
- property free_param_defaults
Free parameter defaults.
- Type:
dict of torch.Tensor
- property fixed_param_defaults
Fixed parameter defaults.
- Type:
dict of torch.Tensor
- property repr_dict
Key-value pairs used to build the string representation.
- Type:
dict
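The name, index, and count properties all follow from the "fixed" flags in param_specs. The sketch below shows one plausible derivation; it assumes indices are positions in the insertion order of param_specs and that each parameter is a single scalar, and `summarize_specs` is a hypothetical helper rather than the library's implementation.

```python
def summarize_specs(param_specs):
    """Hypothetical derivation of the name, index, and count properties."""
    names = list(param_specs)  # insertion order of the specification dict
    free_index = tuple(i for i, name in enumerate(names) if not param_specs[name]["fixed"])
    fixed_index = tuple(i for i, name in enumerate(names) if param_specs[name]["fixed"])
    return {
        "param_names": names,
        "free_param_names": [names[i] for i in free_index],
        "fixed_param_names": [names[i] for i in fixed_index],
        "free_param_index": free_index,
        "fixed_param_index": fixed_index,
        "num_params": len(names),
        "num_free_params": len(free_index),
        "num_fixed_params": len(fixed_index),
    }


# Only the "fixed" flag matters here, so the other spec keys are omitted.
specs = {
    "sigma^2_0": {"fixed": False},
    "sigma^2_1": {"fixed": False},
    "sigma^2_2": {"fixed": True},
}
summary = summarize_specs(specs)
print(summary["free_param_index"], summary["fixed_param_names"])  # (0, 1) ['sigma^2_2']
```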