torch_openreml.utils.get_dtype
- torch_openreml.utils.get_dtype(*args)
Validate and return the shared dtype of a collection of tensors.
- Parameters:
*args (torch.Tensor) – Zero or more tensors. All must share the same dtype.
- Returns:
The shared dtype of all tensors, or the PyTorch default dtype if no tensors are provided.
- Return type:
torch.dtype
- Raises:
ValueError – If any tensor's dtype differs from that of the first tensor.
Example:
>>> import torch
>>> from torch_openreml.utils import get_dtype
>>> x = torch.tensor([1.0, 2.0])
>>> y = torch.tensor([3.0, 4.0])
>>> get_dtype(x, y)
torch.float32
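The documented behaviour (fall back to the default dtype when called with no tensors, otherwise require every tensor to match the first) can be sketched as below. This is a minimal illustration, not the library's actual source: the name `get_dtype_sketch` is hypothetical, and the exact wording of the real `ValueError` message is not given in this page, so the message here is an assumption.

```python
import torch


def get_dtype_sketch(*args: torch.Tensor) -> torch.dtype:
    """Hypothetical re-implementation of the documented get_dtype behaviour."""
    # With no tensors there is nothing to compare, so fall back to PyTorch's default.
    if not args:
        return torch.get_default_dtype()
    first = args[0].dtype
    # Compare every remaining tensor against the first tensor's dtype.
    for i, t in enumerate(args[1:], start=1):
        if t.dtype != first:
            # Assumed message format; the real library's wording may differ.
            raise ValueError(
                f"Tensor at position {i} has dtype {t.dtype}, expected {first}."
            )
    return first


x = torch.tensor([1.0, 2.0])  # float32 under PyTorch's default settings
z = torch.tensor([1, 2])      # integer literals give torch.int64

print(get_dtype_sketch(x, x))  # torch.float32
print(get_dtype_sketch())      # whatever torch.get_default_dtype() returns
try:
    get_dtype_sketch(x, z)     # float32 vs int64, so this raises
except ValueError as err:
    print(err)
```

Comparing against the first tensor, rather than checking all pairs, is enough: if every tensor equals the first tensor's dtype, they all share one dtype.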