@@ -13,8 +13,8 @@ use crate::analyze;
1313use crate :: chc;
1414use crate :: pretty:: PrettyDisplayExt as _;
1515use crate :: refine:: {
16- self , Assumption , BasicBlockType , Env , PlaceType , PlaceTypeBuilder , PlaceTypeVar , TempVarIdx ,
17- TypeBuilder , Var ,
16+ Assumption , BasicBlockType , PlaceType , PlaceTypeBuilder , PlaceTypeVar , TempVarIdx , TypeBuilder ,
17+ Var ,
1818} ;
1919use crate :: rty:: {
2020 self , ClauseBuilderExt as _, ClauseScope as _, ShiftExistential as _, Subtyping as _,
@@ -34,7 +34,7 @@ pub struct Analyzer<'tcx, 'ctx> {
3434 body : Cow < ' tcx , Body < ' tcx > > ,
3535
3636 type_builder : TypeBuilder < ' tcx > ,
37- env : Env ,
37+ env : analyze :: Env ,
3838 local_decls : IndexVec < Local , mir:: LocalDecl < ' tcx > > ,
3939 // TODO: remove this
4040 prophecy_vars : HashMap < usize , TempVarIdx > ,
@@ -350,16 +350,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
350350 . map ( |operand| self . operand_type ( operand) . boxed ( ) )
351351 . collect ( ) ;
352352 match * kind {
353- mir:: AggregateKind :: Adt ( did, variant_id , args, _, _)
353+ mir:: AggregateKind :: Adt ( did, variant_idx , args, _, _)
354354 if self . tcx . def_kind ( did) == DefKind :: Enum =>
355355 {
356- let adt = self . tcx . adt_def ( did) ;
357- let ty_sym = refine:: datatype_symbol ( self . tcx , did) ;
358- let variant = adt. variant ( variant_id) ;
359- let v_sym = refine:: datatype_symbol ( self . tcx , variant. def_id ) ;
360-
361- let enum_variant_def = self . ctx . find_enum_variant ( & ty_sym, & v_sym) . unwrap ( ) ;
362- let variant_rtys = enum_variant_def
356+ let enum_def = self . ctx . get_or_register_enum_def ( did) ;
357+ let variant_def = & enum_def. variants [ variant_idx] ;
358+ let variant_rtys = variant_def
363359 . field_tys
364360 . clone ( )
365361 . into_iter ( )
@@ -386,7 +382,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
386382
387383 let sort_args: Vec < _ > =
388384 rty_args. iter ( ) . map ( |rty| rty. ty . to_sort ( ) ) . collect ( ) ;
389- let ty = rty:: EnumType :: new ( ty_sym . clone ( ) , rty_args) . into ( ) ;
385+ let ty = rty:: EnumType :: new ( enum_def . name . clone ( ) , rty_args) . into ( ) ;
390386
391387 let mut builder = PlaceTypeBuilder :: default ( ) ;
392388 let mut field_terms = Vec :: new ( ) ;
@@ -396,7 +392,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
396392 }
397393 builder. build (
398394 ty,
399- chc:: Term :: datatype_ctor ( ty_sym, sort_args, v_sym, field_terms) ,
395+ chc:: Term :: datatype_ctor (
396+ enum_def. name ,
397+ sort_args,
398+ variant_def. name . clone ( ) ,
399+ field_terms,
400+ ) ,
400401 )
401402 }
402403 _ => PlaceType :: tuple ( field_tys) ,
@@ -924,6 +925,31 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
924925 }
925926 }
926927 }
928+
929+ fn register_enum_defs ( & mut self ) {
930+ for local_decl in & self . local_decls {
931+ use mir_ty:: { TypeSuperVisitable as _, TypeVisitable as _} ;
932+ #[ derive( Default ) ]
933+ struct EnumCollector {
934+ enums : std:: collections:: HashSet < DefId > ,
935+ }
936+ impl < ' tcx > mir_ty:: TypeVisitor < mir_ty:: TyCtxt < ' tcx > > for EnumCollector {
937+ fn visit_ty ( & mut self , ty : mir_ty:: Ty < ' tcx > ) {
938+ if let mir_ty:: TyKind :: Adt ( adt_def, _) = ty. kind ( ) {
939+ if adt_def. is_enum ( ) {
940+ self . enums . insert ( adt_def. did ( ) ) ;
941+ }
942+ }
943+ ty. super_visit_with ( self ) ;
944+ }
945+ }
946+ let mut visitor = EnumCollector :: default ( ) ;
947+ local_decl. ty . visit_with ( & mut visitor) ;
948+ for def_id in visitor. enums {
949+ self . ctx . get_or_register_enum_def ( def_id) ;
950+ }
951+ }
952+ }
927953}
928954
929955/// Turns [`rty::RefinedType<Var>`] into [`rty::RefinedType<T>`].
@@ -967,7 +993,7 @@ impl<T> UnbindAtoms<T> {
967993 self . existentials . extend ( var_ty. existentials ) ;
968994 }
969995
970- pub fn unbind ( mut self , env : & Env , ty : rty:: RefinedType < Var > ) -> rty:: RefinedType < T > {
996+ pub fn unbind ( mut self , env : & analyze :: Env , ty : rty:: RefinedType < Var > ) -> rty:: RefinedType < T > {
971997 let rty:: RefinedType {
972998 ty : src_ty,
973999 refinement,
@@ -1136,14 +1162,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
11361162 self
11371163 }
11381164
1139- pub fn env ( & mut self , env : Env ) -> & mut Self {
1140- self . env = env;
1141- self
1142- }
1143-
11441165 pub fn run ( & mut self , expected : & BasicBlockType ) {
11451166 let span = tracing:: info_span!( "bb" , bb = ?self . basic_block) ;
11461167 let _guard = span. enter ( ) ;
1168+ self . register_enum_defs ( ) ;
11471169
11481170 let params = expected. as_ref ( ) . params . clone ( ) ;
11491171 self . bind_locals ( & params) ;
0 commit comments