torch_openreml.covariance.Adapter
- class torch_openreml.covariance.Adapter(adaptee, param_specs, param_map)
Bases: Matrix

Reparameterise an existing covariance matrix with a new set of parameters.
An Adapter wraps an adaptee matrix and exposes a user-defined set of parameters. A callable param_map translates the adapter’s parameter tensor into the parameter tensor expected by the adaptee:

\[\symbf{V}(\boldsymbol{\theta}) = \symbf{V}_{\text{adaptee}}\!\big(f(\boldsymbol{\theta})\big)\]

where \(f\) is param_map and \(\boldsymbol{\theta}\) are the adapter’s free parameters.

Gradients are computed via the chain rule: the adaptee’s gradient with respect to its own parameters is pulled back through the Jacobian of param_map.

Note
All adapter parameter transforms must be TransformIdentity. Non-identity transforms should be applied inside param_map or on the adaptee directly.

Initialise an adapter wrapping an existing matrix.
- Parameters:
adaptee (Matrix) – The matrix to reparameterise.
param_specs (dict) – Parameter specifications for the adapter’s own parameters. Keys are parameter names; values are dictionaries with keys "fixed", "default", and "trans". All transforms must be TransformIdentity.
param_map (callable) – A function f(params) -> adaptee_params that maps the adapter’s parameter tensor to the parameter tensor expected by adaptee. Must be differentiable (compatible with torch.func.jacrev()).
- Raises:
ValueError – If any parameter specification uses a non-identity transform.
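You can check the requirement that param_map be compatible with torch.func.jacrev() on its own, before passing the map to an Adapter. The sketch below uses only PyTorch and the same param_map as the Example that follows. It evaluates the map and its Jacobian at θ = 0, where the derivative of the sigmoid is 0.25.

```python
import torch
from torch.func import jacrev


def param_map(params: torch.Tensor) -> torch.Tensor:
    # One free logit -> two adaptee parameters that sum to one.
    p = torch.sigmoid(params[0])
    return torch.stack([p, 1 - p])


theta = torch.tensor([0.0])
phi = param_map(theta)        # adaptee parameters: [0.5, 0.5]
J = jacrev(param_map)(theta)  # d phi_k / d theta_j, shape (2, 1): [[0.25], [-0.25]]
print(phi)
print(J)
```

The Jacobian has one row per adaptee parameter and one column per adapter parameter. This is the factor that manual_grad() contracts with the adaptee’s own gradient.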
Example:

    import torch
    from torch_openreml.covariance import DiagonalMatrix, Adapter
    from torch_openreml.covariance.transform import TransformIdentity

    # Wrap a 2x2 diagonal matrix so that both variances are driven by a
    # single parameter (sum-to-one constraint).
    adaptee = DiagonalMatrix(2)

    def param_map(params):
        # sigmoid(params[0]) gives one variance; the other is 1 minus it
        p = torch.sigmoid(params[0])
        return torch.stack([p, 1 - p])

    param_specs = {
        "logit": {
            "fixed": False,
            "default": torch.tensor([0.0]),
            "trans": TransformIdentity(),
        }
    }

    adapter = Adapter(adaptee, param_specs, param_map)
    adapter(torch.tensor([0.0]))
    # tensor([[2.7183, 0.0000],
    #         [0.0000, 2.7183]])

Methods
- __call__([free_params]) – Construct the matrix from a flat parameter tensor.
- auto_grad([free_params]) – Compute the Jacobian of __call__() using automatic differentiation.
- build_params([free_params, include_fixed, ...]) – Construct the full parameter tensor from free parameters.
- get_intermediates(params) – Retrieve cached intermediate computation results if still valid.
- grad([free_params]) – Compute the Jacobian of __call__() with respect to trainable parameters.
- manual_grad([free_params]) – Compute the Jacobian of __call__() via the chain rule.
- map_theta_to_dv(theta) – An interface compatible with torch_openreml.REML that maps parameters to the matrix Jacobian.
- map_theta_to_v(theta) – An interface compatible with torch_openreml.REML that maps parameters to a matrix.
- reset_intermediates() – Clear the intermediate computation cache.
- set_intermediates(params, intermediates) – Cache intermediate computation results keyed by parameter hash.
- trans_grad([free_params]) – Compute the element-wise derivative of the free parameter transforms.
Attributes
- adaptee – The wrapped matrix that this adapter reparameterises.
- fixed_param_defaults – Fixed parameter defaults.
- fixed_param_index – Index of fixed parameters.
- fixed_param_names – Fixed parameter names.
- fixed_param_trans – Transforms for fixed parameters.
- free_param_defaults – Free parameter defaults.
- free_param_index – Index of free parameters.
- free_param_names – Free parameter names.
- free_param_trans – Transforms for free parameters.
- num_fixed_params – Total number of fixed parameters.
- num_free_params – Total number of free parameters.
- num_params – Total number of parameters.
- param_defaults – Parameter defaults.
- param_map – The mapping function f(params) -> adaptee_params that translates the adapter’s parameter tensor to the adaptee’s parameter tensor.
- param_names – Parameter names.
- param_specs – The adapter’s parameter specifications.
- param_trans – Parameter transforms.
- repr_dict – A dictionary representation for display, containing param_specs, adaptee, and a placeholder for param_map.
- shape – Output matrix shape.
- __call__(free_params=None)
Construct the matrix from a flat parameter tensor.
Must be implemented by subclasses. Implementations should pass free_params through build_params() to validate them, include fixed parameters, and apply transforms before any computation.
- Parameters:
free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.
- Returns:
Constructed matrix, with the shape given by the shape attribute.
- Return type:
torch.Tensor
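Here is one way to picture the build_params() step: a flat vector of free values is merged with fixed values, and each entry then passes through its parameter's transform. The sketch below assumes a storage layout to show this. The helper build_full_params, its index arguments, and the transform list are hypothetical illustrations, not the library's implementation. For an Adapter every transform is the identity, so the final step does nothing.

```python
import torch
from typing import Callable, Sequence


def build_full_params(
    free: torch.Tensor,
    fixed: torch.Tensor,
    free_index: Sequence[int],
    fixed_index: Sequence[int],
    transforms: Sequence[Callable[[torch.Tensor], torch.Tensor]],
) -> torch.Tensor:
    """Hypothetical helper: merge free and fixed values, then transform."""
    # order[j] is the slot in the full vector that stacked[j] belongs to.
    order = torch.tensor(list(free_index) + list(fixed_index))
    stacked = torch.cat([free, fixed])
    # Gathering with the inverse permutation builds the full vector
    # out of place, without writing into a pre-allocated buffer.
    full = stacked[torch.argsort(order)]
    # Apply each parameter's transform element-wise.
    return torch.stack([t(full[i]) for i, t in enumerate(transforms)])


free = torch.tensor([0.0])   # one free parameter, stored at position 1
fixed = torch.tensor([2.0])  # one fixed parameter, stored at position 0
out = build_full_params(free, fixed, [1], [0], [lambda x: x, torch.exp])
print(out)  # tensor([2., 1.])
```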
- auto_grad(free_params=None)
Compute the Jacobian of __call__() using automatic differentiation.
Resets the adaptee’s intermediate cache before calling the parent auto_grad(), which uses torch.func.jacrev().
- Parameters:
free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.
- Returns:
(grad, grad_names), where grad is a 3D tensor of shape (num_free_params, *shape) and grad_names is a list of the corresponding parameter names. Returns (None, []) if all parameters are fixed.
- Return type:
tuple
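torch.func.jacrev() puts the input dimension last. For a matrix-valued function, the Jacobian therefore comes back with shape (*shape, num_params), while the documented return value puts the parameter dimension first. The sketch below shows one way to do that reordering for the composition V(θ) = V_adaptee(f(θ)). It uses a stand-in adaptee (a diagonal matrix built with torch.diag_embed) and the example's param_map. Both functions, and the use of movedim, are illustrative assumptions rather than the library's code.

```python
import torch
from torch.func import jacrev


def adaptee_matrix(phi: torch.Tensor) -> torch.Tensor:
    # Stand-in adaptee (assumption): a diagonal covariance with variances phi.
    return torch.diag_embed(phi)


def param_map(theta: torch.Tensor) -> torch.Tensor:
    p = torch.sigmoid(theta[0])
    return torch.stack([p, 1 - p])


def adapted(theta: torch.Tensor) -> torch.Tensor:
    # V(theta) = V_adaptee(f(theta))
    return adaptee_matrix(param_map(theta))


theta = torch.tensor([0.0])
J = jacrev(adapted)(theta)  # jacrev layout: (*shape, num_params) = (2, 2, 1)
grad = J.movedim(-1, 0)     # parameter axis first: (num_free_params, *shape) = (1, 2, 2)
print(grad.shape)           # torch.Size([1, 2, 2])
print(grad[0])              # diagonal entries 0.25 and -0.25
```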
- manual_grad(free_params=None)
Compute the Jacobian of __call__() via the chain rule.
The gradient is obtained by pulling back the adaptee’s gradient through the Jacobian of param_map:

\[\frac{\partial \symbf{V}}{\partial \boldsymbol{\theta}} = \sum_k \frac{\partial \symbf{V}_{\text{adaptee}}}{\partial \phi_k} \cdot \frac{\partial \phi_k}{\partial \boldsymbol{\theta}}\]

where \(\boldsymbol{\phi} = f(\boldsymbol{\theta})\) are the adaptee’s parameters.
- Parameters:
free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.
- Returns:
(grad, grad_names), where grad is a 3D tensor of shape (num_free_params, *shape) and grad_names is a list of the corresponding parameter names. Returns (None, []) if all parameters are fixed.
- Return type:
tuple
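You can check the pull-back formula numerically with plain PyTorch:

1. Compute ∂V_adaptee/∂φ_k and ∂φ_k/∂θ separately.
2. Contract them over k.
3. Compare the result with differentiating the composed map directly.

The diagonal stand-in adaptee and the sigmoid param_map are illustrative assumptions. This is a sketch of the chain rule itself, not the library's manual_grad() implementation.

```python
import torch
from torch.func import jacrev


def adaptee_matrix(phi: torch.Tensor) -> torch.Tensor:
    # Stand-in adaptee (assumption): a diagonal covariance with variances phi.
    return torch.diag_embed(phi)


def param_map(theta: torch.Tensor) -> torch.Tensor:
    p = torch.sigmoid(theta[0])
    return torch.stack([p, 1 - p])


theta = torch.tensor([0.3])
phi = param_map(theta)

# dV_adaptee/dphi in (num_adaptee_params, *shape) layout: (2, 2, 2)
dV_dphi = jacrev(adaptee_matrix)(phi).movedim(-1, 0)
# dphi/dtheta in (num_adaptee_params, num_free_params) layout: (2, 1)
dphi_dtheta = jacrev(param_map)(theta)

# Chain rule: sum over k of dV/dphi_k * dphi_k/dtheta_j -> (num_free_params, *shape)
manual = torch.einsum("kab,kj->jab", dV_dphi, dphi_dtheta)

# Cross-check against differentiating the composition directly.
auto = jacrev(lambda t: adaptee_matrix(param_map(t)))(theta).movedim(-1, 0)
print(torch.allclose(manual, auto))  # True
```

The einsum string "kab,kj->jab" performs the sum over k in a single call. It produces the (num_free_params, *shape) layout that manual_grad() returns.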
- property param_specs
The adapter’s parameter specifications.
All transforms are forced to TransformIdentity.
- Type:
dict
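The identity-only rule can be illustrated with a small validation sketch. The classes Identity and Exp below are hypothetical stand-ins; they are not TransformIdentity or any other torch_openreml class. check_identity_only is also hypothetical and only mirrors the documented behaviour of raising ValueError for a non-identity transform.

```python
import torch


class Identity:
    """Hypothetical stand-in for an identity transform."""

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return x


class Exp:
    """Hypothetical stand-in for a non-identity (exponential) transform."""

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return torch.exp(x)


def check_identity_only(param_specs: dict) -> None:
    # Reject any specification whose "trans" entry is not an identity transform.
    for name, spec in param_specs.items():
        if not isinstance(spec["trans"], Identity):
            raise ValueError(f"parameter {name!r} uses a non-identity transform")


good = {"logit": {"fixed": False, "default": torch.tensor([0.0]), "trans": Identity()}}
bad = {"log_var": {"fixed": False, "default": torch.tensor([0.0]), "trans": Exp()}}

check_identity_only(good)  # passes silently
try:
    check_identity_only(bad)
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```

If an exponential transform is wanted, the Note at the top of this page says to apply it inside param_map (for example, a map returning torch.exp(theta)) or on the adaptee. The adapter's own transforms stay as the identity.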
- property param_map
The mapping function f(params) -> adaptee_params that translates the adapter’s parameter tensor to the adaptee’s parameter tensor.
- Type:
callable
- property repr_dict
A dictionary representation for display, containing param_specs, adaptee, and a placeholder for param_map.
- Type:
dict