Coarse-Tiling Loop IR for the Spyre Backend
Background
Spyre’s compilation pipeline runs a sequence of optimization passes over
ir.Operation objects in CustomPreSchedulingPasses, before Inductor’s
Scheduler is constructed. One planned optimization is coarse-level
tiling: take a sequence of operations that share an iteration space
dimension, split that dimension into K chunks (where K may be a symbolic
shape), and emit the body operations inside a counted outer loop. This
is the key program transformation for working set reduction – a tiling
of the computation in the time domain that enables effective scratchpad
utilization by reshaping the computation so that most tensors can be
allocated to the scratchpad.
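To make the transformation concrete, here is a minimal pure-Python sketch of tiling a two-op chain along one shared dimension. It is not Spyre code: the function names are hypothetical and Python lists stand in for tensors. The point it illustrates is that the intermediate y only ever holds one chunk.

```python
def untiled(a, b, c):
    # Full-size intermediate: y is as large as the inputs.
    y = [x + w for x, w in zip(a, b)]
    return [p * q for p, q in zip(y, c)]


def coarse_tiled(a, b, c, k):
    # Split the shared dimension into k equal chunks and run both ops per chunk.
    n = len(a)
    assert n % k == 0, "sketch assumes k divides the dimension evenly"
    chunk = n // k
    z = [0] * n
    peak_intermediate = 0
    for i in range(k):  # the counted outer loop
        lo, hi = i * chunk, (i + 1) * chunk
        y = [x + w for x, w in zip(a[lo:hi], b[lo:hi])]  # chunk-sized only
        peak_intermediate = max(peak_intermediate, len(y))
        z[lo:hi] = [p * q for p, q in zip(y, c[lo:hi])]
    return z, peak_intermediate


a = list(range(8))
b = list(range(8, 16))
c = [2] * 8
z, peak = coarse_tiled(a, b, c, k=4)
```

The tiled result matches the untiled one, while the largest intermediate shrinks from 8 elements to 2.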
The output of this pass needs to survive through:
- Inductor’s Scheduler (which wraps each ir.Operation in a SchedulerNode)
- Spyre’s SuperDSCScheduling.codegen_node() (which drives SpyreKernel to produce OpSpec objects)
- Downstream SDSC compilation (which needs an explicit loop count to generate correct hardware instructions)
This document describes how that loop structure is represented, transported, and consumed.
Design Overview
The tiling loop structure must be created early (before work division sees the iteration space) and preserved intact through scheduling and codegen so that the hardware executes the reduced per-iteration working set — not the full pre-tiling range. The design has three layers that correspond to the three pipeline stages above.
Pre-scheduling IR pass (CustomPreSchedulingPasses)
  └─ stamps loop_group_id + loop_count on each ir.Operation
  └─ rewrites each op's ranges (divides the tiled dimension by K)
        ↓ Inductor Scheduler wraps each ir.Operation → SchedulerNode
        ↓ CustomPostFusionPasses fires
Post-fusion scheduler pass (build_loop_scheduler_nodes)
  └─ runs BEFORE spyre_fuse_nodes
  └─ scans list[BaseSchedulerNode] for runs sharing a loop_group_id
  └─ wraps each run in a CountedLoopSchedulerNode(count=K, snodes=[...])
  └─ spyre_fuse_nodes runs after; CountedLoopSchedulerNode.can_fuse=False
     prevents cross-group merging
        ↓ Scheduler calls SuperDSCScheduling.codegen_node()
codegen_node
  └─ receives CountedLoopSchedulerNode
  └─ drives SpyreKernel for the inner ops, collecting inner OpSpecs
  └─ wraps them in LoopSpec(count=K, body=[OpSpec, ...])
  └─ LoopSpec is serialized alongside OpSpec in codegen_kernel()
Small Example
Consider two chained pointwise operations over [1024, 4096] tensors, where
A=1024 names the row dimension and B=4096 names the column dimension:
import torch
from torch_spyre._inductor import spyre_hint
from torch_spyre._inductor.propagate_named_dims import declare_tensor_dim, name_tensor_dims

A, B = 1024, 4096
declare_tensor_dim("A", A)
declare_tensor_dim("B", B)

a = torch.randn(A, B, dtype=torch.float16).to("spyre")
b = torch.randn(A, B, dtype=torch.float16).to("spyre")
c = torch.randn(A, B, dtype=torch.float16).to("spyre")
name_tensor_dims(a, ["A", "B"])
name_tensor_dims(b, ["A", "B"])
name_tensor_dims(c, ["A", "B"])

def f(a, b, c):
    with spyre_hint(slices={"A": 2}):      # outer loop: 2 iterations over rows
        with spyre_hint(slices={"B": 4}):  # inner loop: 4 iterations over cols
            y = a + b
            z = y * c
    return z
Both operations are placed in a single tiling group with K=2 in the outer
loop (splitting the 1024 rows into 2 groups of 512) and M=4 in the inner
loop (splitting the 4096 columns into 4 groups of 1024). Each inner-loop
iteration processes a 512 × 1024 tile (1/8th of the full tensor), enabling
the intermediate result y to remain in scratchpad across both operations
within the tile.
This example is the canonical small example tested by
test_hint_nested_loop_with_scratchpad in
tests/inductor/test_coarse_tile_e2e.py.
What the coarse-tiling pass stamps
coarse_tile() sees this as a nested group spec and stamps the following
attributes on both ir.Operation objects:
op.loop_group_id = (0, 0) # depth-2 path: group 0, inner slot 0
op.loop_count = [2, 4] # [K_outer, M_inner]
op.loop_tiled_dims = [[0], [1]] # outer loop tiles dim 0; inner tiles dim 1
_divide_ranges is applied once per level in outermost-first order:
Outer level
- Outer level (K=2, [dim 0]): data.ranges [1024, 4096] → [512, 4096]
- Inner level (M=4, [dim 1]): data.ranges [512, 4096] → [512, 1024]
The per-inner-iteration data.ranges for both ops is [512, 1024].
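The per-level division above can be sketched as follows. This is an illustrative stand-in, not the actual _divide_ranges implementation: the name divide_ranges_sketch and the requirement that each count divides its dimension evenly are assumptions made for the example.

```python
def divide_ranges_sketch(ranges, levels):
    """Apply (count, tiled_dims) levels outermost first; return every stage."""
    current = list(ranges)
    history = [list(current)]
    for count, tiled_dims in levels:
        for dim in tiled_dims:
            # Assumption of this sketch: counts divide the dimension evenly.
            if current[dim] % count != 0:
                raise ValueError(f"dim {dim} ({current[dim]}) not divisible by {count}")
            current[dim] //= count
        history.append(list(current))
    return history


# Outer level K=2 on dim 0, inner level M=4 on dim 1.
history = divide_ranges_sketch([1024, 4096], [(2, [0]), (4, [1])])
```

The last entry of history is the per-inner-iteration range that work division and codegen see.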
LoopLevel IR after CustomPreSchedulingPasses
After coarse_tile, span_reduction, work_distribution, and
scratchpad_planning have all run, the two ComputedBuffer objects look like
this (the _format_operations representation with loop attributes added):
buf0: ComputedBuffer  # y = a + b
  layout = FixedTiledLayout(size=[512, 1024], stride=[1024, 1],
                            device_size=[16, 512, 64])  # per-tile shape
  op_it_space_splits = {1024: 32}  # work division: 32 cores along dim 1
  loop_group_id = (0, 0)
  loop_count = [2, 4]
  loop_tiled_dims = [[0], [1]]
  Pointwise(
    ranges=[512, 1024],  # per-tile iteration space
    inner_fn: load(a, i1 + 4096*i0)
              load(b, i1 + 4096*i0)
              return a + b
  )

buf1: ComputedBuffer  # z = y * c
  layout = FixedTiledLayout(size=[512, 1024], stride=[1024, 1],
                            device_size=[16, 512, 64])  # per-tile shape
  op_it_space_splits = {1024: 32}
  loop_group_id = (0, 0)
  loop_count = [2, 4]
  loop_tiled_dims = [[0], [1]]
  Pointwise(
    ranges=[512, 1024],
    inner_fn: load(buf0, i1 + 4096*i0)  # reads y
              load(c, i1 + 4096*i0)
              return y * c
  )
Key points:
- Both ops share the same loop_group_id = (0, 0), loop_count = [2, 4], and loop_tiled_dims = [[0], [1]]; this is what build_loop_scheduler_nodes uses to wrap them together in a CountedLoopSchedulerNode.
- ranges = [512, 1024] is the per-tile iteration space (1/8th of the full tensor). Work division and codegen see only this reduced space; the loop trip counts carry the information needed to reconstruct the full addressing.
- layout.size = [512, 1024] matches the per-tile ranges. The layout describes the smaller per-tile output buffer allocated for each loop iteration. Per-iteration addressing into the full HBM region is handled by tiled_symbols / affine.apply in bundle.mlir at runtime.
- op_it_space_splits = {1024: 32} is stamped by work_distribution: the coefficient 1024 identifies the per-tile stride-1 dimension (columns after tiling), and 32 is the number of cores dividing that dimension’s work.
- buf0 (y) is the intermediate result. At this point its layout is a FixedTiledLayout with size=[512, 1024]; scratchpad_planning later assigns it allocation={'lx': 0}, placing it in LX scratchpad memory at address 0. Because y is produced and fully consumed within the same tile iteration and its per-tile size fits in scratchpad, no HBM allocation is needed for it at all.
Generated OpSpec (Python wrapper source)
The Python wrapper emitted by codegen_kernel() contains both ops inside a
single nested LoopSpec. Below is the actual output produced by running the e2e test
test_hint_nested_loop_with_scratchpad (which uses spyre_hint(slices=...) /
declare_tensor_dim / name_tensor_dims with lx_planning=True and
allow_all_ops_in_lx_planning=True; concrete HBM addresses replaced with
symbolic names for readability):
sdsc_fused_add_mul_0 = async_compile.sdsc('sdsc_fused_add_mul_0',
  [
    LoopSpec(
      count=sympify('2'),  # outer K=2 loop
      body=[
        LoopSpec(
          count=sympify('4'),  # inner M=4 loop
          body=[
            OpSpec(
              op='add',
              is_reduction=False,
              iteration_space={
                sympify('c0'): (sympify('512'), 32),
                sympify('c1'): (sympify('1024'), 1),
              },
              op_info={},
              tiled_symbols=[sympify('c0'), sympify('c1')],
              args=[
                TensorArg(  # input a
                  is_input=True, arg_index=0,
                  device_dtype=DataFormats.SEN169_FP16,
                  device_size=[64, 1024, 64],
                  device_coordinates=[
                    sympify('floor(c1/64)'),
                    sympify('c0'),
                    sympify('Mod(c1, 64)'),
                  ],
                  allocation={'hbm': <base_addr_a>},
                ),
                TensorArg(  # input b
                  is_input=True, arg_index=1,
                  device_dtype=DataFormats.SEN169_FP16,
                  device_size=[64, 1024, 64],
                  device_coordinates=[
                    sympify('floor(c1/64)'),
                    sympify('c0'),
                    sympify('Mod(c1, 64)'),
                  ],
                  allocation={'hbm': <base_addr_b>},
                ),
                TensorArg(  # output y (LX scratchpad)
                  is_input=False, arg_index=-1,
                  device_dtype=DataFormats.SEN169_FP16,
                  device_size=[16, 512, 64],
                  device_coordinates=[
                    sympify('floor(c1/64)'),
                    sympify('c0'),
                    sympify('Mod(c1, 64)'),
                  ],
                  allocation={'lx': 0},
                ),
              ]
            ),
            OpSpec(
              op='mul',
              is_reduction=False,
              iteration_space={
                sympify('c0'): (sympify('512'), 32),
                sympify('c1'): (sympify('1024'), 1),
              },
              op_info={},
              tiled_symbols=[sympify('c0'), sympify('c1')],
              args=[
                TensorArg(  # input y (LX scratchpad)
                  is_input=True, arg_index=-1,
                  device_dtype=DataFormats.SEN169_FP16,
                  device_size=[16, 512, 64],
                  device_coordinates=[
                    sympify('floor(c1/64)'),
                    sympify('c0'),
                    sympify('Mod(c1, 64)'),
                  ],
                  allocation={'lx': 0},
                ),
                TensorArg(  # input c
                  is_input=True, arg_index=2,
                  device_dtype=DataFormats.SEN169_FP16,
                  device_size=[64, 1024, 64],
                  device_coordinates=[
                    sympify('floor(c1/64)'),
                    sympify('c0'),
                    sympify('Mod(c1, 64)'),
                  ],
                  allocation={'hbm': <base_addr_c>},
                ),
                TensorArg(  # output z (HBM, per-tile)
                  is_input=False, arg_index=3,
                  device_dtype=DataFormats.SEN169_FP16,
                  device_size=[16, 512, 64],
                  device_coordinates=[
                    sympify('floor(c1/64)'),
                    sympify('c0'),
                    sympify('Mod(c1, 64)'),
                  ],
                  allocation={'hbm': <base_addr_z>},
                ),
              ]
            ),
          ],
        ),
      ],
    ),
  ]
)
Key observations:
- c0 and c1 are Inductor’s iteration-space symbols for the two dimensions. iteration_space reflects the per-inner-iteration tile size [512, 1024]. tiled_symbols=[c0, c1] records, outermost first, which symbols correspond to the tiled dimensions: c0 drives the outer scf.for, c1 the inner one.
- The intermediate tensor y (output of add, input to mul) has allocation={'lx': 0}: it lives in LX scratchpad memory at address 0. Its device_size=[16, 512, 64] reflects the per-tile shape [512, 1024]. Because y is produced and fully consumed within the same tile iteration, no HBM allocation is needed and its address is a fixed scratchpad offset that does not change between loop iterations (no affine.apply needed).
- The final output z (output of mul) has allocation={'hbm': ...} and arg_index=3: it lives in HBM. Its device_size=[16, 512, 64] also reflects the per-tile shape; the per-iteration write address into the full HBM buffer is computed by affine.apply in bundle.mlir (see next section).
- HBM inputs a, b, c have device_size=[64, 1024, 64], the full tensor shape [1024, 4096] in Spyre stick layout. Their device_coordinates use c0 and c1 to index the per-iteration tile window into the full tensor. The per-tile output buffers (y, z) have device_size=[16, 512, 64], the stick-layout shape for [512, 1024] fp16: 16 sticks of 64 columns across 512 rows.
Generated bundle.mlir
The SDSC compiler (compile_op_spec) translates tiled_symbols into per-loop
byte strides, producing a 2-dimensional affine_map. For this [1024, 4096]
fp16 tensor with Spyre stick layout (128 bytes/stick, 64 elements/stick):
Outer stride: 512 rows × 64 sticks/row × 128 bytes/stick = 4,194,304 bytes
Inner stride: 1024 columns / 64 elements/stick × 128 bytes/stick = 2,048 bytes
#map_0 = affine_map<(d0, d1)[s0] -> (s0 + 4194304*d0 + 2048*d1)>
module {
  func.func @sdsc_bundle() {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %loop_bound_0 = arith.constant 2 : index
    %loop_bound_1 = arith.constant 4 : index
    %sym_1 = arith.constant <base_addr_a> : index
    %sym_2 = arith.constant <base_addr_b> : index
    %sym_3 = arith.constant <base_addr_c> : index
    %sym_4 = arith.constant <base_addr_z> : index
    scf.for %i_0 = %c0 to %loop_bound_0 step %c1 {
      scf.for %i_1 = %c0 to %loop_bound_1 step %c1 {
        %addr_0 = affine.apply #map_0(%i_0, %i_1)[%sym_1]
        %addr_1 = affine.apply #map_0(%i_0, %i_1)[%sym_2]
        sdscbundle.sdsc_execute (%addr_0, %addr_1) {sdsc_filename="sdsc_0.json", ...} // add: a+b→y(lx)
        %addr_2 = affine.apply #map_0(%i_0, %i_1)[%sym_3]
        %addr_3 = affine.apply #map_0(%i_0, %i_1)[%sym_4]
        sdscbundle.sdsc_execute (%addr_2, %addr_3) {sdsc_filename="sdsc_1.json", ...} // mul: y(lx)*c→z
      }
    }
    return
  }
}
Both operations share the same affine map because they operate on tensors of
the same shape and stride structure. The scratchpad tensor y does not appear
as a symbol — it has a fixed lx address that does not change between
iterations. Each inner-loop iteration dispatches add then mul at tile
(i_0, i_1), keeping the intermediate result y in scratchpad between the
two dispatches.
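As a cross-check of the stride arithmetic and the affine map, the following sketch recomputes both strides from the stick constants quoted above and enumerates the per-tile byte offsets. The helper names loop_strides and tile_address are hypothetical; the formulas are simply the two stride lines and #map_0 restated in Python.

```python
BYTES_PER_STICK = 128
ELEMENTS_PER_STICK = 64


def loop_strides(full_cols, tile_rows, tile_cols):
    # Outer stride: tile_rows full rows, each row holding full_cols / 64 sticks.
    sticks_per_row = full_cols // ELEMENTS_PER_STICK
    outer = tile_rows * sticks_per_row * BYTES_PER_STICK
    # Inner stride: one tile's worth of columns, expressed in sticks.
    inner = (tile_cols // ELEMENTS_PER_STICK) * BYTES_PER_STICK
    return outer, inner


def tile_address(base, i0, i1, strides):
    # Mirrors #map_0: s0 + outer*d0 + inner*d1.
    outer, inner = strides
    return base + outer * i0 + inner * i1


strides = loop_strides(full_cols=4096, tile_rows=512, tile_cols=1024)
offsets = [tile_address(0, i0, i1, strides) for i0 in range(2) for i1 in range(4)]
```

With a base address of 0, the eight (i_0, i_1) pairs produce eight distinct offsets, one per tile.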
Layer 1 — Pre-scheduling IR pass
Attribute contract on ir.Operation
The coarse-tiling pass stamps three attributes onto each ir.Operation that
participates in a loop group. These attributes are plain Python values
attached with setattr; no Inductor base class is modified.
| Attribute | Type | Meaning |
|---|---|---|
| loop_group_id | tuple of int | Nesting-path tuple identifying which loop group this op belongs to. Its length equals the nesting depth. All ops sharing the same tuple form the body of the innermost counted loop at that path. |
| loop_count | list | Trip counts, one per nesting level from outermost to innermost. For a flat (depth-1) group this is a 1-element list [K]. |
| loop_tiled_dims | list[list[int]] | Per-level positional indices into data.ranges. |
The pass also rewrites the op’s iteration ranges: for each level, the
dimensions at the corresponding indices in loop_tiled_dims are divided by
the corresponding count in loop_count, so that each inner OpSpec
describes only the work done per innermost-loop iteration.
loop_group_id is a tuple rather than a flat integer to support nested
loops. See “Nested loops and the loop_group_id tree” below.
Why these three attributes are sufficient
loop_count is redundant across all ops sharing the same loop_group_id
(they must agree), but keeping it on each op means the post-fusion pass does
not need to maintain a separate side table. The loop_group_id is the join
key. loop_tiled_dims is the bridge between the pre-scheduling pass (which
operates on positional data.ranges indices) and the codegen phase (which
uses named sympy Symbols) — it is read by create_op_spec to identify, by
index, which scheduler-level symbols correspond to the tiled dimensions and
should be recorded in OpSpec.tiled_symbols. All levels are flattened
(outermost first) so that tiled_symbols covers every loop variable for the
op. Using a list-of-lists of indices (rather than a count or a flag) allows
different ops in the same loop to tile non-contiguous or differently
positioned dimensions of their respective iteration spaces.
Loops is a frozen dataclass
Inductor’s ir.Loops (the base of Pointwise and Reduction) is
declared @ir_dataclass(frozen=True), so data.ranges = x raises
FrozenInstanceError. The tiling pass uses object.__setattr__ to
bypass this:
object.__setattr__(data, "ranges", ranges)
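The bypass can be demonstrated with the standard library alone. In this sketch FrozenLoops is a hypothetical stand-in for Inductor's ir.Loops, built with the plain dataclasses decorator rather than ir_dataclass.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class FrozenLoops:  # hypothetical stand-in for ir.Loops
    ranges: tuple


data = FrozenLoops(ranges=(1024, 4096))

try:
    data.ranges = (512, 4096)  # ordinary assignment is rejected
    assignment_failed = False
except dataclasses.FrozenInstanceError:
    assignment_failed = True

# object.__setattr__ skips the dataclass-generated __setattr__ guard.
object.__setattr__(data, "ranges", (512, 4096))
```

One consequence worth keeping in mind: a frozen dataclass hashes by field value, so mutating a field this way changes the object's hash. Any dict or set that already holds the object would no longer find it.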
Public API: coarse_tile()
def coarse_tile(
    operations: list[Operation],
    groups: list[tuple],
    *,
    tiled_dims: list[int] | None = None,
) -> None:
    ...
groups is a pre-computed list of group tuples supplied by the caller
(e.g., config.coarse_tiling_groups_fn). Each ops list must be a
contiguous sub-sequence of operations; a gap indicates a data-flow
dependency crossing the group boundary and raises RuntimeError.
Each group tuple takes one of two forms:
Flat (single loop):
(ops, K) # tile dim 0 by K (default)
(ops, K, [0, 1]) # tile dims 0 and 1 by K
The optional third element overrides the tiled_dims keyword argument for
that specific group. None (the default) divides only dimension 0.
Nested (multiple independent loops on the same ops):
(ops, [(K1, [0]), (K2, [1])]) # outer K1 on dim 0; inner K2 on dim 1
(ops, [(K1, [0]), (K2, [0,1])]) # outer K1 on dim 0; inner K2 on dims 0 and 1
The second element is a list of (count, tiled_dims) pairs, outermost first.
The ops end up in the innermost loop body; each level’s count divides the
corresponding dims independently (outermost pass applied first).
coarse_tile() normalises flat syntax to the list-of-pairs form internally,
so _stamp_group always works with the canonical representation.
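A minimal sketch of that normalisation step is shown below. The function name normalise_group and the choice to tell the two forms apart by checking whether the second element is a list are assumptions for illustration; the real coarse_tile() may distinguish them differently.

```python
def normalise_group(group, default_tiled_dims=None):
    """Return (ops, [(count, tiled_dims), ...]) for a flat or nested group tuple."""
    ops, spec = group[0], group[1]
    if isinstance(spec, list):
        # Nested form: already a list of (count, tiled_dims) pairs.
        return ops, [(count, list(dims)) for count, dims in spec]
    # Flat form: an optional third element overrides the keyword default.
    dims = group[2] if len(group) > 2 else default_tiled_dims
    if dims is None:
        dims = [0]  # None means "divide only dimension 0"
    return ops, [(spec, list(dims))]


flat = normalise_group((["op0", "op1"], 4))
flat_override = normalise_group((["op0"], 4, [0, 1]))
nested = normalise_group((["op0"], [(2, [0]), (4, [1])]))
```

After this step every group has the same shape, so the stamping code needs only one code path.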
Feature flag and groups callable
# config.py
coarse_tiling: bool = os.environ.get("COARSE_TILING", "0") == "1"
coarse_tiling_groups_fn: Optional[Callable] = None # overrides hint-derived groups
When coarse_tiling=True and coarse_tiling_groups_fn is None, groups
are derived automatically from spyre_hint(slices=...) annotations via
hints_to_coarse_tile_groups — a no-op if no hints are present. Setting
coarse_tiling_groups_fn to a callable overrides the hint-derived groups
entirely; this is intended for interim testing until the annotation framework
matures and will be removed once it is complete.
coarse_tiling_groups_fn must be a module-level named function, not a
lambda, because Inductor’s FX graph cache pickles the config values.
Placement in CustomPreSchedulingPasses
The coarse-tiling pass runs after layout finalization and before
span_reduction:
deadcode_elimination(operations)
propagate_spyre_tensor_layouts(operations)
optimize_restickify_locations(operations)
finalize_layouts(operations)
insert_restickify(operations)
insert_bmm_padding(operations)
dedup_and_promote_constants(operations)
if config.chunk_large_tensors:
    chunk_large_tensors(operations)
propagate_named_dims(operations)
assign_dim_hints(operations)
if config.coarse_tiling:
    groups = hints_to_coarse_tile_groups(operations)
    if config.coarse_tiling_groups_fn is not None:
        groups = config.coarse_tiling_groups_fn(operations)
    coarse_tile(operations, groups=groups)
span_reduction(operations)
k_fast_ops = (
    k_fast_division(operations) if config.core_id_k_fast_emission else []
)
work_distribution(operations, k_fast_ops)
if config.lx_planning:
    allocator = (
        StrategyBCoOptimizingAllocator()
        if config.co_optimizing_lx_planning
        else None
    )
    scratchpad_planning(graph, allocator=allocator)
This ordering is required by several constraints:
propagate_named_dims and assign_dim_hints must run before coarse tiling.
propagate_named_dims propagates name_tensor_dims() annotations through the
op graph, attaching named dimension metadata to each ir.Operation.
assign_dim_hints then combines those named dimensions with the spyre_hint
scope annotations (attached to FX nodes as meta["custom"]) to produce
op.dim_hints — a flat list of DimHint objects consumed by
hints_to_coarse_tile_groups to form the coarse tiling groups.
Must run after stickify and padding. propagate_spyre_tensor_layouts,
insert_restickify, and insert_bmm_padding establish the final tiled
memory layout for each tensor. The tiling pass must see the post-stickify,
post-padding shapes or it will split on the wrong dimension or produce a
non-stick-aligned inner size.
Must run before work_distribution. work_distribution stamps
op_it_space_splits on each ir.Operation to assign per-core work
slices. It must see the already-reduced (inner) iteration space so that
cores divide the per-iteration work, not the full pre-tiling iteration
space. Running coarse tiling after work_distribution would produce
op_it_space_splits values sized for the full range, which would then
be wrong relative to the reduced ranges written by the tiling pass.
span_reduction and k_fast_division have the same requirement and
already run before work_distribution, so placing coarse_tile with
them is consistent.
scratchpad_planning must run after coarse tiling because it sizes
scratchpad allocations to fit the per-iteration working set. If it ran
before, it would see the full iteration space and allocate too much —
defeating the working-set reduction that coarse tiling is designed to
achieve. scratchpad_planning receives the full GraphLowering object
(not just operations) because it needs access to graph-level metadata
for buffer lifetime analysis.
Buffer propagation: insert_tiling_propagation
coarse_tile() calls insert_tiling_propagation(operations, groups)
immediately after stamping all loop attributes. Its job is to ensure that
any op whose result is consumed outside the loop (or is a graph output)
exposes a complete, fully-sized buffer to its consumers. Ops whose outputs
are consumed only inside the loop are marked so the unroller does not advance
their base addresses.
Use-def analysis
For each ComputedBuffer in a loop group the pass asks two questions:
1. Does this buffer have outside consumers? A consumer is “outside” if it carries a different loop_group_id prefix, or has no loop_group_id at all. Graph outputs (recorded in the Inductor buffer’s users / get_alias_name machinery) count as outside consumers.
2. Does this buffer have inside consumers? A consumer is “inside” if it shares the same loop_group_id tuple (i.e. it is another op in the same innermost loop body).
per_tile_fixed — loop-internal buffers
If a buffer has no outside consumers and is not a graph output, it is a per-tile scratch region that is fully overwritten and read within the same loop iteration. The pass marks it:
if isinstance(op.layout, FixedTiledLayout):
    op.layout.per_tile_fixed = True
This flag propagates to TensorArg.per_tile_fixed during codegen (in
spyre_kernel.py). The unroller (codegen/unroll.py) then skips two
things for these args:
- Address advance: the base address is fixed; no per-iteration offset is added.
- device_size update: the tile geometry is not applied; the hardware uses the original allocation size, which already matches the tile.
Case 1 — output used both inside and outside the loop
The tiled op writes into its small, per-iteration buffer as usual. The pass
allocates a full-sized HBM buffer (sized to the original pre-division ranges)
and inserts a copy op immediately after the tiled op in the operations
list. The copy op carries the same loop_group_id, loop_count, and
loop_tiled_dims as the original, so the scheduler wraps both ops in the
same CountedLoopSchedulerNode. Its TensorArg for the destination uses
the full buffer’s address; the existing tiled_symbols / affine.apply
machinery in SpyreKernel and bundle.py computes the per-iteration slice
offset automatically.
All outside consumers are then patched to read the full buffer instead of the tiled one.
Case 2 — output used only outside the loop
When no inside consumer needs the per-iteration buffer, the simplest fix is to rewire the tiled op itself to write directly into the full buffer:
op.layout = MutationLayoutSHOULDREMOVE(TensorBox(StorageBox(full_buf)))
MutationLayoutSHOULDREMOVE tells Inductor that the op mutates an existing
storage in-place. Because the full buffer is pre-allocated and its address
is encoded in the TensorArg via the tiled_symbols offset, no copy op is
needed.
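The three outcomes (per_tile_fixed, Case 1, Case 2) follow from the two use-def questions. The sketch below encodes that decision with plain tuples. It is a simplification rather than the real pass: classify_buffer is a hypothetical name, and any consumer whose loop_group_id is not identical to the producer's is treated as outside.

```python
def classify_buffer(producer_gid, consumer_gids, is_graph_output=False):
    """consumer_gids holds each consumer's loop_group_id, or None if untiled."""
    inside = any(gid == producer_gid for gid in consumer_gids)
    outside = is_graph_output or any(gid != producer_gid for gid in consumer_gids)
    if not outside:
        return "per_tile_fixed"  # loop-internal scratch buffer
    if inside:
        return "case 1: copy into full buffer"
    return "case 2: write full buffer directly"


internal = classify_buffer((0, 0), [(0, 0)])
both = classify_buffer((0, 0), [(0, 0), None])
outside_only = classify_buffer((0, 0), [None])
graph_output = classify_buffer((0, 0), [], is_graph_output=True)
```

In the small example earlier, y corresponds to the first outcome and z to the last.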
Reduction safety checks
Before running the propagation logic for a Reduction op, the pass calls
_check_reduction_tiling_safety(op) which raises RuntimeError in two
unsupported configurations:
- Matrix multiply (reduction_type == "batchmatmul") inside a tiling loop: the accumulation semantics are not handled.
- Tiled reduction dim: if any entry in loop_tiled_dims is >= len(data.ranges), the tiled index falls in reduction_ranges. Accumulation-buffer support for this case is not yet implemented.
Both checks happen before any buffer allocation, so the error is clean.
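The two checks can be sketched on plain values as follows. This is not the real _check_reduction_tiling_safety(op), which takes an ir.Operation; the flattened argument list, the error messages, and the is_safe wrapper are assumptions of this example.

```python
def check_reduction_tiling_safety_sketch(reduction_type, num_ranges, loop_tiled_dims):
    """Raise RuntimeError for the two unsupported configurations described above."""
    if reduction_type == "batchmatmul":
        raise RuntimeError("matrix multiply inside a tiling loop is not supported")
    for level_dims in loop_tiled_dims:
        for dim in level_dims:
            if dim >= num_ranges:  # index lands in reduction_ranges
                raise RuntimeError(f"tiled dim {dim} is a reduction dimension")


def is_safe(reduction_type, num_ranges, loop_tiled_dims):
    try:
        check_reduction_tiling_safety_sketch(reduction_type, num_ranges, loop_tiled_dims)
        return True
    except RuntimeError:
        return False


ok = is_safe("sum", 2, [[0]])
matmul = is_safe("batchmatmul", 2, [[0]])
tiled_reduction = is_safe("sum", 1, [[0], [1]])
```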
Layer 2 — CountedLoopSchedulerNode
Class definition
CountedLoopSchedulerNode lives in
torch_spyre/_inductor/scheduler.py alongside SuperDSCScheduling.
It subclasses Inductor’s FusedSchedulerNode:
class CountedLoopSchedulerNode(FusedSchedulerNode):
    loop_count: sympy.Expr

    def __init__(
        self,
        scheduler,
        snodes: list[BaseSchedulerNode],
        loop_count: sympy.Expr,
    ) -> None:
        super().__init__(scheduler, snodes)
        self.loop_count = loop_count

    def unpack(self) -> list[BaseSchedulerNode]:
        # CountedLoopSchedulerNode is an atomic codegen unit; do not unpack.
        return [self]

    @classmethod
    def can_fuse(
        cls,
        producer: BaseSchedulerNode,
        consumer: BaseSchedulerNode,
    ) -> bool:
        return False
unpack() returns [self] to prevent Inductor’s
Scheduler.process_grouped_nodes() from dissolving the node back into its
constituent SchedulerNodes before codegen. can_fuse returns False
— a loop group is atomic; nothing can be fused into it from outside.
Why FusedSchedulerNode is the right base
CountedLoopSchedulerNode subclasses FusedSchedulerNode rather than
GroupedSchedulerNode for two reasons:
- Dispatch: Scheduler._codegen only dispatches FusedSchedulerNode | SchedulerNode to codegen_node(). A GroupedSchedulerNode subclass falls through to assert isinstance(node, NopKernelSchedulerNode) and crashes.
- Unpack control: GroupedSchedulerNode is unconditionally unpacked by Scheduler.process_grouped_nodes() at the start of codegen. FusedSchedulerNode is not subject to that unpack, so overriding unpack() is sufficient to keep the node intact.
FusedSchedulerNode already merges unmet_dependencies across all
constituent nodes, exposes get_nodes(), and registers all constituent
names in scheduler.name_to_fused_node. Nothing needs to be
reimplemented.
The post-fusion pass and ordering
CountedLoopSchedulerNodes are created by build_loop_scheduler_nodes,
which is registered in CustomPostFusionPasses ahead of spyre_fuse_nodes:
class CustomPostFusionPasses(CustomNodePassBase):
    def get_passes(self):
        return [memory_planning, build_loop_scheduler_nodes, spyre_fuse_nodes]
build_loop_scheduler_nodes must run before spyre_fuse_nodes.
If spyre_fuse_nodes ran first, it could merge SchedulerNodes from
different loop groups into a single FusedSchedulerNode. The loop-group
pass would then see one fused node spanning multiple groups rather than
the individual nodes with distinct loop_group_ids, and would wrap both
groups in a single CountedLoopSchedulerNode with a single loop_count.
Running loop grouping first ensures each group is already wrapped and
opaque before spyre_fuse_nodes runs. can_fuse = False then prevents
spyre_fuse_nodes from merging across group boundaries.
The grouping algorithm
build_loop_scheduler_nodes scans the flat node list and groups
contiguous runs sharing the same outermost loop_group_id key:
def build_loop_scheduler_nodes(nodes):
    result = []
    i = 0
    while i < len(nodes):
        node = nodes[i]
        gid = _loop_group_id(node)  # reads loop_group_id from the inner ir.Operation
        if gid is None:
            result.append(node)
            i += 1
            continue
        outer_key = gid[0]
        loop_count = _loop_count(node, 0)  # trip count of the outermost level
        run = [node]
        i += 1
        # Extend the run while the next node carries the same outermost key.
        while i < len(nodes):
            next_gid = _loop_group_id(nodes[i])
            if next_gid is None or next_gid[0] != outer_key:
                break
            run.append(nodes[i])
            i += 1
        # Recursively wrap deeper nesting within this run.
        inner = _build_loop_group(run, depth=1)
        result.append(CountedLoopSchedulerNode.create(inner, loop_count))
    return result
Key invariant: because the pre-scheduling pass runs in topological order
and the scheduler’s topological sort preserves that order, a loop group’s
SchedulerNodes will be contiguous in the post-fusion node list. If they
are not contiguous it means a data-flow constraint separates them, which is a
bug in the tiling pass. The post-fusion pass asserts contiguity.
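The contiguity invariant can be checked in one linear scan. The sketch below works on a list of outermost group keys (None for untiled nodes) instead of real SchedulerNodes, and the name groups_are_contiguous is hypothetical. The actual pass asserts rather than returning a boolean.

```python
def groups_are_contiguous(outer_keys):
    """Return False if any group key reappears after its run has ended."""
    closed = set()   # groups whose run has already ended
    previous = None
    for key in outer_keys:
        if key != previous:
            if previous is not None:
                closed.add(previous)
            if key is not None and key in closed:
                return False  # the group resumes after a gap
            previous = key
    return True


contiguous = groups_are_contiguous([None, 0, 0, None, 1, 1])
split = groups_are_contiguous([0, None, 0])
```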
Layer 3 — LoopSpec and codegen
LoopSpec and OpSpec.tiled_symbols in op_spec.py
@dataclasses.dataclass
class LoopSpec:
    count: sympy.Expr
    body: list[OpSpec | UnimplementedOp | LoopSpec]

@dataclasses.dataclass
class OpSpec:
    op: str
    is_reduction: bool
    iteration_space: dict[Symbol, tuple[Expr, int]]
    args: Sequence[TensorArg]
    op_info: dict[str, Any]
    tiled_symbols: list[Symbol] = field(default_factory=list)
LoopSpec is a peer of OpSpec and UnimplementedOp in the list that
SpyreKernel.codegen_kernel() serializes. It is not a subclass of OpSpec
because it has no iteration_space, args, or op_info of its own — those
belong to the inner OpSpecs.
The body type is recursive: a LoopSpec body may itself contain
LoopSpec entries, representing nested counted loops.
OpSpec.tiled_symbols carries the per-op tiling information: all
iteration-space symbols that are divided by any enclosing loop, listed
outermost first. It is empty for ops that are not inside a
LoopSpec. For a single-level tiled op, tiled_symbols = [s0]. For
a two-level nested tiled op, tiled_symbols = [s_outer, s_inner]. The
runtime uses this list together with the enclosing loop variables (also
outermost-first) to build the affine address formula:
base + s_outer_stride * i_outer + s_inner_stride * i_inner.
Tiling information is stored on OpSpec rather than on LoopSpec because
different body ops may tile different iteration-space dimensions. Two ops in
the same loop group can have different tiled_symbols if, for example, work
division or stickification places the batch dimension at different positions in
each op’s iteration space. A single int on LoopSpec cannot express this;
per-op list[Symbol] can.
Nested loops and the loop_group_id tree
Each ir.Operation carries a loop_group_id that is a path rather
than a flat integer. A path is a tuple of integers, one element per
nesting level:
| loop_group_id | Meaning |
|---|---|
| (0,) | outermost loop group 0, not nested |
| (0, 0) | single op nested two levels deep inside group 0 |
| (0, 1) | ops at depth 2 inside outer group 0, inner group 1 |
loop_count is a list parallel to the path. For a flat op at
(0,), loop_count = [K]. For a single op at (0, 0),
loop_count = [K1, K2] — the scheduler reads loop_count[0] = K1 when
building the outer CountedLoopSchedulerNode and loop_count[1] = K2
when building the inner one. This allows a single op to supply the counts
for all its enclosing loops without requiring sibling ops at intermediate
depths.
The post-fusion pass (_build_loop_group) reconstructs the tree
recursively:
1. Group the flat SchedulerNode list into runs that share the same outermost group id element (index depth).
2. Read the count for this depth from _loop_count(node, depth), which indexes loop_count[depth - base_depth]. All nodes in the run must agree on this count.
3. Recursively call _build_loop_group(run, depth + 1) to build the inner level.
4. Wrap the result in a CountedLoopSchedulerNode(count=K_outer, ...).
Because every op carries the full loop_count list, the algorithm works
even when a run contains only a single op that spans all nesting levels —
there is no need for placeholder ops at intermediate depths.
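The recursion described above can be sketched on plain tuples. Here each node is (name, loop_group_id, loop_count) and a loop is represented as ("loop", count, body). These representations, the name build_loop_tree, and the assumption that base_depth is 0 are illustrative choices, not the real _build_loop_group.

```python
def build_loop_tree(nodes, depth=0):
    """nodes: list of (name, loop_group_id or None, loop_count list or None)."""
    result = []
    i = 0
    while i < len(nodes):
        name, gid, counts = nodes[i]
        if gid is None or depth >= len(gid):
            result.append(name)  # no loop at this depth: plain body op
            i += 1
            continue
        key, count = gid[depth], counts[depth]
        run = []
        while i < len(nodes):
            _, next_gid, next_counts = nodes[i]
            if next_gid is None or depth >= len(next_gid) or next_gid[depth] != key:
                break
            if next_counts[depth] != count:
                raise ValueError("nodes in one run must agree on the trip count")
            run.append(nodes[i])
            i += 1
        # Build the inner levels first, then wrap them in this level's loop.
        result.append(("loop", count, build_loop_tree(run, depth + 1)))
    return result


nodes = [("buf0", (0, 0), [2, 4]), ("buf1", (0, 0), [2, 4]), ("buf2", None, None)]
tree = build_loop_tree(nodes)
single = build_loop_tree([("x", (0, 0), [3, 5])])
```

The second call shows the single-op case: one op at path (0, 0) yields both loop levels with no placeholder ops.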
Bundle boundary constraint
A CountedLoopSchedulerNode (at any nesting depth) and all its
descendant SchedulerNodes must be codegen’d into a single SuperDSC
bundle — i.e., a single codegen_node() call must produce the entire
LoopSpec tree. This is automatically satisfied because Inductor calls
codegen_node() once per BaseSchedulerNode in the topological order,
and a CountedLoopSchedulerNode is a single node that encapsulates all
its children. No loop group can be split across two codegen_node()
calls.
The bundle boundary constraint also forbids a loop group from being split
by Inductor fusion: can_fuse returns False on
CountedLoopSchedulerNode, so no external node can be merged into or
absorb part of a loop group.
In bundle.py, generate_bundle iterates the top-level list of specs
emitted by codegen_kernel(), which may mix OpSpec and LoopSpec entries.
When it encounters a LoopSpec it emits SDSC JSON files for each OpSpec
in the body (recursively) and wraps those executions in an scf.for in
bundle.mlir.
Changes to SuperDSCScheduling.codegen_node()
codegen_node already handles FusedSchedulerNode | SchedulerNode.
CountedLoopSchedulerNode is recognized by an isinstance check:
def codegen_node(
    self,
    node: Union[FusedSchedulerNode, SchedulerNode, CountedLoopSchedulerNode],
) -> None:
    if isinstance(node, CountedLoopSchedulerNode):
        self._codegen_counted_loop(node)
        return
    # existing flat-list path unchanged
    ...

def _codegen_counted_loop(self, node: CountedLoopSchedulerNode) -> None:
    inner_nodes = [
        n for n in node.get_nodes()
        if n.get_name() not in self.scheduler.removed_ops
    ]
    kernel = SpyreKernel()
    all_schedule_nodes = []
    with kernel:
        for inner in inner_nodes:
            if isinstance(inner, CountedLoopSchedulerNode):
                self._codegen_loop_body(inner, kernel, all_schedule_nodes)
            else:
                sched = self.generate_node_schedule([inner])
                all_schedule_nodes.extend(sched)
                for snode in sched:
                    var_ranges = iteration_space(snode)
                    vs = list(var_ranges.keys())
                    index_vars = [vs[:len(snode._body.iter_vars)],
                                  vs[len(snode._body.iter_vars):]]
                    snode.codegen(index_vars)
    # Wrap the collected inner specs in a LoopSpec
    kernel.wrap_op_specs_in_loop(node.loop_count)
    with V.set_kernel_handler(kernel):
        src_code = kernel.codegen_kernel()
    kernel_name = self.define_kernel(src_code, all_schedule_nodes, kernel)
    ...
_codegen_loop_body handles nested CountedLoopSchedulerNodes: it
codegens the body ops into the existing kernel, then wraps only the newly
added op_specs entries in an inner LoopSpec. The outer
_codegen_counted_loop then wraps everything in the outer LoopSpec via
wrap_op_specs_in_loop.
SpyreKernel.wrap_op_specs_in_loop(count) replaces the flat self.op_specs
list with [LoopSpec(count=count, body=self.op_specs)].
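The two wrapping steps described above can be sketched on a toy model. This is only an illustration under stated assumptions: MiniKernel and wrap_tail_in_loop are hypothetical names, OpSpec is reduced to a single name field, and counts are plain integers rather than sympy expressions.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Union


@dataclass
class OpSpec:
    # Hypothetical stand-in: the real OpSpec carries many more fields.
    name: str


@dataclass
class LoopSpec:
    count: int
    body: list[Union[OpSpec, LoopSpec]] = field(default_factory=list)


class MiniKernel:
    """Toy model of the op_specs bookkeeping only; not SpyreKernel."""

    def __init__(self) -> None:
        self.op_specs: list[Union[OpSpec, LoopSpec]] = []

    def wrap_op_specs_in_loop(self, count: int) -> None:
        # Replace the flat list with a single LoopSpec that holds it.
        self.op_specs = [LoopSpec(count=count, body=self.op_specs)]

    def wrap_tail_in_loop(self, start: int, count: int) -> None:
        # Nested case (hypothetical helper): wrap only the entries
        # appended since index `start` in an inner LoopSpec.
        inner = self.op_specs[start:]
        self.op_specs[start:] = [LoopSpec(count=count, body=inner)]


kernel = MiniKernel()
kernel.op_specs.append(OpSpec("a"))      # op in the outer loop body
mark = len(kernel.op_specs)              # remember where the inner body starts
kernel.op_specs.append(OpSpec("b"))      # op in the inner loop body
kernel.wrap_tail_in_loop(mark, count=8)  # inner loop wraps only "b"
kernel.wrap_op_specs_in_loop(count=4)    # outer loop wraps everything
```

After these calls, kernel.op_specs holds one outer LoopSpec whose body is the op "a" followed by the inner LoopSpec containing "b".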
generate_node_schedule handles FusedSchedulerNodes that may appear
among the inner nodes (e.g. from earlier passes that fused nodes within
the same loop group) by flattening them into their constituent
SchedulerNodes.
Serialization in codegen_kernel()
codegen_kernel() already iterates self.op_specs to emit Python source.
A LoopSpec entry is serialized as:
LoopSpec(
    count=sympify('K'),
    body=[
        OpSpec(
            ...,
            tiled_symbols=[sympify('c0')],  # emitted only when non-empty
        ),
        LoopSpec(  # nested loop
            count=sympify('J'),
            body=[
                OpSpec(..., tiled_symbols=[sympify('c0'), sympify('c1')]),
            ],
        ),
    ],
)
tiled_symbols is populated by SpyreKernel.create_op_spec: it reads
loop_tiled_dims (a list[list[int]]) from the ir.Operation (stamped
by coarse_tile()), flattens all levels outermost-first, and selects the
symbols at those indices from the scheduler-level iteration_space dict.
MemoryDep.ranges preserves the data.ranges ordering, so this
positional correspondence is stable across the pre-scheduling to codegen
boundary.
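The positional selection described above can be sketched as follows. The helper name select_tiled_symbols is hypothetical, and plain strings stand in for the sympy symbols that key the real iteration_space dict.

```python
def select_tiled_symbols(loop_tiled_dims, iteration_space):
    """Flatten per-level tiled dims (outermost level first) and pick the
    symbols at those positions from the iteration-space keys."""
    keys = list(iteration_space.keys())  # dict order mirrors data.ranges order
    flat = [dim for level in loop_tiled_dims for dim in level]
    return [keys[dim] for dim in flat]


# Strings stand in for sympy symbols; values are the per-dimension ranges.
space = {"c0": 128, "c1": 64, "c2": 32}

# Outer loop tiles dimension 0, inner loop tiles dimension 1.
tiled = select_tiled_symbols([[0], [1]], space)
```

Here tiled is ["c0", "c1"], which corresponds to the tiled_symbols shown for the op inside the nested LoopSpec in the serialized example.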
tiled_symbols is omitted from the serialized source when empty (i.e. for ops
that are not inside a loop), keeping the generated output identical to the
pre-tiling baseline for non-tiled kernels.
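A minimal sketch of a serializer with this omission rule is shown below. The serialize function and the single op field are assumptions for illustration; the real codegen_kernel() emits many more OpSpec fields.

```python
from dataclasses import dataclass, field


@dataclass
class OpSpec:
    op: str  # hypothetical field; the real OpSpec has more
    tiled_symbols: list = field(default_factory=list)


@dataclass
class LoopSpec:
    count: str  # symbol name, printed as sympify('<name>')
    body: list = field(default_factory=list)


def serialize(spec, depth=0):
    """Render a spec tree as Python source text."""
    pad = "    " * depth
    if isinstance(spec, LoopSpec):
        lines = [
            f"{pad}LoopSpec(",
            f"{pad}    count=sympify({spec.count!r}),",
            f"{pad}    body=[",
        ]
        for child in spec.body:
            lines.append(serialize(child, depth + 2) + ",")
        lines += [f"{pad}    ],", f"{pad})"]
        return "\n".join(lines)
    fields = [f"op={spec.op!r}"]
    if spec.tiled_symbols:  # omitted entirely when empty
        syms = ", ".join(f"sympify({s!r})" for s in spec.tiled_symbols)
        fields.append(f"tiled_symbols=[{syms}]")
    return f"{pad}OpSpec({', '.join(fields)})"


flat_src = serialize(OpSpec("add"))
loop_src = serialize(LoopSpec("K", [OpSpec("add", ["c0"])]))
```

In this sketch flat_src contains no tiled_symbols keyword at all, so a non-tiled op serializes exactly as it would without the loop machinery.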
The generated Python wrapper imports LoopSpec from op_spec.py so the
serialized source is re-loadable from the Inductor cache.
The arg_index fixup loop (which maps tensor names to kernel argument
positions) runs before serialization. It must walk the LoopSpec tree
recursively to find all TensorArg objects inside nested bodies, not
just the top-level self.op_specs list.
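One way to structure the recursive walk is a generator that yields every TensorArg in the tree. The sketch below assumes a simplified layout (OpSpec.args as a flat list, TensorArg with name and arg_index fields); iter_tensor_args and fix_arg_indices are hypothetical names.

```python
from dataclasses import dataclass, field


@dataclass
class TensorArg:
    name: str
    arg_index: int = -1  # filled in by the fixup


@dataclass
class OpSpec:
    args: list = field(default_factory=list)  # assumed layout


@dataclass
class LoopSpec:
    count: int
    body: list = field(default_factory=list)


def iter_tensor_args(specs):
    """Yield every TensorArg, descending into nested LoopSpec bodies."""
    for spec in specs:
        if isinstance(spec, LoopSpec):
            yield from iter_tensor_args(spec.body)
        else:
            yield from spec.args


def fix_arg_indices(specs, arg_names):
    """Map each tensor name to its position in the kernel argument list."""
    position = {name: i for i, name in enumerate(arg_names)}
    for arg in iter_tensor_args(specs):
        arg.arg_index = position[arg.name]


specs = [
    OpSpec([TensorArg("x")]),
    LoopSpec(4, [
        OpSpec([TensorArg("y")]),
        LoopSpec(2, [OpSpec([TensorArg("x")])]),
    ]),
]
fix_arg_indices(specs, ["x", "y"])
indices = [arg.arg_index for arg in iter_tensor_args(specs)]
```

A fixup that only iterated the top-level list would leave the two args inside the loops at their default index.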
bundle.mlir generation for loops
generate_bundle in bundle.py emits one
sdscbundle.sdsc_execute line per OpSpec. When a LoopSpec is
present it emits an scf.for block in bundle.mlir wrapping the
execute calls for the body ops.
The loop induction variable has type index and runs from 0 (inclusive)
to count (exclusive) with step 1, so the body executes count times. For
the current prototype, count must be a concrete integer; symbolic loop
counts raise NotImplementedError.
Emitted MLIR for a single-level loop with one body op:
module {
  func.func @sdsc_bundle() {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %loop_bound_0 = arith.constant 4 : index
    scf.for %i_0 = %c0 to %loop_bound_0 step %c1 {
      sdscbundle.sdsc_execute () {sdsc_filename="sdsc_a_0.json"}
    }
    return
  }
}
For nested loops, scf.for blocks are nested and induction variables are
numbered sequentially (%i_0, %i_1, …):
    %loop_bound_0 = arith.constant 4 : index
    %loop_bound_1 = arith.constant 8 : index
    scf.for %i_0 = %c0 to %loop_bound_0 step %c1 {
      sdscbundle.sdsc_execute () {sdsc_filename="sdsc_a_0.json"}
      scf.for %i_1 = %c0 to %loop_bound_1 step %c1 {
        sdscbundle.sdsc_execute () {sdsc_filename="sdsc_a_1.json"}
      }
    }
generate_bundle walks the list[OpSpec | LoopSpec] recursively,
maintaining an indentation level and a counter for SDSC JSON filenames.
The filenames are assigned in depth-first traversal order.
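The recursive walk can be sketched as below. This is not the bundle.py implementation: emit_bundle_mlir is a hypothetical name, the sdsc_<kernel>_<n>.json filename pattern is inferred from the examples above, and the loop-bound constants are hoisted ahead of the loop nest to match the nested example.

```python
from dataclasses import dataclass, field


@dataclass
class OpSpec:
    name: str = ""  # stand-in; the real OpSpec carries more fields


@dataclass
class LoopSpec:
    count: object
    body: list = field(default_factory=list)


def emit_bundle_mlir(specs, kernel_name="a"):
    bounds = []       # hoisted loop-bound constants, one per loop
    body = []         # scf.for nest and sdsc_execute lines
    file_counter = 0  # SDSC JSON files numbered in depth-first order

    def walk(items, depth):
        nonlocal file_counter
        pad = "  " * depth
        for item in items:
            if isinstance(item, LoopSpec):
                if not isinstance(item.count, int):
                    raise NotImplementedError("symbolic loop counts")
                n = len(bounds)  # sequential loop number
                bounds.append(
                    f"    %loop_bound_{n} = arith.constant {item.count} : index"
                )
                body.append(
                    f"{pad}scf.for %i_{n} = %c0 to %loop_bound_{n} step %c1 {{"
                )
                walk(item.body, depth + 1)
                body.append(f"{pad}}}")
            else:
                fname = f"sdsc_{kernel_name}_{file_counter}.json"
                body.append(
                    f'{pad}sdscbundle.sdsc_execute () {{sdsc_filename="{fname}"}}'
                )
                file_counter += 1

    walk(specs, 2)
    header = [
        "module {",
        "  func.func @sdsc_bundle() {",
        "    %c0 = arith.constant 0 : index",
        "    %c1 = arith.constant 1 : index",
    ]
    footer = ["    return", "  }", "}"]
    return "\n".join(header + bounds + body + footer)


mlir = emit_bundle_mlir([LoopSpec(4, [OpSpec(), LoopSpec(8, [OpSpec()])])])
```

For this input the sketch reproduces the nested example: two hoisted bounds, two nested scf.for blocks, and filenames numbered 0 and 1 in traversal order.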
Files changed
| File | Change |
|---|---|
| | Add |
| | Add |
| | Add |
| | Add |
| | Add |
| | Add |
| | |
| | New file: |
| | Add |
| | Add |
| | Add |
| | Add |
| | Extend |
| | |
| | Consolidated unit test suite: |
| | End-to-end compilation tests: baseline, single group, softmax-shaped, two groups, per-group tiled dims, unrolled execution |
| | Unit tests for |
Invariants and failure modes
Contiguity invariant: all SchedulerNodes sharing a loop_group_id
must be contiguous after the scheduler’s topological sort. If the tiling
pass stamps ops that have a data dependency crossing the group boundary,
the post-fusion pass will detect a non-contiguous run and raise a
RuntimeError.
Consistent loop_count: all ops sharing a loop_group_id must agree
on loop_count at every depth level. The post-fusion pass asserts this.
tiled_symbols populated iff inside a loop: OpSpec.tiled_symbols is
non-empty exactly when the op was codegen’d inside a CountedLoopSchedulerNode.
Its elements are the flattened (outermost-first) per-level tiled dims from
loop_tiled_dims on the corresponding ir.Operation, selected from the
scheduler-level iteration_space keys.
Pass ordering: coarse tiling must run after stickify/padding and
before span_reduction, k_fast_division, work_distribution, and
scratchpad_planning. build_loop_scheduler_nodes must run before
spyre_fuse_nodes in CustomPostFusionPasses — see the ordering
rationale above.
Cache invalidation: coarse_tile.py, scratchpad_planning, and all
other pass source files are included in CustomPreSchedulingPasses.uuid()
so the Inductor FX cache is invalidated when any pass changes. The
coarse_tiling_groups_fn must be a module-level named function (not a
lambda) for Inductor’s cache pickling to work.
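The contiguity and consistent loop_count invariants above lend themselves to a single scan over the scheduler order. The sketch below is an assumption-laden illustration, not build_loop_scheduler_nodes itself: find_loop_runs is a hypothetical name, nodes are reduced to (name, loop_group_id, loop_count) tuples, and nesting depth is ignored.

```python
def find_loop_runs(nodes):
    """Group consecutive nodes by loop_group_id.

    `nodes` is a list of (name, loop_group_id, loop_count) tuples in
    scheduler order; nodes outside any loop use loop_group_id=None.
    """
    runs = []       # completed (group_id, count, names) runs
    seen = set()    # group ids whose run has already ended
    current = None  # run being accumulated, if any
    for name, gid, count in nodes:
        if current is not None and gid != current[0]:
            runs.append(current)
            seen.add(current[0])
            current = None
        if gid is None:
            continue
        if gid in seen:
            # The group reappears after its run ended: not contiguous.
            raise RuntimeError(f"loop group {gid} is not contiguous at {name}")
        if current is None:
            current = (gid, count, [name])
        else:
            assert count == current[1], f"loop_count mismatch in group {gid}"
            current[2].append(name)
    if current is not None:
        runs.append(current)
    return runs


runs = find_loop_runs([
    ("a", 0, 4), ("b", 0, 4), ("c", None, None), ("d", 1, 8),
])
```

Each returned run corresponds to one candidate loop wrapper; an input such as a, c, b with a and b sharing a group id raises RuntimeError in this sketch.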
Rejected design alternatives
Inductor’s existing loop IR
Inductor has several loop-related constructs, none of which fit the requirement.
ir.Loops / Pointwise / Reduction (torch/_inductor/ir.py).
These have a ranges: Sequence[Expr] field that describes the iteration
space of a single operation. They model per-op loop bounds, not a loop
that groups multiple operations together. There is no concept of “execute
this sequence of ops N times.”
ir.WhileLoop (torch/_inductor/ir.py). A while-loop IR node for
data-dependent control flow. Trip count is not statically known; not
appropriate for the counted, coarse-tiling use case.
GroupedSchedulerNode (torch/_inductor/scheduler.py). Groups a
sequence of SchedulerNodes so the scheduler cannot interleave other
nodes between them. This is a pure scheduling constraint: it carries no
loop count, does not rewrite iteration spaces, and is unconditionally
unpacked by Scheduler.process_grouped_nodes() before codegen. It also
does not appear in the FusedSchedulerNode | SchedulerNode isinstance
check in Scheduler._codegen, so a subclass of GroupedSchedulerNode
would not be dispatched to codegen_node() at all. These limitations
make FusedSchedulerNode the correct base instead.
codegen.cpp.LoopLevel / LoopNest (torch/_inductor/codegen/cpp.py).
Codegen-time loop structures used by the C++ backend to emit nested
for loops. They exist only during C++ code emission and have no
presence in the scheduler or IR layers where Spyre’s optimization passes
run.
Helion’s ForLoopGraphInfo
Helion (helion/_compiler/device_ir.py) represents loops as
ForLoopGraphInfo nodes. Each node wraps a nested FX sub-graph
(referenced by graph_id) and a block_ids list that determines which
tile dimensions participate in the loop. The FX graph for the outer
scope contains a _for_loop(graph_id, begin, end, args) node
(helion/language/_tracing_ops.py) as a placeholder. A companion
ReductionLoopGraphInfo handles reduction loops.
This design is well-suited to Helion’s tile-strategy-driven GPU compilation model, where the loop structure is discovered during tracing and the body is a reusable sub-graph. It is a poor fit for Spyre’s pipeline for three reasons:
1. Wrong representation layer. Spyre’s optimization passes operate on list[ir.Operation] before the Inductor Scheduler exists. Helion’s loop nodes live in an FX graph; adopting that representation would require building and maintaining a parallel FX graph for the pre-scheduling IR, adding substantial complexity.
2. Tile strategy coupling. ForLoopGraphInfo carries block_ids that reference Helion’s tile strategy objects. Spyre has no tile strategy layer; loop structure comes from the coarse-tiling pass decision, not from a tiling configuration object.
3. Sub-graph identity vs. flat sequence. Helion identifies loop bodies by an opaque graph_id and looks them up in a registry. For Spyre’s use case, a contiguous run of SchedulerNodes that must stay together, a flat ordered list inside CountedLoopSchedulerNode is simpler and directly matches what codegen_node already iterates.
The key insight borrowed from Helion is that the loop body should be a
separate, named structure rather than an attribute on individual ops.
That insight shaped the decision to make CountedLoopSchedulerNode a
first-class scheduler node (rather than stamping a loop-count attribute
on each SchedulerNode and reconstructing the grouping at codegen time).
Attribute-only approach (Option B)
An earlier candidate design stamped loop_group_id and loop_count
directly onto ir.Operation objects and deferred all grouping to
codegen_node(), which would scan the flat node_schedule list and
reconstruct loop boundaries at codegen time.
This was rejected because it is fragile and its failure mode is silent.
If the scheduler ever reorders nodes within what the tiling pass
intended to be a loop group, or if a group boundary does not align
exactly with a fused-node boundary, the reconstruction in
codegen_node() produces wrong output (incorrect trip counts or
mismatched iteration spaces) without raising an error. With coarse
tiling these are correctness bugs, not performance bugs.
CountedLoopSchedulerNode enforces the grouping structurally: the
scheduler cannot split or reorder within it, and a mismatch is caught at
post-fusion pass time rather than silently at codegen time.
Out of scope
- Loops whose trip count is data-dependent (use ir.WhileLoop for that).
- Fusing a non-tiled op into the body of a CountedLoopSchedulerNode.
- Passing the loop induction variable into an OpSpec body (ops inside a loop do not currently use the induction variable; each iteration executes identically on a different slice of the data determined by the reduced iteration space).
- Symbolic loop counts in bundle.mlir (currently raises NotImplementedError; requires runtime shape plumbing into the MLIR function signature).