Skip to content

Commit b941a92

Browse files
authored
Merge pull request #17 from coord-e/enum-on-demand
Register EnumDatatypeDef on-demand
2 parents 827c750 + 52b5563 commit b941a92

16 files changed

Lines changed: 362 additions & 118 deletions

src/analyze.rs

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use std::collections::HashMap;
1111
use std::rc::Rc;
1212

1313
use rustc_hir::lang_items::LangItem;
14+
use rustc_index::IndexVec;
1415
use rustc_middle::mir::{self, BasicBlock, Local};
1516
use rustc_middle::ty::{self as mir_ty, TyCtxt};
1617
use rustc_span::def_id::{DefId, LocalDefId};
@@ -114,6 +115,33 @@ enum DefTy<'tcx> {
114115
Deferred(DeferredDefTy<'tcx>),
115116
}
116117

118+
#[derive(Debug, Clone, Default)]
119+
pub struct EnumDefs {
120+
defs: HashMap<DefId, rty::EnumDatatypeDef>,
121+
}
122+
123+
impl EnumDefs {
124+
pub fn find_by_name(&self, name: &chc::DatatypeSymbol) -> Option<&rty::EnumDatatypeDef> {
125+
self.defs.values().find(|def| &def.name == name)
126+
}
127+
128+
pub fn get(&self, def_id: DefId) -> Option<&rty::EnumDatatypeDef> {
129+
self.defs.get(&def_id)
130+
}
131+
132+
pub fn insert(&mut self, def_id: DefId, def: rty::EnumDatatypeDef) {
133+
self.defs.insert(def_id, def);
134+
}
135+
}
136+
137+
impl refine::EnumDefProvider for Rc<RefCell<EnumDefs>> {
138+
fn enum_def(&self, name: &chc::DatatypeSymbol) -> rty::EnumDatatypeDef {
139+
self.borrow().find_by_name(name).unwrap().clone()
140+
}
141+
}
142+
143+
pub type Env = refine::Env<Rc<RefCell<EnumDefs>>>;
144+
117145
#[derive(Clone)]
118146
pub struct Analyzer<'tcx> {
119147
tcx: TyCtxt<'tcx>,
@@ -131,7 +159,7 @@ pub struct Analyzer<'tcx> {
131159
basic_blocks: HashMap<LocalDefId, HashMap<BasicBlock, BasicBlockType>>,
132160
def_ids: did_cache::DefIdCache<'tcx>,
133161

134-
enum_defs: Rc<RefCell<HashMap<DefId, rty::EnumDatatypeDef>>>,
162+
enum_defs: Rc<RefCell<EnumDefs>>,
135163
}
136164

137165
impl<'tcx> crate::refine::TemplateRegistry for Analyzer<'tcx> {
@@ -174,7 +202,58 @@ impl<'tcx> Analyzer<'tcx> {
174202
}
175203
}
176204

