Source code for torch_openreml.covariance.design_matrix

"""
Design matrix.

This module provides a fixed design matrix for use in linear
mixed-effects models. The matrix is constructed from numeric or
categorical input at initialisation and has no trainable parameters.

Classes:
    DesignMatrix:
        A fixed design matrix constructed from numeric or categorical data.
"""

import torch
from torch_openreml.covariance.matrix import Matrix
from torch_openreml.utils import numeric_to_design_matrix, categorical_to_design_matrix

class DesignMatrix(Matrix):
    r"""
    Fixed design matrix constructed from numeric or categorical input.

    .. math::

        \symbf{V} = \symbf{X}

    where :math:`\symbf{X}` is constructed from ``x`` at initialisation and
    remains fixed thereafter. This matrix has no trainable parameters, so
    :meth:`grad` always returns ``(None, [])``.

    Numeric input (a :class:`torch.Tensor`, or a list/tuple of numbers) is
    passed to :func:`~torch_openreml.utils.numeric_to_design_matrix` and
    categorical string input to
    :func:`~torch_openreml.utils.categorical_to_design_matrix`. In both cases
    ``levels`` and ``drop_first`` control which columns are retained.
    """

    def __init__(self, x, levels=None, drop_first=False, dtype=None, device=None):
        """
        Initialise a fixed design matrix from numeric or categorical input.

        Args:
            x (torch.Tensor, list, or tuple): Input data. Either a numeric
                tensor or list, or a list of strings for categorical data.
                For a list/tuple, the type of the first element decides
                whether the input is treated as categorical (``str``) or
                numeric.
            levels (list, optional): Explicit level ordering for categorical
                input, or bin edges for numeric input.
            drop_first (bool, optional): Whether to drop the first column.
                Defaults to ``False``.
            dtype (torch.dtype, optional): Desired dtype of the matrix. For
                tensor input this falls back to the dtype of ``x``.
            device (torch.device, optional): Desired device of the matrix.
                For tensor input this falls back to the device of ``x``.

        Raises:
            TypeError: If ``x`` is not a :class:`torch.Tensor`, list, or tuple.
            ValueError: If ``x`` is an empty list or tuple.

        Example:
            .. jupyter-execute::

                import torch
                from torch_openreml.covariance import DesignMatrix

                mat = DesignMatrix(torch.tensor([1.0, 2.0, 3.0, 4.0]))
                print(mat())

                mat = DesignMatrix(["a", "b", "a", "c"])
                print(mat())

                mat = DesignMatrix(["a", "b", "a", "c"], levels=["c", "b", "a"])
                print(mat())
        """
        if not isinstance(x, (torch.Tensor, list, tuple)):
            raise TypeError("'x' must be a torch.Tensor, a list or a tuple!")

        if torch.is_tensor(x):
            # Forward levels/drop_first exactly as the list branch does, so a
            # tensor and the equivalent list produce the same columns; dtype
            # and device default to those of the input tensor.
            self._matrix = numeric_to_design_matrix(
                x,
                levels,
                drop_first,
                dtype=dtype or x.dtype,
                device=device or x.device,
            )
        elif len(x) == 0:
            # Without this guard, x[0] below fails with an opaque IndexError.
            raise ValueError("'x' must not be empty!")
        elif isinstance(x[0], str):
            self._matrix = categorical_to_design_matrix(
                x, levels, drop_first, dtype, device
            )
        else:
            self._matrix = numeric_to_design_matrix(
                x, levels, drop_first, dtype, device
            )

        # No trainable parameters: both parameter lists passed to the base
        # class are empty.
        super().__init__((self._matrix.shape[0], self._matrix.shape[1]), [], [])

    def __call__(self, *args, **kwargs):
        """
        Return the fixed design matrix.

        Any positional or keyword arguments are accepted and ignored, because
        the matrix does not depend on them.

        Returns:
            torch.Tensor: The matrix built at initialisation.
        """
        return self._matrix

    @property
    def repr_dict(self):
        """
        Fields describing this matrix for its textual representation.

        Returns:
            dict: ``{"shape": ...}``, where the shape is read from
            ``self._shape`` (presumably set by ``Matrix.__init__``; confirm
            against the base class).
        """
        return {"shape": self._shape}