torch_spyre

When the torch_spyre package is installed, PyTorch loads it automatically through the torch.backends entry-point group (PyTorch's device-backend autoload mechanism), so no explicit import torch_spyre is needed. The Spyre backend registers itself when torch is imported, and its public API is available under torch.spyre, mirroring the torch.cuda surface.

import torch

torch.spyre.is_available()
torch.spyre.device_count()

Device Management

torch.spyre.is_available() → bool

Returns True if at least one Spyre device is available.

>>> torch.spyre.is_available()
True

torch.spyre.device_count() → int

Returns the number of Spyre devices available.

>>> torch.spyre.device_count()
1

torch.spyre.current_device() → int

Returns the index of the currently selected Spyre device.

>>> torch.spyre.current_device()
0

torch.spyre.set_device(idx)

Sets the current Spyre device.

Parameters:

idx (int) – Device index to set as current.

torch.spyre.is_initialized() → bool

Returns True if the Spyre runtime has been initialized.

Note

torch.spyre.get_device_properties() is not yet exposed on the public torch.spyre namespace. The SpyreDeviceProperties dataclass and SpyreInterface.get_device_properties() exist internally and are used by the Inductor device interface (see torch_spyre/device/interface.py).

Random Number Generation

Preferred (device-agnostic): use PyTorch's torch.accelerator API so that your code is portable across backends (CUDA, Spyre, etc.):

torch.accelerator.manual_seed(42)      # current device
torch.accelerator.manual_seed_all(42)  # all devices

Backend-specific alternative:

torch.spyre.manual_seed(seed)

Sets the seed for generating random numbers on the current Spyre device.

Parameters:

seed (int) – The desired seed.

Note

The public binding accepts only a seed argument. To seed a specific device, call torch.spyre.set_device first, or use torch.spyre.manual_seed_all to seed every visible Spyre device.

torch.spyre.manual_seed_all(seed)

Sets the seed for generating random numbers on all Spyre devices.

Parameters:

seed (int) – The desired seed.

Streams

Streams allow overlapping execution of operations. The API mirrors torch.cuda streams.

class torch.spyre.Stream(device=None, priority=0)

Wrapper around a Spyre stream.

A stream is a linear sequence of execution that belongs to a specific device. Operations on different streams can run concurrently. The Stream object is itself a context manager: putting it in a with block sets it as the current stream for that block.

Parameters:
  • device (torch.device or int or str, optional) – Device for the stream. Accepts torch.device, int, or a string like "spyre" or "spyre:0". If None, the current device is used.

  • priority (int) – Priority class for the stream. 0 selects the low-priority pool; any non-zero value selects the high-priority pool. Each pool has 32 streams per device, allocated round-robin. Default: 0.

>>> s = torch.spyre.Stream()
>>> with torch.spyre.stream(s):
...     x = torch.randn(100, device="spyre", dtype=torch.float16)

synchronize()

Wait for all operations on this stream to complete.

query() → bool

Returns True if all operations on this stream have completed.

device() → torch.device

Returns the device associated with this stream. Unlike torch.cuda.Stream.device, this is a method, not a property.

id: int

The stream ID (read-only). 0 is the default stream, 1 to 32 are the low-priority streams, and 33 to 64 are the high-priority streams.

priority: int

The stream priority class (read-only). 0 for low-priority, anything non-zero for high-priority.
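The ID ranges and round-robin pools described above can be restated as a small allocator. StreamIdPool is a hypothetical toy for a single device, not part of torch_spyre; it assumes each pool keeps its own round-robin counter and only reproduces the documented numbering (0 default, 1 to 32 low priority, 33 to 64 high priority).

```python
class StreamIdPool:
    """Toy model of the documented per-device stream-ID scheme (hypothetical)."""

    POOL_SIZE = 32  # streams per priority pool

    def __init__(self):
        # Next offset (0..31) to hand out from each pool; assumed independent.
        self._next = {"low": 0, "high": 0}

    def allocate(self, priority=0):
        # priority 0 selects the low pool; any non-zero value selects the high pool.
        pool = "low" if priority == 0 else "high"
        base = 1 if pool == "low" else 1 + self.POOL_SIZE  # 1..32 or 33..64
        offset = self._next[pool]
        self._next[pool] = (offset + 1) % self.POOL_SIZE   # round-robin wrap
        return base + offset


pool = StreamIdPool()
low_ids = [pool.allocate(0) for _ in range(33)]
print(low_ids[0], low_ids[31], low_ids[32])  # 1 32 1 (wraps after 32 streams)
print(pool.allocate(priority=-1))            # 33
```

Because IDs are recycled round-robin, creating more than 32 streams of one priority hands out IDs that earlier Stream objects already hold.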

torch.spyre.stream(stream)

Pass-through helper for use inside a with block. The actual swap of the current stream is done by Stream.__enter__ and Stream.__exit__; calling stream(s) just returns s so the with form reads naturally.

Parameters:

stream (Stream) – The stream to use.

>>> s = torch.spyre.Stream()
>>> with torch.spyre.stream(s):
...     x = torch.randn(100, device="spyre", dtype=torch.float16)

torch.spyre.current_stream(device=None) → Stream

Returns the currently active stream for the given device.

Parameters:

device (torch.device or int, optional) – Device to query. If None, uses the current device.
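The with-block behaviour of Stream, stream() and current_stream() can be modelled in plain Python. ToyStream and its class-level current slot are hypothetical stand-ins: the real runtime tracks a current stream per device, which this single-device sketch does not attempt. It shows why stream(s) can simply return s: entering the stream saves the previous current stream and exiting restores it, which also makes nested blocks work.

```python
class ToyStream:
    """Hypothetical stand-in for torch.spyre.Stream on a single device."""

    current = None  # what a current_stream() call would report

    def __init__(self, name):
        self.name = name
        self._saved = []  # stack of previous current streams, so re-entry nests

    def __enter__(self):
        self._saved.append(ToyStream.current)
        ToyStream.current = self
        return self

    def __exit__(self, exc_type, exc, tb):
        ToyStream.current = self._saved.pop()
        return False  # never swallow exceptions raised in the with body


def stream(s):
    # Pass-through helper: the swap happens in s.__enter__ / s.__exit__.
    return s


def current_stream():
    return ToyStream.current


ToyStream.current = ToyStream("default")
side = ToyStream("side")
with stream(side):
    print(current_stream().name)  # side
print(current_stream().name)      # default
```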

torch.spyre.default_stream(device=None) → Stream

Returns the default stream (stream ID 0) for the given device.

Parameters:

device (torch.device or int, optional) – Device to query. If None, uses the current device.

torch.spyre.synchronize(device=None)

Waits for all operations on all streams of every Spyre device to complete. If device is specified, only that device is synchronized.

Parameters:

device (torch.device or int or str, optional) – Device to synchronize. If None, synchronizes all devices.

>>> torch.spyre.synchronize()          # sync all devices
>>> torch.spyre.synchronize("spyre:0") # sync device 0
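The device argument accepted by Stream and synchronize comes in several spellings. The hypothetical helper below (plain Python, no Spyre needed) normalizes them into the list of device indices a synchronize call would cover, following the semantics above where None means every device. How the real binding interprets a bare "spyre" string is not stated here; the sketch assumes it means the current device.

```python
DEVICE_NAME = "spyre"  # same value as torch_spyre.constants.DEVICE_NAME


def resolve_sync_targets(device, device_count, current=0):
    """Hypothetical helper: device indices covered by a synchronize(device) call.

    device may be None (every device), an int, a string such as "spyre" or
    "spyre:1", or a torch.device-like object exposing .type and .index.
    current stands in for torch.spyre.current_device().
    """
    if device is None:
        return list(range(device_count))
    if isinstance(device, int):
        index = device
    else:
        if isinstance(device, str):
            dev_type, _, idx = device.partition(":")
            index = int(idx) if idx else None
        else:
            dev_type, index = device.type, device.index
        if dev_type != DEVICE_NAME:
            raise ValueError(f"expected a {DEVICE_NAME} device, got {device!r}")
        if index is None:
            index = current  # assumption: a bare "spyre" means the current device
    if not 0 <= index < device_count:
        raise ValueError(f"device index {index} out of range for {device_count} devices")
    return [index]


print(resolve_sync_targets(None, 2))       # [0, 1]
print(resolve_sync_targets("spyre:1", 2))  # [1]
```

With a real installation, device_count and current would come from torch.spyre.device_count() and torch.spyre.current_device().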

Tensor Operations

Spyre tensors are created using the device="spyre" argument:

import torch

# Create a tensor on Spyre
x = torch.tensor([1, 2], dtype=torch.float16, device="spyre")

# Move an existing CPU tensor to Spyre
cpu_tensor = torch.ones(2, 2, dtype=torch.float16)
y = cpu_tensor.to("spyre")

# Move back to CPU
z = x.cpu()

The default dtype for Spyre is torch.float16. See Tensor Layouts for details on how tensors are laid out in device memory.

Compilation

Spyre models are compiled using torch.compile with the "spyre" backend:

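import torch

model = MyModel().to("spyre")  # MyModel: your own torch.nn.Module
compiled = torch.compile(model, backend="spyre")
inputs = torch.randn(4, 128, dtype=torch.float16, device="spyre")  # shape depends on MyModel
output = compiled(inputs)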

See Running Models on Spyre for details and Supported Operations for the list of supported ops.

Tensor Layouts

Spyre uses a tiled memory layout that differs from PyTorch’s standard strided layout. The following classes and functions allow inspection and manipulation of device tensor layouts. See Tensor Layouts for background.

class torch_spyre._C.SpyreTensorLayout

Describes how a tensor is laid out in Spyre device memory. Each SpyreTensorLayout captures the tiling, padding, and dimension mapping required by the hardware.

It can be constructed in two ways:

import torch
from torch_spyre._C import DataFormats, SpyreTensorLayout

# From host tensor metadata (automatic layout computation)
layout = SpyreTensorLayout(host_size=[4, 128], dtype=torch.float16)

# From explicit device layout parameters
layout = SpyreTensorLayout(
    device_size=[4, 2, 64],
    stride_map=[128, 64, 1],
    device_dtype=DataFormats.SEN169_FP16,
)

device_size: list[int]

Shape on device, including tiling dimensions and padding.

stride_map: list[int]

Host stride for each device dimension. A value of -1 indicates a synthetic or padded dimension with no corresponding host stride.

device_dtype: DataFormats

The on-device data format (e.g., SEN169_FP16).

elems_per_stick() → int

Returns the number of elements per stick for this layout’s dtype.
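To make device_size and stride_map concrete, the sketch below reproduces the [4, 128] float16 example under a deliberately simplified model: a contiguous 2-D host tensor whose last dimension is cut into 128-byte sticks, with the final stick padded. simple_stick_layout and host_offset are hypothetical helpers, not torch_spyre functions; the real layout computation covers other ranks and the synthetic dimensions marked -1, which this model never produces.

```python
STICK_BYTES = 128  # one stick holds 128 bytes (see DataFormats.elems_per_stick)


def simple_stick_layout(host_size, elem_bytes):
    """Hypothetical simplified layout for a contiguous 2-D host tensor.

    The last host dimension is split into (sticks, elems_per_stick) and the
    final stick is padded. Returns (device_size, stride_map).
    """
    rows, cols = host_size
    eps = STICK_BYTES // elem_bytes       # elements per stick
    sticks = (cols + eps - 1) // eps      # round up to whole sticks
    device_size = [rows, sticks, eps]
    stride_map = [cols, eps, 1]           # host stride: next row, next stick, next element
    return device_size, stride_map


def host_offset(coord, device_size, stride_map, host_size):
    """Flat host offset for a device coordinate, or None if it lies in padding."""
    _, stick, elem = coord
    if stick * device_size[2] + elem >= host_size[1]:
        return None
    # Dimensions marked -1 have no host stride and contribute nothing.
    return sum(i * s for i, s in zip(coord, stride_map) if s != -1)


size, strides = simple_stick_layout([4, 128], elem_bytes=2)  # float16
print(size, strides)                                    # [4, 2, 64] [128, 64, 1]
print(host_offset((1, 1, 3), size, strides, [4, 128]))  # 195
```

For a [4, 100] float16 tensor the same model still yields two sticks per row, and coordinates such as (0, 1, 40) map to None because they fall in the padded tail of the second stick.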

class torch_spyre._C.DataFormats

Enumeration of Spyre on-device data formats. Each format defines the bit-level encoding used in device memory.

Common values:

SEN169_FP16

Spyre native 16-bit floating point (default for torch.float16).

IEEE_FP32

IEEE 754 single-precision floating point.

IEEE_FP16

IEEE 754 half-precision floating point.

BFLOAT16

Brain floating-point 16-bit format.

SEN143_FP8

Spyre native 8-bit floating point (E4M3 variant).

SEN152_FP8

Spyre native 8-bit floating point (E5M2 variant).

SENINT8

Spyre native 8-bit integer.

elems_per_stick() → int

Returns the number of elements that fit in a single 128-byte stick for this data format.
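Since every stick is 128 bytes, elems_per_stick() amounts to 128 divided by the storage width of the format. The width table below is an assumption inferred from the format names (16-bit formats take 2 bytes, 8-bit formats 1 byte, IEEE_FP32 4 bytes), not read from torch_spyre; the 64 elements it gives for SEN169_FP16 agree with the trailing 64 in the [4, 2, 64] layout example.

```python
STICK_BYTES = 128

# Assumed storage widths in bytes, inferred from the format names.
FORMAT_BYTES = {
    "SEN169_FP16": 2,
    "IEEE_FP16": 2,
    "BFLOAT16": 2,
    "IEEE_FP32": 4,
    "SEN143_FP8": 1,
    "SEN152_FP8": 1,
    "SENINT8": 1,
}


def elems_per_stick(fmt):
    """Number of elements of format `fmt` that fit in one 128-byte stick."""
    return STICK_BYTES // FORMAT_BYTES[fmt]


for fmt in FORMAT_BYTES:
    print(f"{fmt:12s} {elems_per_stick(fmt)}")
```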

torch_spyre._C.get_spyre_tensor_layout(tensor) → SpyreTensorLayout

Returns the SpyreTensorLayout for a tensor that resides on a Spyre device.

Parameters:

tensor (torch.Tensor) – A Spyre device tensor.

Returns:

The device layout of the tensor.

Return type:

SpyreTensorLayout

>>> import torch_spyre._C
>>> x = torch.randn(4, 128, dtype=torch.float16, device="spyre")
>>> layout = torch_spyre._C.get_spyre_tensor_layout(x)
>>> print(layout.device_size)
[4, 2, 64]

torch_spyre._C.set_spyre_tensor_layout(tensor, layout)

Sets the SpyreTensorLayout on a Spyre device tensor.

Parameters:
  • tensor (torch.Tensor) – A Spyre device tensor.

  • layout (SpyreTensorLayout) – The layout to assign.

Warnings

torch_spyre._C.get_downcast_warning() → bool

Returns whether float32 → float16 downcast warnings are enabled.

torch_spyre._C.set_downcast_warning(enabled)

Enable or disable float32 → float16 downcast warnings.

Parameters:

enabled (bool) – True to enable warnings, False to suppress.

This setting can also be controlled via the TORCH_SPYRE_DOWNCAST_WARN environment variable; set it to 0 to suppress the warnings.
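Assuming the flag is process-wide (the getter and setter take no device or scope argument), a scoped override keeps a deliberate float32 → float16 conversion quiet without silencing the rest of the program. downcast_warnings is a hypothetical helper, not part of torch_spyre; it relies only on the documented getter/setter pair, passed in as arguments so the sketch runs without Spyre hardware.

```python
from contextlib import contextmanager


@contextmanager
def downcast_warnings(enabled, get_flag, set_flag):
    """Temporarily set the downcast-warning flag, restoring the previous value on exit.

    get_flag/set_flag are injected so the helper runs without Spyre; with
    torch_spyre installed they would be torch_spyre._C.get_downcast_warning and
    torch_spyre._C.set_downcast_warning.
    """
    previous = get_flag()
    set_flag(enabled)
    try:
        yield
    finally:
        set_flag(previous)  # restore even if the body raises


# Demo with an in-memory flag standing in for the runtime setting.
state = {"warn": True}
with downcast_warnings(False, lambda: state["warn"], lambda v: state.update(warn=v)):
    print(state["warn"])  # False
print(state["warn"])      # True
```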

Constants

torch_spyre.constants.DEVICE_NAME = "spyre"

The device name string used to register Spyre with PyTorch.

Environment Variables

Spyre runtime and compiler:

Variable                        Purpose
TORCH_SPYRE_DEBUG=1             Enable C++ debug logging and -O0 builds
TORCH_SPYRE_DOWNCAST_WARN=0     Suppress float32 → float16 downcast warnings
SPYRE_INDUCTOR_LOG=1            Enable Spyre Inductor logging
SPYRE_INDUCTOR_LOG_LEVEL=DEBUG  Set Spyre Inductor log verbosity (DEBUG, INFO, WARNING, ERROR)
SPYRE_LOG_FILE=path             Redirect Spyre Inductor logs to a file
TORCH_SENDNN_LOG                SendNN library logging level (default: CRITICAL)
DT_DEEPRT_VERBOSE               DeepTools runtime verbosity (default: -1, disabled)
DTLOG_LEVEL                     DeepTools log level (default: error)

Compiler / Inductor configuration (torch_spyre/_inductor/config.py):

Variable             Purpose
SENCORES             Number of Spyre cores (1–32, default 32)
LX_PLANNING=1        Enable LX scratchpad memory planning during the pre-scheduling pass
DXP_LX_FRAC_AVAIL    Fraction of LX scratchpad available to the planner
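The documented range for SENCORES maps onto a small validation step. sencores_from_env is a hypothetical reader, not the code in torch_spyre/_inductor/config.py; whether the real configuration rejects or clamps out-of-range values is not documented here, so this sketch simply raises. It accepts the environment as a parameter so it can be exercised with a plain dict.

```python
import os

SENCORES_MIN, SENCORES_MAX, SENCORES_DEFAULT = 1, 32, 32


def sencores_from_env(environ=None):
    """Return the Spyre core count from SENCORES, defaulting to 32 when unset."""
    if environ is None:
        environ = os.environ
    raw = environ.get("SENCORES", "").strip()
    if not raw:
        return SENCORES_DEFAULT
    cores = int(raw)  # raises ValueError for non-integer input
    if not SENCORES_MIN <= cores <= SENCORES_MAX:
        raise ValueError(
            f"SENCORES must be between {SENCORES_MIN} and {SENCORES_MAX}, got {cores}"
        )
    return cores


print(sencores_from_env({}))                 # 32
print(sencores_from_env({"SENCORES": "8"}))  # 8
```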

Device enumeration (torch_spyre/csrc/spyre_device_enum.cpp):

Variable          Purpose
AIU_WORLD_SIZE    Override the visible Spyre device count
SPYRE_DEVICES     Comma-separated list of device indices to expose
FLEX_DEVICE       Select the underlying flex runtime mode (PF / VF)
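Assuming these variables are read when the runtime enumerates devices, the dependable way to restrict a job to a subset of devices is to set them in the environment of a fresh process rather than changing os.environ after PyTorch has started. run_with_visible_devices is a hypothetical helper that launches a child interpreter with SPYRE_DEVICES set. To stay runnable without Spyre hardware, the demo child only echoes the variable; in practice the child code would be something like "import torch; print(torch.spyre.device_count())".

```python
import os
import subprocess
import sys


def run_with_visible_devices(code, devices):
    """Run `code` in a fresh interpreter with SPYRE_DEVICES set (hypothetical helper).

    The parent's environment is left untouched; only the child sees the
    restricted device list.
    """
    env = dict(os.environ)
    env["SPYRE_DEVICES"] = ",".join(str(d) for d in devices)
    result = subprocess.run(
        [sys.executable, "-c", code],
        env=env,
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError if the child fails
    )
    return result.stdout.strip()


if __name__ == "__main__":
    echo = "import os; print(os.environ['SPYRE_DEVICES'])"
    print(run_with_visible_devices(echo, [0, 2]))  # 0,2
```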

Internal:

Variable                          Purpose
IS_INDUCTOR_SPAWNED_SUBPROCESS    Marker set by Inductor when spawning compile subprocesses

Useful PyTorch knobs (not defined by torch-spyre):

Variable                  Purpose
TORCH_LOGS="+inductor"    Verbose PyTorch Inductor logging
TORCH_COMPILE_DEBUG=1     Dump Inductor debug artifacts