Source code for torch_openreml.utils

"""
General utilities.

Functions:
    get_device:
        Validate and return the shared device of a collection of tensors.
    get_dtype:
        Validate and return the shared dtype of a collection of tensors.
    numeric_to_design_matrix:
        Construct a design matrix from one or more numeric vectors or tensors.
    categorical_to_design_matrix:
        Construct a one-hot encoded design matrix from a categorical vector.
    augment:
        Horizontally concatenate design matrices.
    interaction:
        Construct an interaction term from two or more categorical vectors.
"""

import torch
import pandas as pd

[docs] def get_device(*args): """ Validate and return the shared device of a collection of tensors. Args: *args (torch.Tensor): Zero or more tensors. All must reside on the same device. Returns: torch.device: The shared device of all tensors, the PyTorch default device if no tensors are provided. Raises: ValueError: If any tensor resides on a different device than the first. Example: .. jupyter-execute:: import torch from torch_openreml.utils import get_device x = torch.tensor([1.0, 2.0]) y = torch.tensor([3.0, 4.0]) get_device(x, y) """ if len(args) == 0: return torch.get_default_device() device = args[0].device for i, t in enumerate(args): if t.device != device: raise ValueError(f"Device mismatch at arg {i}: {t.device} != {device}") return device
def get_dtype(*args):
    """
    Return the dtype shared by every tensor in ``args``.

    Args:
        *args (torch.Tensor):
            Zero or more tensors. All of them must have the same dtype.

    Returns:
        torch.dtype:
            The dtype of the first tensor when all tensors agree, or the
            PyTorch default dtype when called without arguments.

    Raises:
        ValueError:
            If a tensor's dtype differs from the dtype of the first tensor.

    Example:
        .. jupyter-execute::

            import torch
            from torch_openreml.utils import get_dtype

            x = torch.tensor([1.0, 2.0])
            y = torch.tensor([3.0, 4.0])
            get_dtype(x, y)
    """
    if not args:
        return torch.get_default_dtype()

    # The first tensor is the reference every argument is compared against.
    dtype = args[0].dtype

    # Lazily look for the first argument that disagrees with the reference.
    mismatch = next(
        ((i, t) for i, t in enumerate(args) if t.dtype != dtype),
        None,
    )
    if mismatch is not None:
        i, t = mismatch
        raise ValueError(f"Dtype mismatch at arg {i}: {t.dtype} != {dtype}")

    return dtype
[docs] def numeric_to_design_matrix(*args, dtype=None, device=None): """ Construct a design matrix from one or more numeric vectors or tensors. Each argument becomes one column of the resulting design matrix. All inputs must have the same length along the first dimension. Args: *args (torch.Tensor, list, tuple or pd.Series): One or more numeric vectors of equal length. Lists, tuples and pd.Series are converted to tensors. dtype (torch.dtype, optional): Desired dtype of the matrix. Defaults to the PyTorch default dtype. device (torch.device, optional): Desired device of the matrix. Defaults to the PyTorch default device. Returns: torch.Tensor: Design matrix of shape ``(n, num_args)``. Raises: ValueError: If no inputs are provided, or if inputs have inconsistent lengths. TypeError: If any input is not a tensor, list, or tuple. Example: .. jupyter-execute:: import torch from torch_openreml.utils import numeric_to_design_matrix x1 = torch.tensor([1.0, 2.0, 3.0]) x2 = torch.tensor([4.0, 5.0, 6.0]) numeric_to_design_matrix(x1, x2) """ if len(args) == 0: raise ValueError("At least one input is required.") cols = [] n = None dtype = dtype or torch.get_default_dtype() device = device or torch.get_default_device() for i, x in enumerate(args): if isinstance(x, pd.Series): x = x.to_list() if torch.is_tensor(x): x = x.to(dtype=dtype, device=device) elif isinstance(x, (list, tuple)): x = torch.tensor(x, dtype=dtype, device=device) else: raise TypeError("'x' must be either a tensor or a list/tuple of values!") if len(x.shape) == 2: x.unsqueeze_(0) if n is None: n = x.shape[0] elif x.shape[0] != n: raise ValueError(f"Inconsistent lengths at argument {i}: expected {n}, got {x.shape[0]}") cols.append(x) return torch.stack(cols, dim=1)
def augment(*args):
    """
    Join design matrices side by side.

    Thin wrapper around :func:`torch.cat` that concatenates its arguments
    along dimension 1 (the column dimension), e.g. to merge numeric and
    categorical design matrices into one matrix.

    Args:
        *args (torch.Tensor):
            Two or more design matrices with the same number of rows.

    Returns:
        torch.Tensor:
            Concatenated matrix of shape
            ``(n, sum of columns across all inputs)``.

    Example:
        .. jupyter-execute::

            import torch
            from torch_openreml.utils import augment

            x1 = torch.ones(4, 2)
            x2 = torch.zeros(4, 3)
            augment(x1, x2)
    """
    column_axis = 1
    return torch.cat(args, dim=column_axis)
def interaction(*args, sep="\u22C8"):
    """
    Construct an interaction term from categorical vectors.

    Joins the corresponding elements of each input vector into a single string
    separated by ``sep``, producing a new categorical vector whose levels
    represent combined factor combinations.

    Args:
        *args (list or tuple of str):
            One or more (typically two or more) categorical vectors of equal
            length.
        sep (str, optional):
            Separator inserted between joined levels. Defaults to ``"⋈"``
            (U+22C8, the bowtie symbol).

    Returns:
        list of str:
            Interaction vector of the same length as the inputs.

    Raises:
        ValueError:
            If no inputs are provided, or if the inputs have inconsistent
            lengths.
        TypeError:
            If any input is not a list or tuple, or if an element is not a
            string (raised by ``str.join``).

    Example:
        .. jupyter-execute::

            from torch_openreml.utils import interaction

            a = ["control", "control", "treatment"]
            b = ["male", "female", "male"]
            interaction(a, b)
    """
    if not args:
        raise ValueError("At least one input is required.")

    n = None
    for i, arg in enumerate(args):
        if not isinstance(arg, (list, tuple)):
            raise TypeError(f"Argument {i} is not list/tuple of strings!")
        # ``zip`` stops at the shortest input, which would silently drop
        # observations; reject unequal lengths up front instead.
        if n is None:
            n = len(arg)
        elif len(arg) != n:
            raise ValueError(f"Inconsistent lengths at argument {i}: expected {n}, got {len(arg)}")

    return [sep.join(parts) for parts in zip(*args)]
def n_distinct(x):
    """
    Count the distinct elements of a categorical vector.

    Args:
        x (list, tuple or pd.Series):
            Input vector. A pd.Series is converted to a list first.

    Returns:
        int:
            Number of unique elements.

    Raises:
        TypeError:
            If ``x`` is not a list, tuple or pd.Series.

    Example:
        .. jupyter-execute::

            from torch_openreml.utils import n_distinct

            n_distinct(["a", "b", "a", "c"])
    """
    values = x.to_list() if isinstance(x, pd.Series) else x

    # Guard clause: anything other than a plain list/tuple is rejected.
    if not isinstance(values, (list, tuple)):
        raise TypeError(f"Argument x is not a list/tuple of strings!")

    return len(set(values))