Indirect Access (Gather)

This page describes how Torch-Spyre compiles indirect tensor accesses — operations like aten.index where the row of a data tensor is selected by a value loaded at runtime from a separate index tensor.

What is indirect access?

A gather reads a value tensor at a row determined by a runtime-loaded index:

# x: float16 [M, N], i: int32 [P, Q]
# result: float16 [P, Q, N] — row i[p,q] of x, for each (p,q)
result = x[i]         # aten.index
result = x[i].exp()   # aten.index fused with aten.exp
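As a rough reference for these semantics, the following pure-Python sketch spells out the loops that x[i] and x[i].exp() imply. The helper names gather_rows and gather_rows_exp and the small list-of-lists inputs are illustrative only and are not part of Torch-Spyre.

```python
import math

def gather_rows(x, i):
    """Return result[p][q] = copy of row i[p][q] of x (shape [P][Q][N])."""
    return [[list(x[row]) for row in index_row] for index_row in i]

def gather_rows_exp(x, i):
    """Same gather with exp applied elementwise, mirroring x[i].exp()."""
    return [[[math.exp(v) for v in x[row]] for row in index_row] for index_row in i]

x = [[0.0, 0.5], [1.0, 1.5], [2.0, 2.5]]  # value tensor, M=3, N=2
i = [[2, 0], [1, 1]]                      # index tensor, P=2, Q=2

result = gather_rows(x, i)      # shape [2][2][2]
fused = gather_rows_exp(x, i)   # same shape, exp applied to each element
```

The row that is read depends on the contents of i, not on the loop counters p and q, which is what makes the access indirect.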

How Inductor represents indirect access

Inductor lowers aten.index to a Pointwise node whose inner_fn calls ops.indirect_indexing() to convert a loaded value into a loop index:

load(index_tensor, ...)        → tmp0   (int32 value)
indirect_indexing(tmp0, ...)   → i_sym  (used as row address)
load(value_tensor, i_sym * N + c2)

In the resulting MemoryDep, the index expression for the value tensor contains a symbol (e.g. tmp0) that is not present in dep.ranges — it has no static loop bound because it is resolved only at runtime. MemoryDep.is_indirect() returns True for such deps.
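The rule described above (a symbol in the index expression with no entry in the ranges) can be sketched with plain sympy. This is only an illustration of the rule as stated here, not Inductor's implementation of MemoryDep.is_indirect(); the function name indirect_symbols and the concrete ranges are assumptions.

```python
import sympy

c0, c1, c2, tmp0 = sympy.symbols("c0 c1 c2 tmp0", integer=True)
N = 256

# Loop variables with static bounds, as dep.ranges would hold them.
ranges = {c0: 3, c1: 192, c2: N}

# Index expression for the value tensor: row tmp0, column c2.
index = tmp0 * N + c2

def indirect_symbols(index, ranges):
    """Symbols used by the index expression that have no static loop bound."""
    return index.free_symbols - set(ranges)

found = indirect_symbols(index, ranges)
direct = indirect_symbols(c0 * 192 + c1, ranges)
```

Here found contains only tmp0, while the direct access c0 * 192 + c1 yields an empty set.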


The challenge: coordinates that depend on runtime values

Every stage of the Spyre compilation pipeline — stick compatibility checking, normalize_coordinates, align_tensors, and op spec generation — works with device-coordinate expressions: symbolic formulas that map iteration variables (c0, c1, ...) to positions in the on-device tiled layout.

For a direct access these formulas contain only loop variables. For an indirect access, one coordinate of the value tensor contains tmp0, which represents a value that only the hardware will know at runtime. Without special handling, every stage that computes or checks coordinates would crash, silently skip the symbol, or produce wrong layout decisions.


Solution: IndirectAccess and device coordinates

Torch-Spyre represents such a runtime value with IndirectAccess, a sympy Function subclass defined in op_spec.py:

class IndirectAccess(Function):
    """IndirectAccess(tensor_name) — value loaded from that tensor at the current point."""

IndirectAccess('arg1_1') means “the value loaded from the buffer named arg1_1 at the current iteration point”. It is an opaque sympy atom that survives sympify round-trips and can be carried through arithmetic expressions without being evaluated.
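A minimal sketch of this behaviour, using a local stand-in class rather than the real one from op_spec.py (the real class may carry more than a docstring), shows that an undefined sympy Function is never evaluated and can be located again inside a larger expression:

```python
from sympy import Function, Symbol, sympify

class IndirectAccess(Function):
    """Stand-in for illustration: value loaded from the named tensor."""

c2 = Symbol("c2", integer=True)
row = IndirectAccess(Symbol("arg1_1"))

# The atom is carried through arithmetic without being evaluated.
expr = row * 256 + c2

# It can be found again, and its buffer name recovered from args[0].
atoms = expr.atoms(IndirectAccess)
names = {str(a.args[0]) for a in atoms}

# sympify leaves an existing sympy object untouched.
same = sympify(row)
```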

Building the substitution dict

Two helpers in pass_utils.py produce a {indirect_sym: IndirectAccess(name)} substitution dict, depending on the pipeline stage:

| Helper | When to use | How it works |
| --- | --- | --- |
| indirect_access_subs_from_op(op) | Pre-scheduler (propagate_layouts) | Re-executes inner_fn via _IndirectIndexFinder to discover which buffer’s load fed each indirect symbol |
| indirect_access_subs_from_kernel(indirect_vars) | Post-scheduler (SpyreKernel) | Reads SpyreKernel.indirect_vars, which SpyreKernelOpsHandler.indirect_indexing() populates live during codegen |

Both return the same dict shape, so every downstream caller sees a uniform interface regardless of when it runs.

Applying the substitution

compute_coordinates (in views.py) accepts an optional indirect_load_subs parameter. After computing the coordinate expressions in the normal way it applies the substitution:

if indirect_load_subs:
    coordinates = [c.xreplace(indirect_load_subs) for c in coordinates]

This replaces tmp0 with IndirectAccess('arg1_1') in the affected coordinate, giving a gather-aware expression that all later stages can handle.

The device_coordinates wrapper in pass_utils.py forwards the parameter to compute_coordinates, so callers in propagate_layouts.py and spyre_kernel.py only need to pass the dict through one function.
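The substitution step can be reproduced in isolation with sympy. In the sketch below, IndirectAccess is a local stand-in for the class in op_spec.py, and the coordinate list is copied from the walkthrough example rather than computed by compute_coordinates.

```python
from sympy import Function, Mod, Symbol, floor

class IndirectAccess(Function):
    """Stand-in for illustration: value loaded from the named tensor."""

c2 = Symbol("c2", integer=True)
tmp0 = Symbol("tmp0", integer=True)

# Coordinates as they look before the substitution: tmp0 is the row address.
coordinates = [floor(c2 / 64), tmp0, Mod(c2, 64)]

# Substitution dict in the {indirect_sym: IndirectAccess(name)} shape.
indirect_load_subs = {tmp0: IndirectAccess(Symbol("arg1_1"))}

if indirect_load_subs:
    # xreplace swaps exact subexpressions and leaves everything else untouched.
    coordinates = [c.xreplace(indirect_load_subs) for c in coordinates]
```

xreplace performs an exact structural replacement, so coordinates that do not mention tmp0 come back unchanged.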

Range inference for indirect symbols

Before the substitution is applied, compute_coordinates needs a range for tmp0 to know which layout dimension it addresses. Indirect symbols have no entry in dep.ranges, so the range is inferred from the tensor layout: the code finds the dimension whose stride equals the coefficient of tmp0 in the index expression:

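# Zero out every other variable so that only the term in var remains.
term = index.xreplace({v: 0 for v in vars - {var}})
# Setting var to 1 leaves its coefficient (assuming no constant offset).
coeff = int(term.xreplace({var: 1}))
# The range is the size of the first non-trivial dimension with that stride.
inferred = next(
    (sz for st, sz in zip(stride, size) if int(st) == coeff and int(sz) > 1),
    None,
)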
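To see the inference on concrete numbers, the snippet can be wrapped in a standalone function and applied to the value tensor from the walkthrough (x with size [128, 256] and contiguous strides [256, 1]). The function name infer_indirect_range and its signature are assumptions made for this sketch.

```python
import sympy

def infer_indirect_range(index, var, size, stride):
    """Infer the range of an indirect symbol from the stride it multiplies."""
    others = index.free_symbols - {var}
    # Keep only the term in var, then read off its coefficient.
    term = index.xreplace({v: 0 for v in others})
    coeff = int(term.xreplace({var: 1}))
    # First dimension of size > 1 whose stride equals the coefficient.
    return next(
        (sz for st, sz in zip(stride, size) if int(st) == coeff and int(sz) > 1),
        None,
    )

c2, tmp0 = sympy.symbols("c2 tmp0", integer=True)
index = tmp0 * 256 + c2

inferred = infer_indirect_range(index, tmp0, size=[128, 256], stride=[256, 1])
no_match = infer_indirect_range(index, tmp0, size=[128, 256], stride=[512, 1])
```

With the contiguous layout, tmp0 has coefficient 256, which matches the stride of dimension 0, so the inferred range is 128. When no stride matches, the result is None.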

Pipeline walkthrough

The gather x[i] with x: float16 [128, 256] and i: int32 [3, 192] passes through the following stages.

1. propagate_layouts (pre-scheduler)

compute_layouts is called for the Pointwise node. indirect_access_subs_from_op re-executes inner_fn via _IndirectIndexFinder and returns {tmp0: IndirectAccess(Symbol('arg1_1'))}.

The resulting device coordinates (logged when TORCH_LOGS="+inductor" is set) show the substitution explicitly:

input[1] name=arg0_1  (value tensor x)
  device_coordinates=[floor(d2/64), tmp0, Mod(d2, 64)]
    ->  [floor(d2/64), IndirectAccess(arg1_1), Mod(d2, 64)]

The IndirectAccess coordinate is passed through normalize_coordinates as an opaque Term(var=None, offset=IndirectAccess(...)) and through align_tensors unchanged.

2. SpyreKernel codegen (post-scheduler)

SpyreKernelOpsHandler.indirect_indexing() intercepts the indirect_indexing call during inner_fn execution and stores the source TensorAccess in SpyreKernel.indirect_vars:

sym = sympy_index_symbol(f"indirect{self.kernel._indirect_var_count}")
self.kernel._indirect_var_count += 1
self.kernel.indirect_vars[sym] = index_var   # TensorAccess for arg1_1

indirect_access_subs_from_kernel(self.indirect_vars) then converts this to {sym: IndirectAccess(Symbol('arg1_1'))} — the same shape as the pre-scheduler dict, so create_tensor_arg can use it directly.
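The conversion amounts to a dictionary comprehension over indirect_vars. The sketch below assumes that a TensorAccess exposes the buffer name as a name attribute; the TensorAccess dataclass, the IndirectAccess stand-in, and the function subs_from_indirect_vars are all hypothetical simplifications, not the code in pass_utils.py.

```python
from dataclasses import dataclass
from sympy import Function, Symbol

class IndirectAccess(Function):
    """Stand-in for illustration: value loaded from the named tensor."""

@dataclass(frozen=True)
class TensorAccess:
    """Hypothetical stand-in that records only the source buffer name."""
    name: str

def subs_from_indirect_vars(indirect_vars):
    """Map each indirect symbol to an IndirectAccess of its source buffer."""
    return {
        sym: IndirectAccess(Symbol(access.name))
        for sym, access in indirect_vars.items()
    }

indirect0 = Symbol("indirect0", integer=True)
indirect_vars = {indirect0: TensorAccess("arg1_1")}

subs = subs_from_indirect_vars(indirect_vars)
```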

3. Op spec generation

SpyreKernel emits the index tensor as the first TensorArg in the op spec with name='arg1_1' set. This name is what IndirectAccess('arg1_1') refers to in the value tensor’s coordinates:

TensorArg(
    is_input=True, arg_index=0, device_dtype=DataFormats.IEEE_INT32,
    device_size=[1, 6, 3, 32],
    device_coordinates=[0, floor(c1/32), c0, Mod(c1, 32)],
    name='arg1_1',
),
TensorArg(
    is_input=True, arg_index=1, device_dtype=DataFormats.SEN169_FP16,
    device_size=[1, 4, 128, 64],
    device_coordinates=[0, floor(c2/64), IndirectAccess('arg1_1'), Mod(c2, 64)],
),

The backend compiler reads IndirectAccess('arg1_1') as: “load this tensor’s row index from the tensor named arg1_1 at the current iteration point”.
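One way to read the two coordinate lists is as a small interpreter that first loads the index and then addresses the value tensor. The pure-Python sketch below does this for the walkthrough shapes; the stick widths 32 and 64 are read off the device_size fields above, and the function names and the synthetic index data are assumptions made for illustration.

```python
def index_coords(c0, c1):
    """Device coordinates of i[c0, c1], following the first TensorArg."""
    return (0, c1 // 32, c0, c1 % 32)

def value_coords(c2, row):
    """Device coordinates of x[row, c2], following the second TensorArg."""
    return (0, c2 // 64, row, c2 % 64)

def resolve(i, c0, c1, c2):
    """Resolve the IndirectAccess coordinate at one iteration point."""
    row = i[c0][c1]  # the value that IndirectAccess('arg1_1') stands for
    return value_coords(c2, row)

# Synthetic int32-like index data of shape [3, 192] with values in [0, 128).
i = [[(p * 192 + q) % 128 for q in range(192)] for p in range(3)]

where_index = index_coords(2, 100)
where_value = resolve(i, 2, 100, 200)
```

At iteration point (c0, c1, c2) = (2, 100, 200) the loaded index is 100, so the value tensor is read at row 100 of its third device dimension.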

4. Wrapper serialization

SpyreKernel._codegen_op_spec_list serializes IndirectAccess expressions using a dedicated branch of sympy_str:

if isinstance(x, IndirectAccess):
    name_sym = x.args[0]
    return f"IndirectAccess('{name_sym}')"

The generated wrapper imports IndirectAccess from op_spec so the op spec survives eval/sympify round-trips at kernel load time.
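The round-trip can be checked in a few lines. In this sketch coord_str plays the role of the sympy_str branch shown above, IndirectAccess is a local stand-in, and the namespace passed to eval approximates what the generated wrapper's imports would provide; none of these names are taken from the real wrapper.

```python
from sympy import Function, Mod, Symbol, floor

class IndirectAccess(Function):
    """Stand-in for illustration: value loaded from the named tensor."""

def coord_str(x):
    """Serialize one coordinate, with a dedicated branch for IndirectAccess."""
    if isinstance(x, IndirectAccess):
        return f"IndirectAccess('{x.args[0]}')"
    return str(x)

c2 = Symbol("c2", integer=True)
coords = [floor(c2 / 64), IndirectAccess(Symbol("arg1_1")), Mod(c2, 64)]

# Text as it would appear in generated source.
text = "[" + ", ".join(coord_str(c) for c in coords) + "]"

# Names the generated code needs in scope when the text is evaluated.
namespace = {"IndirectAccess": IndirectAccess, "floor": floor, "Mod": Mod, "c2": c2}
restored = eval(text, namespace)
```

The string argument is sympified back to a Symbol when IndirectAccess is called, so restored compares equal to coords.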


Stick compatibility for index tensors

The standard stick-compatibility check (compute_restickify_needed) compares stick-dimension expressions across input and output tensors to decide whether a restickify pass is needed. Index tensors must be excluded from this check: their loaded values determine an address, not a position in the output, so constraining their stick layout to match the output would be incorrect.

compute_restickify_needed accepts an optional op parameter. When op is provided, the function calls indirect_index_dep_names(op) and returns (False, None) immediately for any dep whose name is in that set:

if op is not None and in_dep.name in indirect_index_dep_names(op):
    return False, None

The same exclusion is applied in _multi_arg_pointwise_layouts when collecting stick expressions:

indirect_index_names = indirect_index_dep_names(op)
stick_exprs = {
    device_coordinates(stl, arg.dep)[-1]
    for arg in args
    for stl in arg.layouts
    if arg.dep.name not in indirect_index_names
    and device_coordinates(stl, arg.dep)[-1] != 0
}
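The effect of the exclusion can be illustrated with a simplified model in which each argument is reduced to a buffer name and its stick expression as text. The stick expressions are taken from the walkthrough; the output name buf0 and the helper collect_stick_exprs are hypothetical.

```python
def collect_stick_exprs(args, indirect_index_names):
    """Collect distinct non-zero stick expressions, skipping index tensors."""
    return {
        stick
        for name, stick in args
        if name not in indirect_index_names and stick != "0"
    }

# (buffer name, stick-dimension expression) pairs.
args = [
    ("arg1_1", "Mod(c1, 32)"),  # index tensor
    ("arg0_1", "Mod(c2, 64)"),  # value tensor
    ("buf0", "Mod(c2, 64)"),    # output (hypothetical name)
]

with_exclusion = collect_stick_exprs(args, {"arg1_1"})
without_exclusion = collect_stick_exprs(args, set())
```

With the index tensor excluded, the value tensor and the output agree on a single stick expression. Without the exclusion, the index tensor contributes a second, unrelated expression to the set.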

Op spec layout

For an unfused gather followed by a unary op, such as x[i].exp(), the scheduler produces two op specs:

  1. identity — copies gathered rows from the value tensor into a temporary buffer using IndirectAccess coordinates.

  2. exp (or whichever unary follows) — applies the unary to the temporary buffer using direct coordinates.

When fusion is enabled (currently disabled pending backend support — see the flag in patches.py), the two ops collapse into a single fused op spec with the index tensor as a named input.

The argument ordering rule for op specs that contain an index tensor is:

  1. Index tensor(s), in buffer-name order, each with name set.

  2. Value tensor(s), in the order they appear in the computation.

  3. Output tensor.

The helper _is_indirect_index_arg identifies index-role args post-hoc by scanning other args’ coordinates for IndirectAccess atoms whose name matches.
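The ordering rule and the post-hoc identification can be sketched together. The Arg dataclass below is a deliberately reduced, hypothetical stand-in for TensorArg, and is_index_arg and order_args are illustrative names, not the helpers in the code base.

```python
from dataclasses import dataclass
from typing import Optional
from sympy import Function, Mod, Symbol, floor, sympify

class IndirectAccess(Function):
    """Stand-in for illustration: value loaded from the named tensor."""

@dataclass
class Arg:
    """Hypothetical, reduced stand-in for TensorArg."""
    coords: list
    name: Optional[str] = None
    is_input: bool = True

def referenced_names(args):
    """Buffer names mentioned by IndirectAccess atoms in any coordinate."""
    return {
        str(atom.args[0])
        for arg in args
        for c in arg.coords
        for atom in sympify(c).atoms(IndirectAccess)
    }

def is_index_arg(arg, args):
    """An arg plays the index role if some coordinate refers to its name."""
    return arg.name is not None and arg.name in referenced_names(args)

def order_args(args):
    """Index tensors by name, then value tensors in order, then outputs."""
    index = sorted((a for a in args if is_index_arg(a, args)), key=lambda a: a.name)
    values = [a for a in args if a.is_input and not is_index_arg(a, args)]
    outputs = [a for a in args if not a.is_input]
    return index + values + outputs

c0, c1, c2 = (Symbol(n, integer=True) for n in ("c0", "c1", "c2"))
value = Arg([0, floor(c2 / 64), IndirectAccess(Symbol("arg1_1")), Mod(c2, 64)])
index = Arg([0, floor(c1 / 32), c0, Mod(c1, 32)], name="arg1_1")
out = Arg([c0, floor(c2 / 64), c1, Mod(c2, 64)], is_input=False)

ordered = order_args([value, out, index])
```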


Fusion

Pointwise fusion with the gather is currently disabled in patches.py because IndirectAccess coordinate expressions are not yet handled in SuperDSC generation. When enabled, the identity op and the downstream pointwise op merge into one op spec with a single IndirectAccess coordinate expression in the input arg.


Current limitations

  • Only a single indirect symbol per data dep is supported, that is, one index tensor selecting along one dimension of the value tensor.

  • Scatter index tensors (aten.scatter_, aten.index_put_) are correctly detected via _find_scatter_index_buf_names and excluded from stick compatibility checks. However, IndirectAccess coordinates on output args (the codegen side of scatter) are not yet wired up in SuperDSC generation.

  • The fused (single op spec) path is disabled because IndirectAccess coordinates are not yet handled in SuperDSC generation.