torch_openreml.covariance.Operator
- class torch_openreml.covariance.Operator(shape, operands)
Bases: Matrix

Initialize a covariance matrix with optional parameter transforms.
- Parameters:
shape (tuple or None) – Expected output dimensions of the constructed matrix. Used for validation; the actual shape may be set by subclasses.
param_names (list of str) – Ordered names of the parameters in params. Empty list if there are no trainable parameters (e.g., fixed matrices).
trans (list of Transform or None) – Transforms applied to each parameter before the matrix is constructed. If None, no transforms are used. Typical uses are keeping variances positive via \(\exp(2\theta) > 0\) or constraining correlations to \(\rho \in (-1, 1)\).
no_grad_index (list of int) – Indices of parameters to exclude from gradient computation. Parameters at these indices are omitted from grad and grad_names. For convenience, use set_no_grad() instead.
Note
The transforms are applied as
\[\symbf{V} = \left[f_0(\theta_0), \ldots, f_{p-1}(\theta_{p-1}) \right]^\top,\]
where \(p\) is the number of parameters and each \(f_i\) is the \(i\)-th transform in trans. If trans has length 1, the single transform is broadcast and applied elementwise to all parameters.
- Raises:
TypeError – If param_names is not a list of strings, or if transforms contain non-Transform objects.
ValueError – If parameter names are not unique, or if indices in no_grad_index are out of range.
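As a worked illustration of the transform rule in the Note above, the sketch below builds the transformed vector for a hypothetical 2x2 covariance with two variances and one correlation. Plain callables stand in for the library's Transform objects. It assumes \(\exp(2\theta)\) for the variances and \(\tanh\) for the correlation; \(\tanh\) is an assumed choice, since the documentation only states the \((-1, 1)\) constraint.

```python
import torch

# Hypothetical per-parameter transforms (stand-ins for Transform objects):
# theta = (variance param 1, variance param 2, correlation param).
transforms = [
    lambda t: torch.exp(2 * t),  # f_0: variance of component 1, always > 0
    lambda t: torch.exp(2 * t),  # f_1: variance of component 2, always > 0
    torch.tanh,                  # f_2: correlation in (-1, 1); assumed choice
]

theta = torch.tensor([0.0, 0.5, 0.3])

# V = [f_0(theta_0), ..., f_{p-1}(theta_{p-1})]^T, one transform per entry.
v = torch.stack([f(t) for f, t in zip(transforms, theta)])

# Assemble a symmetric 2x2 covariance from the transformed values.
var1, var2, rho = v
cov = rho * torch.sqrt(var1 * var2)
matrix = torch.stack([torch.stack([var1, cov]), torch.stack([cov, var2])])
print(v)
print(matrix)
```

Because each \(f_i\) maps the whole real line into its constrained range, an optimizer can update \(\theta\) freely without ever producing a negative variance or an out-of-range correlation.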
Methods
__call__(params) – Construct the matrix from a flat parameter tensor.
auto_grad(params) – Compute the Jacobian of build() with respect to trainable parameters using automatic differentiation.
build_operands(params)
check_operands(operands)
check_params(params) – Validate a parameter tensor and return its device and dtype.
from_param_dict(param_dict) – Extract parameter tensors from a dictionary into a flat 1D tensor.
get_intermediates(params) – Retrieve cached intermediate computation results if still valid.
grad(params) – Compute the Jacobian of __call__() with respect to trainable parameters.
manual_grad(params) – Compute the Jacobian of __call__() with respect to trainable parameters using a closed-form analytic expression.
map_theta_to_dv(theta) – An interface compatible with torch_openreml.REML that maps parameters to the matrix Jacobian.
map_theta_to_v(theta) – An interface compatible with torch_openreml.REML that maps parameters to a matrix.
operands_grad(params)
reset_intermediates() – Clear the intermediate computation cache.
set_intermediates(params, intermediates) – Cache intermediate computation results keyed by parameter hash.
set_no_grad([index, param_name]) – Set the indices of parameters to exclude from gradient computation.
to_param_dict(params) – Convert a flat parameter tensor to a parameter dictionary.
trans_grad(params) – Compute the element-wise derivative of the parameter transforms.
trans_params(params) – Apply parameter transforms to a flat parameter tensor.
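The difference between auto_grad() and manual_grad() can be illustrated outside the class. The sketch below assumes a hypothetical diagonal builder \(M(\theta) = \operatorname{diag}(\exp(2\theta))\), not the library's DiagonalMatrix code, and compares a Jacobian from torch.autograd.functional.jacobian with the closed form \(\partial M / \partial \theta_i = 2\exp(2\theta_i)\, e_i e_i^\top\). Both are laid out as one \(n \times n\) slice per parameter, the same layout as the grad() output in the set_no_grad() example.

```python
import torch
from torch.autograd.functional import jacobian

def build(theta: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a diagonal covariance: M = diag(exp(2*theta)).
    return torch.diag(torch.exp(2 * theta))

def auto_grad_sketch(theta: torch.Tensor) -> torch.Tensor:
    # jacobian() returns shape (n, n, p) for an (n, n) output and a (p,) input;
    # move the parameter axis first to get one dM/dtheta_i slice per parameter.
    return jacobian(build, theta).permute(2, 0, 1)

def manual_grad_sketch(theta: torch.Tensor) -> torch.Tensor:
    # Closed form: dM/dtheta_i has a single nonzero entry 2*exp(2*theta_i) at (i, i).
    p = theta.numel()
    out = torch.zeros(p, p, p, dtype=theta.dtype)
    idx = torch.arange(p)
    out[idx, idx, idx] = 2 * torch.exp(2 * theta)
    return out

theta = torch.tensor([0.0, 0.5, 1.0])
auto = auto_grad_sketch(theta)
manual = manual_grad_sketch(theta)
print(torch.allclose(auto, manual))  # both routes agree

# Excluding parameters (as no_grad_index does) drops their slices.
no_grad_index = [0]
keep = [i for i in range(theta.numel()) if i not in no_grad_index]
print(auto[keep].shape)
```

The autodiff route needs only the builder, while the closed form avoids building a graph; comparing the two on a few inputs is a cheap consistency check for a hand-written gradient.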
Attributes
no_grad_index – Indices of parameters excluded from gradient computation.
num_params – Total number of parameters.
param_names – Ordered parameter names.
repr_dict – Key-value pairs used to build the string representation.
shape – Output matrix shape.
trans – Parameter transforms.
- set_no_grad(index=None, param_name=None)
Set the indices of parameters to exclude from gradient computation.
Replaces no_grad_index with the provided indices. Exactly one of index or param_name must be supplied; providing both or neither raises an error.
- Parameters:
index (int or list of int, optional) – Zero-based index or list of indices into param_names to exclude from gradient computation.
param_name (str or list of str, optional) – Parameter name or list of names to exclude. Names must exist in param_names.
- Raises:
ValueError – If both or neither of index and param_name are provided, or if any index is out of range.
KeyError – If any name in param_name is not found in param_names.
Example:
```python
import torch
from torch_openreml.covariance import DiagonalMatrix

mat = DiagonalMatrix(3)
mat.set_no_grad(index=0)  # exclude the first parameter
print(mat.no_grad_index)
print(mat.grad(torch.zeros(3)))
```

```text
[0]
(tensor([[[0., 0., 0.],
         [0., 2., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 2.]]]), ['sigma^2_1', 'sigma^2_2'])
```
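The argument rules for set_no_grad() can be sketched as a standalone function. resolve_no_grad is a hypothetical helper, not part of torch_openreml, and the parameter names below are assumed to follow the sigma^2_i pattern seen in the grad() output above.

```python
from typing import List, Optional, Sequence, Union

def resolve_no_grad(
    param_names: Sequence[str],
    index: Optional[Union[int, Sequence[int]]] = None,
    param_name: Optional[Union[str, Sequence[str]]] = None,
) -> List[int]:
    """Hypothetical helper mirroring the documented set_no_grad() rules."""
    # Exactly one of index / param_name must be given.
    if (index is None) == (param_name is None):
        raise ValueError("Provide exactly one of index or param_name.")
    if param_name is not None:
        names = [param_name] if isinstance(param_name, str) else list(param_name)
        missing = [n for n in names if n not in param_names]
        if missing:
            raise KeyError(f"Unknown parameter name(s): {missing}")
        # Translate names to zero-based positions in param_names.
        return [list(param_names).index(n) for n in names]
    indices = [index] if isinstance(index, int) else list(index)
    for i in indices:
        if not 0 <= i < len(param_names):
            raise ValueError(f"Index {i} out of range for {len(param_names)} parameters.")
    return indices

names = ["sigma^2_0", "sigma^2_1", "sigma^2_2"]  # assumed naming pattern
print(resolve_no_grad(names, index=0))
print(resolve_no_grad(names, param_name=["sigma^2_2", "sigma^2_1"]))
```

Resolving names to indices up front means the rest of the class only has to handle one representation of excluded parameters.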
- trans_params(params)
Apply parameter transforms to a flat parameter tensor.
Applies the transforms in trans element-wise to params. If trans is None or empty, returns params unchanged. If trans has a single entry, that transform is broadcast and applied to all parameters simultaneously. Otherwise, each transform is applied to its corresponding parameter individually.
- Parameters:
params (torch.Tensor or dict) – Flat 1D parameter tensor or dictionary. Converted via from_param_dict() before transformation.
- Returns:
Transformed parameter tensor of the same shape as params.
- Return type:
torch.Tensor
Example:
```python
import torch
from torch_openreml.covariance import DiagonalMatrix

mat = DiagonalMatrix(3)
params = torch.tensor([0.0, 0.5, 1.0])
print(mat.trans_params(params))  # variances exp(2 * theta)
```

```text
tensor([1.0000, 2.7183, 7.3891])
```
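The three cases described for trans_params() can be sketched as a small dispatch function. apply_transforms and exp2 are hypothetical names; plain callables stand in for Transform objects, and the length check in the per-parameter case is an added assumption. With the single \(\exp(2\theta)\) transform broadcast over [0.0, 0.5, 1.0], it returns the same values as the example above.

```python
import torch
from typing import Callable, List, Optional

Transform = Callable[[torch.Tensor], torch.Tensor]

def apply_transforms(params: torch.Tensor, trans: Optional[List[Transform]]) -> torch.Tensor:
    """Hypothetical re-creation of the dispatch rules described for trans_params()."""
    if not trans:  # None or empty list: return params unchanged
        return params
    if len(trans) == 1:  # single transform broadcast over all parameters at once
        return trans[0](params)
    if len(trans) != params.numel():  # assumed guard, not documented behavior
        raise ValueError("Need one transform per parameter.")
    # One transform per parameter, applied individually.
    return torch.stack([f(p) for f, p in zip(trans, params)])

def exp2(t: torch.Tensor) -> torch.Tensor:
    # Variance transform exp(2 * theta).
    return torch.exp(2 * t)

params = torch.tensor([0.0, 0.5, 1.0])
print(apply_transforms(params, None))
print(apply_transforms(params, [exp2]))
print(apply_transforms(params, [exp2, torch.tanh, exp2]))
```

Broadcasting a single transform keeps the common all-variance case to one vectorized call, while the per-parameter path allows mixing variance and correlation constraints in one matrix.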
- property operands
- property no_grad_index
Indices of parameters excluded from gradient computation.
- Type:
list of int
- property repr_dict
Key-value pairs used to build the string representation.
- Type:
dict