# Source code for torch_openreml.covariance.operator

from torch_openreml.covariance.matrix import Matrix
import torch

class Operator(Matrix):
    """Matrix whose parameters are the union of its operands' parameters.

    ``operands`` maps a name to either a :class:`Matrix` or a
    ``torch.Tensor``.  Only ``Matrix`` operands own parameters; each of their
    parameter names is exposed here as ``"<operand_name>/<param_name>"``.
    A flat parameter sequence is consumed operand by operand, in the
    iteration order of the ``operands`` dict, ``operand.num_params`` entries
    at a time.

    NOTE(review): this class only evaluates the operands
    (:meth:`build_operands`, :meth:`operands_grad`); how they are combined is
    presumably left to subclasses -- confirm against the rest of the package.
    """

    _repr_single_line = False

    def __init__(self, shape, operands):
        """Validate ``operands`` and register the prefixed parameter names.

        Args:
            shape: Forwarded unchanged to ``Matrix.__init__``.
            operands: ``dict`` mapping operand name to ``Matrix`` or
                ``torch.Tensor``; see :meth:`check_operands` for the rules.
                The dict is stored by reference, not copied.

        Raises:
            TypeError, ValueError: Propagated from :meth:`check_operands`.
        """
        self.check_operands(operands)
        self._operands = operands
        # Operands without a ``param_names`` attribute (e.g. plain tensors)
        # contribute no names.
        param_names = [
            f"{operand_name}/{name}"
            for operand_name, operand in operands.items()
            for name in getattr(operand, "param_names", [])
        ]
        super().__init__(shape, param_names, [])
        # ``no_grad_index`` is a computed property on this class, so the
        # instance attribute left behind by the base initializer is removed.
        del self._no_grad_index

    def check_operands(self, operands):
        """Validate the ``operands`` mapping.

        Args:
            operands: Candidate operand mapping.

        Raises:
            TypeError: If ``operands`` is not a ``dict``, a key is not a
                ``str``, a value is neither ``Matrix`` nor ``torch.Tensor``,
                or no value is a ``Matrix``.
            ValueError: If a key contains ``"/"`` (reserved as the separator
                in prefixed parameter names).
        """
        if not isinstance(operands, dict):
            raise TypeError(f"operands must be a dict, got {type(operands).__name__}!")

        for key, value in operands.items():
            if not isinstance(key, str):
                raise TypeError(f"Operand name must be a string, got {type(key).__name__}!")
            if "/" in key:
                raise ValueError(f"Invalid operand name '{key}': '/' is not allowed!")
            if not isinstance(value, (Matrix, torch.Tensor)):
                raise TypeError(
                    f"Operand '{key}' must be a Matrix or torch.Tensor, "
                    f"got {type(value).__name__}!"
                )

        if not any(isinstance(v, Matrix) for v in operands.values()):
            raise TypeError("operands must include at least one Matrix!")

    def set_no_grad(self, index=None, param_name=None):
        """Always raise: ``no_grad_index`` is read-only on an operator.

        Raises:
            RuntimeError: Unconditionally.
        """
        raise RuntimeError(
            "This operator only provides a view of no_grad_index. "
            "Set it on the covariance matrix that owns the parameters instead!"
        )

    def _split_params(self, params):
        """Yield ``(name, operand, operand_params)`` for every operand, in order.

        For a ``Matrix`` operand, ``operand_params`` is the next
        ``operand.num_params`` entries of ``params``; for any other operand
        it is ``None`` and no entries are consumed.
        """
        offset = 0
        for name, operand in self.operands.items():
            if isinstance(operand, Matrix):
                end = offset + operand.num_params
                yield name, operand, params[offset:end]
                offset = end
            else:
                yield name, operand, None

    def trans_params(self, params):
        """Apply each ``Matrix`` operand's ``trans_params`` to its own slice.

        Args:
            params: Parameters in any form accepted by
                ``self.from_param_dict`` (defined outside this class).

        Returns:
            The per-operand results concatenated with ``torch.cat``, in
            operand order.
        """
        params = self.from_param_dict(params)
        self.check_params(params)
        return torch.cat([
            operand.trans_params(operand_params)
            for _, operand, operand_params in self._split_params(params)
            if isinstance(operand, Matrix)
        ])

    def build_operands(self, params):
        """Evaluate every operand for the given parameters.

        Args:
            params: Parameters in any form accepted by
                ``self.from_param_dict`` (defined outside this class).

        Returns:
            A list with one entry per operand, in operand order: the result
            of calling a ``Matrix`` operand on its parameter slice, or the
            operand itself when it is not a ``Matrix``.
        """
        params = self.from_param_dict(params)
        self.check_params(params)
        return [
            operand(operand_params) if isinstance(operand, Matrix) else operand
            for _, operand, operand_params in self._split_params(params)
        ]

    def operands_grad(self, params):
        """Collect the gradient of every operand with prefixed gradient names.

        Args:
            params: Parameters in any form accepted by
                ``self.from_param_dict`` (defined outside this class).

        Returns:
            ``(grad_groups, grad_name_groups)``: two lists with one entry per
            operand, in operand order.  For a ``Matrix`` operand whose
            ``grad`` returns a non-``None`` gradient, the entries are that
            gradient and its names prefixed with ``"<operand_name>/"``.
            Otherwise (non-``Matrix`` operand, or ``None`` gradient) the
            entries are ``None`` and ``[]``.
        """
        params = self.from_param_dict(params)
        self.check_params(params)

        grad_groups = []
        grad_name_groups = []
        for name, operand, operand_params in self._split_params(params):
            grad, grad_names = None, []
            if isinstance(operand, Matrix):
                grad, grad_names = operand.grad(operand_params)

            if grad is None:
                grad_groups.append(None)
                grad_name_groups.append([])
            else:
                grad_groups.append(grad)
                grad_name_groups.append([f"{name}/{grad_name}" for grad_name in grad_names])
        return grad_groups, grad_name_groups

    @property
    def operands(self):
        """The operand mapping passed to the constructor."""
        return self._operands

    @property
    def no_grad_index(self):
        """Operands' ``no_grad_index`` entries, shifted into this operator's indexing.

        Each ``Matrix`` operand's indices are offset by the total
        ``num_params`` of the ``Matrix`` operands that precede it.
        """
        result = []
        offset = 0
        for operand in self._operands.values():
            if isinstance(operand, Matrix):
                result.extend(index + offset for index in operand.no_grad_index)
                offset += operand.num_params
        return result

    @property
    def repr_dict(self):
        """Mapping used for the representation: ``{"operands": self.operands}``."""
        return {"operands": self.operands}