11# Julia compiler integration
22
3-
43# # world age lookups
54
65# `tls_world_age` should be used to look up the current world age. in most cases, this is
1211 tls_world_age () = ccall (:jl_get_tls_world_age , UInt, ())
1312end
1413
14+
1515# # looking up method instances
1616
1717export methodinstance, generic_methodinstance
159159
160160
161161# # code instance cache
162+
162163const HAS_INTEGRATED_CACHE = VERSION >= v " 1.11.0-DEV.1552"
163164
164165if ! HAS_INTEGRATED_CACHE
318319 get_method_table_view (world:: UInt , mt:: MTType ) = OverlayMethodTable (world, mt)
319320end
320321
321- struct GPUInterpreter <: CC.AbstractInterpreter
322+ abstract type AbstractGPUInterpreter <: CC.AbstractInterpreter end
323+ struct GPUInterpreter <: AbstractGPUInterpreter
322324 world:: UInt
323325 method_table:: GPUMethodTableView
324326
@@ -436,6 +438,112 @@ function CC.concrete_eval_eligible(interp::GPUInterpreter,
436438end
437439
438440
441+ # # deferred compilation
442+
"""
    DeferredCallInfo <: CC.CallInfo

Call info recorded for a call to `gpuc.deferred` during abstract interpretation.

# Fields
- `rt`: return type obtained when inferring the deferred (inner) call.
- `info`: the `CC.CallInfo` produced when inferring the deferred (inner) call.
"""
struct DeferredCallInfo <: CC.CallInfo
    rt::DataType
    info::CC.CallInfo
end
447+
# Recognize calls to `gpuc.deferred` and save `DeferredCallInfo` metadata.
#
# When `f` is `gpuc.deferred`, the remaining argument types are inferred as an ordinary
# call, and the result of that inner inference is wrapped in a `DeferredCallInfo`. The
# deferred call itself is reported to inference as returning `Ptr{Cvoid}`. Every other
# callee is forwarded to the generic `CC.AbstractInterpreter` method.
function CC.abstract_call_known(interp::AbstractGPUInterpreter, @nospecialize(f),
                                arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
                                max_methods::Int = CC.get_max_methods(interp, f, sv))
    if f === var"gpuc.deferred"
        # drop the `gpuc.deferred` callee itself; what remains is the deferred call
        argvec = arginfo.argtypes[2:end]
        call = CC.abstract_call(interp, CC.ArgInfo(nothing, argvec), si, sv, max_methods)

        # `call.rt` is an inference lattice element (e.g. `Const`), not necessarily a
        # type; widen it before storing it in the `DataType`-typed field. Widened results
        # that are still not a `DataType` (e.g. a `Union`) are recorded conservatively
        # as `Any` instead of raising a conversion error in the middle of inference.
        rt = CC.widenconst(call.rt)
        if !(rt isa DataType)
            rt = Any
        end
        callinfo = DeferredCallInfo(rt, call.info)

        # `CC.CallMeta` takes an additional exception-type argument on 1.11+
        @static if VERSION < v"1.11.0-"
            return CC.CallMeta(Ptr{Cvoid}, CC.Effects(), callinfo)
        else
            return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo)
        end
    end
    return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
                                          arginfo::CC.ArgInfo, si::CC.StmtInfo,
                                          sv::CC.AbsIntState, max_methods::Int)
end
467+
# During inlining, refine deferred calls to `gpuc.lookup` foreigncalls.
#
# The statement flag type differs between Julia versions.
const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8

# Inlining hook for statements annotated with `DeferredCallInfo`.
#
# If the deferred call resolved to exactly one method match, `stmt` is rewritten in place
# from a `:call` into a `:foreigncall` of "extern gpuc.lookup" that carries the target
# method instance followed by the original call arguments. In every other case the
# statement is left untouched. Always returns `nothing`; nothing is pushed onto `todo`.
function CC.handle_call!(todo::Vector{Pair{Int,Any}}, ir::CC.IRCode, idx::CC.Int,
                         stmt::Expr, info::DeferredCallInfo, flag::FlagType,
                         sig::CC.Signature, state::CC.InliningState)
    minfo = info.info
    # Only `MethodMatchInfo` is handled here; for any other kind of call info, leave the
    # call unrefined rather than fail on a missing field in the middle of inlining.
    minfo isa CC.MethodMatchInfo || return nothing

    results = minfo.results
    length(results.matches) == 1 || return nothing
    match = only(results.matches)

    # lookup the target mi with correct edge tracking
    case = CC.compileable_specialization(match, CC.Effects(),
                                         CC.InliningEdgeTracker(state), info)
    @assert case isa CC.InvokeCase
    @assert stmt.head === :call

    args = Any[
        "extern gpuc.lookup",
        Ptr{Cvoid},
        # Must use Any for MethodInstance or ftype
        Core.svec(Any, Any, match.spec_types.parameters[2:end]...),
        0,
        QuoteNode(:llvmcall),
        case.invoke,
        stmt.args[2:end]...
    ]
    stmt.head = :foreigncall
    stmt.args = args
    return nothing
end
499+
"""
    DeferredEdges

Analysis result holding the method instances that a piece of IR refers to through
"extern gpuc.lookup" foreigncalls.
"""
struct DeferredEdges
    edges::Vector{MethodInstance}
