@@ -116,31 +116,39 @@ function get_trampoline(job)
116116 return addr
117117end
118118
119- # import GPUCompiler: deferred_codegen_jobs
120- # @generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
121- # # manual version of native_job because we have a function type
122- # source = methodinstance(F, Base.to_tuple_type(tt), world)
123- # target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
124- # # XXX : do we actually require the Julia runtime?
125- # # with jlruntime=false, we reach an unreachable.
126- # params = TestCompilerParams()
127- # config = CompilerConfig(target, params; kernel=false)
128- # job = CompilerJob(source, config, world)
129- # # XXX : invoking GPUCompiler from a generated function is not allowed!
130- # # for things to work, we need to forward the correct world, at least.
131-
132- # addr = get_trampoline(job)
133- # trampoline = pointer(addr)
134- # id = Base.reinterpret(Int, trampoline)
135-
136- # deferred_codegen_jobs[id] = job
137-
138- # quote
139- # ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
140- # assume(ptr != C_NULL)
141- # return ptr
142- # end
143- # end
119+ const runtime_cache = Dict {Any, Ptr{Cvoid}} ()
120+
121+ function compiler (job)
122+ JuliaContext () do _
123+ ir, meta = GPUCompiler. compile (:llvm , job; validate= false )
124+ # So 1. serialize the module
125+ buf = convert (MemoryBuffer, ir)
126+ buf, LLVM. name (meta. entry)
127+ end
128+ end
129+
130+ function linker (_, (buf, entry_fn))
131+ compiler = jit[]
132+ lljit = compiler. jit
133+ jd = JITDylib (lljit)
134+
135+ # 2. deserialize and wrap by a ThreadSafeModule
136+ ThreadSafeContext () do ts_ctx
137+ tsm = context! (context (ts_ctx)) do
138+ mod = parse (LLVM. Module, buf)
139+ ThreadSafeModule (mod)
140+ end
141+
142+ LLVM. add! (lljit, jd, tsm)
143+ end
144+ addr = LLVM. lookup (lljit, entry_fn)
145+ pointer (addr)
146+ end
147+
148+ function GPUCompiler. var"gpuc.deferred.with" (config:: GPUCompiler.CompilerConfig{<:NativeCompilerTarget} , f:: F , args... ) where F
149+ source = methodinstance (F, Base. to_tuple_type (typeof (args)))
150+ GPUCompiler. cached_compilation (runtime_cache, source, config, compiler, linker):: Ptr{Cvoid}
151+ end
144152
145153@generated function abi_call (f:: Ptr{Cvoid} , rt:: Type{RT} , tt:: Type{T} , func:: F , args:: Vararg{Any, N} ) where {T, RT, F, N}
146154 argtt = tt. parameters[1 ]
226234 rt = Core. Compiler. return_type (f, tt)
227235 # FIXME : Horrible idea, have `var"gpuc.deferred"` actually do the work
228236 # But that will only be needed here, and in Enzyme...
229- ptr = GPUCompiler. var"gpuc.deferred" (f, args... )
237+ target = NativeCompilerTarget (; jlruntime= true , llvm_always_inline= true )
238+ # XXX : do we actually require the Julia runtime?
239+ # with jlruntime=false, we reach an unreachable.
240+ params = TestCompilerParams ()
241+ config = CompilerConfig (target, params; kernel= false )
242+ ptr = GPUCompiler. var"gpuc.deferred.with" (config, f, args... )
230243 abi_call (ptr, rt, tt, f, args... )
231244end
232245
0 commit comments