Source code for torch_openreml.covariance.operator_linear_propagation

from torch_openreml.covariance.operator import Operator
import torch


[docs] class LinearPropagation(Operator): def __init__(self, operands): if len(operands) != 2: raise ValueError("Two operands are required") super().__init__(None, operands) def _get_or_build_intermediates(self, params): cache = self.get_intermediates(params) if cache is None: v_groups = self.build_operands(params) z = v_groups[0] g = v_groups[1] v = z @ g @ z.T cache = {"z": z, "g": g, "v": v} self.set_intermediates(params, cache) return cache
[docs] def __call__(self, params): cache = self._get_or_build_intermediates(params) v = cache["v"] self._shape = tuple(v.shape) return v
[docs] def manual_grad(self, params): grad_groups, grad_name_groups = self.operands_grad(params) cache = self._get_or_build_intermediates(params) z = cache["z"] g = cache["g"] grad_list = [] grad_names = [] dz = grad_groups[0] if dz is not None: grad_z = dz @ g @ z.T + z @ g @ dz.mT grad_list.append(grad_z) grad_names.extend(grad_name_groups[0]) dg = grad_groups[1] if dg is not None: grad_g = z @ dg @ z.T grad_list.append(grad_g) grad_names.extend(grad_name_groups[1]) if len(grad_list) > 0: grad = torch.cat(grad_list) return grad, grad_names else: return None, []