Skip to content

Commit 51084c6

Browse files
authored
Support defining predicates using #[thrust::predicate] attribute with fn statements (#23)
* add: test for annotations of predicates * add: definitions to check if functions are marked as predicates * add: gather functions marked as predicates into Analyzer::predicates * add: logging for found predicates * add: UserDefinedPredDefs in chc::System * add: parse `#[thrust::predicate]` and register user-defined predicate definitions * add: format `UserDefinedPredDef` in SMT-LIB2 * fix: test for `#[thrust::predicate]` * fix: use String for predicate body * fix: insert raw commands and user-defined predicates before datatype defintions in .smt2 file
1 parent 7a5e42b commit 51084c6

7 files changed

Lines changed: 195 additions & 9 deletions

File tree

src/analyze/annot.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ pub fn raw_command_path() -> [Symbol; 2] {
4141
[Symbol::intern("thrust"), Symbol::intern("raw_command")]
4242
}
4343

44+
pub fn predicate_path() -> [Symbol; 2] {
45+
[Symbol::intern("thrust"), Symbol::intern("predicate")]
46+
}
47+
4448
/// A [`annot::Resolver`] implementation for resolving function parameters.
4549
///
4650
/// The parameter names and their sorts needs to be configured via

src/analyze/crate_.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ pub struct Analyzer<'tcx, 'ctx> {
2424
tcx: TyCtxt<'tcx>,
2525
ctx: &'ctx mut analyze::Analyzer<'tcx>,
2626
trusted: HashSet<DefId>,
27+
predicates: HashSet<DefId>,
2728
}
2829

2930
impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
@@ -82,6 +83,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
8283
self.trusted.insert(local_def_id.to_def_id());
8384
}
8485

