Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ impl<'tcx> Analyzer<'tcx> {
.variants()
.iter()
.map(|variant| {
let name = refine::datatype_symbol(self.tcx, variant.def_id);
// TODO: consider using TyCtxt::tag_for_variant
let discr = resolve_discr(self.tcx, variant.discr);
let field_tys = variant
Expand All @@ -222,7 +221,7 @@ impl<'tcx> Analyzer<'tcx> {
})
.collect();
rty::EnumVariantDef {
name,
name: chc::DatatypeSymbol::new(format!("{}.{}", name, variant.name)),
discr,
field_tys,
}
Expand Down
82 changes: 17 additions & 65 deletions src/analyze/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -562,78 +562,30 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
});
}

fn is_box_new(&self, def_id: DefId) -> bool {
// TODO: stop using diagnositc item for semantic purpose
self.tcx.all_diagnostic_items(()).id_to_name.get(&def_id)
== Some(&rustc_span::symbol::sym::box_new)
}

fn is_mem_swap(&self, def_id: DefId) -> bool {
// TODO: stop using diagnositc item for semantic purpose
self.tcx.all_diagnostic_items(()).id_to_name.get(&def_id)
== Some(&rustc_span::symbol::sym::mem_swap)
}

fn type_call<I>(&mut self, func: Operand<'tcx>, args: I, expected_ret: &rty::RefinedType<Var>)
where
I: IntoIterator<Item = Operand<'tcx>>,
{
// TODO: handle const_fn_def on Env side
let func_ty = match func.const_fn_def() {
// TODO: move this to well-known defs?
Some((def_id, args)) if self.is_box_new(def_id) => {
let inner_ty = self
.type_builder
.for_template(&mut self.ctx)
.build(args.type_at(0))
.vacuous();
let param = rty::RefinedType::unrefined(inner_ty.clone());
let ret_term =
chc::Term::box_(chc::Term::var(rty::FunctionParamIdx::from(0_usize)));
let ret = rty::RefinedType::refined_with_term(
rty::PointerType::own(inner_ty).into(),
ret_term,
);
rty::FunctionType::new([param].into_iter().collect(), ret).into()
}
Some((def_id, args)) if self.is_mem_swap(def_id) => {
let inner_ty = self.type_builder.build(args.type_at(0)).vacuous();
let param1 =
rty::RefinedType::unrefined(rty::PointerType::mut_to(inner_ty.clone()).into());
let param2 =
rty::RefinedType::unrefined(rty::PointerType::mut_to(inner_ty.clone()).into());
let param1_var = rty::RefinedTypeVar::Free(rty::FunctionParamIdx::from(0_usize));
let param2_var = rty::RefinedTypeVar::Free(rty::FunctionParamIdx::from(1_usize));
let ret1 = chc::Term::var(param1_var)
.mut_current()
.equal_to(chc::Term::var(param2_var).mut_final());
let ret2 = chc::Term::var(param2_var)
.mut_current()
.equal_to(chc::Term::var(param1_var).mut_final());
let ret_formula = chc::Formula::Atom(ret1).and(chc::Formula::Atom(ret2));
let ret = rty::RefinedType::new(rty::Type::unit(), ret_formula.into());
rty::FunctionType::new([param1, param2].into_iter().collect(), ret).into()
let func_ty = if let Some((def_id, args)) = func.const_fn_def() {
let param_env = self.tcx.param_env(self.local_def_id);
let instance = mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap();
let resolved_def_id = if let Some(instance) = instance {
instance.def_id()
} else {
def_id
};
if def_id != resolved_def_id {
tracing::info!(?def_id, ?resolved_def_id, "resolve",);
}
Some((def_id, args)) => {
let param_env = self.tcx.param_env(self.local_def_id);
let instance =
mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap();
let resolved_def_id = if let Some(instance) = instance {
instance.def_id()
} else {
def_id
};
if def_id != resolved_def_id {
tracing::info!(?def_id, ?resolved_def_id, "resolve",);
}

self.ctx
.def_ty_with_args(resolved_def_id, args)
.expect("unknown def")
.ty
.vacuous()
}
_ => self.operand_type(func.clone()).ty,
self.ctx
.def_ty_with_args(resolved_def_id, args)
.expect("unknown def")
.ty
.vacuous()
} else {
self.operand_type(func.clone()).ty
};
let expected_args: IndexVec<_, _> = args
.into_iter()
Expand Down
9 changes: 7 additions & 2 deletions src/analyze/local_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
let typeck_result = self.tcx.typeck(self.outer_def_id);
if let rustc_hir::def::Res::Def(_, def_id) = typeck_result.qpath_res(qpath, hir_id)
{
assert!(self.inner_def_id.is_none(), "invalid extern_spec_fn");
self.inner_def_id = Some(def_id);
if matches!(
self.tcx.def_kind(def_id),
rustc_hir::def::DefKind::Fn | rustc_hir::def::DefKind::AssocFn
) {
assert!(self.inner_def_id.is_none(), "invalid extern_spec_fn");
self.inner_def_id = Some(def_id);
}
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/chc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ impl DatatypeSort {
pub fn new(symbol: DatatypeSymbol, args: Vec<Sort>) -> Self {
DatatypeSort { symbol, args }
}

pub fn args_mut(&mut self) -> &mut Vec<Sort> {
&mut self.args
}
}

/// A sort is the type of a logical term.
Expand Down
26 changes: 26 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#![feature(rustc_private)]

extern crate rustc_ast;
extern crate rustc_driver;
extern crate rustc_interface;
extern crate rustc_parse;
extern crate rustc_session;
extern crate rustc_span;

use rustc_driver::{Callbacks, Compilation, RunCompiler};
use rustc_interface::interface::{Compiler, Config};
Expand All @@ -17,6 +20,29 @@ impl Callbacks for CompilerCalls {
attrs.push("register_tool(thrust)".to_owned());
}

fn after_crate_root_parsing<'tcx>(
&mut self,
compiler: &Compiler,
queries: &'tcx Queries<'tcx>,
) -> Compilation {
let mut result = queries.parse().unwrap();
let krate = result.get_mut();

let injected = include_str!("../std.rs");
let mut parser = rustc_parse::new_parser_from_source_str(
&compiler.sess.psess,
rustc_span::FileName::Custom("thrust std injected".to_string()),
injected.to_owned(),
);
while let Some(item) = parser
.parse_item(rustc_parse::parser::ForceCollect::No)
.unwrap()
{
krate.items.push(item);
}
Compilation::Continue
}

fn after_analysis<'tcx>(
&mut self,
_compiler: &Compiler,
Expand Down
2 changes: 2 additions & 0 deletions src/refine/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,8 @@ where
.with_scope(&builder)
.build_refined(param_ty.ty)
}
} else if self.param_refinement.is_some() {
rty::RefinedType::unrefined(self.inner.build(param_ty.ty).vacuous())
} else {
rty::RefinedType::unrefined(
self.inner
Expand Down
110 changes: 110 additions & 0 deletions src/rty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,13 @@ impl<FV> Refinement<FV> {
refinement: self,
}
}

pub fn subst_ty_params_in_sorts<T>(&mut self, subst: &TypeParamSubst<T>) {
for sort in &mut self.existentials {
subst_ty_params_in_sort(sort, subst);
}
subst_ty_params_in_body(&mut self.body, subst);
}
}

/// A helper type to map logical variables in a refinement at once.
Expand Down Expand Up @@ -1445,6 +1452,7 @@ impl<FV> RefinedType<FV> {
where
FV: chc::Var,
{
self.refinement.subst_ty_params_in_sorts(subst);
match &mut self.ty {
Type::Int | Type::Bool | Type::String | Type::Never => {}
Type::Param(ty) => {
Expand Down Expand Up @@ -1513,6 +1521,108 @@ impl RefinedType<Closed> {
}
}

/// Substitutes type parameters in a sort.
fn subst_ty_params_in_sort<T>(sort: &mut chc::Sort, subst: &TypeParamSubst<T>) {
match sort {
chc::Sort::Null | chc::Sort::Int | chc::Sort::Bool | chc::Sort::String => {}
chc::Sort::Param(idx) => {
let type_param_idx = TypeParamIdx::from_usize(*idx);
if let Some(rty) = subst.get(type_param_idx) {
*sort = rty.ty.to_sort();
}
}
chc::Sort::Box(s) | chc::Sort::Mut(s) => {
subst_ty_params_in_sort(s, subst);
}
chc::Sort::Tuple(sorts) => {
for s in sorts {
subst_ty_params_in_sort(s, subst);
}
}
chc::Sort::Datatype(dt_sort) => {
for s in dt_sort.args_mut() {
subst_ty_params_in_sort(s, subst);
}
}
}
}

/// Substitutes type parameters in all sorts appearing in a body.
fn subst_ty_params_in_body<T, V>(body: &mut chc::Body<V>, subst: &TypeParamSubst<T>) {
for atom in &mut body.atoms {
subst_ty_params_in_atom(atom, subst);
}
subst_ty_params_in_formula(&mut body.formula, subst);
}

/// Substitutes type parameters in all sorts appearing in an atom.
fn subst_ty_params_in_atom<T, V>(atom: &mut chc::Atom<V>, subst: &TypeParamSubst<T>) {
if let Some(guard) = &mut atom.guard {
subst_ty_params_in_formula(guard, subst);
}
for term in &mut atom.args {
subst_ty_params_in_term(term, subst);
}
}

/// Substitutes type parameters in all sorts appearing in a formula.
fn subst_ty_params_in_formula<T, V>(formula: &mut chc::Formula<V>, subst: &TypeParamSubst<T>) {
match formula {
chc::Formula::Atom(atom) => subst_ty_params_in_atom(atom, subst),
chc::Formula::Not(f) => subst_ty_params_in_formula(f, subst),
chc::Formula::And(fs) | chc::Formula::Or(fs) => {
for f in fs {
subst_ty_params_in_formula(f, subst);
}
}
chc::Formula::Exists(vars, f) => {
for (_, sort) in vars {
subst_ty_params_in_sort(sort, subst);
}
subst_ty_params_in_formula(f, subst);
}
}
}

/// Substitutes type parameters in all sorts appearing in a term.
fn subst_ty_params_in_term<T, V>(term: &mut chc::Term<V>, subst: &TypeParamSubst<T>) {
match term {
chc::Term::Null
| chc::Term::Var(_)
| chc::Term::Bool(_)
| chc::Term::Int(_)
| chc::Term::String(_) => {}
chc::Term::Box(t)
| chc::Term::BoxCurrent(t)
| chc::Term::MutCurrent(t)
| chc::Term::MutFinal(t)
| chc::Term::TupleProj(t, _)
| chc::Term::DatatypeDiscr(_, t) => {
subst_ty_params_in_term(t, subst);
}
chc::Term::Mut(t1, t2) => {
subst_ty_params_in_term(t1, subst);
subst_ty_params_in_term(t2, subst);
}
chc::Term::App(_, args) | chc::Term::Tuple(args) => {
for arg in args {
subst_ty_params_in_term(arg, subst);
}
}
chc::Term::DatatypeCtor(s, _, args) => {
for arg in s.args_mut() {
subst_ty_params_in_sort(arg, subst);
}
for arg in args {
subst_ty_params_in_term(arg, subst);
}
}
chc::Term::FormulaExistentialVar(sort, _) => {
subst_ty_params_in_sort(sort, subst);
}
}
}

pub fn unify_tys_params<I1, I2, T>(tys1: I1, tys2: I2) -> TypeParamSubst<T>
where
T: chc::Var,
Expand Down
Loading