Skip to content

Commit f26313e

Browse files
authored
Merge pull request #48 from coord-e/coord-e/fix-closure-param-match
Enable to call Fn closures via FnMut and FnOnce
2 parents ff27c71 + 2ea6cb2 commit f26313e

11 files changed

Lines changed: 386 additions & 156 deletions

File tree

src/analyze.rs

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -402,17 +402,13 @@ impl<'tcx> Analyzer<'tcx> {
402402
}
403403
}
404404

405-
/// Computes the signature of the local function.
405+
/// Computes the signature of the function using the given `body`.
406406
///
407-
/// This works like `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`,
407+
/// This works like `self.tcx.fn_sig(def_id).instantiate_identity().skip_binder()`,
408408
/// but extracts parameter and return types directly from the given `body` to obtain a signature that
409409
/// reflects potential type instantiations happened after `optimized_mir`.
410-
pub fn local_fn_sig_with_body(
411-
&self,
412-
local_def_id: LocalDefId,
413-
body: &mir::Body<'tcx>,
414-
) -> mir_ty::FnSig<'tcx> {
415-
let ty = self.tcx.type_of(local_def_id).instantiate_identity();
410+
pub fn fn_sig_with_body(&self, def_id: DefId, body: &mir::Body<'tcx>) -> mir_ty::FnSig<'tcx> {
411+
let ty = self.tcx.type_of(def_id).instantiate_identity();
416412
let sig = if let mir_ty::TyKind::Closure(_, substs) = ty.kind() {
417413
substs.as_closure().sig().skip_binder()
418414
} else {
@@ -428,14 +424,14 @@ impl<'tcx> Analyzer<'tcx> {
428424
)
429425
}
430426

431-
/// Computes the signature of the local function.
427+
/// Computes the signature of the function.
432428
///
433-
/// This works like `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`,
429+
/// This works like `self.tcx.fn_sig(def_id).instantiate_identity().skip_binder()`,
434430
/// but extracts parameter and return types directly from [`mir::Body`] to obtain a signature that
435431
/// reflects the actual type of lifted closure functions.
436-
pub fn local_fn_sig(&self, local_def_id: LocalDefId) -> mir_ty::FnSig<'tcx> {
437-
let body = self.tcx.optimized_mir(local_def_id);
438-
self.local_fn_sig_with_body(local_def_id, body)
432+
pub fn fn_sig(&self, def_id: DefId) -> mir_ty::FnSig<'tcx> {
433+
let body = self.tcx.optimized_mir(def_id);
434+
self.fn_sig_with_body(def_id, body)
439435
}
440436

441437
fn extract_require_annot<T>(
@@ -487,4 +483,12 @@ impl<'tcx> Analyzer<'tcx> {
487483
}
488484
ensure_annot
489485
}
486+
487+
/// Whether the given `def_id` corresponds to a method of one of the `Fn` traits.
488+
fn is_fn_trait_method(&self, def_id: DefId) -> bool {
489+
self.tcx
490+
.trait_of_item(def_id)
491+
.and_then(|trait_did| self.tcx.fn_trait_kind_from_def_id(trait_did))
492+
.is_some()
493+
}
490494
}

src/analyze/basic_block.rs

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ pub struct Analyzer<'tcx, 'ctx> {
4141
}
4242

4343
impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
44+
fn ctx(&self) -> &analyze::Analyzer<'tcx> {
45+
&*self.ctx
46+
}
47+
4448
fn is_defined(&self, local: Local) -> bool {
4549
self.env.contains_local(local)
4650
}
@@ -53,6 +57,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
5357
visitor::ReborrowVisitor::new(self)
5458
}
5559

