torch_openreml.utils.get_dtype

torch_openreml.utils.get_dtype(*args)

Validate and return the shared dtype of a collection of tensors.

Parameters:

*args (torch.Tensor) – Zero or more tensors. All must share the same dtype.

Returns:

The shared dtype of all tensors, or the PyTorch default dtype if no tensors are provided.

Return type:

torch.dtype

Raises:

ValueError – If any tensor has a different dtype than the first.

Example:

import torch
from torch_openreml.utils import get_dtype

x = torch.tensor([1.0, 2.0])
y = torch.tensor([3.0, 4.0])
get_dtype(x, y)  # torch.float32 (with PyTorch's stock default dtype)
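
The three documented behaviors (shared dtype returned, default dtype for an empty call, ValueError on a mismatch) can be sketched as follows. This is only an illustration under stated assumptions: `get_dtype_sketch` is a hypothetical stand-in written for this example, not the library's actual source, and the error message wording is invented.

```python
import torch


def get_dtype_sketch(*args: torch.Tensor) -> torch.dtype:
    """Hypothetical stand-in for get_dtype; not the library's actual source."""
    if not args:
        # No tensors given: fall back to PyTorch's current default dtype.
        return torch.get_default_dtype()
    first = args[0].dtype
    for i, t in enumerate(args[1:], start=1):
        if t.dtype != first:
            # Message text is an assumption made for this sketch.
            raise ValueError(f"Tensor {i} has dtype {t.dtype}, expected {first}.")
    return first


x = torch.tensor([1.0, 2.0], dtype=torch.float32)
y = torch.tensor([3.0, 4.0], dtype=torch.float64)

same = get_dtype_sketch(x, x)   # shared dtype of matching tensors
empty = get_dtype_sketch()      # default dtype when called with no tensors

try:
    get_dtype_sketch(x, y)      # float32 vs float64: documented to raise
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
```

Comparing every tensor against the first one matches the wording of the Raises entry above. Whether the real function reports the index of the offending tensor is not stated in this page.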