Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -487,4 +487,12 @@ impl<'tcx> Analyzer<'tcx> {
}
ensure_annot
}

/// Whether the given `def_id` corresponds to a method of one of the `Fn` traits.
fn is_fn_trait_method(&self, def_id: DefId) -> bool {
self.tcx
.trait_of_item(def_id)
.and_then(|trait_did| self.tcx.fn_trait_kind_from_def_id(trait_did))
.is_some()
}
}
40 changes: 35 additions & 5 deletions src/analyze/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ pub struct Analyzer<'tcx, 'ctx> {
}

impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
fn ctx(&self) -> &analyze::Analyzer<'tcx> {
&*self.ctx
}

fn is_defined(&self, local: Local) -> bool {
self.env.contains_local(local)
}
Expand All @@ -53,6 +57,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
visitor::ReborrowVisitor::new(self)
}

fn rust_call_visitor<'a>(&'a mut self) -> visitor::RustCallVisitor<'a, 'tcx, 'ctx> {
visitor::RustCallVisitor::new(self)
}

fn basic_block_ty(&self, bb: BasicBlock) -> &BasicBlockType {
self.ctx.basic_block_ty(self.local_def_id, bb)
}
Expand Down Expand Up @@ -568,12 +576,28 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
{
// TODO: handle const_fn_def on Env side
let func_ty = if let Some((def_id, args)) = func.const_fn_def() {
let param_env = self.tcx.param_env(self.local_def_id);
let instance = mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap();
let resolved_def_id = if let Some(instance) = instance {
instance.def_id()
let resolved_def_id = if self.ctx.is_fn_trait_method(def_id) {
// When calling a closure via `Fn`/`FnMut`/`FnOnce` trait,
// we simply replace the def_id with the closure's function def_id.
// This skips shims, and makes self arguments mismatch. visitor::RustCallVisitor
// adjusts the arguments accordingly.
let mir_ty::TyKind::Closure(closure_def_id, _) = args.type_at(0).kind() else {
panic!("expected closure arg for fn trait");
};
tracing::debug!(?closure_def_id, "closure instance");
*closure_def_id
Comment thread
coord-e marked this conversation as resolved.
} else {
def_id
let param_env = self
.tcx
.param_env(self.local_def_id)
.with_reveal_all_normalized(self.tcx);
let instance =
mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap();
if let Some(instance) = instance {
instance.def_id()
} else {
def_id
}
};
if def_id != resolved_def_id {
tracing::info!(?def_id, ?resolved_def_id, "resolve",);
Expand Down Expand Up @@ -671,6 +695,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
self.env.borrow_place(place, prophecy).into()
}

fn immut_borrow_place(&self, referent: mir::Place<'tcx>) -> rty::RefinedType<Var> {
let place = self.elaborate_place(&referent);
self.env.place_type(place).immut().into()
}

#[tracing::instrument(skip(self, lhs, rvalue))]
fn analyze_assignment(
&mut self,
Expand Down Expand Up @@ -754,6 +783,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
source_info: term.source_info,
};
}
self.rust_call_visitor().visit_terminator(&mut term);
self.reborrow_visitor().visit_terminator(&mut term);
tracing::debug!(term = ?term.kind);
term
Expand Down
139 changes: 4 additions & 135 deletions src/analyze/basic_block/visitor.rs
Original file line number Diff line number Diff line change
@@ -1,136 +1,5 @@
use rustc_middle::mir::{self, Local};
use rustc_middle::ty::{self as mir_ty, TyCtxt};
mod reborrow;
mod rust_call;

use crate::analyze::ReplacePlacesVisitor;

pub struct ReborrowVisitor<'a, 'tcx, 'ctx> {
tcx: TyCtxt<'tcx>,
analyzer: &'a mut super::Analyzer<'tcx, 'ctx>,
}

impl<'tcx> ReborrowVisitor<'_, 'tcx, '_> {
fn insert_borrow(&mut self, place: mir::Place<'tcx>, inner_ty: mir_ty::Ty<'tcx>) -> Local {
let r = mir_ty::Region::new_from_kind(self.tcx, mir_ty::RegionKind::ReErased);
let ty = mir_ty::Ty::new_mut_ref(self.tcx, r, inner_ty);
let decl = mir::LocalDecl::new(ty, Default::default()).immutable();
let new_local = self.analyzer.local_decls.push(decl);
let new_local_ty = self.analyzer.borrow_place_(place, inner_ty);
self.analyzer.bind_local(new_local, new_local_ty);
tracing::info!(old_place = ?place, ?new_local, "implicitly borrowed");
new_local
}

fn insert_reborrow(&mut self, place: mir::Place<'tcx>, inner_ty: mir_ty::Ty<'tcx>) -> Local {
let r = mir_ty::Region::new_from_kind(self.tcx, mir_ty::RegionKind::ReErased);
let ty = mir_ty::Ty::new_mut_ref(self.tcx, r, inner_ty);
let decl = mir::LocalDecl::new(ty, Default::default()).immutable();
let new_local = self.analyzer.local_decls.push(decl);
let new_local_ty = self.analyzer.borrow_place_(place, inner_ty);
self.analyzer.bind_local(new_local, new_local_ty);
tracing::info!(old_place = ?place, ?new_local, "implicitly reborrowed");
new_local
}
}

impl<'a, 'tcx, 'ctx> mir::visit::MutVisitor<'tcx> for ReborrowVisitor<'a, 'tcx, 'ctx> {
fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn visit_assign(
&mut self,
place: &mut mir::Place<'tcx>,
rvalue: &mut mir::Rvalue<'tcx>,
location: mir::Location,
) {
if !self.analyzer.is_defined(place.local) {
self.super_assign(place, rvalue, location);
return;
}

if place.projection.is_empty() && self.analyzer.is_mut_local(place.local) {
let ty = self.analyzer.local_decls[place.local].ty;
let new_local = self.insert_borrow(place.local.into(), ty);
let new_place = self.tcx.mk_place_deref(new_local.into());
ReplacePlacesVisitor::with_replacement(self.tcx, place.local.into(), new_place)
.visit_rvalue(rvalue, location);
*place = new_place;
self.super_assign(place, rvalue, location);
return;
}

let inner_place = if place.projection.last() == Some(&mir::PlaceElem::Deref) {
// *m = *m + 1 => m1 = &mut m; *m1 = *m + 1
let mut projection = place.projection.as_ref().to_vec();
projection.pop();
mir::Place {
local: place.local,
projection: self.tcx.mk_place_elems(&projection),
}
} else {
// s.0 = s.0 + 1 => m1 = &mut s.0; *m1 = *m1 + 1
*place
};

let ty = inner_place.ty(&self.analyzer.local_decls, self.tcx).ty;
let (new_local, new_place) = match ty.kind() {
mir_ty::TyKind::Ref(_, inner_ty, m) if m.is_mut() => {
let new_local = self.insert_reborrow(*place, *inner_ty);
(new_local, new_local.into())
}
mir_ty::TyKind::Adt(adt, args) if adt.is_box() => {
let inner_ty = args.type_at(0);
let new_local = self.insert_borrow(*place, inner_ty);
(new_local, new_local.into())
}
_ => {
let new_local = self.insert_borrow(*place, ty);
(new_local, self.tcx.mk_place_deref(new_local.into()))
}
};

ReplacePlacesVisitor::with_replacement(self.tcx, inner_place, new_place)
.visit_rvalue(rvalue, location);
*place = self.tcx.mk_place_deref(new_local.into());
self.super_assign(place, rvalue, location);
}

// TODO: is it always true that the operand is not referred again in rvalue
fn visit_operand(&mut self, operand: &mut mir::Operand<'tcx>, location: mir::Location) {
let Some(p) = operand.place() else {
self.super_operand(operand, location);
return;
};

let mir_ty::TyKind::Ref(_, inner_ty, m) =
p.ty(&self.analyzer.local_decls, self.tcx).ty.kind()
else {
self.super_operand(operand, location);
return;
};

if m.is_mut() {
let new_local = self.insert_reborrow(self.tcx.mk_place_deref(p), *inner_ty);
*operand = mir::Operand::Move(new_local.into());
}

self.super_operand(operand, location);
}
}

