torch_openreml.covariance.Adapter

class torch_openreml.covariance.Adapter(adaptee, param_specs, param_map)

Bases: Matrix

Reparameterise an existing covariance matrix with a new set of parameters.

An Adapter wraps an adaptee matrix and exposes a user-defined set of parameters. A callable param_map translates the adapter’s parameter tensor into the parameter tensor expected by the adaptee:

\[\symbf{V}(\boldsymbol{\theta}) = \symbf{V}_{\text{adaptee}}\!\big(f(\boldsymbol{\theta})\big)\]

where \(f\) is param_map and \(\boldsymbol{\theta}\) are the adapter’s free parameters.

Gradients are computed via the chain rule: the adaptee’s gradient with respect to its own parameters is pulled back through the Jacobian of param_map.
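The composition above can be illustrated with plain PyTorch functions. This is a minimal sketch, not the library's implementation: `adaptee_v`, `param_map` and `adapter_v` are hypothetical stand-ins for the adaptee's `__call__()`, a user-supplied map, and the adapter's `__call__()`. The adaptee here is assumed to take its variances directly, with no transform of its own.

```python
import torch

def adaptee_v(phi: torch.Tensor) -> torch.Tensor:
    # Hypothetical adaptee: a diagonal covariance built from its parameters.
    return torch.diag(phi)

def param_map(theta: torch.Tensor) -> torch.Tensor:
    # Hypothetical map f: one adapter parameter shared by both diagonal entries.
    return theta.repeat(2)

def adapter_v(theta: torch.Tensor) -> torch.Tensor:
    # V(theta) = V_adaptee(f(theta))
    return adaptee_v(param_map(theta))

theta = torch.tensor([0.7])
V = adapter_v(theta)  # 2x2 diagonal matrix with 0.7 on both diagonal entries
```

The adapter itself never builds a matrix; it only changes which parameters the adaptee is evaluated at.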

Note

All adapter parameter transforms must be TransformIdentity. Non-identity transforms should be applied inside param_map or on the adaptee directly.
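As a sketch of moving a non-identity transform into param_map, the hypothetical map below treats the adapter's parameters as log-variances and exponentiates them before handing them to a hypothetical adaptee that takes variances directly. The adapter's own specs can then keep TransformIdentity. The Jacobian that the chain rule needs comes from torch.func.jacrev().

```python
import torch
from torch.func import jacrev

def adaptee_v(phi: torch.Tensor) -> torch.Tensor:
    # Hypothetical adaptee that takes variances directly (no transform of its own).
    return torch.diag(phi)

def param_map(log_var: torch.Tensor) -> torch.Tensor:
    # The positivity transform lives inside param_map, so the adapter's
    # parameter specs can all use the identity transform.
    return torch.exp(log_var)

theta = torch.tensor([0.0, 1.0])
V = adaptee_v(param_map(theta))   # diag(exp(0), exp(1))
J = jacrev(param_map)(theta)      # d phi / d theta = diag(exp(theta)), shape (2, 2)
```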

Initialise an adapter wrapping an existing matrix.

Parameters:
  • adaptee (Matrix) – The matrix to reparameterise.

  • param_specs (dict) – Parameter specifications for the adapter’s own parameters. Keys are parameter names; values are dictionaries with keys "fixed", "default", and "trans". All transforms must be TransformIdentity.

  • param_map (callable) – A function f(params) -> adaptee_params that maps the adapter’s parameter tensor to the parameter tensor expected by adaptee. Must be differentiable (compatible with torch.func.jacrev()).
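A quick way to check the differentiability requirement on param_map is to call torch.func.jacrev() on it directly. In this sketch the map is hypothetical: two adapter parameters (a, b) drive three adaptee parameters, with the first two tied together. The resulting Jacobian has one row per adaptee parameter and one column per adapter parameter.

```python
import torch
from torch.func import jacrev

def param_map(params: torch.Tensor) -> torch.Tensor:
    # Hypothetical map: phi = (a, a, b), tying the first two adaptee parameters.
    return torch.stack([params[0], params[0], params[1]])

theta = torch.tensor([0.3, 1.5])
phi = param_map(theta)
# If param_map is jacrev-compatible this succeeds; the result has shape
# (num_adaptee_params, num_adapter_params), here rows (1, 0), (1, 0), (0, 1).
J = jacrev(param_map)(theta)
```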

Raises:

ValueError – If any parameter specification uses a non-identity transform.

Example:

import torch
from torch_openreml.covariance import DiagonalMatrix, Adapter
from torch_openreml.covariance.transform import TransformIdentity

# Wrap a 2x2 diagonal matrix so that both of its parameters are
# driven by a single logit and sum to one.
adaptee = DiagonalMatrix(2)

def param_map(params):
    # sigmoid(params[0]) is one adaptee parameter; the other is 1 minus it
    p = torch.sigmoid(params[0])
    return torch.stack([p, 1 - p])

param_specs = {
    "logit": {
        "fixed": False,
        "default": torch.tensor([0.0]),
        "trans": TransformIdentity(),
    }
}

adapter = Adapter(adaptee, param_specs, param_map)
adapter(torch.tensor([0.0]))
tensor([[2.7183, 0.0000],
        [0.0000, 2.7183]])
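The values the adaptee receives in this example can be checked with plain PyTorch, independently of the adapter. The sketch below reuses only the example's param_map; the printed diagonal above additionally depends on how DiagonalMatrix transforms its own parameters, which this sketch does not model.

```python
import torch
from torch.func import jacrev

def param_map(params: torch.Tensor) -> torch.Tensor:
    # Same map as in the example above.
    p = torch.sigmoid(params[0])
    return torch.stack([p, 1 - p])

theta = torch.tensor([0.0])
phi = param_map(theta)        # sigmoid(0) = 0.5, so phi = [0.5, 0.5]
J = jacrev(param_map)(theta)  # shape (2, 1)
# d sigmoid / dx = s * (1 - s) = 0.25 at x = 0, so J = [[0.25], [-0.25]]
```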

Methods

__call__([free_params])

Construct the matrix from a flat parameter tensor.

auto_grad([free_params])

Compute the Jacobian of __call__() using automatic differentiation.

build_params([free_params, include_fixed, ...])

Construct the full parameter tensor from free parameters.

get_intermediates(params)

Retrieve cached intermediate computation results if still valid.

grad([free_params])

Compute the Jacobian of __call__() with respect to trainable parameters.

manual_grad([free_params])

Compute the Jacobian of __call__() via the chain rule.

map_theta_to_dv(theta)

An interface compatible with torch_openreml.REML that maps parameters to the matrix Jacobian.

map_theta_to_v(theta)

An interface compatible with torch_openreml.REML that maps parameters to a matrix.

reset_intermediates()

Clear the intermediate computation cache.

set_intermediates(params, intermediates)

Cache intermediate computation results keyed by parameter hash.

trans_grad([free_params])

Compute the element-wise derivative of the free parameter transforms.

Attributes

adaptee

The wrapped matrix that this adapter reparameterises.

fixed_param_defaults

Fixed parameter defaults.

fixed_param_index

Index of fixed parameters.

fixed_param_names

Fixed parameter names.

fixed_param_trans

Transforms for fixed parameters.

free_param_defaults

Free parameter defaults.

free_param_index

Index of free parameters.

free_param_names

Free parameter names.

free_param_trans

Transforms for free parameters.

num_fixed_params

Total number of fixed parameters.

num_free_params

Total number of free parameters.

num_params

Total number of parameters.

param_defaults

Parameter defaults.

param_map

The mapping function f(params) -> adaptee_params that translates the adapter's parameter tensor to the adaptee's parameter tensor.

param_names

Parameter names.

param_specs

The adapter's parameter specifications.

param_trans

Parameter transforms.

repr_dict

A dictionary representation for display, containing param_specs, adaptee, and a placeholder for param_map.

shape

Output matrix shape.

__call__(free_params=None)

Construct the matrix from a flat parameter tensor.

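In the base class Matrix this method must be implemented by subclasses, which should convert free_params via build_params() to validate them, include fixed parameters, and apply transforms before any computation. For an Adapter, the result is the adaptee evaluated at param_map applied to the built parameters.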

Parameters:

free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.

Returns:

Constructed matrix, with shape given by the shape attribute.

Return type:

torch.Tensor
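The flow of __call__() for an adapter can be sketched end to end. Everything here is hypothetical: the index tensors and defaults stand in for free_param_index, fixed_param_index and fixed_param_defaults, and this build_params assumes free values and fixed defaults are scattered into one full vector by position. Because all adapter transforms are identity, no transform step appears before param_map.

```python
import torch

# Hypothetical layout: three adapter parameters, the middle one fixed.
free_param_index = torch.tensor([0, 2])
fixed_param_index = torch.tensor([1])
fixed_param_defaults = torch.tensor([1.0])
num_params = 3

def build_params(free_params: torch.Tensor) -> torch.Tensor:
    # Scatter free values and fixed defaults into one full parameter vector
    # (out-of-place, so autograd can differentiate through it).
    full = torch.zeros(num_params)
    full = full.index_copy(0, free_param_index, free_params)
    return full.index_copy(0, fixed_param_index, fixed_param_defaults)

def param_map(params: torch.Tensor) -> torch.Tensor:
    # Hypothetical map: log-variances to variances.
    return torch.exp(params)

def adaptee_v(phi: torch.Tensor) -> torch.Tensor:
    # Hypothetical adaptee: diagonal matrix of the given variances.
    return torch.diag(phi)

def adapter_call(free_params: torch.Tensor) -> torch.Tensor:
    return adaptee_v(param_map(build_params(free_params)))

V = adapter_call(torch.tensor([0.5, 2.0]))
```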

auto_grad(free_params=None)

Compute the Jacobian of __call__() using automatic differentiation.

Resets the adaptee’s intermediate cache before calling the parent auto_grad(), which uses torch.func.jacrev().

Parameters:

free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.

Returns:

(grad, grad_names), where grad is a 3D tensor of shape (num_free_params, *shape) and grad_names is a list of the corresponding parameter names. Returns (None, []) if all parameters are fixed.

Return type:

tuple
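torch.func.jacrev() places the input dimension last, so a result in the documented (num_free_params, *shape) layout needs one axis move. The sketch below shows that step and the all-fixed early return; adapter_v and auto_grad_sketch are hypothetical and do not touch the library's intermediate cache.

```python
import torch
from torch.func import jacrev

def adapter_v(theta: torch.Tensor) -> torch.Tensor:
    # Hypothetical adapter: phi = (theta0, theta0, theta1), V = diag(exp(phi)).
    phi = torch.stack([theta[0], theta[0], theta[1]])
    return torch.diag_embed(torch.exp(phi))

def auto_grad_sketch(free_params: torch.Tensor, names):
    if free_params.numel() == 0:
        return None, []  # every parameter is fixed
    jac = jacrev(adapter_v)(free_params)     # shape (*shape, num_free_params)
    return jac.movedim(-1, 0), list(names)   # shape (num_free_params, *shape)

grad, grad_names = auto_grad_sketch(torch.tensor([0.0, 1.0]), ["a", "b"])
```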

manual_grad(free_params=None)

Compute the Jacobian of __call__() via the chain rule.

The gradient is obtained by pulling back the adaptee’s gradient through the Jacobian of param_map:

\[\frac{\partial \symbf{V}}{\partial \boldsymbol{\theta}} = \sum_k \frac{\partial \symbf{V}_{\text{adaptee}}}{\partial \phi_k} \cdot \frac{\partial \phi_k}{\partial \boldsymbol{\theta}}\]

where \(\boldsymbol{\phi} = f(\boldsymbol{\theta})\) are the adaptee’s parameters.

Parameters:

free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.

Returns:

(grad, grad_names), where grad is a 3D tensor of shape (num_free_params, *shape) and grad_names is a list of the corresponding parameter names. Returns (None, []) if all parameters are fixed.

Return type:

tuple
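The pullback formula above is a single contraction over the adaptee parameter index k. In this sketch, adaptee_v, param_map and adaptee_grad are hypothetical stand-ins (the adaptee gradient is taken with jacrev here rather than an analytic rule), and torch.einsum contracts a (K, n, n) adaptee gradient with the (K, P) Jacobian of param_map. The result is compared with direct automatic differentiation of the composition.

```python
import torch
from torch.func import jacrev

def adaptee_v(phi: torch.Tensor) -> torch.Tensor:
    # Hypothetical adaptee: V_adaptee(phi) = diag(exp(phi)).
    return torch.diag_embed(torch.exp(phi))

def param_map(theta: torch.Tensor) -> torch.Tensor:
    # Hypothetical map: phi = (theta0, theta0, theta1).
    return torch.stack([theta[0], theta[0], theta[1]])

def adaptee_grad(phi: torch.Tensor) -> torch.Tensor:
    # dV_adaptee / dphi_k in (K, n, n) layout.
    return jacrev(adaptee_v)(phi).movedim(-1, 0)

def manual_grad_sketch(theta: torch.Tensor) -> torch.Tensor:
    phi = param_map(theta)
    d_adaptee = adaptee_grad(phi)   # (K, n, n)
    J = jacrev(param_map)(theta)    # (K, P): dphi_k / dtheta_p
    # dV/dtheta_p = sum_k dV_adaptee/dphi_k * dphi_k/dtheta_p
    return torch.einsum("kij,kp->pij", d_adaptee, J)  # (P, n, n)

theta = torch.tensor([0.2, -0.4])
manual = manual_grad_sketch(theta)
auto = jacrev(lambda t: adaptee_v(param_map(t)))(theta).movedim(-1, 0)
```

Both routes should agree up to floating-point error; the manual route only needs the adaptee's own gradient and the (usually small) Jacobian of param_map.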

property param_specs

The adapter’s parameter specifications.

All transforms are forced to TransformIdentity.

Type:

dict

property param_map

The mapping function f(params) -> adaptee_params that translates the adapter’s parameter tensor to the adaptee’s parameter tensor.

Type:

callable

property adaptee

The wrapped matrix that this adapter reparameterises.

Type:

Matrix

property repr_dict

A dictionary representation for display, containing param_specs, adaptee, and a placeholder for param_map.

Type:

dict