Indirect Access (Gather)
This page describes how Torch-Spyre compiles indirect tensor accesses —
operations like aten.index where the row of a data tensor is selected by a
value loaded at runtime from a separate index tensor.
What is indirect access?
A gather reads a value tensor at a row determined by a runtime-loaded index:
# x: float16 [M, N], i: int32 [P, Q]
# result: float16 [P, Q, N] — row i[p,q] of x, for each (p,q)
result = x[i] # aten.index
result = x[i].exp() # aten.index fused with aten.exp
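As a reference for the semantics only, the same gather can be sketched in pure Python on nested lists. The helper name gather_rows is hypothetical and is not part of Torch-Spyre or PyTorch; the shapes are kept tiny for readability.

```python
import math

def gather_rows(x, i):
    """Return result[p][q] = x[i[p][q]] for a 2-D index i and 2-D data x."""
    return [[list(x[row]) for row in index_row] for index_row in i]

# x has shape [M=3, N=2]; i has shape [P=2, Q=2].
x = [[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]
i = [[2, 0], [1, 1]]

result = gather_rows(x, i)  # shape [P, Q, N]

# Fused variant: apply exp elementwise to every gathered row.
fused = [[[math.exp(v) for v in row] for row in rows] for rows in result]
```

Each entry of i picks a whole row of x, which is why the result gains a trailing dimension of length N.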
How Inductor represents indirect access
Inductor lowers aten.index to a Pointwise node whose inner_fn calls
ops.indirect_indexing() to convert a loaded value into a loop index:
load(index_tensor, ...) → tmp0 (int32 value)
indirect_indexing(tmp0, ...) → i_sym (used as row address)
load(value_tensor, i_sym * N + c2)
In the resulting MemoryDep, the index expression for the value tensor
contains a symbol (e.g. tmp0) that is not present in dep.ranges —
it has no static loop bound because it is resolved only at runtime.
MemoryDep.is_indirect() returns True for such deps.
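The three lowered steps can be emulated on flat, row-major buffers to make the role of tmp0 concrete. This is an illustrative sketch, not Inductor code; the function name gather_flat and the buffer contents are assumptions.

```python
def gather_flat(value_buf, index_buf, P, Q, N):
    """Emulate the lowered loop nest on flat, row-major buffers."""
    out = [0.0] * (P * Q * N)
    for c0 in range(P):
        for c1 in range(Q):
            tmp0 = index_buf[c0 * Q + c1]   # load(index_tensor, ...)
            i_sym = tmp0                    # indirect_indexing(tmp0, ...)
            for c2 in range(N):
                # load(value_tensor, i_sym * N + c2)
                out[(c0 * Q + c1) * N + c2] = value_buf[i_sym * N + c2]
    return out

value_buf = [0.0, 1.0, 10.0, 11.0, 20.0, 21.0]  # x: [3, 2], flattened
index_buf = [2, 0, 1, 1]                        # i: [2, 2], flattened
out = gather_flat(value_buf, index_buf, P=2, Q=2, N=2)
```

The loop variables c0, c1 and c2 iterate over static ranges, whereas tmp0 is known only after the index load has executed, which is why it has no entry in dep.ranges.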
The challenge: coordinates that depend on runtime values
Every stage of the Spyre compilation pipeline — stick compatibility checking,
normalize_coordinates, align_tensors, and op spec generation — works with
device-coordinate expressions: symbolic formulas that map iteration variables
(c0, c1, ...) to positions in the on-device tiled layout.
For a direct access these formulas contain only loop variables. For an
indirect access, one coordinate of the value tensor contains tmp0, which
represents a value that only the hardware will know at runtime. Without
special handling every stage that computes or checks coordinates would either
crash, silently skip the symbol, or produce wrong layout decisions.
Solution: IndirectAccess and device coordinates
Torch-Spyre handles this with IndirectAccess, a sympy Function subclass defined in
op_spec.py:
class IndirectAccess(Function):
"""IndirectAccess(tensor_name) — value loaded from that tensor at the current point."""
IndirectAccess('arg1_1') means “the value loaded from the buffer named arg1_1
at the current iteration point”. It is an opaque sympy atom that survives
sympify round-trips and can be carried through arithmetic expressions without
being evaluated.
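A minimal sketch of that behaviour, assuming only that sympy is installed. The class below is a stand-in that mirrors the definition shown above rather than an import from op_spec.py.

```python
from sympy import Function, Symbol, sympify

class IndirectAccess(Function):
    """Stand-in mirroring the class shown above; no evaluation rule is defined."""

arg = IndirectAccess(Symbol("arg1_1"))
expr = 256 * arg + Symbol("c2")  # arithmetic leaves the call unevaluated

# Round-trip through a string, telling sympify how to resolve the class name.
restored = sympify(str(expr), locals={"IndirectAccess": IndirectAccess})
```

Because the subclass defines no eval method, sympy never simplifies the call away, and passing the class through locals lets the parsed expression compare equal to the original.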
Building the substitution dict
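Two helpers in pass_utils.py produce a {indirect_sym → IndirectAccess(name)}
substitution dict, depending on the pipeline stage:

| Helper | When to use | How it works |
|---|---|---|
| indirect_access_subs_from_op | Pre-scheduler (propagate_layouts) | Re-executes inner_fn via _IndirectIndexFinder |
| indirect_access_subs_from_kernel | Post-scheduler (SpyreKernel codegen) | Reads SpyreKernel.indirect_vars |
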
Both return the same dict shape, so every downstream caller sees a uniform interface regardless of when it runs.
Applying the substitution
compute_coordinates (in views.py) accepts an optional indirect_load_subs
parameter. After computing the coordinate expressions in the normal way, it
applies the substitution:
if indirect_load_subs:
    coordinates = [c.xreplace(indirect_load_subs) for c in coordinates]
This replaces tmp0 with IndirectAccess('arg1_1') in the affected coordinate,
giving a gather-aware expression that all later stages can handle.
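The substitution can be reproduced in isolation with plain sympy. This is a sketch under the assumption that the coordinates are ordinary sympy expressions; IndirectAccess is again a local stand-in and the coordinate list is taken from the walkthrough example on this page.

```python
from sympy import Function, Mod, Symbol, floor

class IndirectAccess(Function):
    """Stand-in for the class defined in op_spec.py."""

d2, tmp0 = Symbol("d2"), Symbol("tmp0")
coordinates = [floor(d2 / 64), tmp0, Mod(d2, 64)]

# Same dict shape as the helpers return: indirect symbol -> IndirectAccess(name).
indirect_load_subs = {tmp0: IndirectAccess(Symbol("arg1_1"))}
if indirect_load_subs:
    coordinates = [c.xreplace(indirect_load_subs) for c in coordinates]
```

xreplace performs an exact structural replacement, so coordinates that do not mention tmp0 come back unchanged.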
The device_coordinates wrapper in pass_utils.py forwards the parameter
to compute_coordinates, so callers in propagate_layouts.py and
spyre_kernel.py only need to pass the dict through one function.
Range inference for indirect symbols
Before the substitution is applied, compute_coordinates needs a range for
tmp0 to know which layout dimension it addresses. Indirect symbols have no
entry in dep.ranges, so the range is inferred from the tensor layout: the
code finds the dimension whose stride equals the coefficient of tmp0 in the
index expression:
term = index.xreplace({v: 0 for v in vars - {var}})
coeff = int(term.xreplace({var: 1}))
inferred = next(
    (sz for st, sz in zip(stride, size) if int(st) == coeff and int(sz) > 1),
    None,
)
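Wrapped in a function, the inference can be run end to end on the x: [128, 256] example. The function name infer_indirect_range is hypothetical, and the sketch assumes the index expression is linear in the indirect symbol.

```python
from sympy import Symbol

def infer_indirect_range(index, var, stride, size):
    """Return the size of the dimension whose stride equals var's coefficient."""
    others = index.free_symbols - {var}
    term = index.xreplace({v: 0 for v in others})  # keep only the var term
    coeff = int(term.xreplace({var: 1}))           # coefficient of var
    return next(
        (sz for st, sz in zip(stride, size) if int(st) == coeff and int(sz) > 1),
        None,
    )

tmp0, c2 = Symbol("tmp0"), Symbol("c2")
index = 256 * tmp0 + c2  # value-tensor index for a row-major [128, 256] tensor
inferred = infer_indirect_range(index, tmp0, stride=[256, 1], size=[128, 256])
```

Here tmp0 has coefficient 256, which matches the stride of the row dimension, so the inferred range is the row count.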
Pipeline walkthrough
The gather x[i] with x: float16 [128, 256] and i: int32 [3, 192] passes
through the following stages.
1. propagate_layouts (pre-scheduler)
compute_layouts is called for the Pointwise node.
indirect_access_subs_from_op re-executes inner_fn via _IndirectIndexFinder
and returns {tmp0: IndirectAccess(Symbol('arg1_1'))}.
The resulting device coordinates (logged when TORCH_LOGS="+inductor" is set)
show the substitution explicitly:
input[1] name=arg0_1 (value tensor x)
device_coordinates=[floor(d2/64), tmp0, Mod(d2, 64)]
-> [floor(d2/64), IndirectAccess(arg1_1), Mod(d2, 64)]
The IndirectAccess coordinate is passed through normalize_coordinates as an
opaque Term(var=None, offset=IndirectAccess(...)) and through align_tensors
unchanged.
2. SpyreKernel codegen (post-scheduler)
SpyreKernelOpsHandler.indirect_indexing() intercepts the indirect_indexing
call during inner_fn execution and stores the source TensorAccess in
SpyreKernel.indirect_vars:
sym = sympy_index_symbol(f"indirect{self.kernel._indirect_var_count}")
self.kernel._indirect_var_count += 1
self.kernel.indirect_vars[sym] = index_var # TensorAccess for arg1_1
indirect_access_subs_from_kernel(self.indirect_vars) then converts this to
{sym: IndirectAccess(Symbol('arg1_1'))} — the same shape as the pre-scheduler
dict, so create_tensor_arg can use it directly.
3. Op spec generation
SpyreKernel emits the index tensor as the first TensorArg in the op
spec with name='arg1_1' set. This name is what IndirectAccess('arg1_1') refers
to in the value tensor’s coordinates:
TensorArg(
    is_input=True, arg_index=0, device_dtype=DataFormats.IEEE_INT32,
    device_size=[1, 6, 3, 32],
    device_coordinates=[0, floor(c1/32), c0, Mod(c1, 32)],
    name='arg1_1',
),
TensorArg(
    is_input=True, arg_index=1, device_dtype=DataFormats.SEN169_FP16,
    device_size=[1, 4, 128, 64],
    device_coordinates=[0, floor(c2/64), IndirectAccess('arg1_1'), Mod(c2, 64)],
),
The backend compiler reads IndirectAccess('arg1_1') as: “load this tensor’s row
index from the tensor named arg1_1 at the current iteration point”.
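The device_size values in this op spec are consistent with splitting the innermost host dimension into fixed-length sticks. The sketch below reproduces them; the helper names are hypothetical, and the stick lengths (32 int32 elements, 64 float16 elements) are read off the Mod expressions in the op spec rather than taken from a layout specification.

```python
def tiled_device_size(rows, cols, stick_elems):
    """Device size [1, cols // stick_elems, rows, stick_elems] for a 2-D tensor."""
    if cols % stick_elems:
        raise ValueError("cols must be a multiple of the stick length")
    return [1, cols // stick_elems, rows, stick_elems]

def device_coordinate(row, col, stick_elems):
    """Coordinate of host element (row, col) in that tiled layout."""
    return [0, col // stick_elems, row, col % stick_elems]

index_size = tiled_device_size(3, 192, stick_elems=32)    # i: int32 [3, 192]
value_size = tiled_device_size(128, 256, stick_elems=64)  # x: float16 [128, 256]
```

For the value tensor, the row argument of device_coordinate is the position that IndirectAccess('arg1_1') fills in at runtime.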
4. Wrapper serialization
SpyreKernel._codegen_op_spec_list serializes IndirectAccess expressions
using a dedicated branch of sympy_str:
if isinstance(x, IndirectAccess):
    name_sym = x.args[0]
    return f"IndirectAccess('{name_sym}')"
The generated wrapper imports IndirectAccess from op_spec so the op spec
survives eval/sympify round-trips at kernel load time.
Stick compatibility for index tensors
The standard stick-compatibility check (compute_restickify_needed) compares
stick-dimension expressions across input and output tensors to decide whether
a restickify pass is needed. Index tensors must be excluded from this check:
their loaded values determine an address, not a position in the output, so
constraining their stick layout to match the output would be incorrect.
compute_restickify_needed accepts an optional op parameter. When
provided, it calls indirect_index_dep_names(op) and returns (False, None)
immediately for any dep whose name is in that set:
if op is not None and in_dep.name in indirect_index_dep_names(op):
    return False, None
The same exclusion is applied in _multi_arg_pointwise_layouts when
collecting stick expressions:
indirect_index_names = indirect_index_dep_names(op)
stick_exprs = {
    device_coordinates(stl, arg.dep)[-1]
    for arg in args
    for stl in arg.layouts
    if arg.dep.name not in indirect_index_names
    and device_coordinates(stl, arg.dep)[-1] != 0
}
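The effect of the exclusion can be seen in a toy model that uses strings in place of sympy stick expressions. The Dep class and collect_stick_exprs function are hypothetical stand-ins, not the Torch-Spyre types.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Dep:
    name: str
    stick_expr: str  # stand-in for the last device coordinate

def collect_stick_exprs(deps, indirect_index_names):
    """Collect stick expressions, skipping index tensors and constant-0 sticks."""
    return {
        d.stick_expr
        for d in deps
        if d.name not in indirect_index_names and d.stick_expr != "0"
    }

deps = [
    Dep("arg1_1", "Mod(c1, 32)"),  # index tensor
    Dep("arg0_1", "Mod(c2, 64)"),  # value tensor
    Dep("buf0", "Mod(c2, 64)"),    # output
]
with_exclusion = collect_stick_exprs(deps, indirect_index_names={"arg1_1"})
without_exclusion = collect_stick_exprs(deps, indirect_index_names=set())
```

With the index tensor excluded, only one stick expression remains; without the exclusion, the index tensor's unrelated stick expression would appear as a second candidate.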
Op spec layout
For an unfused gather the scheduler produces two op specs:
identity: copies gathered rows from the value tensor into a temporary buffer using IndirectAccess coordinates.
exp (or whichever unary follows): applies the unary to the temporary buffer using direct coordinates.
When fusion is enabled (currently disabled pending backend support — see the
flag in patches.py), the two ops collapse into a single fused op spec with
the index tensor as a named input.
The argument ordering rule for op specs that contain an index tensor is:
Index tensor(s), in buffer-name order, each with name set.
Value tensor(s), in the order they appear in the computation.
Output tensor.
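The ordering rule can be sketched as a small sort. The Arg class and its role field are assumptions made for this illustration; the real TensorArg does not necessarily carry an explicit role.

```python
from dataclasses import dataclass

@dataclass
class Arg:
    buffer: str
    role: str  # "index", "value", or "output" (hypothetical field)

def order_args(args):
    """Index args sorted by buffer name, then value args in given order, then outputs."""
    index = sorted((a for a in args if a.role == "index"), key=lambda a: a.buffer)
    value = [a for a in args if a.role == "value"]
    output = [a for a in args if a.role == "output"]
    return index + value + output

args = [
    Arg("arg0_1", "value"),
    Arg("buf0", "output"),
    Arg("arg2_1", "index"),
    Arg("arg1_1", "index"),
]
ordered = [a.buffer for a in order_args(args)]
```

Only the index group is sorted; value args keep the order in which the computation encounters them.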
The helper _is_indirect_index_arg identifies index-role args post-hoc by
scanning other args’ coordinates for IndirectAccess atoms whose name matches.
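One way such a scan could look, assuming coordinates are sympy expressions or plain integers. The function below is an illustrative sketch with a hypothetical signature, not the actual _is_indirect_index_arg implementation.

```python
from sympy import Function, Mod, Symbol, floor, sympify

class IndirectAccess(Function):
    """Stand-in for the class defined in op_spec.py."""

def is_indirect_index_arg(name, other_coordinates):
    """True if any coordinate of another arg contains IndirectAccess(name)."""
    target = Symbol(name)
    return any(
        atom.args[0] == target
        for coords in other_coordinates
        for c in coords
        for atom in sympify(c).atoms(IndirectAccess)
    )

c2 = Symbol("c2")
value_coords = [0, floor(c2 / 64), IndirectAccess(Symbol("arg1_1")), Mod(c2, 64)]
found = is_indirect_index_arg("arg1_1", [value_coords])
```

Passing a type to atoms() returns every subexpression of that type, so an IndirectAccess nested inside a larger arithmetic expression would be found as well.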
Fusion
Pointwise fusion with the gather is currently disabled in patches.py
because IndirectAccess coordinate expressions are not yet handled in SuperDSC
generation. When enabled, the identity op and the downstream pointwise op
merge into one op spec with a single IndirectAccess coordinate expression in
the input arg.
Current limitations
Only 1-D index tensors (a single indirect symbol per data dep) are supported.
Scatter index tensors (aten.scatter_, aten.index_put_) are correctly detected via _find_scatter_index_buf_names and excluded from stick compatibility checks. However, IndirectAccess coordinates on output args (the codegen side of scatter) are not yet wired up in SuperDSC generation.
The fused (single op spec) path is disabled because IndirectAccess coordinates are not yet handled in SuperDSC generation.