torch_openreml.utils.get_device

torch_openreml.utils.get_device(*args)

Validate and return the shared device of a collection of tensors.

Parameters:

*args (torch.Tensor) – Zero or more tensors. All must reside on the same device.

Returns:

The shared device of all tensors, or the PyTorch default device if no tensors are provided.

Return type:

torch.device

Raises:

ValueError – If any tensor resides on a different device from the first tensor.

Example:

import torch
from torch_openreml.utils import get_device

x = torch.tensor([1.0, 2.0])
y = torch.tensor([3.0, 4.0])
get_device(x, y)  # returns device(type='cpu')
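The documented behaviour (shared device, default-device fallback, ValueError on mismatch) can be illustrated with the short sketch below. It is a hypothetical re-implementation named get_device_sketch, not the library's source. Two details are assumptions: it reads the default device from a freshly created empty tensor, and the wording of the error message is invented for illustration. The mismatch case uses the "meta" device so that it runs without a GPU.

```python
import torch


def get_device_sketch(*args: torch.Tensor) -> torch.device:
    """Hypothetical sketch of the behaviour documented for get_device."""
    if not args:
        # No tensors given: assume "default device" means the device that
        # factory functions such as torch.empty allocate on.
        return torch.empty(0).device
    device = args[0].device
    # Compare every remaining tensor against the first one.
    for i, t in enumerate(args[1:], start=1):
        if t.device != device:
            # Hypothetical message; the real library's wording may differ.
            raise ValueError(f"tensor {i} is on {t.device}, expected {device}")
    return device


x = torch.tensor([1.0, 2.0])
print(get_device_sketch(x, x * 2))  # the device of x (normally the CPU)
print(get_device_sketch())          # falls back to the default device

try:
    # A tensor on the "meta" device does not match a CPU tensor.
    get_device_sketch(x, torch.zeros(2, device="meta"))
except ValueError as err:
    print("ValueError:", err)
```

Checking only against the first tensor keeps the test linear in the number of arguments and gives the error message a fixed reference device to report.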