Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 165 additions & 14 deletions src/annot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,57 @@ impl<T> AnnotFormula<T> {
}
}

/// A path in an annotation.
#[derive(Debug, Clone)]
pub struct AnnotPath {
pub segments: Vec<AnnotPathSegment>,
}

impl AnnotPath {
pub fn to_datatype_ctor<V>(&self, ctor_args: Vec<chc::Term<V>>) -> (chc::Term<V>, chc::Sort) {
let mut segments = self.segments.clone();

let ctor = segments.pop().unwrap();
if !ctor.generic_args.is_empty() {
unimplemented!("generic arguments in datatype constructor");
}
let Some(ty_last_segment) = segments.last_mut() else {
unimplemented!("single segment path");
};
let generic_args: Vec<_> = ty_last_segment.generic_args.drain(..).collect();
let ty_path_idents: Vec<_> = segments
.into_iter()
.map(|segment| {
if !segment.generic_args.is_empty() {
unimplemented!("generic arguments in datatype constructor");
}
segment.ident.to_string()
})
.collect();
// see refine::datatype_symbol
let d_sym = chc::DatatypeSymbol::new(ty_path_idents.join("."));
let v_sym = chc::DatatypeSymbol::new(format!("{}.{}", d_sym, ctor.ident.as_str()));
let term = chc::Term::datatype_ctor(d_sym.clone(), generic_args.clone(), v_sym, ctor_args);
let sort = chc::Sort::datatype(d_sym, generic_args);
(term, sort)
}

pub fn single_segment_ident(&self) -> Option<&Ident> {
if self.segments.len() == 1 && self.segments[0].generic_args.is_empty() {
Some(&self.segments[0].ident)
} else {
None
}
}
}

/// A segment of a path in an annotation.
#[derive(Debug, Clone)]
pub struct AnnotPathSegment {
pub ident: Ident,
pub generic_args: Vec<chc::Sort>,
}

/// A trait for resolving variables in annotations to their logical representation and their sorts.
pub trait Resolver {
type Output;
Expand Down Expand Up @@ -298,6 +349,84 @@ where
}
}

fn parse_path_tail(&mut self, head: Ident) -> Result<AnnotPath> {
let mut segments: Vec<AnnotPathSegment> = Vec::new();
segments.push(AnnotPathSegment {
ident: head,
generic_args: Vec::new(),
});
while let Some(Token {
kind: TokenKind::ModSep,
..
}) = self.look_ahead_token(0)
{
self.consume();
match self.next_token("ident or <")? {
t @ Token {
kind: TokenKind::Lt,
..
} => {
if segments.is_empty() {
return Err(ParseAttrError::unexpected_token(
"path segment before <",
t.clone(),
));
}
let mut generic_args = Vec::new();
loop {
let sort = self.parse_sort()?;
generic_args.push(sort);
match self.next_token(", or >")? {
Token {
kind: TokenKind::Comma,
..
} => {}
Token {
kind: TokenKind::Gt,
..
} => break,
t => return Err(ParseAttrError::unexpected_token(", or >", t.clone())),
}
}
segments.last_mut().unwrap().generic_args = generic_args;
}
t @ Token {
kind: TokenKind::Ident(_, _),
..
} => {
let (ident, _) = t.ident().unwrap();
segments.push(AnnotPathSegment {
ident,
generic_args: Vec::new(),
});
}
t => return Err(ParseAttrError::unexpected_token("ident or <", t.clone())),
}
}
Ok(AnnotPath { segments })
}

fn parse_datatype_ctor_args(&mut self) -> Result<Vec<chc::Term<T::Output>>> {
let mut terms = Vec::new();
loop {
let formula_or_term = self.parse_formula_or_term()?;
let (t, _) = formula_or_term.into_term().ok_or_else(|| {
ParseAttrError::unexpected_formula("in datatype constructor arguments")
})?;
terms.push(t);
if let Some(Token {
kind: TokenKind::Comma,
..
}) = self.look_ahead_token(0)
{
self.consume();
} else {
break;
}
}
Ok(terms)
}

fn parse_atom(&mut self) -> Result<FormulaOrTerm<T::Output>> {
let tt = self.next_token_tree("term or formula")?.clone();

Expand All @@ -317,21 +446,43 @@ where
};

let formula_or_term = if let Some((ident, _)) = t.ident() {
match (
ident.as_str(),
self.formula_existentials.get(ident.name.as_str()),
) {
("true", _) => FormulaOrTerm::Formula(chc::Formula::top()),
("false", _) => FormulaOrTerm::Formula(chc::Formula::bottom()),
(_, Some(sort)) => {
let var =
chc::Term::FormulaExistentialVar(sort.clone(), ident.name.to_string());
FormulaOrTerm::Term(var, sort.clone())
}
_ => {
let (v, sort) = self.resolve(ident)?;
FormulaOrTerm::Term(chc::Term::var(v), sort)
let path = self.parse_path_tail(ident)?;
if let Some(ident) = path.single_segment_ident() {
match (
ident.as_str(),
self.formula_existentials.get(ident.name.as_str()),
) {
("true", _) => FormulaOrTerm::Formula(chc::Formula::top()),
("false", _) => FormulaOrTerm::Formula(chc::Formula::bottom()),
(_, Some(sort)) => {
let var =
chc::Term::FormulaExistentialVar(sort.clone(), ident.name.to_string());
FormulaOrTerm::Term(var, sort.clone())
}
_ => {
let (v, sort) = self.resolve(*ident)?;
FormulaOrTerm::Term(chc::Term::var(v), sort)
}
}
} else {
let next_tt = self
.next_token_tree("arguments for datatype constructor")?
.clone();
let TokenTree::Delimited(_, _, Delimiter::Parenthesis, s) = next_tt else {
return Err(ParseAttrError::unexpected_token_tree(
"arguments for datatype constructor",
next_tt.clone(),
));
};
let mut parser = Parser {
resolver: self.boxed_resolver(),
cursor: s.trees(),
formula_existentials: self.formula_existentials.clone(),
};
let args = parser.parse_datatype_ctor_args()?;
parser.end_of_input()?;
let (term, sort) = path.to_datatype_ctor(args);
FormulaOrTerm::Term(term, sort)
}
} else {
match t.kind {
Expand Down
20 changes: 20 additions & 0 deletions tests/ui/fail/annot_enum_simple.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//@error-in-other-file: Unsat

pub enum X {
A(i64),
B(bool),
}

#[thrust::requires(x == X::A(1))]
#[thrust::ensures(true)]
fn test(x: X) {
if let X::A(i) = x {
assert!(i == 2);
} else {
loop {}
}
}

fn main() {
test(X::A(1));
}
20 changes: 20 additions & 0 deletions tests/ui/pass/annot_enum_simple.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//@check-pass

pub enum X {
A(i64),
B(bool),
}

#[thrust::requires(x == X::A(1))]
#[thrust::ensures(true)]
fn test(x: X) {
if let X::A(i) = x {
assert!(i == 1);
} else {
loop {}
}
}

fn main() {
test(X::A(1));
}