diff --git a/gcd.rs b/gcd.rs new file mode 100644 index 0000000..bd3e07f --- /dev/null +++ b/gcd.rs @@ -0,0 +1,18 @@ +fn gcd(mut a: i32, mut b: i32) -> i32 { + while a != b { + let (l, r) = if a < b { + (&mut a, &b) + } else { + (&mut b, &a) + }; + *l -= *r; + } + a +} + +#[thrust::callable] +fn check_gcd(a: i32, b: i32) { + assert!(gcd(a, b) <= a); +} + +fn main() {} \ No newline at end of file diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 2dbb9ea..68fcbaf 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -33,6 +33,10 @@ pub fn callable_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("callable")] } +pub fn predicate_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("predicate")] +} + pub fn extern_spec_fn_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("extern_spec_fn")] } diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 9dd85f9..ef7e10a 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -1,8 +1,8 @@ //! Analyze a local crate. -use std::collections::HashSet; +use std::collections::{HashSet, HashMap}; -use rustc_middle::ty::{self as mir_ty, TyCtxt}; +use rustc_middle::ty::{self as mir_ty, TyCtxt, FnSig}; use rustc_span::def_id::{DefId, LocalDefId}; use crate::analyze; @@ -23,6 +23,7 @@ pub struct Analyzer<'tcx, 'ctx> { tcx: TyCtxt<'tcx>, ctx: &'ctx mut analyze::Analyzer<'tcx>, trusted: HashSet, + predicates: HashMap>, } impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { @@ -45,6 +46,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.trusted.insert(local_def_id.to_def_id()); } + if analyzer.is_annotated_as_predicate() { + self.predicates.insert(local_def_id.to_def_id(), sig); + } + if analyzer.is_annotated_as_extern_spec_fn() { assert!(analyzer.is_fully_annotated()); self.trusted.insert(local_def_id.to_def_id()); @@ -73,6 +78,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { tracing::info!(?local_def_id, "trusted"); continue; } + if self.predicates.contains_key(&local_def_id.to_def_id()) { + let sig = self.predicates.get(&local_def_id.to_def_id()).unwrap(); + tracing::info!(?local_def_id, ?sig, "predicate"); + continue; + } let Some(expected) = self.ctx.concrete_def_ty(local_def_id.to_def_id()) else { // when the local_def_id is deferred it would be skipped continue; @@ -180,7 +190,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { pub fn new(ctx: &'ctx mut analyze::Analyzer<'tcx>) -> Self { let tcx = ctx.tcx; let trusted = HashSet::default(); - Self { ctx, tcx, trusted } + let predicates = HashMap::default(); + Self { ctx, tcx, trusted, predicates } } pub fn run(&mut self) { diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index d556ef0..5a80165 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -126,6 +126,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .is_some() } + pub fn is_annotated_as_predicate(&self) -> bool { + self.tcx + .get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::predicate_path(), + ) + .next() + .is_some() + } + pub fn is_annotated_as_extern_spec_fn(&self) -> bool { self.tcx .get_attrs_by_path( diff --git a/src/chc.rs b/src/chc.rs index 5543de4..972cdb3 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -902,12 +902,87 @@ impl MatcherPred { } } +// TODO: DatatypeSymbolをほぼそのままコピーする形になっているので、エイリアスなどで共通化すべき? +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct UserDefinedPredSymbol { + inner: String, +} + +impl std::fmt::Display for UserDefinedPredSymbol { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.inner.fmt(f) + } +} + +impl<'a, 'b, D> Pretty<'a, D, termcolor::ColorSpec> for &'b UserDefinedPredSymbol +where + D: pretty::DocAllocator<'a, termcolor::ColorSpec>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> { + allocator.text(self.inner.clone()) + } +} + +impl UserDefinedPredSymbol { + pub fn new(inner: String) -> Self { + Self { inner } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct UserDefinedPred { + symbol: UserDefinedPredSymbol, + args: Vec, +} + +impl<'a> std::fmt::Display for UserDefinedPred { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.symbol.inner) + } +} + +impl<'a, 'b, D> Pretty<'a, D, termcolor::ColorSpec> for &'b UserDefinedPred +where + D: pretty::DocAllocator<'a, termcolor::ColorSpec>, + D::Doc: Clone, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> { + let args = allocator.intersperse( + self.args.iter().map(|a| a.pretty(allocator)), + allocator.text(", "), + ); + allocator + .text("user_defined_pred") + .append( + allocator + .text(self.symbol.inner.clone()) + .append(args.angles()) + .angles(), + ) + .group() + } +} + +impl UserDefinedPred { + pub fn new(symbol: UserDefinedPredSymbol, args: Vec) -> Self { + Self { + symbol, + args, + } + } + + pub fn name(&self) -> &str { + &self.symbol.inner + } +} + /// A predicate. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Pred { Known(KnownPred), Var(PredVarId), Matcher(MatcherPred), + UserDefined(UserDefinedPredSymbol), } impl std::fmt::Display for Pred { @@ -916,6 +991,7 @@ impl std::fmt::Display for Pred { Pred::Known(p) => p.fmt(f), Pred::Var(p) => p.fmt(f), Pred::Matcher(p) => p.fmt(f), + Pred::UserDefined(p) => p.fmt(f), } } } @@ -930,6 +1006,7 @@ where Pred::Known(p) => p.pretty(allocator), Pred::Var(p) => p.pretty(allocator), Pred::Matcher(p) => p.pretty(allocator), + Pred::UserDefined(p) => p.pretty(allocator), } } } @@ -958,6 +1035,7 @@ impl Pred { Pred::Known(p) => p.name().into(), Pred::Var(p) => p.to_string().into(), Pred::Matcher(p) => p.name().into(), + Pred::UserDefined(p) => p.inner.clone().into(), } } @@ -966,6 +1044,7 @@ impl Pred { Pred::Known(p) => p.is_negative(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } @@ -974,6 +1053,7 @@ impl Pred { Pred::Known(p) => p.is_infix(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } @@ -982,6 +1062,7 @@ impl Pred { Pred::Known(p) => p.is_top(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } @@ -990,6 +1071,7 @@ impl Pred { Pred::Known(p) => p.is_bottom(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } } diff --git a/src/chc/unbox.rs b/src/chc/unbox.rs index 5be1240..0c3308d 100644 --- a/src/chc/unbox.rs +++ b/src/chc/unbox.rs @@ -42,6 +42,7 @@ fn unbox_pred(pred: Pred) -> Pred { Pred::Known(pred) => Pred::Known(pred), Pred::Var(pred) => Pred::Var(pred), Pred::Matcher(pred) => unbox_matcher_pred(pred), + Pred::UserDefined(pred) => Pred::UserDefined(pred), } } diff --git a/tests/ui/pass/annot_preds.rs b/tests/ui/pass/annot_preds.rs new file mode 100644 index 0000000..79bf978 --- /dev/null +++ b/tests/ui/pass/annot_preds.rs @@ -0,0 +1,24 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +#[thrust::predicate] +fn is_double(x: i64, doubled_x: i64) -> bool { + x * 2 == doubled_x + // "(=( + // (* (x 2)) + // doubled_x + // ))" +} + +#[thrust::requires(true)] +#[thrust::ensures(result == 2 * x)] +// #[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); + assert!(is_double(a, double(a))); +}