torch_openreml.covariance.Operator

class torch_openreml.covariance.Operator(*args, **kwargs)

Bases: Matrix

Abstract base class for composite covariance matrix operators.

An operator combines one or more Matrix instances and optional fixed torch.Tensor operands into a single covariance matrix. Parameters and gradients are namespaced by operand — a parameter "sigma^2" belonging to operand "residual" is exposed as "residual/sigma^2" in the combined param_names.
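
The naming scheme described above can be sketched in plain Python. The helper `namespace_param_names` and the operand name "genetic" are hypothetical and only illustrate the "operand/parameter" convention; they are not part of torch_openreml.

```python
# Hypothetical helper: illustrates the "operand/parameter" naming scheme only.
def namespace_param_names(operand_param_names):
    """Prefix each parameter name with its operand name, separated by "/"."""
    names = []
    for operand_name, param_names in operand_param_names.items():
        for param_name in param_names:
            names.append(f"{operand_name}/{param_name}")
    return names


combined = namespace_param_names({
    "genetic": ["sigma^2"],    # assumed operand name, for illustration
    "residual": ["sigma^2"],
})
print(combined)  # ['genetic/sigma^2', 'residual/sigma^2']
```

Because "/" is the separator, two operands can reuse the same parameter name without colliding in the combined list.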

Operands may be passed as positional arguments, as keyword arguments, or as a single dictionary. Mixing positional and keyword arguments is not permitted. Positional operands are assigned names "op_0", "op_1", etc.

At least one operand must be a Matrix instance. Pure-tensor operands are treated as fixed matrices with no free parameters.

Initialize the operator from positional or keyword operands.

Parameters:
  • *args – Operands as positional arguments, each a Matrix or torch.Tensor. A single dict argument is also accepted and treated as a named operand mapping.

  • **kwargs – Operands as keyword arguments, mapping operand names to Matrix or torch.Tensor instances. Operand names must not contain "/".

Raises:
  • ValueError – If both positional and keyword arguments are provided, or if any operand name contains "/".

  • TypeError – If operands is not a dict, if any operand name is not a string, if any operand is not a Matrix or torch.Tensor, or if no operand is a Matrix.
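
The operand naming rules above could be sketched roughly as follows. `normalize_operands` is a hypothetical function written for illustration, not the library's constructor; it uses plain strings as stand-ins for operands and omits the Matrix and torch.Tensor type checks.

```python
def normalize_operands(*args, **kwargs):
    """Hypothetical sketch of the documented naming rules; not the library code."""
    if args and kwargs:
        raise ValueError("positional and keyword operands cannot be mixed")
    if len(args) == 1 and isinstance(args[0], dict):
        # A single dict is treated as a named operand mapping.
        operands = dict(args[0])
    elif args:
        # Positional operands receive the names "op_0", "op_1", ...
        operands = {f"op_{i}": operand for i, operand in enumerate(args)}
    else:
        operands = dict(kwargs)
    for name in operands:
        if not isinstance(name, str):
            raise TypeError(f"operand name {name!r} is not a string")
        if "/" in name:
            # "/" is reserved as the namespace separator in parameter names.
            raise ValueError(f"operand name {name!r} must not contain '/'")
    return operands


# Strings stand in for Matrix or torch.Tensor operands here.
print(normalize_operands("K", "I"))      # {'op_0': 'K', 'op_1': 'I'}
print(normalize_operands(A="K", B="I"))  # {'A': 'K', 'B': 'I'}
```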

Example

>>> import torch
>>> from torch_openreml.covariance import Sum, ScalarMatrix

>>> x = Sum(ScalarMatrix(2), ScalarMatrix(2))
>>> x
Sum(
  operands={
    'op_0': ScalarMatrix(shape=(2, 2), param_specs={'sigma^2': {'fixed': False, 'default': tensor([0.]), 'trans': TransformExpPow2()}}),
    'op_1': ScalarMatrix(shape=(2, 2), param_specs={'sigma^2': {'fixed': False, 'default': tensor([0.]), 'trans': TransformExpPow2()}})
  }
)
>>> x = Sum(A=ScalarMatrix(2), B=ScalarMatrix(2))
>>> x
Sum(
  operands={
    'A': ScalarMatrix(shape=(2, 2), param_specs={'sigma^2': {'fixed': False, 'default': tensor([0.]), 'trans': TransformExpPow2()}}),
    'B': ScalarMatrix(shape=(2, 2), param_specs={'sigma^2': {'fixed': False, 'default': tensor([0.]), 'trans': TransformExpPow2()}})
  }
)
>>> x.param_names
['A/sigma^2', 'B/sigma^2']
>>> x(torch.zeros(2))
tensor([[2., 0.],
        [0., 2.]])
>>> x({"A/sigma^2": torch.zeros(1), "B/sigma^2": torch.zeros(1)})
tensor([[2., 0.],
        [0., 2.]])

Methods

__call__([free_params])

Construct the matrix from a flat parameter tensor.

auto_grad([free_params])

Compute the Jacobian of build() with respect to free parameters using automatic differentiation.

build_operands([free_params])

Evaluate each operand at the current free parameters.

build_params([free_params, include_fixed, ...])

Construct the full parameter tensor by delegating to each operand.

get_intermediates(params)

Retrieve cached intermediate computation results if still valid.

grad([free_params])

Compute the Jacobian of __call__() with respect to trainable parameters.

manual_grad([free_params])

Compute the Jacobian of __call__() with respect to free parameters using a closed-form analytic expression.

map_theta_to_dv(theta)

An interface compatible with torch_openreml.REML that maps parameters to the matrix Jacobian.

map_theta_to_v(theta)

An interface compatible with torch_openreml.REML that maps parameters to a matrix.

operands_grad([free_params])

Compute the Jacobian of each operand with respect to its parameters.

reset_intermediates()

Clear the intermediate computation cache.

set_intermediates(params, intermediates)

Cache intermediate computation results keyed by parameter hash.

trans_grad([free_params])

Compute the element-wise derivative of the free parameter transforms.

Attributes

fixed_param_defaults

Fixed parameter defaults.

fixed_param_index

Index of fixed parameters.

fixed_param_names

Fixed parameter names.

fixed_param_trans

Transforms for fixed parameters.

free_param_defaults

Free parameter defaults.

free_param_index

Index of free parameters.

free_param_names

Free parameter names.

free_param_trans

Transforms for free parameters.

num_fixed_params

Total number of fixed parameters.

num_free_params

Total number of free parameters.

num_params

Total number of parameters.

operands

Mapping from operand names to operand matrices or tensors.

param_defaults

Parameter defaults.

param_names

Parameter names.

param_specs

Parameter specifications.

param_trans

Parameter transforms.

repr_dict

Key-value pairs used to build the string representation.

shape

Output matrix shape.

build_params(free_params=None, include_fixed=True, trans=True, out_format='tensor')

Construct the full parameter tensor by delegating to each operand.

Splits free_params into per-operand slices according to each operand’s num_free_params, calls build_params() on each Matrix operand, and concatenates the results. Fixed tensor operands are skipped.
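
The slicing step can be illustrated with plain Python lists. `split_free_params` is a hypothetical helper, not a torch_openreml function; it assumes each operand consumes a consecutive run of the flat parameter vector and that a fixed tensor operand consumes none.

```python
def split_free_params(free_params, num_free_per_operand):
    """Hypothetical helper: cut a flat sequence into consecutive per-operand slices."""
    expected = sum(num_free_per_operand)
    if len(free_params) != expected:
        raise ValueError(
            f"expected {expected} free parameters, got {len(free_params)}"
        )
    slices, start = [], 0
    for count in num_free_per_operand:
        slices.append(free_params[start:start + count])
        start += count
    return slices


# Three operands: one free parameter, a fixed tensor (none), and two free parameters.
parts = split_free_params([0.0, 0.5, 1.0], [1, 0, 2])
print(parts)  # [[0.0], [], [0.5, 1.0]]
```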

Parameters:
  • free_params (torch.Tensor or dict) – Flat 1D tensor of free parameters or a dictionary mapping parameter names to tensors. If omitted, default values are used. Default: None.

  • include_fixed (bool, optional) – Whether to include fixed parameters in the output. Passed through to each operand’s build_params(). Default: True.

  • trans (bool, optional) – Whether to apply parameter transforms to the output. Passed through to each operand’s build_params(). Default: True.

  • out_format (str, optional) – Output format. One of "tensor" or "dict". Default: "tensor".

Returns:

Concatenated parameter tensor, or a dictionary mapping namespaced parameter names to value tensors.

Return type:

torch.Tensor or dict

Raises:
  • TypeError – If free_params is neither a torch.Tensor nor a dict.

  • ValueError – If free_params is not a 1D tensor or has the wrong length, or if free_params is a dict with missing or unexpected keys.

  • ValueError – If out_format is not "tensor" or "dict".

Example

>>> import torch
>>> from torch_openreml.covariance import Sum, ScalarMatrix

>>> x = Sum(ScalarMatrix(2), ScalarMatrix(2))
>>> free_params = torch.tensor([0.0, 0.5])

>>> x.build_params(free_params)
tensor([1.0000, 2.7183])
>>> x.build_params()
tensor([1., 1.])
>>> x.build_params(free_params, trans=False)
tensor([0.0000, 0.5000])
>>> x.build_params(free_params, out_format="dict")
{'op_0/sigma^2': tensor(1.), 'op_1/sigma^2': tensor(2.7183)}

build_operands(free_params=None)

Evaluate each operand at the current free parameters.

Splits free_params into per-operand slices and calls each Matrix operand to produce its matrix. Fixed tensor operands are included as-is.

Parameters:

free_params (torch.Tensor or dict) – Flat 1D joint parameter tensor or parameter dictionary of length num_free_params. If omitted, default values are used. Default: None.

Returns:

Evaluated operand matrices in the same order as operands.

Return type:

list of torch.Tensor

Example

>>> import torch
>>> from torch_openreml.covariance import Sum, ScalarMatrix

>>> x = Sum(ScalarMatrix(2), ScalarMatrix(2))
>>> v_groups = x.build_operands(torch.tensor([1.0, 2.0]))
>>> print(v_groups[0])
tensor([[7.3891, 0.0000],
        [0.0000, 7.3891]])
>>> print(v_groups[1])
tensor([[54.5982,  0.0000],
        [ 0.0000, 54.5982]])
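
The numbers in this example can be reproduced by hand under one assumption: that the TransformExpPow2 transform maps a free parameter x to sigma^2 = exp(x)**2. That mapping is inferred from the transform's name and the outputs shown on this page, not taken from the library source. The function `scalar_matrix` below is a hypothetical stand-in that uses nested lists instead of tensors.

```python
import math


def scalar_matrix(x, n):
    """Assumed model: sigma^2 = exp(x)**2 on the diagonal of an n-by-n matrix."""
    sigma2 = math.exp(x) ** 2
    return [[sigma2 if i == j else 0.0 for j in range(n)] for i in range(n)]


# One matrix per operand, evaluated at that operand's own free parameter.
v_groups = [scalar_matrix(x, 2) for x in (1.0, 2.0)]
for v in v_groups:
    print([[round(value, 4) for value in row] for row in v])
```

Under that assumption, the diagonal entries round to 7.3891 and 54.5982, matching the tensors printed above.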

operands_grad(free_params=None)

Compute the Jacobian of each operand with respect to its parameters.

Splits free_params into per-operand slices, calls grad() on each Matrix operand, and prefixes the returned names with the operand name. Fixed tensor operands contribute None and an empty name list.

Parameters:

free_params (torch.Tensor or dict) – Flat 1D joint parameter tensor or parameter dictionary of length num_free_params. If omitted, default values are used. Default: None.

Returns:

(grad_groups, grad_name_groups), where grad_groups is a list of per-operand Jacobian tensors or None for fixed operands, and grad_name_groups is a list of corresponding namespaced parameter name lists.

Return type:

tuple

Example

>>> import torch
>>> from torch_openreml.covariance import Sum, ScalarMatrix

>>> x = Sum(ScalarMatrix(2), ScalarMatrix(2))
>>> grad_groups, grad_name_groups = x.operands_grad(torch.tensor([1.0, 2.0]))
>>> print(grad_groups[0])
tensor([[[14.7781,  0.0000],
         [ 0.0000, 14.7781]]])
>>> print(grad_groups[1])
tensor([[[109.1963,   0.0000],
         [  0.0000, 109.1963]]])
>>> print(grad_name_groups[0])
['op_0/sigma^2']
>>> print(grad_name_groups[1])
['op_1/sigma^2']
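
The gradient values in this example follow from the chain rule if the transform is sigma^2 = exp(x)**2, which is an assumption inferred from the TransformExpPow2 name and the printed outputs. In that case d(sigma^2)/dx = 2 * exp(x)**2. The function names below are hypothetical; this is a worked check, not the library's gradient code.

```python
import math


def sigma2(x):
    """Assumed transform: free parameter x -> variance exp(x)**2."""
    return math.exp(x) ** 2


def dsigma2_dx(x):
    """Chain rule: d/dx exp(x)**2 = 2 * exp(x) * exp(x) = 2 * exp(x)**2."""
    return 2.0 * math.exp(x) ** 2


for x in (1.0, 2.0):
    # The derivative is twice the variance at the same free parameter.
    print(x, round(sigma2(x), 4), round(dsigma2_dx(x), 4))
```

The derivatives at x = 1.0 and x = 2.0 round to 14.7781 and 109.1963, the diagonal entries of the two Jacobians printed above.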

auto_grad(free_params=None)

Compute the Jacobian of build() with respect to free parameters using automatic differentiation.

Uses torch.func.jacrev() to compute the full Jacobian.

If all parameters are fixed, returns (None, []).

Parameters:

free_params (torch.Tensor or dict) – Flat 1D parameter tensor or dict. If omitted, default values are used. Default: None.

Raises:
  • TypeError – If free_params is neither a torch.Tensor nor a dict.

  • ValueError – If free_params is not a 1D tensor or has the wrong length, or if free_params is a dict with missing or unexpected keys.

Returns:

(grad, grad_names), where grad is a 3D tensor of shape (num_free_params, *shape), and grad_names lists the free parameter names, one per slice along the first dimension of grad.

Return type:

tuple

Example

>>> import torch
>>> from torch_openreml.covariance import DiagonalMatrix

>>> mat = DiagonalMatrix(2)
>>> free_params = torch.tensor([0.0, 0.5])
>>> grad, grad_names = mat.auto_grad(free_params)
>>> grad, grad_names
(tensor([[[2.0000, 0.0000],
          [0.0000, 0.0000]],
 
         [[0.0000, 0.0000],
          [0.0000, 5.4366]]]),
 ['sigma^2_0', 'sigma^2_1'])
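
As an independent check of the (num_free_params, *shape) layout, the Jacobian can be approximated with central finite differences in plain Python. This is a different technique from the torch.func.jacrev() call the method uses. It assumes each diagonal entry is exp(x_i)**2, which is inferred from the outputs above, and both function names are hypothetical.

```python
import math


def diagonal_matrix(free_params):
    """Assumed model: diagonal entry i is exp(x_i)**2, off-diagonals are zero."""
    n = len(free_params)
    return [
        [math.exp(free_params[i]) ** 2 if i == j else 0.0 for j in range(n)]
        for i in range(n)
    ]


def finite_difference_jacobian(build, free_params, step=1e-6):
    """Central differences; result[k][i][j] approximates d build(x)[i][j] / d x[k]."""
    jacobian = []
    for k in range(len(free_params)):
        upper = list(free_params)
        lower = list(free_params)
        upper[k] += step
        lower[k] -= step
        upper_matrix, lower_matrix = build(upper), build(lower)
        jacobian.append([
            [(u - l) / (2 * step) for u, l in zip(upper_row, lower_row)]
            for upper_row, lower_row in zip(upper_matrix, lower_matrix)
        ])
    return jacobian


jac = finite_difference_jacobian(diagonal_matrix, [0.0, 0.5])
for k, slice_k in enumerate(jac):
    print(k, [[round(value, 4) for value in row] for row in slice_k])
```

The first index selects the free parameter, so slice 0 is nonzero only in entry (0, 0) and slice 1 only in entry (1, 1), mirroring the structure of the tensor printed above.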
property operands

Mapping from operand names to operand matrices or tensors.

Type:

dict

property param_specs

Parameter specifications.

Type:

dict

property repr_dict

Key-value pairs used to build the string representation.

Type:

dict