torch_openreml.utils.get_device
- torch_openreml.utils.get_device(*args)
Validate and return the shared device of a collection of tensors.
- Parameters:
*args (torch.Tensor) – Zero or more tensors. All must reside on the same device.
- Returns:
The shared device of all tensors, or the PyTorch default device if no tensors are provided.
- Return type:
torch.device
- Raises:
ValueError – If any tensor resides on a different device than the first.
Example:
>>> import torch
>>> from torch_openreml.utils import get_device
>>> x = torch.tensor([1.0, 2.0])
>>> y = torch.tensor([3.0, 4.0])
>>> get_device(x, y)
device(type='cpu')
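The contract described by the Returns and Raises fields can be sketched as follows. This is a hypothetical re-implementation for illustration only, not the library's source: get_device_sketch is an illustrative name that is not part of torch_openreml. It only relies on each argument having a .device attribute (as torch.Tensor does), and it assumes PyTorch 2.3 or later for torch.get_default_device() in the no-argument case.

```python
from typing import Any


def get_device_sketch(*args: Any) -> Any:
    """Hypothetical sketch of the documented get_device contract.

    Each argument is assumed to expose a ``.device`` attribute, as
    torch.Tensor does. This is NOT the torch_openreml implementation.
    """
    if not args:
        # No tensors given: fall back to PyTorch's default device.
        # Assumption: PyTorch >= 2.3, where torch.get_default_device() exists.
        import torch

        return torch.get_default_device()

    # Every later argument is compared against the first one's device.
    first = args[0].device
    for position, tensor in enumerate(args[1:], start=1):
        if tensor.device != first:
            raise ValueError(
                f"argument {position} is on {tensor.device}, "
                f"but argument 0 is on {first}"
            )
    return first
```

With the tensors x and y from the example above, the sketch returns device(type='cpu'); passing a CPU tensor together with a CUDA tensor raises ValueError. Because torch.device("cuda") and torch.device("cuda:0") do not compare equal, the sketch compares the devices of actual tensors, which always carry an explicit index on CUDA.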