@@ -404,21 +404,14 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
404404 # XXX : byval is not round-trippable on LLVM < 12 (see maleadt/LLVM.jl#186)
405405 # so we need to re-classify the Julia arguments.
406406 # remove this once we only support 1.7.
407- has_kernel_state = kernel_state_type (job) != = Nothing
408- orig_ft = if has_kernel_state
409- # the kernel state has been added here already, so strip the first parameter
410- LLVM. FunctionType (LLVM. return_type (ft), parameters (ft)[2 : end ]; vararg= isvararg (ft))
411- else
412- ft
413- end
414- args = classify_arguments (job, orig_ft)
407+ args = classify_arguments (job, ft)
415408 filter! (args) do arg
416409 arg. cc != GHOST
417410 end
418411 for arg in args
419412 if arg. cc == BITS_REF
420413 # NOTE: +1 since this pass runs after introducing the kernel state
421- byval[arg. codegen. i+ has_kernel_state ] = true
414+ byval[arg. codegen. i] = true
422415 end
423416 end
424417 end
@@ -510,6 +503,7 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
510503 # NOTE: if we ever have legitimate uses of the old function, create a shim instead
511504 fn = LLVM. name (f)
512505 @assert isempty (uses (f))
506+ replace_metadata_uses! (f, new_f)
513507 unsafe_delete! (mod, f)
514508 LLVM. name! (new_f, fn)
515509
535529# so that the julia.gpu.state_getter` can be simplified to return an opaque pointer.
536530
537531# add a state argument to every function in the module, starting from the kernel entry point
538- function add_kernel_state! (@nospecialize (job :: CompilerJob ), mod:: LLVM.Module ,
539- entry :: LLVM.Function )
532+ function add_kernel_state! (mod:: LLVM.Module )
533+ job = current_job :: CompilerJob
540534 ctx = context (mod)
541- entry_fn = LLVM. name (entry)
542535
543536 # check if we even need a kernel state argument
544537 state = kernel_state_type (job)
@@ -552,12 +545,18 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
552545 # this is both for extern uses, and to make this transformation a two-step process.
553546 state_intr = kernel_state_intr (mod, T_state)
554547
548+ kernels = []
549+ kernels_md = metadata (mod)[" julia.kernel" ]
550+ for kernel_md in operands (kernels_md)
551+ push! (kernels, Value (operands (kernel_md)[1 ]; ctx))
552+ end
553+
555554 # determine which functions need a kernel state argument
556555 #
557556 # previously, we add the argument to every function and relied on unused arg elim to
558557 # clean-up the IR. however, some libraries do Funny Stuff, e.g., libdevice bitcasting
559558 # function pointers. such IR is hard to rewrite, so instead be more conservative.
560- worklist = Set {LLVM.Function} ([entry, state_intr ])
559+ worklist = Set {LLVM.Function} ([state_intr, kernels ... ])
561560 worklist_length = 0
562561 while worklist_length != length (worklist)
563562 # iteratively discover functions that use the intrinsic or any function calling it
@@ -669,6 +668,7 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
669668 error (" old function still has uses" )
670669 end
671670 end
671+ replace_metadata_uses! (f, workmap[f])
672672 unsafe_delete! (mod, f)
673673 end
674674
@@ -707,10 +707,12 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
707707 elseif val isa LLVM. CallBase
708708 # the function is being passed as an argument, which we'll just permit,
709709 # because we expect to have rewritten the call down the line separately.
710+ elseif val isa LLVM. StoreInst
711+ # the function is being stored, which again we'll permit like before.
710712 elseif val isa ConstantExpr
711713 rewrite_uses! (val)
712714 else
713- error (" Cannot rewrite unknown use of function: $val " )
715+ error (" Cannot rewrite $( typeof (val)) use of function: $val " )
714716 end
715717 end
716718 end
0 commit comments