-
Notifications
You must be signed in to change notification settings - Fork 41
Open
Description
Note: Please write your bug report in English to ensure it can be understood and addressed by the development team.
Describe the bug
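Triton code generation fails when a host-side Python `int` (here `num_workers`, derived from the device's `multi_processor_count`) is used in arithmetic inside a device loop (`num_workers * work_item`).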
To Reproduce
Helion compiler triton codegen error for @helion.kernel(config=helion.Config(block_sizes=[1, 8192], indexing='pointer', loop_orders=[[0, 1]], num_stages=3, num_warps=4, pid_type='flat', range_flattens=[None, None, None, None], range_multi_buffers=[None, None, None, None], range_num_stages=[0, 0, 0, 0], range_unroll_factors=[0, 0, 0, 0], range_warp_specializes=[]), static_shapes=False)
Traceback (most recent call last):
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1337, in run_node
result = lowering.codegen(self, n)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 784, in codegen
return expr_from_string(ctx.cg.device_function.user_sympy_expr(self.expr))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 383, in user_sympy_expr
return self.sympy_expr(expr)
^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 352, in sympy_expr
self._lift_sympy_arg(sym), integer=True
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 371, in _lift_sympy_arg
return self.expr_arg(expr, origin.origin).name
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 479, in expr_arg
name=self.new_var(origin.suggest_var_name()),
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/variable_origin.py", line 166, in suggest_var_name
return f"{self.value.suggest_var_name()}_attr_{self.key}"
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/variable_origin.py", line 98, in suggest_var_name
raise NotImplementedError(type(self).__name__)
NotImplementedError: SourceOrigin
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/runtime/kernel.py", line 436, in compile_config
triton_code = self.to_triton_code(
^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/runtime/kernel.py", line 411, in to_triton_code
root = generate_ast(self.host_function, config, emit_repro_caller)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/generate_ast.py", line 407, in generate_ast
codegen.add_statement(codegen.visit(stmt))
^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/ast_extension.py", line 272, in visit
return visitor(node)
^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/generate_ast.py", line 234, in visit_For
codegen_call_with_graph(
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1410, in codegen_call_with_graph
return GraphInterpreter(graph, cg).run(*new_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 174, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1337, in run_node
result = lowering.codegen(self, n)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 754, in codegen
return self.api_func._codegen(
^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/language/_tracing_ops.py", line 86, in _
return HostFunction.current().device_ir.graphs[state.proxy_arg(0)].codegen(state) # pyright: ignore[reportArgumentType,reportCallIssue]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_ir.py", line 237, in codegen
return codegen_call_with_graph(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1410, in codegen_call_with_graph
return GraphInterpreter(graph, cg).run(*new_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 174, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1380, in run_node
raise InductorLoweringError(
helion.exc.InductorLoweringError: Error in codegen for node mul (<built-in function mul>): SourceOrigin
While processing:
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/repro2.py", line 30, in fn
work_n = (worker_id + num_workers * work_item) % n_count
^^^^^^^^^^^^^^^^^^^^^^^
While executing %mul : [num_users=1] = call_function[target=operator.mul](args = (%symnode, %u7), kwargs = {})
Original traceback:
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/repro2.py", line 30, in fn
work_n = (worker_id + num_workers * work_item) % n_count
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
While executing %_for_loop : [num_users=0] = call_function[target=helion.language._tracing_ops._for_loop](args = (1, [0], [%symnode], []), kwargs = {})
Original traceback:
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/repro2.py", line 29, in fn
for work_item in hl.grid(work_per_worker):
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
from typing import Any
from typing import Callable
import torch
import helion
import helion.language as hl
@helion.kernel(config = helion.Config(block_sizes=[1, 8192]))
def fn(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    # Repro kernel for a Helion Triton-codegen failure. It copies `x` into a new
    # tensor `out`, splitting the blocks of work across a fixed number of workers.
    # Codegen fails on the `num_workers * work_item` arithmetic inside the device loops.
    #
    # Args:
    #   x: 2-D input tensor (unpacked below as (m, n)).
    #   weight: only its length is checked against n; it is otherwise unused.
    #   eps: unused by the body.
    # Returns:
    #   `out`, a tensor like `x` holding x's values after a float32 round trip.
    #   Annotated as a single tensor because the body returns only `out`.
    m, n = x.size()
    # Specialize on the column count; presumably this makes n a compile-time
    # constant for the kernel (confirm against hl.specialize docs).
    n = hl.specialize(n)
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    out = torch.empty_like(x)
    # Tunable block sizes. With the config above these are presumably m_block=1 and
    # n_block=8192 (assuming block_sizes follows registration order; confirm).
    m_block = hl.register_block_size(m)
    n_block = hl.register_block_size(n)
    # Number of blocks along each dimension (ceiling division).
    m_count = (m + m_block - 1) // m_block
    n_count = (n + n_block - 1) // n_block
    # Host-side Python int read from the device properties. This is the value that
    # later fails to lower when it is used in arithmetic inside the device loops.
    num_workers = torch.cuda.get_device_properties(x.device).multi_processor_count  # type: ignore[arg-type]
    # 4 x the device's multi_processor_count.
    num_workers = num_workers + num_workers + num_workers + num_workers
    total_work = m_count * n_count
    # Work items per worker (ceiling division), so that
    # num_workers * work_per_worker >= total_work.
    work_per_worker = (total_work + num_workers - 1) // num_workers
    for worker_id in hl.grid(num_workers):
        for work_item in hl.grid(work_per_worker):
            # Linear work index worker_id + num_workers * work_item: worker w takes
            # items w, w + num_workers, w + 2 * num_workers, ...  The index is split
            # into an (m, n) block coordinate with n varying fastest.
            # The reported InductorLoweringError ("SourceOrigin") is raised while
            # lowering the `num_workers * work_item` multiplication. Per the traceback,
            # naming the lifted `num_workers` argument reaches an attribute origin
            # whose base SourceOrigin does not implement suggest_var_name().
            work_n = (worker_id + num_workers * work_item) % n_count
            work_m = (worker_id + num_workers * work_item) // n_count
            # Element bounds of this work item's block, clamped to the tensor shape.
            # For indices >= total_work, work_m_start >= m and so
            # work_m_end <= work_m_start. Presumably that yields an empty tile range
            # (confirm hl.tile semantics).
            work_n_start = work_n * n_block
            work_n_end = min(work_n_start + n_block, n)
            work_m_start = work_m * m_block
            work_m_end = min(work_m_start + m_block, m)
            for tile_n, tile_m in hl.tile([work_n_start, work_m_start], [work_n_end, work_m_end], block_size=[n_block, m_block]):
                # Load as float32, then store in out's dtype: a copy of x into out.
                x_tile = x[tile_m, tile_n].to(torch.float32)
                out[tile_m, tile_n] = x_tile.to(out.dtype)
    return out
# Echo this script's own source before running it, so the captured output shows
# the repro followed by the resulting traceback. A context manager closes the file
# handle, which the bare open(...).read() leaked. Python source files are UTF-8 by
# default, so decode explicitly rather than with the locale encoding.
with open(__file__, encoding="utf-8") as _src:
    print(_src.read(), '\n\n\n', '=' * 50, '\n\n')
# Trigger compilation and a launch of the kernel on CUDA inputs.
fn(torch.randn(1000, 1024, device='cuda'), torch.randn(1024, device='cuda'))
==================================================
Traceback (most recent call last):
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1337, in run_node
result = lowering.codegen(self, n)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 784, in codegen
return expr_from_string(ctx.cg.device_function.user_sympy_expr(self.expr))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 383, in user_sympy_expr
return self.sympy_expr(expr)
^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 352, in sympy_expr
self._lift_sympy_arg(sym), integer=True
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 371, in _lift_sympy_arg
return self.expr_arg(expr, origin.origin).name
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_function.py", line 479, in expr_arg
name=self.new_var(origin.suggest_var_name()),
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/variable_origin.py", line 166, in suggest_var_name
return f"{self.value.suggest_var_name()}_attr_{self.key}"
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/variable_origin.py", line 98, in suggest_var_name
raise NotImplementedError(type(self).__name__)
NotImplementedError: SourceOrigin
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/repro2.py", line 44, in <module>
fn(torch.randn(1000, 1024, device='cuda'), torch.randn(1024, device='cuda'))
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/runtime/kernel.py", line 286, in __call__
return self.bind(args)(*args)
^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/runtime/kernel.py", line 620, in __call__
self.set_config(config)
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/runtime/kernel.py", line 525, in set_config
self._run = self.compile_config(config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/runtime/kernel.py", line 436, in compile_config
triton_code = self.to_triton_code(
^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/runtime/kernel.py", line 411, in to_triton_code
root = generate_ast(self.host_function, config, emit_repro_caller)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/generate_ast.py", line 407, in generate_ast
codegen.add_statement(codegen.visit(stmt))
^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/ast_extension.py", line 272, in visit
return visitor(node)
^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/generate_ast.py", line 234, in visit_For
codegen_call_with_graph(
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1410, in codegen_call_with_graph
return GraphInterpreter(graph, cg).run(*new_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 174, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1337, in run_node
result = lowering.codegen(self, n)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 754, in codegen
return self.api_func._codegen(
^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/language/_tracing_ops.py", line 86, in _
return HostFunction.current().device_ir.graphs[state.proxy_arg(0)].codegen(state) # pyright: ignore[reportArgumentType,reportCallIssue]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/device_ir.py", line 237, in codegen
return codegen_call_with_graph(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1410, in codegen_call_with_graph
return GraphInterpreter(graph, cg).run(*new_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 174, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/_compiler/inductor_lowering.py", line 1380, in run_node
raise InductorLoweringError(
helion.exc.InductorLoweringError: Error in codegen for node mul (<built-in function mul>): SourceOrigin
While processing:
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/repro2.py", line 30, in fn
work_n = (worker_id + num_workers * work_item) % n_count
^^^^^^^^^^^^^^^^^^^^^^^
While executing %mul : [num_users=1] = call_function[target=operator.mul](args = (%symnode, %u7), kwargs = {})
Original traceback:
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/repro2.py", line 30, in fn
work_n = (worker_id + num_workers * work_item) % n_count
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
While executing %_for_loop : [num_users=0] = call_function[target=helion.language._tracing_ops._for_loop](args = (1, [0], [%symnode], []), kwargs = {})
Original traceback:
File "/data/users/mhoehnerbach/projects/helion-day/helion/helion/repro2.py", line 29, in fn
for work_item in hl.grid(work_per_worker):
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
Expected behavior
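The kernel should compile and run: arithmetic that combines a host-side Python `int` with a loop index inside a device loop should lower to Triton code without raising an error.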
Versions
Helion `main` branch
Metadata
Metadata
Assignees
Labels
No labels