diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 2dbb9ea..30e70d3 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -37,6 +37,10 @@ pub fn extern_spec_fn_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("extern_spec_fn")] } +pub fn raw_define_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("raw_define")] +} + /// A [`annot::Resolver`] implementation for resolving function parameters. /// /// The parameter names and their sorts needs to be configured via diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 9dd85f9..4172f94 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -2,12 +2,14 @@ use std::collections::HashSet; +use rustc_hir::def_id::CRATE_DEF_ID; use rustc_middle::ty::{self as mir_ty, TyCtxt}; use rustc_span::def_id::{DefId, LocalDefId}; use crate::analyze; use crate::chc; use crate::rty::{self, ClauseBuilderExt as _}; +use crate::annot; /// An implementation of local crate analysis. /// @@ -26,6 +28,21 @@ pub struct Analyzer<'tcx, 'ctx> { } impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { + fn analyze_raw_define_annot(&mut self) { + for attrs in self.tcx.get_attrs_by_path( + CRATE_DEF_ID.to_def_id(), + &analyze::annot::raw_define_path(), + ) { + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let parser = annot::AnnotParser::new( + // TODO: this resolver is not actually used. + analyze::annot::ParamResolver::default() + ); + let raw_definition = parser.parse_raw_definition(ts).unwrap(); + self.ctx.system.borrow_mut().push_raw_definition(raw_definition); + } + } + fn refine_local_defs(&mut self) { for local_def_id in self.tcx.mir_keys(()) { if self.tcx.def_kind(*local_def_id).is_fn_like() { @@ -187,6 +204,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let span = tracing::debug_span!("crate", krate = %self.tcx.crate_name(rustc_span::def_id::LOCAL_CRATE)); let _guard = span.enter(); + self.analyze_raw_define_annot(); self.refine_local_defs(); self.analyze_local_defs(); self.assert_callable_entry(); diff --git a/src/annot.rs b/src/annot.rs index 128c289..4d5e758 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -8,7 +8,7 @@ //! The main entry point is [`AnnotParser`], which parses a [`TokenStream`] into a //! [`rty::RefinedType`] or a [`chc::Formula`]. -use rustc_ast::token::{BinOpToken, Delimiter, LitKind, Token, TokenKind}; +use rustc_ast::token::{BinOpToken, Delimiter, LitKind, Lit, Token, TokenKind}; use rustc_ast::tokenstream::{RefTokenTreeCursor, Spacing, TokenStream, TokenTree}; use rustc_index::IndexVec; use rustc_span::symbol::Ident; @@ -420,7 +420,7 @@ where Ok(AnnotPath { segments }) } - fn parse_datatype_ctor_args(&mut self) -> Result>> { + fn parse_arg_terms(&mut self) -> Result>> { if self.look_ahead_token(0).is_none() { return Ok(Vec::new()); } @@ -478,6 +478,28 @@ where FormulaOrTerm::Term(var, sort.clone()) } _ => { + // If the single-segment identifier is followed by parethesized arguments, + // parse them as user-defined predicate calls. + let next_tt = self.look_ahead_token_tree(0); + + if let Some(TokenTree::Delimited(_, _, Delimiter::Parenthesis, args)) = next_tt { + let args = args.clone(); + self.consume(); + + let pred_symbol = chc::UserDefinedPredSymbol::new(ident.name.to_string()); + let pred = chc::Pred::UserDefined(pred_symbol); + + let mut parser = Parser { + resolver: self.boxed_resolver(), + cursor: args.trees(), + formula_existentials: self.formula_existentials.clone(), + }; + let args = parser.parse_arg_terms()?; + + let atom = chc::Atom::new(pred, args); + let formula = chc::Formula::Atom(atom); + return Ok(FormulaOrTerm::Formula(formula)); + } let (v, sort) = self.resolve(*ident)?; FormulaOrTerm::Term(chc::Term::var(v), sort) } @@ -497,7 +519,7 @@ where cursor: s.trees(), formula_existentials: self.formula_existentials.clone(), }; - let args = parser.parse_datatype_ctor_args()?; + let args = parser.parse_arg_terms()?; parser.end_of_input()?; let (term, sort) = path.to_datatype_ctor(args); FormulaOrTerm::Term(term, sort) @@ -1076,6 +1098,32 @@ where .ok_or_else(|| ParseAttrError::unexpected_term("in annotation"))?; Ok(AnnotFormula::Formula(formula)) } + + pub fn parse_annot_raw_definition(&mut self) -> Result { + let t = self.next_token("raw CHC definition")?; + + match t { + Token { + kind: TokenKind::Literal( + Lit { kind, symbol, .. } + ), + .. + } => { + match kind { + LitKind::Str => { + let definition = symbol.to_string(); + Ok(chc::RawDefinition{ definition }) + }, + _ => Err(ParseAttrError::unexpected_token( + "string literal", t.clone() + )) + } + }, + _ => Err(ParseAttrError::unexpected_token( + "string literal", t.clone() + )) + } + } } /// A [`Resolver`] implementation for resolving specific variable as [`rty::RefinedTypeVar::Value`]. @@ -1208,4 +1256,15 @@ where parser.end_of_input()?; Ok(formula) } + + pub fn parse_raw_definition(&self, ts: TokenStream) -> Result { + let mut parser = Parser { + resolver: &self.resolver, + cursor: ts.trees(), + formula_existentials: Default::default(), + }; + let raw_definition = parser.parse_annot_raw_definition()?; + parser.end_of_input()?; + Ok(raw_definition) + } } diff --git a/src/chc.rs b/src/chc.rs index 5543de4..9a91aaa 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -902,12 +902,86 @@ impl MatcherPred { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct UserDefinedPredSymbol { + inner: String, +} + +impl std::fmt::Display for UserDefinedPredSymbol { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.inner.fmt(f) + } +} + +impl<'a, 'b, D> Pretty<'a, D, termcolor::ColorSpec> for &'b UserDefinedPredSymbol +where + D: pretty::DocAllocator<'a, termcolor::ColorSpec>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> { + allocator.text(self.inner.clone()) + } +} + +impl UserDefinedPredSymbol { + pub fn new(inner: String) -> Self { + Self { inner } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct UserDefinedPred { + symbol: UserDefinedPredSymbol, + args: Vec, +} + +impl<'a> std::fmt::Display for UserDefinedPred { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.symbol.inner) + } +} + +impl<'a, 'b, D> Pretty<'a, D, termcolor::ColorSpec> for &'b UserDefinedPred +where + D: pretty::DocAllocator<'a, termcolor::ColorSpec>, + D::Doc: Clone, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> { + let args = allocator.intersperse( + self.args.iter().map(|a| a.pretty(allocator)), + allocator.text(", "), + ); + allocator + .text("user_defined_pred") + .append( + allocator + .text(self.symbol.inner.clone()) + .append(args.angles()) + .angles(), + ) + .group() + } +} + +impl UserDefinedPred { + pub fn new(symbol: UserDefinedPredSymbol, args: Vec) -> Self { + Self { + symbol, + args, + } + } + + pub fn name(&self) -> &str { + &self.symbol.inner + } +} + /// A predicate. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Pred { Known(KnownPred), Var(PredVarId), Matcher(MatcherPred), + UserDefined(UserDefinedPredSymbol), } impl std::fmt::Display for Pred { @@ -916,6 +990,7 @@ impl std::fmt::Display for Pred { Pred::Known(p) => p.fmt(f), Pred::Var(p) => p.fmt(f), Pred::Matcher(p) => p.fmt(f), + Pred::UserDefined(p) => p.fmt(f), } } } @@ -930,6 +1005,7 @@ where Pred::Known(p) => p.pretty(allocator), Pred::Var(p) => p.pretty(allocator), Pred::Matcher(p) => p.pretty(allocator), + Pred::UserDefined(p) => p.pretty(allocator), } } } @@ -958,6 +1034,7 @@ impl Pred { Pred::Known(p) => p.name().into(), Pred::Var(p) => p.to_string().into(), Pred::Matcher(p) => p.name().into(), + Pred::UserDefined(p) => p.to_string().into(), } } @@ -966,6 +1043,7 @@ impl Pred { Pred::Known(p) => p.is_negative(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } @@ -974,6 +1052,7 @@ impl Pred { Pred::Known(p) => p.is_infix(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } @@ -982,6 +1061,7 @@ impl Pred { Pred::Known(p) => p.is_top(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } @@ -990,6 +1070,7 @@ impl Pred { Pred::Known(p) => p.is_bottom(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } } @@ -1606,6 +1687,14 @@ impl Clause { } } +/// A definition specified using #![thrust::define_raw()] +/// +/// Those will be directly inserted into the generated SMT-LIB2 file. +#[derive(Debug, Clone)] +pub struct RawDefinition { + pub definition: String, +} + /// A selector for a datatype constructor. /// /// A selector is a function that extracts a field from a datatype value. @@ -1655,6 +1744,7 @@ pub struct PredVarDef { /// A CHC system. #[derive(Debug, Clone, Default)] pub struct System { + pub raw_definitions: Vec, pub datatypes: Vec, pub clauses: IndexVec, pub pred_vars: IndexVec, @@ -1665,6 +1755,10 @@ impl System { self.pred_vars.push(PredVarDef { sig, debug_info }) } + pub fn push_raw_definition(&mut self, raw_definition: RawDefinition) { + self.raw_definitions.push(raw_definition) + } + pub fn push_clause(&mut self, clause: Clause) -> Option { if clause.is_nop() { return None; diff --git a/src/chc/format_context.rs b/src/chc/format_context.rs index 2123315..1c75215 100644 --- a/src/chc/format_context.rs +++ b/src/chc/format_context.rs @@ -21,6 +21,7 @@ use crate::chc::{self, hoice::HoiceDatatypeRenamer}; pub struct FormatContext { renamer: HoiceDatatypeRenamer, datatypes: Vec, + raw_definitions: Vec, } // FIXME: this is obviously ineffective and should be replaced @@ -273,13 +274,18 @@ impl FormatContext { .filter(|d| d.params == 0) .collect(); let renamer = HoiceDatatypeRenamer::new(&datatypes); - FormatContext { renamer, datatypes } + let raw_definitions = system.raw_definitions.clone(); + FormatContext { renamer, datatypes, raw_definitions } } pub fn datatypes(&self) -> &[chc::Datatype] { &self.datatypes } + pub fn raw_definitions(&self) -> &[chc::RawDefinition] { + &self.raw_definitions + } + pub fn box_ctor(&self, sort: &chc::Sort) -> impl std::fmt::Display { let ss = Sort::new(sort).sorts(); format!("box{ss}") diff --git a/src/chc/smtlib2.rs b/src/chc/smtlib2.rs index 167d108..0e0f2e9 100644 --- a/src/chc/smtlib2.rs +++ b/src/chc/smtlib2.rs @@ -370,6 +370,30 @@ impl<'ctx, 'a> Clause<'ctx, 'a> { } } +/// A wrapper around a [`chc::RawDefinition`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. +#[derive(Debug, Clone)] +pub struct RawDefinition<'a> { + inner: &'a chc::RawDefinition, +} + +impl<'a> std::fmt::Display for RawDefinition<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + self.inner.definition, + ) + } +} + +impl<'a> RawDefinition<'a> { + pub fn new(inner: &'a chc::RawDefinition) -> Self { + Self { + inner + } + } +} + /// A wrapper around a [`chc::DatatypeSelector`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. #[derive(Debug, Clone)] pub struct DatatypeSelector<'ctx, 'a> { @@ -555,6 +579,11 @@ impl<'a> std::fmt::Display for System<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!(f, "(set-logic HORN)\n")?; + // insert definition from #![thrust::define_chc()] here + for raw_def in self.ctx.raw_definitions() { + writeln!(f, "{}\n", RawDefinition::new(raw_def))?; + } + writeln!(f, "{}\n", Datatypes::new(&self.ctx, self.ctx.datatypes()))?; for datatype in self.ctx.datatypes() { writeln!(f, "{}", DatatypeDiscrFun::new(&self.ctx, datatype))?; diff --git a/src/chc/unbox.rs b/src/chc/unbox.rs index 5be1240..da6d8c4 100644 --- a/src/chc/unbox.rs +++ b/src/chc/unbox.rs @@ -42,6 +42,7 @@ fn unbox_pred(pred: Pred) -> Pred { Pred::Known(pred) => Pred::Known(pred), Pred::Var(pred) => Pred::Var(pred), Pred::Matcher(pred) => unbox_matcher_pred(pred), + Pred::UserDefined(pred) => Pred::UserDefined(pred), } } @@ -161,6 +162,7 @@ pub fn unbox(system: System) -> System { clauses, pred_vars, datatypes, + raw_definitions, } = system; let datatypes = datatypes.into_iter().map(unbox_datatype).collect(); let clauses = clauses.into_iter().map(unbox_clause).collect(); @@ -169,5 +171,6 @@ pub fn unbox(system: System) -> System { clauses, pred_vars, datatypes, + raw_definitions, } } diff --git a/tests/ui/fail/annot_raw_define.rs b/tests/ui/fail/annot_raw_define.rs new file mode 100644 index 0000000..346a158 --- /dev/null +++ b/tests/ui/fail/annot_raw_define.rs @@ -0,0 +1,20 @@ +//@error-in-other-file: UnexpectedToken +//@compile-flags: -Adead_code -C debug-assertions=off + +// Insert definitions written in SMT-LIB2 format into .smt file directly. +// This feature is intended for debug or experiment purpose. +#![feature(custom_inner_attributes)] +#![thrust::raw_define(true)] // argument must be single string literal + +#[thrust::requires(true)] +#[thrust::ensures(result == 2 * x)] +// #[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); + // assert!(is_double(a, double(a))); +} diff --git a/tests/ui/fail/annot_raw_define_without_params.rs b/tests/ui/fail/annot_raw_define_without_params.rs new file mode 100644 index 0000000..d6683a8 --- /dev/null +++ b/tests/ui/fail/annot_raw_define_without_params.rs @@ -0,0 +1,20 @@ +//@error-in-other-file: invalid attribute +//@compile-flags: -Adead_code -C debug-assertions=off + +// Insert definitions written in SMT-LIB2 format into .smt file directly. +// This feature is intended for debug or experiment purpose. +#![feature(custom_inner_attributes)] +#![thrust::raw_define] // argument must be single string literal + +#[thrust::requires(true)] +#[thrust::ensures(result == 2 * x)] +// #[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); + // assert!(is_double(a, double(a))); +} diff --git a/tests/ui/pass/annot_preds_raw_define.rs b/tests/ui/pass/annot_preds_raw_define.rs new file mode 100644 index 0000000..bfcde39 --- /dev/null +++ b/tests/ui/pass/annot_preds_raw_define.rs @@ -0,0 +1,25 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +// Insert definitions written in SMT-LIB2 format into .smt file directly. +// This feature is intended for debug or experiment purpose. +#![feature(custom_inner_attributes)] +#![thrust::raw_define("(define-fun is_double ((x Int) (doubled_x Int)) Bool + (= + (* x 2) + doubled_x + ) +)")] + +#[thrust::requires(true)] +// #[thrust::ensures(result == 2 * x)] +#[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); + // assert!(is_double(a, double(a))); +} diff --git a/tests/ui/pass/annot_preds_raw_define_multi.rs b/tests/ui/pass/annot_preds_raw_define_multi.rs new file mode 100644 index 0000000..c0419f9 --- /dev/null +++ b/tests/ui/pass/annot_preds_raw_define_multi.rs @@ -0,0 +1,36 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +#![feature(custom_inner_attributes)] +#![thrust::raw_define("(define-fun is_double ((x Int) (doubled_x Int)) Bool + (= + (* x 2) + doubled_x + ) +)")] + +#![thrust::raw_define("(define-fun is_triple ((x Int) (tripled_x Int)) Bool + (= + (* x 3) + tripled_x + ) +)")] + +#[thrust::requires(true)] +#[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +#[thrust::require(is_double(x, doubled_x))] +#[thrust::ensures(is_triple(x, result))] +fn triple(x: i64, doubled_x: i64) -> i64 { + x + doubled_x +} + +fn main() { + let a = 3; + let double_a = double(a); + assert!(double_a == 6); + assert!(triple(a, double_a) == 9); +} diff --git a/tests/ui/pass/annot_raw_define.rs b/tests/ui/pass/annot_raw_define.rs new file mode 100644 index 0000000..c1c47a7 --- /dev/null +++ b/tests/ui/pass/annot_raw_define.rs @@ -0,0 +1,25 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +// Insert definitions written in SMT-LIB2 format into .smt file directly. +// This feature is intended for debug or experiment purpose. +#![feature(custom_inner_attributes)] +#![thrust::raw_define("(define-fun is_double ((x Int) (doubled_x Int)) Bool + (= + (* x 2) + doubled_x + ) +)")] + +#[thrust::requires(true)] +#[thrust::ensures(result == 2 * x)] +// #[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); + // assert!(is_double(a, double(a))); +} diff --git a/tests/ui/pass/annot_raw_define_multi.rs b/tests/ui/pass/annot_raw_define_multi.rs new file mode 100644 index 0000000..ae9e0eb --- /dev/null +++ b/tests/ui/pass/annot_raw_define_multi.rs @@ -0,0 +1,33 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +// Insert definitions written in SMT-LIB2 format into .smt file directly. +// This feature is intended for debug or experiment purpose. +#![feature(custom_inner_attributes)] +#![thrust::raw_define("(define-fun is_double ((x Int) (doubled_x Int)) Bool + (= + (* x 2) + doubled_x + ) +)")] + +// multiple raw definitions can be inserted. +#![thrust::raw_define("(define-fun is_triple ((x Int) (tripled_x Int)) Bool + (= + (* x 3) + tripled_x + ) +)")] + +#[thrust::requires(true)] +#[thrust::ensures(result == 2 * x)] +// #[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); + // assert!(is_double(a, double(a))); +}