Skip to content

error when operating on python int #778

@v0i0

Description

@v0i0

Note: Please write your bug report in English to ensure it can be understood and addressed by the development team.

Describe the bug
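Triton codegen fails with `NotImplementedError: SourceOrigin` when a host-side Python `int` is used in arithmetic inside a device loop. In the repro below, the `int` is derived from `torch.cuda.get_device_properties(x.device).multi_processor_count`, and the failing expression is `num_workers * work_item` inside `hl.grid`.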

To Reproduce

Helion compiler triton codegen error for @helion.kernel(config=helion.Config(block_sizes=[1, 8192], indexing='pointer', loop_orders=[[0, 1]], num_stages=3, num_warps=4, pid_type='flat', range_flattens=[None, None, None, None], range_multi_buffers=[None, None, None, None], range_num_stages=[0, 0, 0, 0], range_unroll_factors=[0, 0, 0, 0], range_warp_specializes=[]), static_shapes=False)
Traceback (most recent call last):
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1337, in run_node
    result = lowering.codegen(self, n)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 784, in codegen
    return expr_from_string(ctx.cg.device_function.user_sympy_expr(self.expr))
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 383, in user_sympy_expr
    return self.sympy_expr(expr)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 352, in sympy_expr
    self._lift_sympy_arg(sym), integer=True
    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 371, in _lift_sympy_arg
    return self.expr_arg(expr, origin.origin).name
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 479, in expr_arg
    name=self.new_var(origin.suggest_var_name()),
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/variable_origin.py", line 166, in suggest_var_name
    return f"{self.value.suggest_var_name()}_attr_{self.key}"
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/variable_origin.py", line 98, in suggest_var_name
    raise NotImplementedError(type(self).__name__)
NotImplementedError: SourceOrigin

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/runtime/kernel.py", line 436, in compile_config
    triton_code = self.to_triton_code(
                  ^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/runtime/kernel.py", line 411, in to_triton_code
    root = generate_ast(self.host_function, config, emit_repro_caller)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/generate_ast.py", line 407, in generate_ast
    codegen.add_statement(codegen.visit(stmt))
                          ^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/ast_extension.py", line 272, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/generate_ast.py", line 234, in visit_For
    codegen_call_with_graph(
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1410, in codegen_call_with_graph
    return GraphInterpreter(graph, cg).run(*new_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 174, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1337, in run_node
    result = lowering.codegen(self, n)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 754, in codegen
    return self.api_func._codegen(
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/language/_tracing_ops.py", line 86, in _
    return HostFunction.current().device_ir.graphs[state.proxy_arg(0)].codegen(state)  # pyright: ignore[reportArgumentType,reportCallIssue]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_ir.py", line 237, in codegen
    return codegen_call_with_graph(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1410, in codegen_call_with_graph
    return GraphInterpreter(graph, cg).run(*new_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 174, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1380, in run_node
    raise InductorLoweringError(
helion.exc.InductorLoweringError: Error in codegen for node mul (<built-in function mul>): SourceOrigin
While processing:
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/repro2.py", line 30, in fn
    work_n = (worker_id + num_workers * work_item) % n_count
                          ^^^^^^^^^^^^^^^^^^^^^^^


While executing %mul : [num_users=1] = call_function[target=operator.mul](args = (%symnode, %u7), kwargs = {})
Original traceback:
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/repro2.py", line 30, in fn
    work_n = (worker_id + num_workers * work_item) % n_count

Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

While executing %_for_loop : [num_users=0] = call_function[target=helion.language._tracing_ops._for_loop](args = (1, [0], [%symnode], []), kwargs = {})
Original traceback:
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/repro2.py", line 29, in fn
    for work_item in hl.grid(work_per_worker):

Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
from typing import Any
from typing import Callable

import torch

import helion
import helion.language as hl

@helion.kernel(config=helion.Config(block_sizes=[1, 8192]))
def fn(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    """Write ``x`` (round-tripped through float32) into a new tensor, tile by tile.

    Minimal repro for helion issue #778. ``num_workers`` is a host-side Python
    ``int`` derived from ``torch.cuda.get_device_properties``. Multiplying it by
    the device loop index ``work_item`` fails in Triton codegen with
    ``NotImplementedError: SourceOrigin``.

    Work is distributed as follows. Each of ``num_workers`` grid programs
    visits the work items ``worker_id + num_workers * work_item``. Each work
    item is mapped to one ``(m_block, n_block)`` tile of ``x``.

    Args:
        x: 2-D input tensor of shape ``(m, n)``.
        weight: Only used to assert that ``weight.size(0) == n``.
        eps: Unused. Kept so the signature matches the original repro.

    Returns:
        A tensor shaped like ``x`` (via ``torch.empty_like``). Tiles visited by
        the loop hold ``x`` cast to float32 and back to ``x``'s dtype.
    """
    m, n = x.size()
    n = hl.specialize(n)
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"

    out = torch.empty_like(x)
    m_block = hl.register_block_size(m)
    n_block = hl.register_block_size(n)
    # Number of tiles along each dimension (ceil division).
    m_count = (m + m_block - 1) // m_block
    n_count = (n + n_block - 1) // n_block

    # Host-side Python int. Per the traceback in this report, using it in
    # `num_workers * work_item` below is what triggers the SourceOrigin error.
    num_workers = torch.cuda.get_device_properties(x.device).multi_processor_count  # type: ignore[arg-type]
    # 4x the multiprocessor count. Written as repeated addition, kept
    # exactly as in the original repro.
    num_workers = num_workers + num_workers + num_workers + num_workers
    total_work = m_count * n_count
    # Ceil division so that num_workers * work_per_worker >= total_work.
    work_per_worker = (total_work + num_workers - 1) // num_workers

    for worker_id in hl.grid(num_workers):
        for work_item in hl.grid(work_per_worker):
            # Linear work index -> (tile row, tile column). Indices past
            # total_work yield work_m_start >= m, i.e. an empty tile range.
            work_n = (worker_id + num_workers * work_item) % n_count
            work_m = (worker_id + num_workers * work_item) // n_count
            work_n_start = work_n * n_block
            work_n_end = min(work_n_start + n_block, n)
            work_m_start = work_m * m_block
            work_m_end = min(work_m_start + m_block, m)
            for tile_n, tile_m in hl.tile([work_n_start, work_m_start], [work_n_end, work_m_end], block_size=[n_block, m_block]):
                x_tile = x[tile_m, tile_n].to(torch.float32)
                out[tile_m, tile_n] = x_tile.to(out.dtype)

    return out

# Echo this script's own source so the bug report contains the exact repro.
# Use a context manager so the file handle is closed (the bare
# open(...).read() form leaked it).
with open(__file__, encoding="utf-8") as _src:
    print(_src.read(), '\n\n\n', '=' * 50, '\n\n')

# Calling the kernel triggers compilation (and the codegen error). Requires CUDA.
fn(torch.randn(1000, 1024, device='cuda'), torch.randn(1024, device='cuda'))
 


 ================================================== 


Traceback (most recent call last):
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1337, in run_node
    result = lowering.codegen(self, n)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 784, in codegen
    return expr_from_string(ctx.cg.device_function.user_sympy_expr(self.expr))
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 383, in user_sympy_expr
    return self.sympy_expr(expr)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 352, in sympy_expr
    self._lift_sympy_arg(sym), integer=True
    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 371, in _lift_sympy_arg
    return self.expr_arg(expr, origin.origin).name
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 479, in expr_arg
    name=self.new_var(origin.suggest_var_name()),
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/variable_origin.py", line 166, in suggest_var_name
    return f"{self.value.suggest_var_name()}_attr_{self.key}"
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/variable_origin.py", line 98, in suggest_var_name
    raise NotImplementedError(type(self).__name__)
NotImplementedError: SourceOrigin

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/repro2.py", line 44, in <module>
    fn(torch.randn(1000, 1024, device='cuda'), torch.randn(1024, device='cuda'))
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/runtime/kernel.py", line 286, in __call__
    return self.bind(args)(*args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/runtime/kernel.py", line 620, in __call__
    self.set_config(config)
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/runtime/kernel.py", line 525, in set_config
    self._run = self.compile_config(config)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/runtime/kernel.py", line 436, in compile_config
    triton_code = self.to_triton_code(
                  ^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/runtime/kernel.py", line 411, in to_triton_code
    root = generate_ast(self.host_function, config, emit_repro_caller)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/generate_ast.py", line 407, in generate_ast
    codegen.add_statement(codegen.visit(stmt))
                          ^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/ast_extension.py", line 272, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/generate_ast.py", line 234, in visit_For
    codegen_call_with_graph(
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1410, in codegen_call_with_graph
    return GraphInterpreter(graph, cg).run(*new_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 174, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1337, in run_node
    result = lowering.codegen(self, n)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 754, in codegen
    return self.api_func._codegen(
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/language/_tracing_ops.py", line 86, in _
    return HostFunction.current().device_ir.graphs[state.proxy_arg(0)].codegen(state)  # pyright: ignore[reportArgumentType,reportCallIssue]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_ir.py", line 237, in codegen
    return codegen_call_with_graph(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1410, in codegen_call_with_graph
    return GraphInterpreter(graph, cg).run(*new_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 174, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1380, in run_node
    raise InductorLoweringError(
helion.exc.InductorLoweringError: Error in codegen for node mul (<built-in function mul>): SourceOrigin
While processing:
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/repro2.py", line 30, in fn
    work_n = (worker_id + num_workers * work_item) % n_count
                          ^^^^^^^^^^^^^^^^^^^^^^^


While executing %mul : [num_users=1] = call_function[target=operator.mul](args = (%symnode, %u7), kwargs = {})
Original traceback:
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/repro2.py", line 30, in fn
    work_n = (worker_id + num_workers * work_item) % n_count

Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

While executing %_for_loop : [num_users=0] = call_function[target=helion.language._tracing_ops._for_loop](args = (1, [0], [%symnode], []), kwargs = {})
Original traceback:
  File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/repro2.py", line 29, in fn
    for work_item in hl.grid(work_per_worker):

Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

Expected behavior
The kernel should compile and run without error.

Versions
Helion built from the `main` branch (exact commit not recorded).

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions