@@ -36,7 +36,7 @@ struct NeverInlineMeta <: InlineStateMeta end
3636import GPUCompiler: abstract_call_known, GPUInterpreter
3737import Core. Compiler: CallMeta, Effects, NoCallInfo, ArgInfo,
3838 StmtInfo, AbsIntState, EFFECTS_TOTAL,
39- MethodResultPure
39+ MethodResultPure, CallInfo, IRCode
4040
4141function abstract_call_known (meta:: InlineStateMeta , interp:: GPUInterpreter , @nospecialize (f),
4242 arginfo:: ArgInfo , si:: StmtInfo , sv:: AbsIntState , max_methods:: Int )
@@ -69,5 +69,179 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec
6969 return nothing
7070end
7171
72+ struct MockEnzymeMeta end
7273
73- end
74+ # Having to define this function is annoying
75+ # introduce `abstract type InferenceMeta`
76+ function inlining_handler (meta:: MockEnzymeMeta , interp:: GPUInterpreter , @nospecialize (atype), callinfo)
77+ return nothing
78+ end
79+
80+ function autodiff end
81+
82+ import GPUCompiler: DeferredCallInfo
83+ struct AutodiffCallInfo <: CallInfo
84+ rt
85+ info:: DeferredCallInfo
86+ end
87+
88+ function abstract_call_known (meta:: Nothing , interp:: GPUInterpreter , f:: typeof (autodiff),
89+ arginfo:: ArgInfo , si:: StmtInfo , sv:: AbsIntState , max_methods:: Int )
90+ (; fargs, argtypes) = arginfo
91+
92+ @assert f === autodiff
93+ if length (argtypes) <= 1
94+ @static if VERSION < v " 1.11.0-"
95+ return CallMeta (Union{}, Effects (), NoCallInfo ())
96+ else
97+ return CallMeta (Union{}, Union{}, Effects (), NoCallInfo ())
98+ end
99+ end
100+
101+ other_fargs = fargs === nothing ? nothing : fargs[2 : end ]
102+ other_arginfo = ArgInfo (other_fargs, argtypes[2 : end ])
103+ # TODO : Ought we not change absint to use MockEnzymeMeta(), otherwise we fill the cache for nothing.
104+ call = Core. Compiler. abstract_call (interp, other_arginfo, si, sv, max_methods)
105+ callinfo = DeferredCallInfo (MockEnzymeMeta (), call. rt, call. info)
106+
107+ # Real Enzyme must compute `rt` and `exct` according to enzyme semantics
108+ # and likely perform a unwrapping of fargs...
109+ rt = call. rt
110+
111+ # TODO : Edges? Effects?
112+ @static if VERSION < v " 1.11.0-"
113+ # Can't use call.effects since otherwise this call might be just replaced with rt
114+ return CallMeta (rt, Effects (), AutodiffCallInfo (rt, callinfo))
115+ else
116+ return CallMeta (rt, call. exct, Effects (), AutodiffCallInfo (rt, callinfo))
117+ end
118+ end
119+
120+ function abstract_call_known (meta:: MockEnzymeMeta , interp:: GPUInterpreter , @nospecialize (f),
121+ arginfo:: ArgInfo , si:: StmtInfo , sv:: AbsIntState , max_methods:: Int )
122+ return nothing
123+ end
124+
125+ import Core. Compiler: insert_node!, NewInstruction, ReturnNode, Instruction, InliningState, Signature
126+
127+ # We really need a Compiler stdlib
128+ Base. getindex (ir:: IRCode , i) = Core. Compiler. getindex (ir, i)
129+ Base. setindex! (inst:: Instruction , val, i) = Core. Compiler. setindex! (inst, val, i)
130+
131+ const FlagType = VERSION >= v " 1.11.0-" ? UInt32 : UInt8
132+ function Core. Compiler. handle_call! (todo:: Vector{Pair{Int,Any}} , ir:: IRCode , stmt_idx:: Int ,
133+ stmt:: Expr , info:: AutodiffCallInfo , flag:: FlagType ,
134+ sig:: Signature , state:: InliningState )
135+
136+ # Goal:
137+ # The IR we want to inline here is:
138+ # unpack the args ..
139+ # ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
140+ # ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid, args...}}, ptr, adjoint_args...)
141+
142+ # 0. Obtain primal mi from DeferredCallInfo
143+ # TODO : remove this code duplication
144+ deferred_info = info. info
145+ minfo = deferred_info. info
146+ results = minfo. results
147+ if length (results. matches) != 1
148+ return nothing
149+ end
150+ match = only (results. matches)
151+
152+ # lookup the target mi with correct edge tracking
153+ # TODO : Effects?
154+ case = Core. Compiler. compileable_specialization (
155+ match, Core. Compiler. Effects (), Core. Compiler. InliningEdgeTracker (state), info)
156+ @assert case isa Core. Compiler. InvokeCase
157+ @assert stmt. head === :call
158+
159+ # Now create the IR we want to inline
160+ ir = Core. Compiler. IRCode () # contains a placeholder
161+ args = [Core. Compiler. Argument (i) for i in 2 : length (stmt. args)] # f, args...
162+ idx = 0
163+
164+ # 0. Enzyme proper: Desugar args
165+ primal_args = args
166+ primal_argtypes = match. spec_types. parameters[2 : end ]
167+
168+ adjoint_rt = info. rt
169+ adjoint_args = args # TODO
170+ adjoint_argtypes = primal_argtypes
171+
172+ # 1: Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call
173+ expr = Expr (:foreigncall ,
174+ " extern gpuc.lookup" ,
175+ Ptr{Cvoid},
176+ Core. svec (#= meta=# Any, #= mi=# Any, #= f=# Any, primal_argtypes... ), # Must use Any for MethodInstance or ftype
177+ 0 ,
178+ QuoteNode (:llvmcall ),
179+ deferred_info. meta,
180+ case. invoke,
181+ primal_args...
182+ )
183+ ptr = insert_node! (ir, (idx += 1 ), NewInstruction (expr, Ptr{Cvoid}))
184+
185+ # 2. Call to magic `__autodiff`
186+ expr = Expr (:foreigncall ,
187+ " extern __autodiff" ,
188+ adjoint_rt,
189+ Core. svec (Ptr{Cvoid}, Any, adjoint_argtypes... ),
190+ 0 ,
191+ QuoteNode (:llvmcall ),
192+ ptr,
193+ adjoint_args...
194+ )
195+ ret = insert_node! (ir, idx, NewInstruction (expr, adjoint_rt))
196+
197+ # Finally replace placeholder return
198+ ir[Core. SSAValue (1 )][:inst ] = Core. ReturnNode (ret)
199+ ir[Core. SSAValue (1 )][:type ] = Ptr{Cvoid}
200+
201+ ir = Core. Compiler. compact! (ir)
202+
203+ # which mi to use here?
204+ # push inlining todos
205+ # TODO : Effects
206+ # aviatesk mentioned using inlining_policy instead...
207+ itodo = Core. Compiler. InliningTodo (case. invoke, ir, Core. Compiler. Effects ())
208+ @assert itodo. linear_inline_eligible
209+ push! (todo, (stmt_idx=> itodo))
210+
211+ return nothing
212+ end
213+
214+ function mock_enzyme! (@nospecialize (job), intrinsic, mod:: LLVM.Module )
215+ changed = false
216+
217+ for use in LLVM. uses (intrinsic)
218+ call = LLVM. user (use)
219+ LLVM. @dispose builder= LLVM. IRBuilder () begin
220+ LLVM. position! (builder, call)
221+ ops = LLVM. operands (call)
222+ target = ops[1 ]
223+ if target isa LLVM. ConstantExpr && (LLVM. opcode (target) == LLVM. API. LLVMPtrToInt ||
224+ LLVM. opcode (target) == LLVM. API. LLVMBitCast)
225+ target = first (LLVM. operands (target))
226+ end
227+ funcT = LLVM. called_type (call)
228+ funcT = LLVM. FunctionType (LLVM. return_type (funcT), LLVM. parameters (funcT)[3 : end ])
229+ direct_call = LLVM. call! (builder, funcT, target,
230+ [ops[i] for i in 3 : length (ops)])
231+
232+ LLVM. replace_uses! (call, direct_call)
233+ end
234+ if isempty (LLVM. uses (call))
235+ LLVM. erase! (call)
236+ changed = true
237+ else
238+ # the validator will detect this
239+ end
240+ end
241+
242+ return changed
243+ end
244+
245+ GPUCompiler. register_plugin! (" __autodiff" , mock_enzyme!)
246+
247+ end # module
0 commit comments