torch_openreml.covariance.OperatorGram
- class torch_openreml.covariance.OperatorGram(x, gram_type='xtx')
Bases: Operator
Gram matrix operator: \(\symbf{X}^\top \symbf{X}\) or \(\symbf{X} \symbf{X}^\top\).
Given a matrix \(\symbf{X}\) (which may be a fixed torch.Tensor or a trainable Matrix), this operator computes its Gram product in the direction specified at construction:
\[\symbf{V} = \begin{cases} \symbf{X}^\top \symbf{X} & \text{if } \texttt{gram\_type} = \texttt{"xtx"} \\ \symbf{X} \symbf{X}^\top & \text{if } \texttt{gram\_type} = \texttt{"xxt"} \end{cases}\]
If \(\symbf{X}\) has shape (n, m), then "xtx" yields an (m, m) matrix and "xxt" yields an (n, n) matrix.
When \(\symbf{X}\) is a Matrix, its parameters are exposed through this operator, and gradients are computed analytically via the product rule.
Initialise a Gram operator.
- Parameters:
x (torch.Tensor or Matrix) – The input matrix of shape (n, m).
gram_type (str) – Which Gram product to compute. Must be one of "xtx" (\(\symbf{X}^\top \symbf{X}\)) or "xxt" (\(\symbf{X} \symbf{X}^\top\)). Default: "xtx".
- Raises:
ValueError – If gram_type is not "xtx" or "xxt".
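The Gram product and the constructor's validation rule can be reproduced with plain PyTorch. This is a minimal sketch, not the torch_openreml implementation: the helper gram() is hypothetical, and the input assumes a 3 x 2 lower-triangular matrix of ones, which is consistent with the values printed in the OperatorGram example.

```python
import torch


def gram(x: torch.Tensor, gram_type: str = "xtx") -> torch.Tensor:
    """Hypothetical stand-in for the Gram product computed by OperatorGram."""
    # Mirror the documented validation: only two directions are allowed.
    if gram_type not in ("xtx", "xxt"):
        raise ValueError(f"gram_type must be 'xtx' or 'xxt', got {gram_type!r}")
    if gram_type == "xtx":
        return x.T @ x  # (m, n) @ (n, m) -> (m, m)
    return x @ x.T  # (n, m) @ (m, n) -> (n, n)


# Assumed input: a 3 x 2 lower-triangular matrix of ones.
x = torch.tril(torch.ones(3, 2))
print(gram(x, "xtx"))  # 2 x 2 result
print(gram(x, "xxt"))  # 3 x 3 result
```

Because X has shape (3, 2), the "xtx" result is 2 x 2 and the "xxt" result is 3 x 3, matching the shape rule stated above.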
Example:
>>> import torch
>>> from torch_openreml.covariance import OperatorGram, LowerTriangularMatrix
>>> x = LowerTriangularMatrix(3, 2)
>>> op = OperatorGram(x, gram_type="xtx")
>>> op()
tensor([[3., 2.],
        [2., 2.]])
>>> op_xxt = OperatorGram(x, gram_type="xxt")
>>> op_xxt()
tensor([[1., 1., 1.],
        [1., 2., 2.],
        [1., 2., 2.]])
Methods
- __call__([free_params]) – Construct the matrix from a flat parameter tensor.
- auto_grad([free_params]) – Compute the Jacobian of build() with respect to free parameters using automatic differentiation.
- build_operands([free_params]) – Evaluate each operand at the current free parameters.
- build_params([free_params, include_fixed, ...]) – Construct the full parameter tensor by delegating to each operand.
- get_intermediates(params) – Retrieve cached intermediate computation results if still valid.
- grad([free_params]) – Compute the Jacobian of __call__() with respect to trainable parameters.
- manual_grad([free_params]) – Compute the Jacobian of __call__() with respect to trainable parameters using the product rule.
- map_theta_to_dv(theta) – An interface compatible with torch_openreml.REML that maps parameters to the matrix Jacobian.
- map_theta_to_v(theta) – An interface compatible with torch_openreml.REML that maps parameters to a matrix.
- operands_grad([free_params]) – Compute the Jacobian of each operand with respect to its parameters.
- reset_intermediates() – Clear the intermediate computation cache.
- set_intermediates(params, intermediates) – Cache intermediate computation results keyed by parameter hash.
- trans_grad([free_params]) – Compute the element-wise derivative of the free parameter transforms.
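The caching methods describe results that are stored under a key derived from the parameter values and returned only while those values are unchanged. A minimal sketch of that idea follows. The class IntermediateCache and its methods are hypothetical, not the torch_openreml API. It keys on the exact tuple of parameter values rather than on a hash alone, which rules out false hits from hash collisions at the cost of storing the key.

```python
import torch


class IntermediateCache:
    """Hypothetical single-entry cache keyed by parameter values.

    Illustrates the idea behind set_intermediates / get_intermediates /
    reset_intermediates; it is not the torch_openreml implementation.
    """

    def __init__(self):
        self._key = None
        self._value = None

    @staticmethod
    def _make_key(params: torch.Tensor) -> tuple:
        # Exact parameter values as a hashable tuple (detached, on the CPU).
        return tuple(params.detach().cpu().reshape(-1).tolist())

    def set(self, params: torch.Tensor, intermediates) -> None:
        self._key = self._make_key(params)
        self._value = intermediates

    def get(self, params: torch.Tensor):
        # Return cached results only if they were computed for these exact values.
        if self._key is not None and self._key == self._make_key(params):
            return self._value
        return None

    def reset(self) -> None:
        self._key = None
        self._value = None


cache = IntermediateCache()
theta = torch.tensor([0.0, 0.5, 1.0])
cache.set(theta, {"x": torch.eye(2)})
print(cache.get(theta) is not None)  # True: same parameters
print(cache.get(theta + 1.0))        # None: parameters changed
```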
Attributes
- fixed_param_defaults – Fixed parameter defaults.
- fixed_param_index – Index of fixed parameters.
- fixed_param_names – Fixed parameter names.
- fixed_param_trans – Transforms for fixed parameters.
- free_param_defaults – Free parameter defaults.
- free_param_index – Index of free parameters.
- free_param_names – Free parameter names.
- free_param_trans – Transforms for free parameters.
- gram_type – The Gram product type ("xtx" or "xxt").
- num_fixed_params – Total number of fixed parameters.
- num_free_params – Total number of free parameters.
- num_params – Total number of parameters.
- operands – Mapping from operand names to operand matrices or tensors.
- param_defaults – Parameter defaults.
- param_names – Parameter names.
- param_specs – Parameter specifications.
- param_trans – Parameter transforms.
- repr_dict – Key-value pairs used to build the string representation.
- shape – Output matrix shape.
- __call__(free_params=None)
Construct the matrix from a flat parameter tensor.
Must be implemented by subclasses. Implementations should convert free_params via build_params() to validate them, include fixed parameters, and apply transforms before any computation.
- Parameters:
free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.
- Returns:
Constructed matrix whose shape is given by the shape attribute.
- Return type:
torch.Tensor
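The order of operations described for __call__() (normalise the input, insert fixed values, apply transforms, and only then build the matrix) can be illustrated with a self-contained toy. Everything in this sketch is hypothetical and not part of torch_openreml: build_params_sketch, call_sketch, the parameter names a and log_b, the fixed value 0.5 and the exp transform are illustrative assumptions.

```python
import torch

# Hypothetical layout: two free parameters and one fixed parameter.
FREE_NAMES = ("a", "log_b")
FREE_DEFAULTS = torch.tensor([1.0, 0.0])
FIXED_C = torch.tensor(0.5)  # held constant, never trained


def build_params_sketch(free_params=None):
    """Normalise None, a dict or a flat 1D tensor, then add fixed values and transforms."""
    if free_params is None:
        flat = FREE_DEFAULTS.clone()  # fall back to defaults
    elif isinstance(free_params, dict):
        flat = torch.stack([torch.as_tensor(free_params[n], dtype=torch.float32) for n in FREE_NAMES])
    else:
        flat = torch.as_tensor(free_params, dtype=torch.float32)
    if flat.shape != (len(FREE_NAMES),):
        raise ValueError(f"expected {len(FREE_NAMES)} free parameters, got shape {tuple(flat.shape)}")
    a, log_b = flat
    # exp keeps b strictly positive; the fixed value is inserted unchanged.
    return a, torch.exp(log_b), FIXED_C


def call_sketch(free_params=None):
    a, b, c = build_params_sketch(free_params)
    # The matrix is built only from validated, transformed values.
    return torch.stack([torch.stack([a, c]), torch.stack([c, b])])


print(call_sketch())                          # defaults: a = 1, b = exp(0) = 1
print(call_sketch({"a": 2.0, "log_b": 0.0}))  # dictionary input
print(call_sketch(torch.tensor([2.0, 0.0])))  # flat tensor input
```

Routing every input form through one normalisation step keeps validation and transforms in a single place, so the matrix-building code never sees raw or unchecked parameters.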
- manual_grad(free_params=None)
Compute the Jacobian of __call__() with respect to trainable parameters using the product rule.
For \(\symbf{V} = \symbf{X}^\top \symbf{X}\):
\[\frac{\partial \symbf{V}}{\partial \theta_k} = \symbf{X}^\top \frac{\partial \symbf{X}}{\partial \theta_k} + \left(\frac{\partial \symbf{X}}{\partial \theta_k}\right)^\top \symbf{X}\]
For \(\symbf{V} = \symbf{X} \symbf{X}^\top\):
\[\frac{\partial \symbf{V}}{\partial \theta_k} = \frac{\partial \symbf{X}}{\partial \theta_k} \symbf{X}^\top + \symbf{X} \left(\frac{\partial \symbf{X}}{\partial \theta_k}\right)^\top\]
If \(\symbf{X}\) has no trainable parameters, returns (None, []).
- Parameters:
free_params (torch.Tensor or dict) – Flat 1D parameter tensor or parameter dictionary. If omitted, default values are used. Default: None.
- Returns:
(grad, grad_names), where grad is a 3D tensor of shape (num_free_params, *shape) and grad_names is a list of the corresponding parameter names. Returns (None, []) if all parameters are fixed.
- Return type:
tuple
Example:
>>> import torch
>>> from torch_openreml.covariance import OperatorGram, LowerTriangularMatrix
>>> x = LowerTriangularMatrix(3, 2)
>>> op = OperatorGram(x, gram_type="xtx")
>>> free_params = torch.tensor([0.0, 0.5, 1.0, 0.2, -0.3])
>>> grad, grad_names = op.manual_grad(free_params)
>>> grad
tensor([[[ 0.0000,  0.0000],
         [ 0.0000,  0.0000]],
        [[ 1.0000,  1.0000],
         [ 1.0000,  0.0000]],
        [[ 0.0000,  0.5000],
         [ 0.5000,  2.0000]],
        [[ 0.4000, -0.3000],
         [-0.3000,  0.0000]],
        [[ 0.0000,  0.2000],
         [ 0.2000, -0.6000]]])
>>> grad_names
['x/L_0_0', 'x/L_1_0', 'x/L_1_1', 'x/L_2_0', 'x/L_2_1']
>>> op_xxt = OperatorGram(x, gram_type="xxt")
>>> grad_xxt, grad_names_xxt = op_xxt.manual_grad(free_params)
>>> grad_xxt
tensor([[[ 0.0000,  0.5000,  0.2000],
         [ 0.5000,  0.0000,  0.0000],
         [ 0.2000,  0.0000,  0.0000]],
        [[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  1.0000,  0.2000],
         [ 0.0000,  0.2000,  0.0000]],
        [[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  2.0000, -0.3000],
         [ 0.0000, -0.3000,  0.0000]],
        [[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.5000],
         [ 0.0000,  0.5000,  0.4000]],
        [[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  1.0000],
         [ 0.0000,  1.0000, -0.6000]]])
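The two product-rule formulas can be checked against automatic differentiation with plain PyTorch. This sketch assumes, based on the parameter names and printed values in the example, that LowerTriangularMatrix(3, 2) places the free parameters row by row into the lower triangle with identity transforms. The helpers build_x, gram, product_rule_grad and autograd_grad are hypothetical; only torch.autograd.functional.jacobian is a real PyTorch function.

```python
import torch

# Assumed layout: free parameters fill the lower triangle of a 3 x 2 matrix row
# by row (L_0_0, L_1_0, L_1_1, L_2_0, L_2_1), with identity transforms.
IDX = [(0, 0), (1, 0), (1, 1), (2, 0), (2, 1)]
theta = torch.tensor([0.0, 0.5, 1.0, 0.2, -0.3])


def build_x(t: torch.Tensor) -> torch.Tensor:
    x = torch.zeros(3, 2, dtype=t.dtype)
    rows, cols = zip(*IDX)
    x[list(rows), list(cols)] = t  # differentiable scatter of the parameters
    return x


def gram(t: torch.Tensor, gram_type: str) -> torch.Tensor:
    x = build_x(t)
    return x.T @ x if gram_type == "xtx" else x @ x.T


def product_rule_grad(t: torch.Tensor, gram_type: str) -> torch.Tensor:
    x = build_x(t)
    grads = []
    for r, c in IDX:
        dx = torch.zeros_like(x)
        dx[r, c] = 1.0  # dX/dtheta_k has a single unit entry
        if gram_type == "xtx":
            grads.append(x.T @ dx + dx.T @ x)
        else:
            grads.append(dx @ x.T + x @ dx.T)
    return torch.stack(grads)  # (num_free_params, *shape)


def autograd_grad(t: torch.Tensor, gram_type: str) -> torch.Tensor:
    jac = torch.autograd.functional.jacobian(lambda p: gram(p, gram_type), t)
    return jac.permute(2, 0, 1)  # (*shape, num_free_params) -> (num_free_params, *shape)


for g in ("xtx", "xxt"):
    print(g, torch.allclose(product_rule_grad(theta, g), autograd_grad(theta, g)))  # True
```

The autograd Jacobian puts the parameter axis last, so it is permuted to the (num_free_params, *shape) layout that manual_grad returns.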
- property gram_type
The Gram product type ("xtx" or "xxt").
- Type:
str
- property repr_dict
Key-value pairs used to build the string representation.
- Type:
dict