Run PyTorch on a Spyre Device
Torch-Spyre lets you run PyTorch on a Spyre device, as described in this document.
Creating a Tensor
Torch-Spyre adds the spyre device type to PyTorch. This device type works like other PyTorch device types. The following example creates a tensor on the spyre device:
import torch
x = [[1, 2], [3, 4]]
x = torch.tensor(x, dtype=torch.float16, device="spyre")
print(x)
print(x.device)
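Because the spyre device type behaves like any other PyTorch device type, a script can probe for it and fall back to the CPU when it is not available. The following is a minimal sketch, not part of Torch-Spyre: `pick_device` is a hypothetical helper introduced here for illustration, and the broad `except` is an assumption, since the exact error raised on a machine without the backend may differ between builds.

```python
import torch


def pick_device(preferred: str = "spyre") -> torch.device:
    """Hypothetical helper: return `preferred` if a tensor can be allocated on it, else CPU."""
    try:
        # Probe allocation; this raises if the device type is unknown or unusable.
        torch.empty(1, device=preferred)
        return torch.device(preferred)
    except Exception:
        # Broad on purpose (assumption): the error type may vary between builds.
        return torch.device("cpu")


device = pick_device()
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float16, device=device)
print(x.device.type)     # "spyre" when the backend is available, otherwise "cpu"
print(x.cpu().tolist())  # copy back to the CPU to inspect the values
```

Copying the tensor back with `.cpu()` before inspecting it keeps the rest of the script independent of the device that was selected.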
Running Tensor Operations
Operations that Torch-Spyre supports run on a Spyre device in the same way as on other devices.
For example, you can add two Spyre tensors as follows:
import torch
DEVICE = torch.device("spyre")
x = torch.rand(512, 1024, dtype=torch.float16).to(DEVICE)
y = torch.rand(512, 1024, dtype=torch.float16).to(DEVICE)
output = x + y  # or torch.add(x, y)
print(output)
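One way to sanity-check a device result is to compare it with a reference computed on the CPU. The sketch below does this for the addition above. The CPU fallback and the tolerance values are assumptions chosen for illustration, not values documented by Torch-Spyre.

```python
import torch

try:
    device = torch.device("spyre")
    torch.empty(1, device=device)  # probe that the device is usable
except Exception:
    device = torch.device("cpu")   # assumption: fall back so the sketch still runs

# Keep CPU copies of the inputs so a reference can be computed from them.
x_cpu = torch.rand(512, 1024, dtype=torch.float16)
y_cpu = torch.rand(512, 1024, dtype=torch.float16)

output = x_cpu.to(device) + y_cpu.to(device)

# float32 reference on the CPU; the float16 result should differ only by rounding.
reference = x_cpu.float() + y_cpu.float()
matches = torch.allclose(output.cpu().float(), reference, rtol=1e-3, atol=1e-3)
print(f"Matches CPU reference: {matches}")
```

Computing the reference in float32 avoids comparing two results that share the same float16 rounding, so the tolerance only has to absorb the rounding of the device result.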
You can perform matrix multiplication in several ways, as shown below:
import torch
DEVICE = torch.device("spyre")
x = torch.rand(512, 1024, dtype=torch.float16).to(DEVICE)
y = torch.rand(1024, 512, dtype=torch.float16).to(DEVICE)
output = torch.matmul(x, y)
print(f"Output of torch.matmul:\n{output}")
output = torch.mm(x, y)
print(f"Output of torch.mm:\n{output}")
output = x @ y
print(f"Output of matmul with @ operator:\n{output}")
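For 2-D inputs these three forms compute the same product. In stock PyTorch they differ in which input ranks they accept: `torch.mm` takes only 2-D matrices, while `torch.matmul` and `@` also broadcast over leading batch dimensions. The sketch below illustrates this on the CPU with small float32 tensors; it shows general PyTorch semantics only and assumes nothing about which of these cases Torch-Spyre supports.

```python
import torch

# CPU, float32, small shapes: illustrates stock PyTorch semantics only.
a = torch.rand(4, 8)
b = torch.rand(8, 3)

# For 2-D inputs, all three forms agree.
same_matmul = torch.allclose(torch.matmul(a, b), torch.mm(a, b))
same_operator = torch.allclose(a @ b, torch.mm(a, b))

# torch.matmul broadcasts over a leading batch dimension ...
batch = torch.rand(2, 4, 8)
batched = torch.matmul(batch, b)  # shape (2, 4, 3)

# ... while torch.mm rejects inputs that are not 2-D.
try:
    torch.mm(batch, b)
    mm_accepts_3d = True
except RuntimeError:
    mm_accepts_3d = False

print(same_matmul, same_operator, tuple(batched.shape), mm_accepts_3d)
```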
The following example uses torch.compile to compile torch.matmul:
import torch
DEVICE = torch.device("spyre")
x = torch.rand(512, 1024, dtype=torch.float16).to(DEVICE)
y = torch.rand(1024, 512, dtype=torch.float16).to(DEVICE)
c_matmul = torch.compile(torch.matmul)
output = c_matmul(x, y)
print(f"Output of matmul with torch.compile:\n{output}")
More Examples
For more examples of using PyTorch on a Spyre device, see the examples directory in this repository.