Commit 0efcf06

[Docs] Add list of indexing autotuning docs (#1027)
1 parent fdc98b7 commit 0efcf06

3 files changed: +69 -7 lines changed

README.md

Lines changed: 9 additions & 4 deletions
@@ -35,6 +35,7 @@ portable between different hardware. Helion automates and autotunes over:

 * Automatically calculates strides and indices.
 * Autotunes choices among various indexing methods (pointers, block pointers, TensorDescriptors).
+* Supports per-load indexing strategies for fine-grained memory access control.

 2. **Masking:**

@@ -257,10 +258,14 @@ Reorders the program IDs (PIDs) of the generated kernel for improved L2
 cache behavior. A value of `1` disables this optimization, while higher
 values specify the grouping size.

-* **indexing** (`"pointer"`, `"tensor_descriptor"` or `"block_ptr"`):
-Specifies the type of indexing code to generate. The `"tensor_descriptor"`
-option uses Tensor Memory Accelerators (TMAs) but requires a Hopper or
-newer GPU and the latest development version of Triton.
+* **indexing** (`"pointer"`, `"tensor_descriptor"`, `"block_ptr"`, or a list of these):
+Specifies the memory indexing strategy for load operations. Can be:
+- A single strategy (applies to all loads): `indexing="block_ptr"`
+- A list of strategies (one per load operation): `indexing=["pointer", "block_ptr", "tensor_descriptor"]`
+- Empty/omitted (defaults to `"pointer"` for all loads)
+
+The `"tensor_descriptor"` option uses Tensor Memory Accelerators (TMAs) but
+requires a Hopper or newer GPU and the latest development version of Triton.

 * **pid_type** (`"flat"`, `"xyz"`, `"persistent_blocked"`, or `"persistent_interleaved"`):
 Specifies the program ID mapping strategy. `"flat"` uses only the x-dimension,
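
Note on the `indexing` forms in the README hunk above: the new text describes three accepted forms (a single string, a list with one entry per load, or omitted). A minimal sketch of how those forms could be expanded into one strategy per load follows. `resolve_indexing` and `VALID_STRATEGIES` are hypothetical names used only for illustration, not part of Helion, and raising on a length mismatch is an assumption; the commit does not say how Helion treats a list whose length differs from the number of loads.

```python
from typing import List, Optional, Sequence, Union

# The three strategy names listed in the diff.
VALID_STRATEGIES = ("pointer", "tensor_descriptor", "block_ptr")


def resolve_indexing(
    indexing: Optional[Union[str, Sequence[str]]], num_loads: int
) -> List[str]:
    """Expand an `indexing` setting into one strategy per load.

    Hypothetical helper: it mirrors the documented semantics only.
    """
    if indexing is None or (not isinstance(indexing, str) and len(indexing) == 0):
        # Empty/omitted: every load falls back to "pointer".
        strategies = ["pointer"] * num_loads
    elif isinstance(indexing, str):
        # Single strategy: applied to all loads.
        strategies = [indexing] * num_loads
    else:
        # List form: one entry per load, in load order.
        strategies = list(indexing)
        if len(strategies) != num_loads:
            # Assumption: a mismatch is treated as an error in this sketch.
            raise ValueError(
                f"expected {num_loads} indexing entries, got {len(strategies)}"
            )
    for name in strategies:
        if name not in VALID_STRATEGIES:
            raise ValueError(f"unknown indexing strategy: {name!r}")
    return strategies
```

Under these assumptions, `resolve_indexing("block_ptr", 2)` yields the same strategy for both loads, while the list form maps positionally onto the loads in the order they appear in the kernel.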

docs/api/config.md

Lines changed: 59 additions & 3 deletions
@@ -109,10 +109,30 @@ Configs are typically discovered automatically through autotuning, but can also

 .. autoattribute:: Config.indexing

-Memory indexing strategy:
+Memory indexing strategy for load operations. Can be specified as:

-- ``"pointer"``: Pointer-based indexing
-- ``"tensor_descriptor"``: Tensor descriptor indexing
+**Single strategy (applies to all loads - backward compatible):**
+
+.. code-block:: python
+
+indexing="block_ptr" # All loads use block pointers
+
+**Per-load strategies (list, one per load operation):**
+
+.. code-block:: python
+
+indexing=["pointer", "block_ptr", "tensor_descriptor"]
+
+**Empty/omitted (defaults to** ``"pointer"`` **for all loads):**
+
+.. code-block:: python
+
+# indexing not specified - all loads use pointer indexing
+
+**Valid strategies:**
+
+- ``"pointer"``: Pointer-based indexing (default)
+- ``"tensor_descriptor"``: Tensor descriptor indexing (requires Hopper+ GPU)
 - ``"block_ptr"``: Block pointer indexing
 ```

@@ -185,6 +205,42 @@ def kernel_with_eviction(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 # hl.load(x, [tile], eviction_policy="evict_first")
 ```

+### Per-Load Indexing Example
+
+```python
+import torch
+import helion
+import helion.language as hl
+
+# Single indexing strategy for all loads (backward compatible)
+@helion.kernel(config={"indexing": "block_ptr"})
+def kernel_uniform_indexing(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    out = torch.empty_like(x)
+    for tile in hl.tile(x.size(0)):
+        a = hl.load(x, [tile]) # Uses block_ptr
+        b = hl.load(y, [tile]) # Uses block_ptr
+        out[tile] = a + b
+    return out
+
+# Per-load indexing strategies for fine-grained control
+@helion.kernel(
+    config={
+        "block_size": 16,
+        "indexing": ["pointer", "block_ptr", "tensor_descriptor"],
+    }
+)
+def kernel_mixed_indexing(
+    x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
+) -> torch.Tensor:
+    out = torch.empty_like(x)
+    for tile in hl.tile(x.size(0)):
+        a = hl.load(x, [tile]) # First load: pointer indexing
+        b = hl.load(y, [tile]) # Second load: block_ptr indexing
+        c = hl.load(z, [tile]) # Third load: tensor_descriptor indexing
+        out[tile] = a + b + c
+    return out
+```
+
 ### Config Serialization

 ```python

docs/index.md

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ portable between different hardware. Helion automates and autotunes over:

 * Automatically calculates strides and indices.
 * Autotunes choices among various indexing methods (pointers, block pointers, TensorDescriptors).
+* Supports per-load indexing strategies for fine-grained memory access control.

 2. **Masking:**