86+
if analyzer.is_annotated_as_predicate() {
87+
self.predicates.insert(local_def_id.to_def_id());
88+
analyzer.analyze_predicate_definition(local_def_id);
89+
}
90+
8591
use mir_ty::TypeVisitableExt as _;
8692
if sig.has_param() && !analyzer.is_fully_annotated() {
8793
self.ctx.register_deferred_def(local_def_id.to_def_id());
@@ -105,6 +111,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
105111
tracing::info!(?local_def_id, "trusted");
106112
continue;
107113
}
114+
if self.predicates.contains(&local_def_id.to_def_id()) {
115+
tracing::info!(?local_def_id, "predicate");
116+
continue;
117+
}
108118
let Some(expected) = self.ctx.concrete_def_ty(local_def_id.to_def_id()) else {
109119
// when the local_def_id is deferred it would be skipped
110120
continue;
@@ -212,7 +222,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
212222
pub fn new(ctx: &'ctx mut analyze::Analyzer<'tcx>) -> Self {
213223
let tcx = ctx.tcx;
214224
let trusted = HashSet::default();
215-
Self { ctx, tcx, trusted }
225+
let predicates = HashSet::default();
226+
Self {
227+
ctx,
228+
tcx,
229+
trusted,
230+
predicates,
231+
}
216232
}
217233

218234
pub fn run(&mut self) {

src/analyze/local_def.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,27 @@ use crate::pretty::PrettyDisplayExt as _;
1616
use crate::refine::{BasicBlockType, TypeBuilder};
1717
use crate::rty;
1818

19+
fn stmt_str_literal(stmt: &rustc_hir::Stmt) -> Option<String> {
20+
use rustc_ast::LitKind;
21+
use rustc_hir::{Expr, ExprKind, Stmt, StmtKind};
22+
23+
match stmt {
24+
Stmt {
25+
kind:
26+
StmtKind::Semi(Expr {
27+
kind:
28+
ExprKind::Lit(rustc_hir::Lit {
29+
node: LitKind::Str(symbol, _),
30+
..
31+
}),
32+
..
33+
}),
34+
..
35+
} => Some(symbol.to_string()),
36+
_ => None,
37+
}
38+
}
39+
1940
/// An implementation of the typing of local definitions.
2041
///
2142
/// The current implementation only applies to function definitions. The entry point is
@@ -106,6 +127,49 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
106127
ret_annot
107128
}
108129

130+
pub fn analyze_predicate_definition(&self, local_def_id: LocalDefId) {
131+
let pred_name = self.tcx.item_name(local_def_id.to_def_id()).to_string();
132+
133+
// function's body
134+
use rustc_hir::{Block, Expr, ExprKind};
135+
136+
let hir_map = self.tcx.hir();
137+
let body_id = hir_map.maybe_body_owned_by(local_def_id).unwrap();
138+
let hir_body = hir_map.body(body_id);
139+
140+
let predicate_body = match hir_body.value {
141+
Expr {
142+
kind: ExprKind::Block(Block { stmts, .. }, _),
143+
..
144+
} => stmts
145+
.iter()
146+
.find_map(stmt_str_literal)
147+
.expect("invalid predicate definition: no string literal was found."),
148+
_ => panic!("expected function body, got: {hir_body:?}"),
149+
};
150+
151+
// names and sorts of arguments
152+
let arg_names = self
153+
.tcx
154+
.fn_arg_names(local_def_id.to_def_id())
155+
.iter()
156+
.map(|ident| ident.to_string());
157+
158+
let sig = self.ctx.local_fn_sig(local_def_id);
159+
let arg_sorts = sig
160+
.inputs()
161+
.iter()
162+
.map(|input_ty| self.type_builder.build(*input_ty).to_sort());
163+
164+
let arg_name_and_sorts = arg_names.into_iter().zip(arg_sorts).collect::<Vec<_>>();
165+
166+
self.ctx.system.borrow_mut().push_pred_define(
167+
chc::UserDefinedPred::new(pred_name),
168+
chc::UserDefinedPredSig::from(arg_name_and_sorts),
169+
predicate_body,
170+
);
171+
}
172+
109173
pub fn is_annotated_as_trusted(&self) -> bool {
110174
self.tcx
111175
.get_attrs_by_path(
@@ -136,6 +200,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
136200
.is_some()
137201
}
138202

203+
pub fn is_annotated_as_predicate(&self) -> bool {
204+
self.tcx
205+
.get_attrs_by_path(
206+
self.local_def_id.to_def_id(),
207+
&analyze::annot::predicate_path(),
208+
)
209+
.next()
210+
.is_some()
211+
}
212+
139213
// TODO: unify this logic with extraction functions above
140214
pub fn is_fully_annotated(&self) -> bool {
141215
let has_require = self

src/chc.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,11 +1704,21 @@ pub struct PredVarDef {
17041704
pub debug_info: DebugInfo,
17051705
}
17061706

1707+
pub type UserDefinedPredSig = Vec<(String, Sort)>;
1708+
1709+
#[derive(Debug, Clone)]
1710+
pub struct UserDefinedPredDef {
1711+
symbol: UserDefinedPred,
1712+
sig: UserDefinedPredSig,
1713+
body: String,
1714+
}
1715+
17071716
/// A CHC system.
17081717
#[derive(Debug, Clone, Default)]
17091718
pub struct System {
17101719
pub raw_commands: Vec<RawCommand>,
17111720
pub datatypes: Vec<Datatype>,
1721+
pub user_defined_pred_defs: Vec<UserDefinedPredDef>,
17121722
pub clauses: IndexVec<ClauseId, Clause>,
17131723
pub pred_vars: IndexVec<PredVarId, PredVarDef>,
17141724
}
@@ -1722,6 +1732,16 @@ impl System {
17221732
self.raw_commands.push(raw_command)
17231733
}
17241734

1735+
pub fn push_pred_define(
1736+
&mut self,
1737+
symbol: UserDefinedPred,
1738+
sig: UserDefinedPredSig,
1739+
body: String,
1740+
) {
1741+
self.user_defined_pred_defs
1742+
.push(UserDefinedPredDef { symbol, sig, body })
1743+
}
1744+
17251745
pub fn push_clause(&mut self, clause: Clause) -> Option<ClauseId> {
17261746
if clause.is_nop() {
17271747
return None;

src/chc/smtlib2.rs

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,33 @@ impl<'ctx, 'a> MatcherPredFun<'ctx, 'a> {
562562
}
563563
}
564564

565+
pub struct UserDefinedPredDef<'ctx, 'a> {
566+
ctx: &'ctx FormatContext,
567+
inner: &'a chc::UserDefinedPredDef,
568+
}
569+
570+
impl<'ctx, 'a> std::fmt::Display for UserDefinedPredDef<'ctx, 'a> {
571+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
572+
let params = List::closed(
573+
self.inner
574+
.sig
575+
.iter()
576+
.map(|(name, sort)| format!("({} {})", name, self.ctx.fmt_sort(sort))),
577+
);
578+
write!(
579+
f,
580+
"(define-fun {name} {params} Bool {body})",
581+
name = self.inner.symbol,
582+
body = &self.inner.body,
583+
)
584+
}
585+
}
586+
587+
impl<'ctx, 'a> UserDefinedPredDef<'ctx, 'a> {
588+
pub fn new(ctx: &'ctx FormatContext, inner: &'a chc::UserDefinedPredDef) -> Self {
589+
Self { ctx, inner }
590+
}
591+
}
565592
/// A wrapper around a [`chc::System`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format.
566593
#[derive(Debug, Clone)]
567594
pub struct System<'a> {
@@ -573,16 +600,25 @@ impl<'a> std::fmt::Display for System<'a> {
573600
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
574601
writeln!(f, "(set-logic HORN)\n")?;
575602

603+
writeln!(f, "{}\n", Datatypes::new(&self.ctx, self.ctx.datatypes()))?;
604+
for datatype in self.ctx.datatypes() {
605+
writeln!(f, "{}", DatatypeDiscrFun::new(&self.ctx, datatype))?;
606+
writeln!(f, "{}", MatcherPredFun::new(&self.ctx, datatype))?;
607+
}
608+
576609
// insert command from #![thrust::raw_command()] here
577610
for raw_command in &self.inner.raw_commands {
578611
writeln!(f, "{}\n", RawCommand::new(raw_command))?;
579612
}
580613

581-
writeln!(f, "{}\n", Datatypes::new(&self.ctx, self.ctx.datatypes()))?;
582-
for datatype in self.ctx.datatypes() {
583-
writeln!(f, "{}", DatatypeDiscrFun::new(&self.ctx, datatype))?;
584-
writeln!(f, "{}", MatcherPredFun::new(&self.ctx, datatype))?;
614+
for user_defined_pred_def in &self.inner.user_defined_pred_defs {
615+
writeln!(
616+
f,
617+
"{}\n",
618+
UserDefinedPredDef::new(&self.ctx, user_defined_pred_def)
619+
)?;
585620
}
621+
586622
writeln!(f)?;
587623
for (p, def) in self.inner.pred_vars.iter_enumerated() {
588624
if !def.debug_info.is_empty() {

src/chc/unbox.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,25 +152,40 @@ fn unbox_datatype(datatype: Datatype) -> Datatype {
152152
}
153153
}
154154

155+
fn unbox_user_defined_pred_def(user_defined_pred_def: UserDefinedPredDef) -> UserDefinedPredDef {
156+
let UserDefinedPredDef { symbol, sig, body } = user_defined_pred_def;
157+
let sig = sig
158+
.into_iter()
159+
.map(|(name, sort)| (name, unbox_sort(sort)))
160+
.collect();
161+
UserDefinedPredDef { symbol, sig, body }
162+
}
163+
155164
/// Remove all `Box` sorts and `Box`/`BoxCurrent` terms from the system.
156165
///
157166
/// The box values in Thrust represent an owned pointer, but are logically equivalent to the inner type.
158167
/// This pass removes them to reduce the complexity of the CHCs sent to the solver.
159168
/// This function traverses a [`System`] and removes all `Box` related constructs.
160169
pub fn unbox(system: System) -> System {
161170
let System {
171+
raw_commands,
172+
datatypes,
173+
user_defined_pred_defs,
162174
clauses,
163175
pred_vars,
164-
datatypes,
165-
raw_commands,
166176
} = system;
167177
let datatypes = datatypes.into_iter().map(unbox_datatype).collect();
168178
let clauses = clauses.into_iter().map(unbox_clause).collect();
169179
let pred_vars = pred_vars.into_iter().map(unbox_pred_var_def).collect();
180+
let user_defined_pred_defs = user_defined_pred_defs
181+
.into_iter()
182+
.map(unbox_user_defined_pred_def)
183+
.collect();
170184
System {
185+
raw_commands,
186+
datatypes,
187+
user_defined_pred_defs,
171188
clauses,
172189
pred_vars,
173-
datatypes,
174-
raw_commands,
175190
}
176191
}

tests/ui/pass/annot_preds.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//@check-pass
2+
//@compile-flags: -Adead_code -C debug-assertions=off
3+
4+
#[thrust::predicate]
5+
fn is_double(x: i64, doubled_x: i64) -> bool {
6+
"(=
7+
(* x 2)
8+
doubled_x
9+
)"; true
10+
}
11+
12+
#[thrust::requires(true)]
13+
#[thrust::ensures(is_double(x, result))]
14+
fn double(x: i64) -> i64 {
15+
x + x
16+
}
17+
18+
fn main() {
19+
let a = 3;
20+
assert!(double(a) == 6);
21+
}

0 commit comments

Comments
 (0)