impl<'a, 'tcx, 'ctx> ReborrowVisitor<'a, 'tcx, 'ctx> {
pub fn new(analyzer: &'a mut super::Analyzer<'tcx, 'ctx>) -> Self {
let tcx = analyzer.tcx;
Self { analyzer, tcx }
}

pub fn visit_statement(&mut self, stmt: &mut mir::Statement<'tcx>) {
// dummy location
mir::visit::MutVisitor::visit_statement(self, stmt, mir::Location::START);
}

pub fn visit_terminator(&mut self, term: &mut mir::Terminator<'tcx>) {
// dummy location
mir::visit::MutVisitor::visit_terminator(self, term, mir::Location::START);
}
}
pub use reborrow::ReborrowVisitor;
pub use rust_call::RustCallVisitor;
136 changes: 136 additions & 0 deletions src/analyze/basic_block/visitor/reborrow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
use rustc_middle::mir::{self, Local};
use rustc_middle::ty::{self as mir_ty, TyCtxt};

use crate::analyze::ReplacePlacesVisitor;

pub struct ReborrowVisitor<'a, 'tcx, 'ctx> {
tcx: TyCtxt<'tcx>,
analyzer: &'a mut crate::analyze::basic_block::Analyzer<'tcx, 'ctx>,
}

impl<'tcx> ReborrowVisitor<'_, 'tcx, '_> {
fn insert_borrow(&mut self, place: mir::Place<'tcx>, inner_ty: mir_ty::Ty<'tcx>) -> Local {
let r = mir_ty::Region::new_from_kind(self.tcx, mir_ty::RegionKind::ReErased);
let ty = mir_ty::Ty::new_mut_ref(self.tcx, r, inner_ty);
let decl = mir::LocalDecl::new(ty, Default::default()).immutable();
let new_local = self.analyzer.local_decls.push(decl);
let new_local_ty = self.analyzer.borrow_place_(place, inner_ty);
self.analyzer.bind_local(new_local, new_local_ty);
tracing::info!(old_place = ?place, ?new_local, "implicitly borrowed");
new_local
}

fn insert_reborrow(&mut self, place: mir::Place<'tcx>, inner_ty: mir_ty::Ty<'tcx>) -> Local {
let r = mir_ty::Region::new_from_kind(self.tcx, mir_ty::RegionKind::ReErased);
let ty = mir_ty::Ty::new_mut_ref(self.tcx, r, inner_ty);
let decl = mir::LocalDecl::new(ty, Default::default()).immutable();
let new_local = self.analyzer.local_decls.push(decl);
let new_local_ty = self.analyzer.borrow_place_(place, inner_ty);
self.analyzer.bind_local(new_local, new_local_ty);
tracing::info!(old_place = ?place, ?new_local, "implicitly reborrowed");
new_local
}
}

