"""
General utilities.
Functions:
get_device:
Validate and return the shared device of a collection of tensors.
get_dtype:
Validate and return the shared dtype of a collection of tensors.
numeric_to_design_matrix:
Construct a design matrix from one or more numeric vectors or tensors.
categorical_to_design_matrix:
Construct a one-hot encoded design matrix from a categorical vector.
augment:
Horizontally concatenate design matrices.
interaction:
Construct an interaction term from two or more categorical vectors.
"""
import torch
[docs]
def get_device(*args):
"""
Validate and return the shared device of a collection of tensors.
Args:
*args (torch.Tensor): Zero or more tensors. All must reside on
the same device.
Returns:
torch.device: The shared device of all tensors, the PyTorch
default device if no tensors are provided.
Raises:
ValueError: If any tensor resides on a different device than the first.
Example:
.. jupyter-execute::
import torch
from torch_openreml.utils import get_device
x = torch.tensor([1.0, 2.0])
y = torch.tensor([3.0, 4.0])
get_device(x, y)
"""
if len(args) == 0:
return torch.get_default_device()
device = args[0].device
for i, t in enumerate(args):
if t.device != device:
raise ValueError(f"Device mismatch at arg {i}: {t.device} != {device}")
return device
[docs]
def get_dtype(*args):
"""
Validate and return the shared dtype of a collection of tensors.
Args:
*args (torch.Tensor): Zero or more tensors. All must share the
same dtype.
Returns:
torch.dtype: The shared dtype of all tensors, or the PyTorch
default dtype if no tensors are provided.
Raises:
ValueError: If any tensor has a different dtype than the first.
Example:
.. jupyter-execute::
import torch
from torch_openreml.utils import get_dtype
x = torch.tensor([1.0, 2.0])
y = torch.tensor([3.0, 4.0])
get_dtype(x, y)
"""
if len(args) == 0:
return torch.get_default_dtype()
dtype = args[0].dtype
for i, t in enumerate(args):
if t.dtype != dtype:
raise ValueError(f"Dtype mismatch at arg {i}: {t.dtype} != {dtype}")
return dtype
[docs]
def numeric_to_design_matrix(*args, dtype=None, device=None):
"""
Construct a design matrix from one or more numeric vectors or tensors.
Each argument becomes one column of the resulting design matrix. All
inputs must have the same length along the first dimension.
Args:
*args (torch.Tensor, list, or tuple): One or more numeric vectors
of equal length. Lists and tuples are converted to tensors.
dtype (torch.dtype, optional): Desired dtype of the matrix.
Defaults to the PyTorch default dtype.
device (torch.device, optional): Desired device of the matrix.
Defaults to the PyTorch default device.
Returns:
torch.Tensor: Design matrix of shape ``(n, num_args)``.
Raises:
ValueError: If no inputs are provided, or if inputs have
inconsistent lengths.
TypeError: If any input is not a tensor, list, or tuple.
Example:
.. jupyter-execute::
import torch
from torch_openreml.utils import numeric_to_design_matrix
x1 = torch.tensor([1.0, 2.0, 3.0])
x2 = torch.tensor([4.0, 5.0, 6.0])
numeric_to_design_matrix(x1, x2)
"""
if len(args) == 0:
raise ValueError("At least one input is required.")
cols = []
n = None
dtype = dtype or torch.float32
device = device or "cpu"
for i, x in enumerate(args):
if torch.is_tensor(x):
x = x.to(dtype=dtype, device=device)
elif isinstance(x, (list, tuple)):
x = torch.tensor(x, dtype=dtype, device=device)
else:
raise TypeError("'x' must be either a tensor or a list/tuple of values!")
if len(x.shape) == 2:
x.unsqueeze_(0)
if n is None:
n = x.shape[0]
elif x.shape[0] != n:
raise ValueError(f"Inconsistent lengths at argument {i}: expected {n}, got {x.shape[0]}")
cols.append(x)
return torch.stack(cols, dim=1)
[docs]
def categorical_to_design_matrix(x, levels=None, drop_first=False, dtype=None, device=None):
"""
Construct a one-hot encoded design matrix from a categorical vector.
Each unique level in ``x`` becomes one column in the output matrix.
The column ordering follows ``levels`` if provided, otherwise the
sorted unique values of ``x``.
Args:
x (list or tuple of str): Categorical vector of string labels.
levels (list of str, optional): Explicit level ordering. Must
contain exactly the same unique values as ``x``. Defaults to
the sorted unique values of ``x``.
drop_first (bool, optional): Whether to drop the first column to
avoid multicollinearity. Defaults to ``False``.
dtype (torch.dtype, optional): Desired dtype of the matrix.
Defaults to the PyTorch default dtype.
device (torch.device, optional): Desired device of the matrix.
Defaults to the PyTorch default device.
Returns:
torch.Tensor: One-hot encoded matrix of shape
``(len(x), len(levels))`` or ``(len(x), len(levels) - 1)`` if
``drop_first=True``.
Raises:
ValueError: If ``levels`` does not match the number of unique
values in ``x``.
Example:
.. jupyter-execute::
import torch
from torch_openreml.utils import categorical_to_design_matrix
print(categorical_to_design_matrix(["a", "b", "a", "c"]))
print(categorical_to_design_matrix(["a", "b", "a", "c"], drop_first=True))
"""
dtype = dtype or torch.float32
device = device or "cpu"
if levels is None:
levels = sorted(set(x))
if len(levels) != len(set(x)):
raise ValueError("'levels' must match the number of unique values in 'x'!")
level_to_idx = {lev: i for i, lev in enumerate(levels)}
idx = torch.tensor([level_to_idx[v] for v in x], device=device, dtype=torch.long)
z = torch.nn.functional.one_hot(idx, num_classes=len(levels)).to(dtype=dtype)
if drop_first:
z = z[:, 1:]
return z
[docs]
def augment(*args):
"""
Horizontally concatenate two or more design matrices.
A convenience wrapper around :func:`torch.cat` along the column
dimension. Useful for combining numeric and categorical design
matrices into a single matrix.
Args:
*args (torch.Tensor): Two or more design matrices with the same
number of rows.
Returns:
torch.Tensor: Concatenated matrix of shape
``(n, sum of columns across all inputs)``.
Example:
.. jupyter-execute::
import torch
from torch_openreml.utils import augment
x1 = torch.ones(4, 2)
x2 = torch.zeros(4, 3)
augment(x1, x2)
"""
return torch.cat(args, dim=1)
[docs]
def interaction(*args, sep="\u22C8"):
"""
Construct an interaction term from two or more categorical vectors.
Joins the corresponding elements of each input vector into a single
string separated by ``sep``, producing a new categorical vector whose
levels represent combined factor combinations.
Args:
*args (list or tuple of str): Two or more categorical vectors of
equal length.
sep (str, optional): Separator inserted between joined levels.
Defaults to ``"⋈"`` (U+22C8, the bowtie symbol).
Returns:
list of str: Interaction vector of the same length as the inputs.
Raises:
ValueError: If no inputs are provided.
TypeError: If any input is not a list or tuple of strings.
Example:
.. jupyter-execute::
from torch_openreml.utils import interaction
a = ["control", "control", "treatment"]
b = ["male", "female", "male"]
interaction(a, b)
"""
if len(args) == 0:
raise ValueError("At least one input is required.")
for i, arg in enumerate(args):
if not isinstance(arg, (list, tuple)):
raise TypeError(f"Argument {i} is not list/tuple of strings!")
return [sep.join(parts) for parts in zip(*args)]