# Source code for torch_openreml.covariance.operator_block_diagonal
from torch_openreml.covariance.operator import Operator
import torch
class BlockDiagonal(Operator):
    """Operator that assembles its operand matrices into one block-diagonal
    matrix via ``torch.block_diag``.

    The per-operand row/column spans inside the assembled matrix are cached
    alongside the matrix itself (through the inherited ``get_intermediates`` /
    ``set_intermediates`` protocol) so that ``manual_grad`` can embed each
    operand's gradient back into the full matrix shape.
    """

    def __init__(self, operands):
        # A block-diagonal of a single block is just that block; require >= 2.
        if len(operands) < 2:
            raise ValueError("At least two operands are required")
        super().__init__(None, operands)

    def _get_or_build_intermediates(self, params):
        """Return cached intermediates for ``params``, building them on a miss.

        The cache dict holds:
          - ``"v_groups"``: the list of operand matrices,
          - ``"v"``: the assembled block-diagonal matrix,
          - ``"row_offsets"`` / ``"col_offsets"``: per-operand ``(start, stop)``
            spans inside ``v``.
        """
        cache = self.get_intermediates(params)
        if cache is not None:
            return cache

        v_groups = self.build_operands(params)
        v = torch.block_diag(*v_groups)

        # Track where each operand's block lands inside the full matrix.
        row_offsets, col_offsets = [], []
        row, col = 0, 0
        for vg in v_groups:
            rows, cols = vg.shape
            row_offsets.append((row, row + rows))
            col_offsets.append((col, col + cols))
            row += rows
            col += cols

        cache = {
            "v_groups": v_groups,
            "v": v,
            "row_offsets": row_offsets,
            "col_offsets": col_offsets,
        }
        self.set_intermediates(params, cache)
        return cache

    def __call__(self, params):
        """Evaluate the operator: the block-diagonal matrix for ``params``."""
        v = self._get_or_build_intermediates(params)["v"]
        self._shape = tuple(v.shape)
        return v

    def manual_grad(self, params):
        """Return ``(grads, names)`` with each operand gradient embedded into
        the full block-diagonal shape.

        ``grads`` has shape ``(total_params, *v.shape)`` (concatenated along
        dim 0), or is ``None`` when no operand contributed a gradient.
        Operands whose gradient is ``None`` are skipped; offset lookup stays
        aligned because it indexes by the operand's position, not by the
        number of gradients emitted so far.
        """
        grad_groups, grad_name_groups = self.operands_grad(params)
        cache = self._get_or_build_intermediates(params)
        full_shape = tuple(cache["v"].shape)
        row_offsets = cache["row_offsets"]
        col_offsets = cache["col_offsets"]

        grad_list, grad_names = [], []
        for i, grad in enumerate(grad_groups):
            if grad is None:
                continue
            r0, r1 = row_offsets[i]
            c0, c1 = col_offsets[i]
            # Embed this operand's gradient into a zero matrix of the full
            # shape at its block's position.
            # NOTE(review): dtype/device come from ``params`` — assumes
            # ``params`` is a tensor; confirm against Operator's contract.
            tmp = torch.zeros((grad.shape[0],) + full_shape,
                              dtype=params.dtype,
                              device=params.device)
            tmp[:, r0:r1, c0:c1] = grad
            grad_list.append(tmp)
            grad_names.extend(grad_name_groups[i])

        if not grad_list:
            return None, []
        return torch.cat(grad_list), grad_names