torch_spyre
When the torch_spyre package is installed, PyTorch picks it up
through the torch.backends autoload entry point, so no explicit
import torch_spyre is needed. The Spyre backend registers itself
when torch is imported, and the public API is available under
torch.spyre, mirroring the torch.cuda surface.
import torch
torch.spyre.is_available()
torch.spyre.device_count()
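To check which installed packages are registered for autoloading, you can query the entry-point group that PyTorch scans. The sketch below is a minimal example, not part of torch_spyre. It assumes the group is literally named torch.backends, as described above, and it uses only the standard library, so it does not import torch. autoload_backends is a hypothetical helper name.

```python
from importlib.metadata import entry_points
from typing import List, Tuple

# Entry-point group scanned for out-of-tree device backends
# (assumption: the name comes from the "torch.backends" wording above).
GROUP = "torch.backends"


def autoload_backends(group: str = GROUP) -> List[Tuple[str, str]]:
    """Return sorted (name, target) pairs for every installed entry point in group."""
    try:
        # Python 3.10+: entry_points() accepts a group keyword.
        eps = entry_points(group=group)
    except TypeError:
        # Python 3.8/3.9: entry_points() takes no arguments and returns a dict.
        eps = entry_points().get(group, [])
    return sorted((ep.name, ep.value) for ep in eps)


if __name__ == "__main__":
    for name, target in autoload_backends():
        print(f"{name} -> {target}")
```

On a machine where torch_spyre is installed, you would expect an entry for it in this list. An empty list means no out-of-tree backend is registered in that group.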
Device Management
- torch.spyre.is_available() → bool
Returns True if at least one Spyre device is available.
>>> torch.spyre.is_available()
True
- torch.spyre.device_count() → int
Returns the number of Spyre devices available.
>>> torch.spyre.device_count()
1
- torch.spyre.current_device() → int
Returns the index of the currently selected Spyre device.
>>> torch.spyre.current_device()
0
- torch.spyre.set_device(idx)
Sets the current Spyre device.
- Parameters:
idx (int) – Device index to set as current.
Note
torch.spyre.get_device_properties() is not yet exposed on the public
torch.spyre namespace. The SpyreDeviceProperties dataclass and
SpyreInterface.get_device_properties() exist internally and are used
by the Inductor device interface (see torch_spyre/device/interface.py).
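A common pattern is to fall back to CPU when no Spyre device is present. The sketch below uses only the calls documented above (is_available and current_device). pick_device is a hypothetical helper name. The function also returns "cpu" when PyTorch itself is not installed, so it runs anywhere.

```python
def pick_device() -> str:
    """Return "spyre:<index>" when a Spyre device is usable, else "cpu"."""
    try:
        import torch  # importing torch autoloads torch_spyre if it is installed
    except ImportError:
        return "cpu"
    # torch.spyre only exists when the torch_spyre backend has been loaded.
    spyre = getattr(torch, "spyre", None)
    if spyre is not None and spyre.is_available():
        return f"spyre:{spyre.current_device()}"
    return "cpu"


device = pick_device()
print(device)
```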
Random Number Generation
Preferred (device-agnostic): Use the PyTorch torch.accelerator API so
that your code is portable across backends (CUDA, Spyre, etc.):
torch.accelerator.manual_seed(42) # current device
torch.accelerator.manual_seed_all(42) # all devices
Backend-specific alternative:
- torch.spyre.manual_seed(seed)
Sets the seed for generating random numbers on the current Spyre device.
- Parameters:
seed (int) – The desired seed.
Note
The public binding accepts a single seed argument. To target a specific device, either call set_device first, or use torch.spyre.manual_seed_all, which seeds every visible Spyre device.
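Both options from the note above can be wrapped in small helpers. This sketch assumes only the calls named in this section (manual_seed, manual_seed_all, set_device, current_device). seed_all_spyre and seed_one_spyre are hypothetical names. seed_one_spyre restores the previously selected device afterwards, so seeding has no lasting effect on device selection.

```python
def seed_all_spyre(seed: int) -> bool:
    """Seed every visible Spyre device; return False if the backend is absent."""
    try:
        import torch
    except ImportError:
        return False
    spyre = getattr(torch, "spyre", None)
    if spyre is None or not spyre.is_available():
        return False
    spyre.manual_seed_all(seed)
    return True


def seed_one_spyre(idx: int, seed: int) -> None:
    """Seed a single Spyre device by making it current, then restore the old one."""
    import torch

    previous = torch.spyre.current_device()
    torch.spyre.set_device(idx)
    try:
        torch.spyre.manual_seed(seed)  # seeds the current device only
    finally:
        torch.spyre.set_device(previous)


if seed_all_spyre(42):
    seed_one_spyre(0, 1234)  # device 0 then gets its own seed
```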
Streams
Streams allow overlapping execution of operations. The API mirrors
torch.cuda streams.
- class torch.spyre.Stream(device=None, priority=0)
Wrapper around a Spyre stream.
A stream is a linear sequence of execution that belongs to a specific device. Operations on different streams can run concurrently. The Stream object is itself a context manager: putting it in a with block sets it as the current stream for that block.
- Parameters:
device (torch.device or int or str, optional) – Device for the stream. Accepts torch.device, int, or a string like "spyre" or "spyre:0". If None, the current device is used.
priority (int) – Priority class for the stream. 0 selects the low-priority pool; any non-zero value selects the high-priority pool. Each pool has 32 streams per device, allocated round-robin. Default: 0.
>>> s = torch.spyre.Stream()
>>> with s:
...     x = torch.randn(100, device="spyre", dtype=torch.float16)
- synchronize()
Wait for all operations on this stream to complete.
- device() → torch.device
Returns the device associated with this stream. Unlike
torch.cuda.Stream.device, this is a method, not a property.
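The priority pools described above can be modeled in a few lines. This is a hypothetical pure-Python sketch: StreamPoolModel is not part of torch_spyre. It assumes each device keeps an independent counter per pool and hands out slots 0 to 31 in order, wrapping around. The real allocator may number streams differently, for example because stream ID 0 is the default stream, so the model returns a (pool, slot) pair rather than a stream ID.

```python
from collections import defaultdict
from itertools import count

STREAMS_PER_POOL = 32  # from the docs: 32 streams per pool per device


class StreamPoolModel:
    """Hypothetical model of round-robin stream allocation (not the torch_spyre API)."""

    def __init__(self, streams_per_pool: int = STREAMS_PER_POOL):
        self.streams_per_pool = streams_per_pool
        # One increasing counter per (device, pool) pair, starting at 0.
        self._counters = defaultdict(count)

    def acquire(self, device: int, priority: int = 0) -> tuple:
        # priority 0 -> low-priority pool; any non-zero value -> high-priority pool.
        pool = "low" if priority == 0 else "high"
        n = next(self._counters[(device, pool)])
        # Round-robin: the 33rd request on a pool wraps back to slot 0.
        return pool, n % self.streams_per_pool


pools = StreamPoolModel()
slots = [pools.acquire(device=0)[1] for _ in range(34)]
print(slots[:3], slots[-2:])  # [0, 1, 2] [0, 1]
print(pools.acquire(device=0, priority=-1))  # ('high', 0)
print(pools.acquire(device=1))  # ('low', 0)
```

Because the pools wrap around, two Stream objects created 32 requests apart on the same device and pool would share an underlying slot in this model.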
- torch.spyre.stream(stream)
Pass-through helper for use inside a with block. The actual swap of the current stream is done by Stream.__enter__ and Stream.__exit__; calling stream(s) just returns s so the with form reads naturally.
- Parameters:
stream (Stream) – The stream to use.
>>> s = torch.spyre.Stream()
>>> with torch.spyre.stream(s):
...     x = torch.randn(100, device="spyre", dtype=torch.float16)
- torch.spyre.current_stream(device=None) → Stream
Returns the currently active stream for the given device.
- Parameters:
device (torch.device or int, optional) – Device to query. If
None, uses the current device.
- torch.spyre.default_stream(device=None) → Stream
Returns the default stream (stream ID 0) for the given device.
- Parameters:
device (torch.device or int, optional) – Device to query. If
None, uses the current device.
- torch.spyre.synchronize(device=None)
Waits for all operations on all streams to complete. If a device is specified, synchronizes only that device.
- Parameters:
device (torch.device or int or str, optional) – Device to synchronize. If
None, synchronizes all devices.
>>> torch.spyre.synchronize()  # sync all devices
>>> torch.spyre.synchronize("spyre:0")  # sync device 0
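The stream and synchronization APIs accept several spellings of a device. The hypothetical helper below, normalize_device, is not part of torch_spyre. It sketches one way to map those spellings to a device index. It assumes that None and a bare "spyre" both mean the current device. It accepts torch.device-like objects by reading their type and index attributes, so the sketch runs without torch installed.

```python
from types import SimpleNamespace

DEVICE_NAME = "spyre"  # same value as torch_spyre.constants.DEVICE_NAME


def normalize_device(device: object = None, current: int = 0) -> int:
    """Map None / int / "spyre" / "spyre:N" / torch.device-like to a device index."""
    if device is None:
        return current
    if isinstance(device, bool):
        raise TypeError("bool is not a valid device")
    if isinstance(device, int):
        index = device
    elif isinstance(device, str):
        kind, _, idx = device.partition(":")
        if kind != DEVICE_NAME:
            raise ValueError(f"expected a {DEVICE_NAME!r} device, got {device!r}")
        index = int(idx) if idx else None
    else:
        # Duck-type torch.device, which exposes .type and .index.
        if getattr(device, "type", None) != DEVICE_NAME:
            raise ValueError(f"expected a {DEVICE_NAME!r} device, got {device!r}")
        index = device.index
    if index is None:
        return current
    if index < 0:
        raise ValueError(f"device index must be non-negative, got {index}")
    return index


print(normalize_device(None, current=1))  # 1
print(normalize_device("spyre"))  # 0
print(normalize_device("spyre:3"))  # 3
print(normalize_device(SimpleNamespace(type="spyre", index=2)))  # 2
```

Note that for torch.spyre.synchronize, None means all devices rather than the current one. A caller built on this helper would therefore handle None before normalizing.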
Tensor Operations
Spyre tensors are created using the device="spyre" argument:
import torch

# Create a tensor on Spyre
x = torch.tensor([1, 2], dtype=torch.float16, device="spyre")
# Move an existing CPU tensor to Spyre
cpu_tensor = torch.ones(2, dtype=torch.float16)
y = cpu_tensor.to("spyre")
# Move back to CPU
z = x.cpu()
The default dtype for Spyre is torch.float16. See
Tensor Layouts for details on how tensors are
laid out in device memory.
Compilation
Spyre models are compiled using torch.compile with the "spyre"
backend:
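import torch
model = torch.nn.Linear(128, 64, dtype=torch.float16).to("spyre")  # any nn.Module
compiled = torch.compile(model, backend="spyre")
inputs = torch.randn(4, 128, dtype=torch.float16, device="spyre")
output = compiled(inputs)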
See Running Models on Spyre for details and Supported Operations for the list of supported ops.
Tensor Layouts
Spyre uses a tiled memory layout that differs from PyTorch’s standard strided layout. The following classes and functions allow inspection and manipulation of device tensor layouts. See Tensor Layouts for background.
- class torch_spyre._C.SpyreTensorLayout
Describes how a tensor is laid out in Spyre device memory. Each SpyreTensorLayout captures the tiling, padding, and dimension mapping required by the hardware. It can be constructed in two ways:
import torch
from torch_spyre._C import DataFormats, SpyreTensorLayout

# From host tensor metadata (automatic layout computation)
layout = SpyreTensorLayout(host_size=[4, 128], dtype=torch.float16)

# From explicit device layout parameters
layout = SpyreTensorLayout(
    device_size=[4, 2, 64],
    stride_map=[128, 64, 1],
    device_dtype=DataFormats.SEN169_FP16,
)
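- device_size: list[int]
Size of each dimension of the tensor as laid out in device memory, for example [4, 2, 64] for a float16 host tensor of size [4, 128].
- stride_map: list[int]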
Host stride for each device dimension. A value of -1 indicates a synthetic or padded dimension with no corresponding host stride.
- device_dtype: DataFormats
The on-device data format (e.g.,
SEN169_FP16).
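The relationship between host_size, device_size, and stride_map in the example above can be reproduced with a simplified model. The sketch assumes a contiguous 2-D host tensor and a fixed tile ("stick") of 64 elements along the last dimension, the value implied by the [4, 128] → [4, 2, 64] example. The real layout computation in torch_spyre also covers other ranks, dtypes, and padding rules, so tiled_layout_2d, host_offset, and is_padding are hypothetical helpers.

```python
import math

STICK = 64  # assumption: elements per tile along the last dim, from the [4, 2, 64] example


def tiled_layout_2d(host_size, stick=STICK):
    """Hypothetical model: split the last host dim into (tiles, stick)."""
    rows, cols = host_size
    tiles = math.ceil(cols / stick)  # a partial last tile is padded on device
    device_size = [rows, tiles, stick]
    # A contiguous [rows, cols] host tensor has strides [cols, 1];
    # moving one tile along device dim 1 skips `stick` host elements.
    stride_map = [cols, stick, 1]
    return device_size, stride_map


def host_offset(device_index, stride_map):
    """Host element offset for a device coordinate; -1 strides contribute nothing."""
    return sum(i * s for i, s in zip(device_index, stride_map) if s != -1)


def is_padding(device_index, host_size, stick=STICK):
    """True if the device coordinate lies in the padded tail of the last tile."""
    _, tile, k = device_index
    return tile * stick + k >= host_size[1]


device_size, stride_map = tiled_layout_2d([4, 128])
print(device_size, stride_map)  # [4, 2, 64] [128, 64, 1]
print(host_offset([1, 1, 5], stride_map))  # 1*128 + 1*64 + 5 = 197
```

For a [3, 100] host tensor the same model gives device_size [3, 2, 64]. Device coordinates whose last-dim position is 100 or more fall in padding and have no host element.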
- class torch_spyre._C.DataFormats
Enumeration of Spyre on-device data formats. Each format defines the bit-level encoding used in device memory.
Common values:
- SEN169_FP16
Spyre native 16-bit floating point (default for
torch.float16).
- IEEE_FP32
IEEE 754 single-precision floating point.
- IEEE_FP16
IEEE 754 half-precision floating point.
- BFLOAT16
Brain floating-point 16-bit format.
- SEN143_FP8
Spyre native 8-bit floating point (E4M3 variant).
- SEN152_FP8
Spyre native 8-bit floating point (E5M2 variant).
- SENINT8
Spyre native 8-bit integer.
- torch_spyre._C.get_spyre_tensor_layout(tensor) → SpyreTensorLayout
Returns the SpyreTensorLayout for a tensor that resides on a Spyre device.
- Parameters:
tensor (torch.Tensor) – A Spyre device tensor.
- Returns:
The device layout of the tensor.
- Return type:
SpyreTensorLayout
>>> x = torch.randn(4, 128, dtype=torch.float16, device="spyre")
>>> layout = torch_spyre._C.get_spyre_tensor_layout(x)
>>> print(layout.device_size)
[4, 2, 64]
- torch_spyre._C.set_spyre_tensor_layout(tensor, layout)
Sets the SpyreTensorLayout on a Spyre device tensor.
- Parameters:
tensor (torch.Tensor) – A Spyre device tensor.
layout (SpyreTensorLayout) – The layout to assign.
Warnings
Constants
- torch_spyre.constants.DEVICE_NAME = "spyre"
The device name string used to register Spyre with PyTorch.
Environment Variables
Spyre runtime and compiler:
| Variable | Purpose |
|---|---|
| | Enable C++ debug logging and |
| | Suppress float32 → float16 downcast warnings |
| | Enable Spyre Inductor logging |
| | Set Spyre Inductor log verbosity (DEBUG, INFO, WARNING, ERROR) |
| | Redirect Spyre Inductor logs to a file |
| | SendNN library logging level (default: |
| | DeepTools runtime verbosity (default: |
| | DeepTools log level (default: |
Compiler / Inductor configuration (torch_spyre/_inductor/config.py):
| Variable | Purpose |
|---|---|
| | Number of Spyre cores (1–32, default 32) |
| | Enable LX scratchpad memory planning during the pre-scheduling pass |
| | Fraction of LX scratchpad available to the planner |
Device enumeration (torch_spyre/csrc/spyre_device_enum.cpp):
| Variable | Purpose |
|---|---|
| | Override the visible Spyre device count |
| | Comma-separated list of device indices to expose |
| | Select the underlying flex runtime mode (PF / VF) |
Internal:
| Variable | Purpose |
|---|---|
| | Marker set by Inductor when spawning compile subprocesses |
Useful PyTorch knobs (not defined by torch-spyre):
| Variable | Purpose |
|---|---|
| | Verbose PyTorch Inductor logging |
| | Dump Inductor debug artifacts |