torch_openreml.covariance.OperatorGram

class torch_openreml.covariance.OperatorGram(x, gram_type='xtx')

Bases: Operator

Gram matrix operator: \(\symbf{X}^\top \symbf{X}\) or \(\symbf{X} \symbf{X}^\top\).

Given a matrix \(\symbf{X}\) (which may be a fixed torch.Tensor or a trainable Matrix), this operator computes its Gram product in the direction specified at construction:

\[\begin{split}\symbf{V} = \begin{cases} \symbf{X}^\top \symbf{X} & \text{if } \texttt{gram\_type} = \texttt{"xtx"} \\ \symbf{X} \symbf{X}^\top & \text{if } \texttt{gram\_type} = \texttt{"xxt"} \end{cases}\end{split}\]

If \(\symbf{X}\) has shape (n, m), then "xtx" yields an (m, m) matrix and "xxt" yields an (n, n) matrix.
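
The shape rule can be checked with plain PyTorch, independent of this library. The sketch below is only an illustration: `gram` is a hypothetical helper, not part of torch_openreml, and the input is a 3 x 2 lower-triangular matrix of ones chosen for the example.

```python
import torch


def gram(x: torch.Tensor, gram_type: str = "xtx") -> torch.Tensor:
    """Hypothetical stand-in for the two Gram products (not the library code)."""
    if gram_type == "xtx":
        return x.T @ x  # (m, n) @ (n, m) -> (m, m)
    if gram_type == "xxt":
        return x @ x.T  # (n, m) @ (m, n) -> (n, n)
    raise ValueError(f"gram_type must be 'xtx' or 'xxt', got {gram_type!r}")


# A (3, 2) lower-triangular matrix of ones, so n = 3 and m = 2.
x = torch.tensor([[1.0, 0.0],
                  [1.0, 1.0],
                  [1.0, 1.0]])

v_xtx = gram(x, "xtx")
v_xxt = gram(x, "xxt")
print(v_xtx.shape, v_xxt.shape)  # torch.Size([2, 2]) torch.Size([3, 3])
```

Both results are symmetric by construction, whichever direction is chosen.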

When \(\symbf{X}\) is a Matrix, its parameters are exposed through this operator and gradients are computed analytically via the product rule.

Initialise a Gram operator.

Parameters:
  • x (torch.Tensor or Matrix) – The input matrix of shape (n, m).

  • gram_type (str) – Which Gram product to compute. Must be one of "xtx" (\(\symbf{X}^\top \symbf{X}\)) or "xxt" (\(\symbf{X} \symbf{X}^\top\)). Default: "xtx".

Raises:

ValueError – If gram_type is not "xtx" or "xxt".

Example:

>>> import torch
>>> from torch_openreml.covariance import OperatorGram, LowerTriangularMatrix

>>> x = LowerTriangularMatrix(3, 2)
>>> op = OperatorGram(x, gram_type="xtx")
>>> op()
tensor([[3., 2.],
        [2., 2.]])
>>> op_xxt = OperatorGram(x, gram_type="xxt")
>>> op_xxt()
tensor([[1., 1., 1.],
        [1., 2., 2.],
        [1., 2., 2.]])

Methods

__call__([free_params])

Construct the matrix from a flat parameter tensor.

auto_grad([free_params])

Compute the Jacobian of __call__() with respect to free parameters using automatic differentiation.

build_operands([free_params])

Evaluate each operand at the current free parameters.

build_params([free_params, include_fixed, ...])

Construct the full parameter tensor by delegating to each operand.

get_intermediates(params)

Retrieve cached intermediate computation results if still valid.

grad([free_params])

Compute the Jacobian of __call__() with respect to trainable parameters.

manual_grad([free_params])

Compute the Jacobian of __call__() with respect to trainable parameters using the product rule.

map_theta_to_dv(theta)

An interface compatible with torch_openreml.REML that maps parameters to the matrix Jacobian.

map_theta_to_v(theta)

An interface compatible with torch_openreml.REML that maps parameters to a matrix.

operands_grad([free_params])

Compute the Jacobian of each operand with respect to its parameters.

reset_intermediates()

Clear the intermediate computation cache.

set_intermediates(params, intermediates)

Cache intermediate computation results keyed by parameter hash.

trans_grad([free_params])

Compute the element-wise derivative of the free parameter transforms.

Attributes

fixed_param_defaults

Fixed parameter defaults.

fixed_param_index

Index of fixed parameters.

fixed_param_names

Fixed parameter names.

fixed_param_trans

Transforms for fixed parameters.

free_param_defaults

Free parameter defaults.

free_param_index

Index of free parameters.

free_param_names

Free parameter names.

free_param_trans

Transforms for free parameters.

gram_type

The Gram product type ("xtx" or "xxt").

num_fixed_params

Total number of fixed parameters.

num_free_params

Total number of free parameters.

num_params

Total number of parameters.

operands

Mapping from operand names to operand matrices or tensors.

param_defaults

Parameter defaults.

param_names

Parameter names.

param_specs

Parameter specifications.

param_trans

Parameter transforms.

repr_dict

Key-value pairs used to build the string representation.

shape

Output matrix shape.

__call__(free_params=None)

Construct the matrix from a flat parameter tensor.

Must be implemented by subclasses. Implementations should convert free_params via build_params() to validate, include fixed parameters, and apply transforms before any computation.

Parameters:

free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.

Returns:

Constructed matrix with the dimensions given by the shape attribute.

Return type:

torch.Tensor

manual_grad(free_params=None)

Compute the Jacobian of __call__() with respect to trainable parameters using the product rule.

For \(\symbf{V} = \symbf{X}^\top \symbf{X}\):

\[\frac{\partial \symbf{V}}{\partial \theta_k} = \symbf{X}^\top \frac{\partial \symbf{X}}{\partial \theta_k} + \left(\frac{\partial \symbf{X}}{\partial \theta_k}\right)^\top \symbf{X}\]

For \(\symbf{V} = \symbf{X} \symbf{X}^\top\):

\[\frac{\partial \symbf{V}}{\partial \theta_k} = \frac{\partial \symbf{X}}{\partial \theta_k} \symbf{X}^\top + \symbf{X} \left(\frac{\partial \symbf{X}}{\partial \theta_k}\right)^\top\]
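
A worked consequence of the two formulas above: in each case the second term is the transpose of the first, so every derivative slice is symmetric and can be written with a single matrix product. The symbol \(\symbf{A}_k\) is introduced here only for illustration; whether the implementation exploits this is not stated.

\[\frac{\partial \symbf{V}}{\partial \theta_k} = \symbf{A}_k + \symbf{A}_k^\top, \qquad \symbf{A}_k = \begin{cases} \symbf{X}^\top \dfrac{\partial \symbf{X}}{\partial \theta_k} & \text{if } \texttt{gram\_type} = \texttt{"xtx"} \\ \dfrac{\partial \symbf{X}}{\partial \theta_k} \symbf{X}^\top & \text{if } \texttt{gram\_type} = \texttt{"xxt"} \end{cases}\]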

If \(\symbf{X}\) has no trainable parameters, returns (None, []).
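
The product-rule formulas can be compared against automatic differentiation in plain PyTorch. This is a minimal sketch under stated assumptions, not the library's implementation: `gram`, `gram_product_rule`, and `build_x` are hypothetical names, and the parametrisation \(\symbf{X}(\theta) = [\theta, \theta^2]\) is made up for the check.

```python
import torch
from torch.autograd.functional import jacobian


def gram(x, gram_type="xtx"):
    """Plain Gram product (hypothetical helper, not torch_openreml code)."""
    return x.T @ x if gram_type == "xtx" else x @ x.T


def gram_product_rule(x, dx, gram_type="xtx"):
    """Product rule for the Gram matrix.

    x:  (n, m) matrix X.
    dx: (p, n, m) stack with dx[k] = dX/dtheta_k.
    Returns a (p, m, m) stack for "xtx" or a (p, n, n) stack for "xxt".
    """
    if gram_type == "xtx":
        a = x.T @ dx  # X^T dX_k, batched over k
    else:
        a = dx @ x.T  # dX_k X^T, batched over k
    return a + a.transpose(-1, -2)  # add the transposed term


def build_x(theta):
    """Made-up parametrisation: column 0 is theta, column 1 is theta squared."""
    return torch.stack([theta, theta ** 2], dim=1)  # shape (3, 2)


theta = torch.tensor([0.5, -1.0, 2.0], dtype=torch.float64)
x = build_x(theta)

# Analytic dX/dtheta_k: only row k of X depends on theta_k.
idx = torch.arange(3)
dx = torch.zeros(3, 3, 2, dtype=torch.float64)
dx[idx, idx, 0] = 1.0          # d(theta_k)/d(theta_k)
dx[idx, idx, 1] = 2.0 * theta  # d(theta_k ** 2)/d(theta_k)

analytic = {}
autodiff = {}
for gram_type in ("xtx", "xxt"):
    analytic[gram_type] = gram_product_rule(x, dx, gram_type)
    # jacobian() returns shape (*output_shape, p); move p to the front.
    jac = jacobian(lambda t: gram(build_x(t), gram_type), theta)
    autodiff[gram_type] = jac.permute(2, 0, 1)
    print(gram_type, torch.allclose(analytic[gram_type], autodiff[gram_type]))
```

The leading dimension indexes the parameters, matching the (num_free_params, *shape) layout described for the return value.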

Parameters:

free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.

Returns:

(grad, grad_names), where grad is a 3D tensor of shape (num_free_params, *shape) and grad_names is a list of the corresponding parameter names. Returns (None, []) if all parameters are fixed.

Return type:

tuple

Example:

>>> import torch
>>> from torch_openreml.covariance import OperatorGram, LowerTriangularMatrix

>>> x = LowerTriangularMatrix(3, 2)
>>> op = OperatorGram(x, gram_type="xtx")
>>> free_params = torch.tensor([0.0, 0.5, 1.0, 0.2, -0.3])
>>> grad, grad_names = op.manual_grad(free_params)
>>> grad
tensor([[[ 0.0000,  0.0000],
         [ 0.0000,  0.0000]],

        [[ 1.0000,  1.0000],
         [ 1.0000,  0.0000]],

        [[ 0.0000,  0.5000],
         [ 0.5000,  2.0000]],

        [[ 0.4000, -0.3000],
         [-0.3000,  0.0000]],

        [[ 0.0000,  0.2000],
         [ 0.2000, -0.6000]]])
>>> grad_names
['x/L_0_0', 'x/L_1_0', 'x/L_1_1', 'x/L_2_0', 'x/L_2_1']
>>> op_xxt = OperatorGram(x, gram_type="xxt")
>>> grad_xxt, grad_names_xxt = op_xxt.manual_grad(free_params)
>>> grad_xxt
tensor([[[ 0.0000,  0.5000,  0.2000],
         [ 0.5000,  0.0000,  0.0000],
         [ 0.2000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  1.0000,  0.2000],
         [ 0.0000,  0.2000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  2.0000, -0.3000],
         [ 0.0000, -0.3000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.5000],
         [ 0.0000,  0.5000,  0.4000]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  1.0000],
         [ 0.0000,  1.0000, -0.6000]]])
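
The numbers in the example above can be reproduced by hand with plain PyTorch. The sketch assumes, based on the parameter names and printed values, that the five free parameters fill the lower triangle of the 3 x 2 matrix row by row with no transform applied; that layout is an inference, not something this page states.

```python
import torch

n, m = 3, 2
theta = torch.tensor([0.0, 0.5, 1.0, 0.2, -0.3])

# Assumed layout: lower-triangle positions in row-major order,
# i.e. (0,0), (1,0), (1,1), (2,0), (2,1), matching L_0_0 ... L_2_1.
rows, cols = torch.tril_indices(n, m)

# X with the free parameters placed in the lower triangle.
x = torch.zeros(n, m)
x[rows, cols] = theta

# dX/dtheta_k is a single 1 at the position owned by parameter k.
dx = torch.zeros(len(theta), n, m)
dx[torch.arange(len(theta)), rows, cols] = 1.0

# Product rule for both Gram directions, batched over parameters.
a = x.T @ dx
grad_xtx = a + a.transpose(-1, -2)  # shape (5, 2, 2)
b = dx @ x.T
grad_xxt = b + b.transpose(-1, -2)  # shape (5, 3, 3)

print(grad_xtx[1])  # slice for x/L_1_0
print(grad_xxt[4])  # slice for x/L_2_1
```
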
property gram_type

The Gram product type ("xtx" or "xxt").

Type:

str

property repr_dict

Key-value pairs used to build the string representation.

Type:

dict