Skip to content
233 changes: 139 additions & 94 deletions Compiler/src/abstractinterpretation.jl

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions Compiler/src/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -355,10 +355,10 @@ mutable struct InferenceState
for i = 1:nslots
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
if argtyp === Bool && has_conditional(typeinf_lattice(interp))
argtyp = Conditional(i, Const(true), Const(false))
argtyp = Conditional(i, #= ssadef =# 0, Const(true), Const(false))
end
slottypes[i] = argtyp
bb_vartable1[i] = VarState(argtyp, i > nargtypes)
bb_vartable1[i] = VarState(argtyp, #= ssadef =# 0, i > nargtypes)
end
src.ssavaluetypes = ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
ssaflags = copy(src.ssaflags)
Expand Down Expand Up @@ -764,7 +764,7 @@ function sptypes_from_meth_instance(mi::MethodInstance)
ty = Const(v)
undef = false
end
sptypes[i] = VarState(ty, undef)
sptypes[i] = VarState(ty, typemin(Int), undef)
end
return sptypes
end
Expand Down
2 changes: 1 addition & 1 deletion Compiler/src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractIn
bb_vartables = Union{VarTable,Nothing}[]
for block = 1:length(cfg.blocks)
push!(bb_vartables, VarState[
VarState(slottypes[slot], src.slotflags[slot] & SLOT_USEDUNDEF != 0)
VarState(slottypes[slot], typemin(Int), src.slotflags[slot] & SLOT_USEDUNDEF != 0)
for slot = 1:nslots
])
end
Expand Down
2 changes: 1 addition & 1 deletion Compiler/src/reflection_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ end

function statement_costs!(interp::AbstractInterpreter, cost::Vector{Int}, body::Vector{Any}, src::Union{CodeInfo, IRCode}, match::Core.MethodMatch)
params = OptimizationParams(interp)
sptypes = VarState[VarState(sp, false) for sp in match.sparams]
sptypes = VarState[VarState(sp, #= ssadef =# typemin(Int), false) for sp in match.sparams]
return statement_costs!(cost, body, src, sptypes, params)
end

Expand Down
2 changes: 1 addition & 1 deletion Compiler/src/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ end

function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, sstate::StatementState, irsv::IRInterpretationState)
si = StmtInfo(true, sstate.saw_latestworld) # TODO better job here?
call = abstract_call(interp, arginfo, si, irsv)::Future
call = abstract_call(interp, arginfo, si, sstate.vtypes, irsv)::Future
Future{Any}(call, interp, irsv) do call, interp, irsv
irsv.ir.stmts[irsv.curridx][:info] = call.info
nothing
Expand Down
13 changes: 7 additions & 6 deletions Compiler/src/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ end

function not_tfunc(𝕃::AbstractLattice, @nospecialize(b))
if isa(b, Conditional)
return Conditional(b.slot, b.elsetype, b.thentype)
return Conditional(b.slot, b.ssadef, b.elsetype, b.thentype)
elseif isa(b, Const)
return Const(not_int(b.val))
end
Expand Down Expand Up @@ -354,14 +354,14 @@ end
if isa(x, Conditional)
y = widenconditional(y)
if isa(y, Const)
y.val === false && return Conditional(x.slot, x.elsetype, x.thentype)
y.val === false && return Conditional(x.slot, x.ssadef, x.elsetype, x.thentype)
y.val === true && return x
return Const(false)
end
elseif isa(y, Conditional)
x = widenconditional(x)
if isa(x, Const)
x.val === false && return Conditional(y.slot, y.elsetype, y.thentype)
x.val === false && return Conditional(y.slot, y.ssadef, y.elsetype, y.thentype)
x.val === true && return y
return Const(false)
end
Expand Down Expand Up @@ -1363,7 +1363,7 @@ end
return Bool
end

@nospecs function abstract_modifyop!(interp::AbstractInterpreter, ff, argtypes::Vector{Any}, si::StmtInfo, sv::AbsIntState)
@nospecs function abstract_modifyop!(interp::AbstractInterpreter, ff, argtypes::Vector{Any}, si::StmtInfo, vtypes::Union{VarTable,Nothing}, sv::AbsIntState)
if ff === modifyfield!
minargs = 5
maxargs = 6
Expand Down Expand Up @@ -1424,7 +1424,7 @@ end
# as well as compute the info for the method matches
op = unwrapva(argtypes[op_argi])
v = unwrapva(argtypes[v_argi])
callinfo = abstract_call(interp, ArgInfo(nothing, Any[op, TF, v]), StmtInfo(true, si.saw_latestworld), sv, #=max_methods=#1)
callinfo = abstract_call(interp, ArgInfo(nothing, Any[op, TF, v]), StmtInfo(true, si.saw_latestworld), vtypes, sv, #=max_methods=#1)
TF = Core.Box(TF)
RT = Core.Box(RT)
return Future{CallMeta}(callinfo, interp, sv) do callinfo, interp, sv
Expand Down Expand Up @@ -3113,7 +3113,8 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
old_restrict = sv.restrict_abstract_call_sites
sv.restrict_abstract_call_sites = false
end
call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), si, sv, #=max_methods=#-1)
# TODO: vtypes?
call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), si, nothing, sv, #=max_methods=#-1)
tt = Core.Box(tt)
return Future{CallMeta}(call, interp, sv) do call, interp, sv
if isa(sv, InferenceState)
Expand Down
2 changes: 1 addition & 1 deletion Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ function type_annotate!(interp::AbstractInterpreter, sv::InferenceState)
for slot in 1:nslots
vt = varstate[slot]
widened_type = widenslotwrapper(ignorelimited(vt.typ))
varstate[slot] = VarState(widened_type, vt.undef)
varstate[slot] = VarState(widened_type, vt.ssadef, vt.undef)
end
end
end
Expand Down
90 changes: 60 additions & 30 deletions Compiler/src/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ the type of `SlotNumber(cnd.slot)` will be limited by `cnd.thentype`
and in the false branch, it will be limited by `cnd.elsetype`.
Example:
```julia
let cond = isa(x::Union{Int, Float}, Int)::Conditional(x, Int, Float)
let cond = isa(x::Union{Int, Float}, Int)::Conditional(x, _, Int, Float)
if cond
# May assume x is `Int` now
else
Expand All @@ -43,27 +43,30 @@ end
"""
struct Conditional
slot::Int
ssadef::Int
thentype
elsetype
# `isdefined` indicates this `Conditional` is from `@isdefined slot`, implying that
# the `undef` information of `slot` can be improved in the then branch.
# Since this is only beneficial for local inference, it is not translated into `InterConditional`.
isdefined::Bool
function Conditional(slot::Int, @nospecialize(thentype), @nospecialize(elsetype);
function Conditional(slot::Int, ssadef::Int, @nospecialize(thentype), @nospecialize(elsetype);
isdefined::Bool=false)
assert_nested_slotwrapper(thentype)
assert_nested_slotwrapper(elsetype)
limited = may_form_limited_typ(thentype, elsetype, Bool)
limited !== nothing && return limited
return new(slot, thentype, elsetype, isdefined)
return new(slot, ssadef, thentype, elsetype, isdefined)
end
end
Conditional(var::SlotNumber, @nospecialize(thentype), @nospecialize(elsetype); isdefined::Bool=false) =
Conditional(slot_id(var), thentype, elsetype; isdefined)
Conditional(var::SlotNumber, ssadef::Int, @nospecialize(thentype), @nospecialize(elsetype); isdefined::Bool=false) =
Conditional(slot_id(var), ssadef, thentype, elsetype; isdefined)

const AnyConditional = Union{Conditional,InterConditional}
Conditional(cnd::InterConditional) = Conditional(cnd.slot, cnd.thentype, cnd.elsetype)
InterConditional(cnd::Conditional) = InterConditional(cnd.slot, cnd.thentype, cnd.elsetype)
function InterConditional(cnd::Conditional)
@assert cnd.ssadef == 0
InterConditional(cnd.slot, cnd.thentype, cnd.elsetype)
end

"""
alias::MustAlias
Expand All @@ -90,21 +93,22 @@ N.B. currently this lattice element is only used in abstractinterpret, not in op
"""
struct MustAlias
slot::Int
ssadef::Int
vartyp::Any
fldidx::Int
fldtyp::Any
function MustAlias(slot::Int, @nospecialize(vartyp), fldidx::Int, @nospecialize(fldtyp))
function MustAlias(slot::Int, ssadef::Int, @nospecialize(vartyp), fldidx::Int, @nospecialize(fldtyp))
assert_nested_slotwrapper(vartyp)
assert_nested_slotwrapper(fldtyp)
# @assert !isalreadyconst(vartyp) "vartyp is already const"
# @assert !isalreadyconst(fldtyp) "fldtyp is already const"
limited = may_form_limited_typ(vartyp, fldtyp, fldtyp)
limited !== nothing && return limited
return new(slot, vartyp, fldidx, fldtyp)
return new(slot, ssadef, vartyp, fldidx, fldtyp)
end
end
MustAlias(var::SlotNumber, @nospecialize(vartyp), fldidx::Int, @nospecialize(fldtyp)) =
MustAlias(slot_id(var), vartyp, fldidx, fldtyp)
MustAlias(var::SlotNumber, ssadef::Int, @nospecialize(vartyp), fldidx::Int, @nospecialize(fldtyp)) =
MustAlias(slot_id(var), ssadef, vartyp, fldidx, fldtyp)

"""
alias::InterMustAlias
Expand All @@ -130,8 +134,10 @@ InterMustAlias(var::SlotNumber, @nospecialize(vartyp), fldidx::Int, @nospecializ
InterMustAlias(slot_id(var), vartyp, fldidx, fldtyp)

const AnyMustAlias = Union{MustAlias,InterMustAlias}
MustAlias(alias::InterMustAlias) = MustAlias(alias.slot, alias.vartyp, alias.fldidx, alias.fldtyp)
InterMustAlias(alias::MustAlias) = InterMustAlias(alias.slot, alias.vartyp, alias.fldidx, alias.fldtyp)
function InterMustAlias(alias::MustAlias)
@assert alias.ssadef == 0
InterMustAlias(alias.slot, alias.vartyp, alias.fldidx, alias.fldtyp)
end

struct PartialTypeVar
tv::TypeVar
Expand All @@ -145,8 +151,20 @@ end
struct StateUpdate
var::SlotNumber
vtype::VarState
conditional::Bool
StateUpdate(var::SlotNumber, vtype::VarState, conditional::Bool=false) = new(var, vtype, conditional)
end

"""
Similar to `StateUpdate`, except with the additional guarantee that object identity
is preserved by the update (i.e. `x (before) === x (after)`).
"""
struct StateRefinement
slot::Int
# XXX: This should be an intersection of the old type with the new
# (i.e. newtyp ⊑ oldtyp)
newtyp
undef::Bool

StateRefinement(slot::Int, @nospecialize(newtyp), undef::Bool) = new(slot, newtyp, undef)
end

"""
Expand Down Expand Up @@ -284,6 +302,7 @@ end
return false
end

is_same_conditionals(a::Conditional, b::Conditional) = a.slot == b.slot && a.ssadef == b.ssadef
is_same_conditionals(a::C, b::C) where C<:AnyConditional = a.slot == b.slot

@nospecializeinfer is_lattice_bool(lattice::AbstractLattice, @nospecialize(typ)) = typ !== Bottom && ⊑(lattice, typ, Bool)
Expand Down Expand Up @@ -332,7 +351,7 @@ end
end

@nospecializeinfer function form_mustalias_conditional(alias::MustAlias, @nospecialize(thentype), @nospecialize(elsetype))
(; slot, vartyp, fldidx) = alias
(; slot, ssadef, vartyp, fldidx) = alias
if isa(vartyp, PartialStruct)
fields = vartyp.fields
thenfields = thentype === Bottom ? nothing : copy(fields)
Expand All @@ -343,7 +362,7 @@ end
elsefields === nothing || (elsefields[fldidx] = elsetype)
undefs[fldidx] = false
end
return Conditional(slot,
return Conditional(slot, ssadef,
thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, undefs, thenfields),
elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, undefs, elsefields))
else
Expand All @@ -360,7 +379,7 @@ end
elsefields === nothing || push!(elsefields, t)
end
end
return Conditional(slot,
return Conditional(slot, ssadef,
thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp_widened, thenfields),
elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp_widened, elsefields))
end
Expand Down Expand Up @@ -713,34 +732,39 @@ widenconst(::LimitedAccuracy) = error("unhandled LimitedAccuracy")
# state management #
####################

function smerge(lattice::AbstractLattice, sa::Union{NotFound,VarState}, sb::Union{NotFound,VarState})
function smerge(lattice::AbstractLattice, sa::Union{NotFound,VarState}, sb::Union{NotFound,VarState}, join_pc::Int)
sa === sb && return sa
sa === NOT_FOUND && return sb
sb === NOT_FOUND && return sa
return VarState(tmerge(lattice, sa.typ, sb.typ), sa.undef | sb.undef)
return VarState(tmerge(lattice, sa.typ, sb.typ), sa.ssadef == sb.ssadef ? sa.ssadef : join_pc, sa.undef | sb.undef)
end

@nospecializeinfer @inline schanged(lattice::AbstractLattice, @nospecialize(n), @nospecialize(o)) =
(n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !(n.undef <= o.undef && ⊑(lattice, n.typ, o.typ))))
@nospecializeinfer @inline schanged(lattice::AbstractLattice, @nospecialize(n), @nospecialize(o), join_pc::Int) =
(n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !(n.undef <= o.undef && (n.ssadef === o.ssadef || o.ssadef === join_pc) && ⊑(lattice, n.typ, o.typ))))

# remove any lattice elements that wrap the reassigned slot object from the vartable
function invalidate_slotwrapper(vt::VarState, changeid::Int, ignore_conditional::Bool)
function invalidate_slotwrapper(vt::VarState, changeid::Int)
newtyp = ignorelimited(vt.typ)
if (!ignore_conditional && isa(newtyp, Conditional) && newtyp.slot == changeid) ||
(isa(newtyp, MustAlias) && newtyp.slot == changeid)
if ((isa(newtyp, Conditional) && newtyp.slot == changeid) ||
(isa(newtyp, MustAlias) && newtyp.slot == changeid))
newtyp = @noinline widenwrappedslotwrapper(vt.typ)
return VarState(newtyp, vt.undef)
return VarState(newtyp, vt.ssadef, vt.undef)
end
return nothing
end

function stupdate!(lattice::AbstractLattice, state::VarTable, changes::VarTable)
function stupdate!(lattice::AbstractLattice, state::VarTable, changes::VarTable, join_pc::Int)
changed = false
for i = 1:length(state)
newtype = changes[i]
oldtype = state[i]
if schanged(lattice, newtype, oldtype)
state[i] = smerge(lattice, oldtype, newtype)
# In addition to computing the type, the merge here computes the "reaching definition"
# for a slot. The provided `join_pc` is a "virtual" PC, which corresponds to the ϕ-block
# that would exist at the beginning of the BasicBlock.
#
# This effectively applies the "path-convergence criterion" for SSA construction.
if schanged(lattice, newtype, oldtype, join_pc)
state[i] = smerge(lattice, oldtype, newtype, join_pc)
changed = true
end
end
Expand All @@ -757,7 +781,7 @@ end
function stoverwrite1!(state::VarTable, change::StateUpdate)
changeid = slot_id(change.var)
for i = 1:length(state)
invalidated = invalidate_slotwrapper(state[i], changeid, change.conditional)
invalidated = invalidate_slotwrapper(state[i], changeid)
if invalidated !== nothing
state[i] = invalidated
end
Expand All @@ -768,6 +792,12 @@ function stoverwrite1!(state::VarTable, change::StateUpdate)
return state
end

function strefine1!(state::VarTable, refinement::StateRefinement)
(; newtyp, undef, slot) = refinement
state[slot] = VarState(newtyp, state[slot].ssadef, undef)
return state
end

# The ::AbstractLattice argument is unused and simply serves to disambiguate
# different instances of the compiler that may share the `Core.PartialStruct`
# type.
Expand Down
10 changes: 5 additions & 5 deletions Compiler/src/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -500,24 +500,24 @@ end
# type-lattice for Conditional wrapper (NOTE never be merged with InterConditional)
if isa(typea, Conditional) && isa(typeb, Const)
if typeb.val === true
typeb = Conditional(typea.slot, Any, Union{})
typeb = Conditional(typea.slot, typea.ssadef, Any, Union{})
elseif typeb.val === false
typeb = Conditional(typea.slot, Union{}, Any)
typeb = Conditional(typea.slot, typea.ssadef, Union{}, Any)
end
end
if isa(typeb, Conditional) && isa(typea, Const)
if typea.val === true
typea = Conditional(typeb.slot, Any, Union{})
typea = Conditional(typeb.slot, typeb.ssadef, Any, Union{})
elseif typea.val === false
typea = Conditional(typeb.slot, Union{}, Any)
typea = Conditional(typeb.slot, typeb.ssadef, Union{}, Any)
end
end
if isa(typea, Conditional) && isa(typeb, Conditional)
if is_same_conditionals(typea, typeb)
thentype = tmerge(widenlattice(lattice), typea.thentype, typeb.thentype)
elsetype = tmerge(widenlattice(lattice), typea.elsetype, typeb.elsetype)
if thentype !== elsetype
return Conditional(typea.slot, thentype, elsetype)
return Conditional(typea.slot, typea.ssadef, thentype, elsetype)
end
end
val = maybe_extract_const_bool(typea)
Expand Down
11 changes: 10 additions & 1 deletion Compiler/src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,23 @@ SpecInfo(src::CodeInfo) = SpecInfo(
A special wrapper that represents a local variable of a method being analyzed.
This does not participate in the native type system nor the inference lattice, and it thus
should be always unwrapped to `v.typ` when performing any type or lattice operations on it.

`v.undef` represents undefined-ness of this static parameter. If `true`, it means that the
variable _may_ be undefined at runtime, otherwise it is guaranteed to be defined.
If `v.typ === Bottom` it means that the variable is strictly undefined.

`v.ssadef` represents the "reaching definition" for the variable.
If zero, then the value comes from an argument.
If negative, this refers to a "virtual ϕ-block" preceding the given index,
that would have been inserted as the value of this slot in a truly SSA-form IR.
If a slot has the same `ssadef` at two different points of execution,
the slot contents are guaranteed to share identity (`x₀ === x₁`).
"""
struct VarState
typ
ssadef::Int
undef::Bool
VarState(@nospecialize(typ), undef::Bool) = new(typ, undef)
VarState(@nospecialize(typ), ssadef::Int, undef::Bool) = new(typ, ssadef, undef)
end

struct AnalysisResults
Expand Down
6 changes: 3 additions & 3 deletions Compiler/test/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,10 +408,10 @@ Compiler.nsplit_impl(info::NoinlineCallInfo) = Compiler.nsplit(info.info)
Compiler.getsplit_impl(info::NoinlineCallInfo, idx::Int) = Compiler.getsplit(info.info, idx)
Compiler.getresult_impl(info::NoinlineCallInfo, idx::Int) = Compiler.getresult(info.info, idx)

function Compiler.abstract_call(interp::NoinlineInterpreter,
arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo, sv::Compiler.InferenceState, max_methods::Int)
function Compiler.abstract_call(interp::NoinlineInterpreter, arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo,
vtypes::Union{Compiler.VarTable,Nothing}, sv::Compiler.InferenceState, max_methods::Int)
ret = @invoke Compiler.abstract_call(interp::Compiler.AbstractInterpreter,
arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo, sv::Compiler.InferenceState, max_methods::Int)
arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo, vtypes::Union{Compiler.VarTable,Nothing}, sv::Compiler.InferenceState, max_methods::Int)
return Compiler.Future{Compiler.CallMeta}(ret, interp, sv) do ret, interp, sv
if sv.mod in noinline_modules(interp)
(;rt, exct, effects, info) = ret
Expand Down
Loading