60+
fn rust_call_visitor<'a>(&'a mut self) -> visitor::RustCallVisitor<'a, 'tcx, 'ctx> {
61+
visitor::RustCallVisitor::new(self)
62+
}
63+
5664
fn basic_block_ty(&self, bb: BasicBlock) -> &BasicBlockType {
5765
self.ctx.basic_block_ty(self.local_def_id, bb)
5866
}
@@ -568,12 +576,28 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
568576
{
569577
// TODO: handle const_fn_def on Env side
570578
let func_ty = if let Some((def_id, args)) = func.const_fn_def() {
571-
let param_env = self.tcx.param_env(self.local_def_id);
572-
let instance = mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap();
573-
let resolved_def_id = if let Some(instance) = instance {
574-
instance.def_id()
579+
let resolved_def_id = if self.ctx.is_fn_trait_method(def_id) {
580+
// When calling a closure via `Fn`/`FnMut`/`FnOnce` trait,
581+
// we simply replace the def_id with the closure's function def_id.
582+
// This skips shims, and makes self arguments mismatch. visitor::RustCallVisitor
583+
// adjusts the arguments accordingly.
584+
let mir_ty::TyKind::Closure(closure_def_id, _) = args.type_at(0).kind() else {
585+
panic!("expected closure arg for fn trait");
586+
};
587+
tracing::debug!(?closure_def_id, "closure instance");
588+
*closure_def_id
575589
} else {
576-
def_id
590+
let param_env = self
591+
.tcx
592+
.param_env(self.local_def_id)
593+
.with_reveal_all_normalized(self.tcx);
594+
let instance =
595+
mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap();
596+
if let Some(instance) = instance {
597+
instance.def_id()
598+
} else {
599+
def_id
600+
}
577601
};
578602
if def_id != resolved_def_id {
579603
tracing::info!(?def_id, ?resolved_def_id, "resolve",);
@@ -671,6 +695,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
671695
self.env.borrow_place(place, prophecy).into()
672696
}
673697

698+
fn immut_borrow_place(&self, referent: mir::Place<'tcx>) -> rty::RefinedType<Var> {
699+
let place = self.elaborate_place(&referent);
700+
self.env.place_type(place).immut().into()
701+
}
702+
674703
#[tracing::instrument(skip(self, lhs, rvalue))]
675704
fn analyze_assignment(
676705
&mut self,
@@ -754,6 +783,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
754783
source_info: term.source_info,
755784
};
756785
}
786+
self.rust_call_visitor().visit_terminator(&mut term);
757787
self.reborrow_visitor().visit_terminator(&mut term);
758788
tracing::debug!(term = ?term.kind);
759789
term

src/analyze/basic_block/visitor.rs

Lines changed: 4 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -1,136 +1,5 @@
1-
use rustc_middle::mir::{self, Local};
2-
use rustc_middle::ty::{self as mir_ty, TyCtxt};
1+
mod reborrow;
2+
mod rust_call;
33

4-
use crate::analyze::ReplacePlacesVisitor;
5-
6-
pub struct ReborrowVisitor<'a, 'tcx, 'ctx> {
7-
tcx: TyCtxt<'tcx>,
8-
analyzer: &'a mut super::Analyzer<'tcx, 'ctx>,
9-
}
10-
11-
impl<'tcx> ReborrowVisitor<'_, 'tcx, '_> {
12-
fn insert_borrow(&mut self, place: mir::Place<'tcx>, inner_ty: mir_ty::Ty<'tcx>) -> Local {
13-
let r = mir_ty::Region::new_from_kind(self.tcx, mir_ty::RegionKind::ReErased);
14-
let ty = mir_ty::Ty::new_mut_ref(self.tcx, r, inner_ty);
15-
let decl = mir::LocalDecl::new(ty, Default::default()).immutable();
16-
let new_local = self.analyzer.local_decls.push(decl);
17-
let new_local_ty = self.analyzer.borrow_place_(place, inner_ty);
18-
self.analyzer.bind_local(new_local, new_local_ty);
19-
tracing::info!(old_place = ?place, ?new_local, "implicitly borrowed");
20-
new_local
21-
}
22-
23-
fn insert_reborrow(&mut self, place: mir::Place<'tcx>, inner_ty: mir_ty::Ty<'tcx>) -> Local {
24-
let r = mir_ty::Region::new_from_kind(self.tcx, mir_ty::RegionKind::ReErased);
25-
let ty = mir_ty::Ty::new_mut_ref(self.tcx, r, inner_ty);
26-
let decl = mir::LocalDecl::new(ty, Default::default()).immutable();
27-
let new_local = self.analyzer.local_decls.push(decl);
28-
let new_local_ty = self.analyzer.borrow_place_(place, inner_ty);
29-
self.analyzer.bind_local(new_local, new_local_ty);
30-
tracing::info!(old_place = ?place, ?new_local, "implicitly reborrowed");
31-
new_local
32-
}
33-
}
34-
35-
impl<'a, 'tcx, 'ctx> mir::visit::MutVisitor<'tcx> for ReborrowVisitor<'a, 'tcx, 'ctx> {
36-
fn tcx(&self) -> TyCtxt<'tcx> {
37-
self.tcx
38-
}
39-
40-
fn visit_assign(
41-
&mut self,
42-
place: &mut mir::Place<'tcx>,
43-
rvalue: &mut mir::Rvalue<'tcx>,
44-
location: mir::Location,
45-
) {
46-
if !self.analyzer.is_defined(place.local) {
47-
self.super_assign(place, rvalue, location);
48-
return;
49-
}
50-
51-
if place.projection.is_empty() && self.analyzer.is_mut_local(place.local) {
52-
let ty = self.analyzer.local_decls[place.local].ty;
53-
let new_local = self.insert_borrow(place.local.into(), ty);
54-
let new_place = self.tcx.mk_place_deref(new_local.into());
55-
ReplacePlacesVisitor::with_replacement(self.tcx, place.local.into(), new_place)
56-
.visit_rvalue(rvalue, location);
57-
*place = new_place;
58-
self.super_assign(place, rvalue, location);
59-
return;
60-
}
61-
62-
let inner_place = if place.projection.last() == Some(&mir::PlaceElem::Deref) {
63-
// *m = *m + 1 => m1 = &mut m; *m1 = *m + 1
64-
let mut projection = place.projection.as_ref().to_vec();
65-
projection.pop();
66-
mir::Place {
67-
local: place.local,
68-
projection: self.tcx.mk_place_elems(&projection),
69-
}
70-
} else {
71-
// s.0 = s.0 + 1 => m1 = &mut s.0; *m1 = *m1 + 1
72-
*place
73-
};
74-
75-
let ty = inner_place.ty(&self.analyzer.local_decls, self.tcx).ty;
76-
let (new_local, new_place) = match ty.kind() {
77-
mir_ty::TyKind::Ref(_, inner_ty, m) if m.is_mut() => {
78-
let new_local = self.insert_reborrow(*place, *inner_ty);
79-
(new_local, new_local.into())
80-
}
81-
mir_ty::TyKind::Adt(adt, args) if adt.is_box() => {
82-
let inner_ty = args.type_at(0);
83-
let new_local = self.insert_borrow(*place, inner_ty);
84-
(new_local, new_local.into())
85-
}
86-
_ => {
87-
let new_local = self.insert_borrow(*place, ty);
88-
(new_local, self.tcx.mk_place_deref(new_local.into()))
89-
}
90-
};
91-
92-
ReplacePlacesVisitor::with_replacement(self.tcx, inner_place, new_place)
93-
.visit_rvalue(rvalue, location);
94-
*place = self.tcx.mk_place_deref(new_local.into());
95-
self.super_assign(place, rvalue, location);
96-
}
97-
98-
// TODO: is it always true that the operand is not referred again in rvalue
99-
fn visit_operand(&mut self, operand: &mut mir::Operand<'tcx>, location: mir::Location) {
100-
let Some(p) = operand.place() else {
101-
self.super_operand(operand, location);
102-
return;
103-
};
104-
105-
let mir_ty::TyKind::Ref(_, inner_ty, m) =
106-
p.ty(&self.analyzer.local_decls, self.tcx).ty.kind()
107-
else {
108-
self.super_operand(operand, location);
109-
return;
110-
};
111-
112-
if m.is_mut() {
113-
let new_local = self.insert_reborrow(self.tcx.mk_place_deref(p), *inner_ty);
114-
*operand = mir::Operand::Move(new_local.into());
115-
}
116-
117-
self.super_operand(operand, location);
118-
}
119-
}
120-
121-
impl<'a, 'tcx, 'ctx> ReborrowVisitor<'a, 'tcx, 'ctx> {
122-
pub fn new(analyzer: &'a mut super::Analyzer<'tcx, 'ctx>) -> Self {
123-
let tcx = analyzer.tcx;
124-
Self { analyzer, tcx }
125-
}
126-
127-
pub fn visit_statement(&mut self, stmt: &mut mir::Statement<'tcx>) {
128-
// dummy location
129-
mir::visit::MutVisitor::visit_statement(self, stmt, mir::Location::START);
130-
}
131-
132-
pub fn visit_terminator(&mut self, term: &mut mir::Terminator<'tcx>) {
133-
// dummy location
134-
mir::visit::MutVisitor::visit_terminator(self, term, mir::Location::START);
135-
}
136-
}
4+
pub use reborrow::ReborrowVisitor;
5+
pub use rust_call::RustCallVisitor;

0 commit comments

Comments
 (0)