impl<'a, 'tcx, 'ctx> mir::visit::MutVisitor<'tcx> for ReborrowVisitor<'a, 'tcx, 'ctx> {
fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn visit_assign(
&mut self,
place: &mut mir::Place<'tcx>,
rvalue: &mut mir::Rvalue<'tcx>,
location: mir::Location,
) {
if !self.analyzer.is_defined(place.local) {
self.super_assign(place, rvalue, location);
return;
}

if place.projection.is_empty() && self.analyzer.is_mut_local(place.local) {
let ty = self.analyzer.local_decls[place.local].ty;
let new_local = self.insert_borrow(place.local.into(), ty);
let new_place = self.tcx.mk_place_deref(new_local.into());
ReplacePlacesVisitor::with_replacement(self.tcx, place.local.into(), new_place)
.visit_rvalue(rvalue, location);
*place = new_place;
self.super_assign(place, rvalue, location);
return;
}

let inner_place = if place.projection.last() == Some(&mir::PlaceElem::Deref) {
// *m = *m + 1 => m1 = &mut m; *m1 = *m + 1
let mut projection = place.projection.as_ref().to_vec();
projection.pop();
mir::Place {
local: place.local,
projection: self.tcx.mk_place_elems(&projection),
}
} else {
// s.0 = s.0 + 1 => m1 = &mut s.0; *m1 = *m1 + 1
*place
};

let ty = inner_place.ty(&self.analyzer.local_decls, self.tcx).ty;
let (new_local, new_place) = match ty.kind() {
mir_ty::TyKind::Ref(_, inner_ty, m) if m.is_mut() => {
let new_local = self.insert_reborrow(*place, *inner_ty);
(new_local, new_local.into())
}
mir_ty::TyKind::Adt(adt, args) if adt.is_box() => {
let inner_ty = args.type_at(0);
let new_local = self.insert_borrow(*place, inner_ty);
(new_local, new_local.into())
}
_ => {
let new_local = self.insert_borrow(*place, ty);
(new_local, self.tcx.mk_place_deref(new_local.into()))
}
};

ReplacePlacesVisitor::with_replacement(self.tcx, inner_place, new_place)
.visit_rvalue(rvalue, location);
*place = self.tcx.mk_place_deref(new_local.into());
self.super_assign(place, rvalue, location);
}

// TODO: is it always true that the operand is not referred again in rvalue
fn visit_operand(&mut self, operand: &mut mir::Operand<'tcx>, location: mir::Location) {
let Some(p) = operand.place() else {
self.super_operand(operand, location);
return;
};

let mir_ty::TyKind::Ref(_, inner_ty, m) =
p.ty(&self.analyzer.local_decls, self.tcx).ty.kind()
else {
self.super_operand(operand, location);
return;
};

if m.is_mut() {
let new_local = self.insert_reborrow(self.tcx.mk_place_deref(p), *inner_ty);
*operand = mir::Operand::Move(new_local.into());
}

self.super_operand(operand, location);
}
}

impl<'a, 'tcx, 'ctx> ReborrowVisitor<'a, 'tcx, 'ctx> {
pub fn new(analyzer: &'a mut crate::analyze::basic_block::Analyzer<'tcx, 'ctx>) -> Self {
let tcx = analyzer.tcx;
Self { analyzer, tcx }
}

pub fn visit_statement(&mut self, stmt: &mut mir::Statement<'tcx>) {
// dummy location
mir::visit::MutVisitor::visit_statement(self, stmt, mir::Location::START);
}

pub fn visit_terminator(&mut self, term: &mut mir::Terminator<'tcx>) {
// dummy location
mir::visit::MutVisitor::visit_terminator(self, term, mir::Location::START);
}
}
Loading