torch_openreml.utils.augment

torch_openreml.utils.augment(*args)

Horizontally concatenate two or more design matrices.

A convenience wrapper around torch.cat() along the column dimension. Useful for combining numeric and categorical design matrices into a single matrix.

Parameters:

*args (torch.Tensor) – Two or more design matrices with the same number of rows.

Returns:

Concatenated matrix of shape (n, p), where n is the number of rows shared by all inputs and p is the total number of columns across all inputs.

Return type:

torch.Tensor

Example:

>>> import torch
>>> from torch_openreml.utils import augment

>>> x1 = torch.ones(4, 2)
>>> x2 = torch.zeros(4, 3)
>>> augment(x1, x2)
tensor([[1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.]])
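
The use case mentioned in the description, combining a numeric and a categorical design matrix, might look roughly like the sketch below. It is illustrative only: augment_sketch is a hypothetical stand-in that calls torch.cat() with dim=1, which is assumed (not confirmed by this page) to match what augment does, and the data and the one-hot encoding of the factor are made up for the example.

```python
import torch
import torch.nn.functional as F


def augment_sketch(*args):
    # Hypothetical stand-in for augment: assumed to be equivalent to
    # concatenating the inputs along the column dimension (dim=1).
    return torch.cat(args, dim=1)


# Numeric design matrix: an intercept column plus one covariate (4 rows).
x_num = torch.tensor([[1.0, 0.5],
                      [1.0, 1.5],
                      [1.0, 2.5],
                      [1.0, 3.5]])

# Categorical factor with 3 levels, one-hot encoded into a 4 x 3 indicator
# matrix and cast to the same dtype as the numeric block.
levels = torch.tensor([0, 2, 1, 0])
x_cat = F.one_hot(levels, num_classes=3).to(x_num.dtype)

# Combined design matrix: 4 rows, 2 + 3 = 5 columns.
design = augment_sketch(x_num, x_cat)
print(design.shape)  # torch.Size([4, 5])
```

With plain torch.cat(), inputs whose row counts differ raise a RuntimeError; whether augment performs any additional validation of its own is not stated on this page.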