Source code for torch_openreml.covariance.operator_sum

from torch_openreml.covariance.operator import Operator
import torch

class Sum(Operator):
    """Operator whose value is the sum of the values of its operands."""

    def __init__(self, operands):
        """Register the operands to be added together.

        Args:
            operands: Sized collection of operands; passed unchanged as the
                second argument of ``Operator.__init__``.

        Raises:
            ValueError: If fewer than two operands are supplied.
        """
        operand_count = len(operands)
        if operand_count < 2:
            raise ValueError("At least two operands are required")
        # The first base-class argument is None here.
        # NOTE(review): its meaning is defined by Operator - confirm there.
        super().__init__(None, operands)

    def __call__(self, params):
        """Evaluate every operand and return their sum.

        Args:
            params: Forwarded as-is to ``self.build_operands``.

        Returns:
            The result of adding all built operand values together. Its
            shape is also stored on ``self._shape`` as a tuple.
        """
        # Built-in sum() starts from the int 0, so the operand values must
        # support addition with an int.
        total = sum(self.build_operands(params))
        self._shape = tuple(total.shape)
        return total

    def manual_grad(self, params):
        """Collect the operands' gradients into a single tensor.

        Args:
            params: Forwarded as-is to ``self.operands_grad``.

        Returns:
            A ``(grad, grad_names)`` pair. ``grad`` is the ``torch.cat`` of
            every operand gradient that is not ``None`` and ``grad_names``
            is the flattened list of all operand name groups. When every
            operand gradient is ``None`` the pair is ``(None, [])``.
        """
        operand_grads, operand_names = self.operands_grad(params)
        present = [g for g in operand_grads if g is not None]
        if not present:
            return None, []

        # torch.cat with its default dim joins along the first dimension.
        stacked = torch.cat(present)
        # Names are flattened from every group, including groups whose
        # gradient was None.
        # NOTE(review): presumably those groups are empty - confirm against
        # operands_grad.
        flat_names = []
        for names in operand_names:
            flat_names.extend(names)
        return stacked, flat_names