end
503+
# Scan `ir` for `:foreigncall` statements targeting "extern gpuc.lookup" and collect the
# value stored in their sixth argument slot. Returns a `Vector{MethodInstance}` with
# duplicates removed.
function find_deferred_edges(ir::CC.IRCode)
    found = MethodInstance[]
    # XXX: can we add this instead in handle_call?
    for node in ir.stmts
        ex = node[:inst]
        if ex isa Expr && ex.head === :foreigncall && ex.args[1] == "extern gpuc.lookup"
            push!(found, ex.args[6])
        end
    end
    return unique!(found)
end
520+
# Attach the deferred edges found in the optimized IR to the inference result, then
# continue with the generic `CC.AbstractInterpreter` implementation of the hook.
@static if VERSION >= v"1.11.0-"
    function CC.ipo_dataflow_analysis!(interp::AbstractGPUInterpreter, ir::CC.IRCode,
                                       caller::CC.InferenceResult)
        deferred = find_deferred_edges(ir)
        isempty(deferred) || CC.stack_analysis_result!(caller, DeferredEdges(deferred))
        return @invoke CC.ipo_dataflow_analysis!(interp::CC.AbstractInterpreter,
                                                 ir::CC.IRCode,
                                                 caller::CC.InferenceResult)
    end
else # v1.10
    # 1.10 doesn't have stack_analysis_result or ipo_dataflow_analysis
    function CC.finish(interp::AbstractGPUInterpreter, opt::CC.OptimizationState,
                       ir::CC.IRCode, caller::CC.InferenceResult)
        deferred = find_deferred_edges(ir)
        if !isempty(deferred)
            # HACK: we store the deferred edges in the argescapes field, which is invalid,
            #       but nobody should be running EA on our results.
            caller.argescapes = DeferredEdges(deferred)
        end
        return @invoke CC.finish(interp::CC.AbstractInterpreter,
                                 opt::CC.OptimizationState, ir::CC.IRCode,
                                 caller::CC.InferenceResult)
    end
end
545+
546+
439547# # world view of the cache
440548using Core. Compiler: WorldView
441549
@@ -584,6 +692,24 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
584692 error (" Cannot compile $(job. source) for world $(job. world) ; method is only valid in worlds $(job. source. def. primary_world) to $(job. source. def. deleted_world) " )
585693 end
586694
695+ compiled = IdDict ()
696+ llvm_mod, outstanding = compile_method_instance (job, compiled)
697+ worklist = outstanding
698+ while ! isempty (worklist)
699+ source = pop! (worklist)
700+ haskey (compiled, source) && continue
701+ job2 = CompilerJob (source, job. config)
702+ @debug " Processing..." job2
703+ llvm_mod2, outstanding = compile_method_instance (job2, compiled)
704+ append! (worklist, outstanding)
705+ @assert context (llvm_mod) == context (llvm_mod2)
706+ link! (llvm_mod, llvm_mod2)
707+ end
708+
709+ return llvm_mod, compiled
710+ end
711+
712+ function compile_method_instance (@nospecialize (job:: CompilerJob ), compiled:: IdDict{Any, Any} )
587713 # populate the cache
588714 interp = get_interpreter (job)
589715 cache = CC. code_cache (interp)
@@ -594,7 +720,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
594720
595721 # create a callback to look-up function in our cache,
596722 # and keep track of the method instances we needed.
597- method_instances = []
723+ method_instances = Any []
598724 if Sys. ARCH == :x86 || Sys. ARCH == :x86_64
599725 function lookup_fun (mi, min_world, max_world)
600726 push! (method_instances, mi)
@@ -659,7 +785,6 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
659785 end
660786
661787 # process all compiled method instances
662- compiled = Dict ()
663788 for mi in method_instances
664789 ci = ci_cache_lookup (cache, mi, job. world, job. world)
665790 ci === nothing && continue
@@ -696,10 +821,34 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
696821 compiled[mi] = (; ci, func= llvm_func, specfunc= llvm_specfunc)
697822 end
698823
824+ # Collect the deferred edges
825+ outstanding = Any[]
826+ for mi in method_instances
827+ ! haskey (compiled, mi) && continue # Equivalent to ci_cache_lookup == nothing
828+ ci = compiled[mi]. ci
829+ @static if VERSION >= v " 1.11.0-"
830+ edges = CC. traverse_analysis_results (ci) do @nospecialize result
831+ return result isa DeferredEdges ? result : return
832+ end
833+ else
834+ edges = ci. argescapes
835+ if ! (edges isa Union{Nothing, DeferredEdges})
836+ edges = nothing
837+ end
838+ end
839+ if edges != = nothing
840+ for deferred_mi in (edges:: DeferredEdges ). edges
841+ if ! haskey (compiled, deferred_mi)
842+ push! (outstanding, deferred_mi)
843+ end
844+ end
845+ end
846+ end
847+
699848 # ensure that the requested method instance was compiled
700849 @assert haskey (compiled, job. source)
701850
702- return llvm_mod, compiled
851+ return llvm_mod, outstanding
703852end
704853
# partially revert JuliaLang/julia#49391
0 commit comments