177-
pub fn register_enum_def(&mut self, def_id: DefId, enum_def: rty::EnumDatatypeDef) {
205+
fn build_enum_def(&self, def_id: DefId) -> rty::EnumDatatypeDef {
206+
let adt = self.tcx.adt_def(def_id);
207+
208+
let name = refine::datatype_symbol(self.tcx, def_id);
209+
let variants: IndexVec<_, _> = adt
210+
.variants()
211+
.iter()
212+
.map(|variant| {
213+
let name = refine::datatype_symbol(self.tcx, variant.def_id);
214+
// TODO: consider using TyCtxt::tag_for_variant
215+
let discr = resolve_discr(self.tcx, variant.discr);
216+
let field_tys = variant
217+
.fields
218+
.iter()
219+
.map(|field| {
220+
let field_ty = self.tcx.type_of(field.did).instantiate_identity();
221+
TypeBuilder::new(self.tcx, def_id).build(field_ty)
222+
})
223+
.collect();
224+
rty::EnumVariantDef {
225+
name,
226+
discr,
227+
field_tys,
228+
}
229+
})
230+
.collect();
231+
232+
let generics = self.tcx.generics_of(def_id);
233+
let ty_params = (0..generics.count())
234+
.filter(|idx| {
235+
matches!(
236+
generics.param_at(*idx, self.tcx).kind,
237+
mir_ty::GenericParamDefKind::Type { .. }
238+
)
239+
})
240+
.count();
241+
tracing::debug!(?def_id, ?name, ?ty_params, "ty_params count");
242+
243+
rty::EnumDatatypeDef {
244+
name,
245+
ty_params,
246+
variants,
247+
}
248+
}
249+
250+
pub fn get_or_register_enum_def(&self, def_id: DefId) -> rty::EnumDatatypeDef {
251+
let mut enum_defs = self.enum_defs.borrow_mut();
252+
if let Some(enum_def) = enum_defs.get(def_id) {
253+
return enum_def.clone();
254+
}
255+
256+
let enum_def = self.build_enum_def(def_id);
178257
tracing::debug!(def_id = ?def_id, enum_def = ?enum_def, "register_enum_def");
179258
let ctors = enum_def
180259
.variants
@@ -199,21 +278,10 @@ impl<'tcx> Analyzer<'tcx> {
199278
params: enum_def.ty_params,
200279
ctors,
201280
};
202-
self.enum_defs.borrow_mut().insert(def_id, enum_def);
281+
enum_defs.insert(def_id, enum_def.clone());
203282
self.system.borrow_mut().datatypes.push(datatype);
204-
}
205283

206-
pub fn find_enum_variant(
207-
&self,
208-
ty_sym: &chc::DatatypeSymbol,
209-
v_sym: &chc::DatatypeSymbol,
210-
) -> Option<rty::EnumVariantDef> {
211-
self.enum_defs
212-
.borrow()
213-
.iter()
214-
.find(|(_, d)| &d.name == ty_sym)
215-
.and_then(|(_, d)| d.variants.iter().find(|v| &v.name == v_sym))
216-
.cloned()
284+
enum_def
217285
}
218286

219287
pub fn register_def(&mut self, def_id: DefId, rty: rty::RefinedType) {
@@ -304,14 +372,8 @@ impl<'tcx> Analyzer<'tcx> {
304372
self.register_def(panic_def_id, rty::RefinedType::unrefined(panic_ty.into()));
305373
}
306374

307-
pub fn new_env(&self) -> refine::Env {
308-
let defs = self
309-
.enum_defs
310-
.borrow()
311-
.values()
312-
.map(|def| (def.name.clone(), def.clone()))
313-
.collect();
314-
refine::Env::new(defs)
375+
pub fn new_env(&self) -> Env {
376+
refine::Env::new(Rc::clone(&self.enum_defs))
315377
}
316378

317379
pub fn crate_analyzer(&mut self) -> crate_::Analyzer<'tcx, '_> {

src/analyze/basic_block.rs

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ use crate::analyze;
1313
use crate::chc;
1414
use crate::pretty::PrettyDisplayExt as _;
1515
use crate::refine::{
16-
self, Assumption, BasicBlockType, Env, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx,
17-
TypeBuilder, Var,
16+
Assumption, BasicBlockType, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx, TypeBuilder,
17+
Var,
1818
};
1919
use crate::rty::{
2020
self, ClauseBuilderExt as _, ClauseScope as _, ShiftExistential as _, Subtyping as _,
@@ -34,7 +34,7 @@ pub struct Analyzer<'tcx, 'ctx> {
3434
body: Cow<'tcx, Body<'tcx>>,
3535

3636
type_builder: TypeBuilder<'tcx>,
37-
env: Env,
37+
env: analyze::Env,
3838
local_decls: IndexVec<Local, mir::LocalDecl<'tcx>>,
3939
// TODO: remove this
4040
prophecy_vars: HashMap<usize, TempVarIdx>,
@@ -350,16 +350,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
350350
.map(|operand| self.operand_type(operand).boxed())
351351
.collect();
352352
match *kind {
353-
mir::AggregateKind::Adt(did, variant_id, args, _, _)
353+
mir::AggregateKind::Adt(did, variant_idx, args, _, _)
354354
if self.tcx.def_kind(did) == DefKind::Enum =>
355355
{
356-
let adt = self.tcx.adt_def(did);
357-
let ty_sym = refine::datatype_symbol(self.tcx, did);
358-
let variant = adt.variant(variant_id);
359-
let v_sym = refine::datatype_symbol(self.tcx, variant.def_id);
360-
361-
let enum_variant_def = self.ctx.find_enum_variant(&ty_sym, &v_sym).unwrap();
362-
let variant_rtys = enum_variant_def
356+
let enum_def = self.ctx.get_or_register_enum_def(did);
357+
let variant_def = &enum_def.variants[variant_idx];
358+
let variant_rtys = variant_def
363359
.field_tys
364360
.clone()
365361
.into_iter()
@@ -386,7 +382,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
386382

387383
let sort_args: Vec<_> =
388384
rty_args.iter().map(|rty| rty.ty.to_sort()).collect();
389-
let ty = rty::EnumType::new(ty_sym.clone(), rty_args).into();
385+
let ty = rty::EnumType::new(enum_def.name.clone(), rty_args).into();
390386

391387
let mut builder = PlaceTypeBuilder::default();
392388
let mut field_terms = Vec::new();
@@ -396,7 +392,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
396392
}
397393
builder.build(
398394
ty,
399-
chc::Term::datatype_ctor(ty_sym, sort_args, v_sym, field_terms),
395+
chc::Term::datatype_ctor(
396+
enum_def.name,
397+
sort_args,
398+
variant_def.name.clone(),
399+
field_terms,
400+
),
400401
)
401402
}
402403
_ => PlaceType::tuple(field_tys),
@@ -924,6 +925,31 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
924925
}
925926
}
926927
}
928+
929+
fn register_enum_defs(&mut self) {
930+
for local_decl in &self.local_decls {
931+
use mir_ty::{TypeSuperVisitable as _, TypeVisitable as _};
932+
#[derive(Default)]
933+
struct EnumCollector {
934+
enums: std::collections::HashSet<DefId>,
935+
}
936+
impl<'tcx> mir_ty::TypeVisitor<mir_ty::TyCtxt<'tcx>> for EnumCollector {
937+
fn visit_ty(&mut self, ty: mir_ty::Ty<'tcx>) {
938+
if let mir_ty::TyKind::Adt(adt_def, _) = ty.kind() {
939+
if adt_def.is_enum() {
940+
self.enums.insert(adt_def.did());
941+
}
942+
}
943+
ty.super_visit_with(self);
944+
}
945+
}
946+
let mut visitor = EnumCollector::default();
947+
local_decl.ty.visit_with(&mut visitor);
948+
for def_id in visitor.enums {
949+
self.ctx.get_or_register_enum_def(def_id);
950+
}
951+
}
952+
}
927953
}
928954

929955
/// Turns [`rty::RefinedType<Var>`] into [`rty::RefinedType<T>`].
@@ -967,7 +993,7 @@ impl<T> UnbindAtoms<T> {
967993
self.existentials.extend(var_ty.existentials);
968994
}
969995

970-
pub fn unbind(mut self, env: &Env, ty: rty::RefinedType<Var>) -> rty::RefinedType<T> {
996+
pub fn unbind(mut self, env: &analyze::Env, ty: rty::RefinedType<Var>) -> rty::RefinedType<T> {
971997
let rty::RefinedType {
972998
ty: src_ty,
973999
refinement,
@@ -1136,14 +1162,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
11361162
self
11371163
}
11381164

1139-
pub fn env(&mut self, env: Env) -> &mut Self {
1140-
self.env = env;
1141-
self
1142-
}
1143-
11441165
pub fn run(&mut self, expected: &BasicBlockType) {
11451166
let span = tracing::info_span!("bb", bb = ?self.basic_block);
11461167
let _guard = span.enter();
1168+
self.register_enum_defs();
11471169

11481170
let params = expected.as_ref().params.clone();
11491171
self.bind_locals(&params);

src/analyze/crate_.rs

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22
33
use std::collections::HashSet;
44

5-
use rustc_hir::def::DefKind;
6-
use rustc_index::IndexVec;
75
use rustc_middle::ty::{self as mir_ty, TyCtxt};
86
use rustc_span::def_id::{DefId, LocalDefId};
97

108
use crate::analyze;
119
use crate::chc;
12-
use crate::refine::{self, TypeBuilder};
1310
use crate::rty::{self, ClauseBuilderExt as _};
1411

1512
/// An implementation of local crate analysis.
@@ -167,57 +164,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
167164
}
168165
}
169166
}
170-
171-
fn register_enum_defs(&mut self) {
172-
for local_def_id in self.tcx.iter_local_def_id() {
173-
let DefKind::Enum = self.tcx.def_kind(local_def_id) else {
174-
continue;
175-
};
176-
let adt = self.tcx.adt_def(local_def_id);
177-
178-
let name = refine::datatype_symbol(self.tcx, local_def_id.to_def_id());
179-
let variants: IndexVec<_, _> = adt
180-
.variants()
181-
.iter()
182-
.map(|variant| {
183-
let name = refine::datatype_symbol(self.tcx, variant.def_id);
184-
// TODO: consider using TyCtxt::tag_for_variant
185-
let discr = analyze::resolve_discr(self.tcx, variant.discr);
186-
let field_tys = variant
187-
.fields
188-
.iter()
189-
.map(|field| {
190-
let field_ty = self.tcx.type_of(field.did).instantiate_identity();
191-
TypeBuilder::new(self.tcx, local_def_id.to_def_id()).build(field_ty)
192-
})
193-
.collect();
194-
rty::EnumVariantDef {
195-
name,
196-
discr,
197-
field_tys,
198-
}
199-
})
200-
.collect();
201-
202-
let generics = self.tcx.generics_of(local_def_id);
203-
let ty_params = (0..generics.count())
204-
.filter(|idx| {
205-
matches!(
206-
generics.param_at(*idx, self.tcx).kind,
207-
mir_ty::GenericParamDefKind::Type { .. }
208-
)
209-
})
210-
.count();
211-
tracing::debug!(?local_def_id, ?name, ?ty_params, "ty_params count");
212-
213-
let def = rty::EnumDatatypeDef {
214-
name,
215-
ty_params,
216-
variants,
217-
};
218-
self.ctx.register_enum_def(local_def_id.to_def_id(), def);
219-
}
220-
}
221167
}
222168

223169
impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
@@ -231,7 +177,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
231177
let span = tracing::debug_span!("crate", krate = %self.tcx.crate_name(rustc_span::def_id::LOCAL_CRATE));
232178
let _guard = span.enter();
233179

234-
self.register_enum_defs();
235180
self.refine_local_defs();
236181
self.analyze_local_defs();
237182
self.assert_callable_entry();

src/chc/unbox.rs

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,42 @@ fn unbox_term(term: Term) -> Term {
1313
Term::App(fun, args) => Term::App(fun, args.into_iter().map(unbox_term).collect()),
1414
Term::Tuple(ts) => Term::Tuple(ts.into_iter().map(unbox_term).collect()),
1515
Term::TupleProj(t, i) => Term::TupleProj(Box::new(unbox_term(*t)), i),
16-
Term::DatatypeCtor(s1, s2, args) => {
17-
Term::DatatypeCtor(s1, s2, args.into_iter().map(unbox_term).collect())
18-
}
16+
Term::DatatypeCtor(s1, s2, args) => Term::DatatypeCtor(
17+
unbox_datatype_sort(s1),
18+
s2,
19+
args.into_iter().map(unbox_term).collect(),
20+
),
1921
Term::DatatypeDiscr(sym, arg) => Term::DatatypeDiscr(sym, Box::new(unbox_term(*arg))),
2022
Term::FormulaExistentialVar(sort, name) => {
2123
Term::FormulaExistentialVar(unbox_sort(sort), name)
2224
}
2325
}
2426
}
2527

28+
fn unbox_matcher_pred(pred: MatcherPred) -> Pred {
29+
let MatcherPred {
30+
datatype_symbol,
31+
datatype_args,
32+
} = pred;
33+
let datatype_args = datatype_args.into_iter().map(unbox_sort).collect();
34+
Pred::Matcher(MatcherPred {
35+
datatype_symbol,
36+
datatype_args,
37+
})
38+
}
39+
40+
fn unbox_pred(pred: Pred) -> Pred {
41+
match pred {
42+
Pred::Known(pred) => Pred::Known(pred),
43+
Pred::Var(pred) => Pred::Var(pred),
44+
Pred::Matcher(pred) => unbox_matcher_pred(pred),
45+
}
46+
}
47+
2648
fn unbox_atom(atom: Atom) -> Atom {
2749
let Atom { guard, pred, args } = atom;
2850
let guard = guard.map(|fo| Box::new(unbox_formula(*fo)));
51+
let pred = unbox_pred(pred);
2952
let args = args.into_iter().map(unbox_term).collect();
3053
Atom { guard, pred, args }
3154
}

0 commit comments

Comments
 (0)