1+ use std:: sync:: atomic:: { AtomicBool , Ordering } ;
12use std:: sync:: Arc ;
23use std:: { borrow:: Borrow , collections:: HashMap } ;
34
45use tokio:: sync:: RwLock ;
6+ use tokio:: task:: JoinHandle ;
57
68use crate :: provider:: { FeatureProvider , NoOpProvider } ;
79
@@ -32,20 +34,26 @@ impl ProviderRegistry {
3234 }
3335
3436 pub async fn set_default < T : FeatureProvider > ( & self , mut provider : T ) {
35- let mut map = self . providers . write ( ) . await ;
36- map. remove ( "" ) ;
37+ let old_provider = self . providers . write ( ) . await . remove ( "" ) ;
38+
39+ if let Some ( old_provider) = old_provider {
40+ old_provider. shutdown_in_background ( ) ;
41+ }
3742
3843 provider
3944 . initialize ( self . global_evaluation_context . get ( ) . await . borrow ( ) )
4045 . await ;
4146
42- map. insert ( String :: default ( ) , FeatureProviderWrapper :: new ( provider) ) ;
47+ self . providers
48+ . write ( )
49+ . await
50+ . insert ( String :: default ( ) , FeatureProviderWrapper :: new ( provider) ) ;
4351 }
4452
4553 pub async fn set_named < T : FeatureProvider > ( & self , name : & str , mut provider : T ) {
4654 // Drop the already registered provider if any.
47- if self . get_named ( name ) . await . is_some ( ) {
48- self . providers . write ( ) . await . remove ( name ) ;
55+ if let Some ( old_provider ) = self . providers . write ( ) . await . remove ( name ) {
56+ old_provider . shutdown_in_background ( ) ;
4957 }
5058
5159 provider
@@ -74,7 +82,21 @@ impl ProviderRegistry {
7482 }
7583
7684 pub async fn clear ( & self ) {
85+ let providers: Vec < FeatureProviderWrapper > =
86+ self . providers . read ( ) . await . values ( ) . cloned ( ) . collect ( ) ;
87+
88+ let mut shutdown_handles = Vec :: with_capacity ( providers. len ( ) ) ;
89+ for provider in providers {
90+ if let Some ( handle) = provider. shutdown_in_background ( ) {
91+ shutdown_handles. push ( handle) ;
92+ }
93+ }
94+
7795 self . providers . write ( ) . await . clear ( ) ;
96+
97+ for handle in shutdown_handles {
98+ let _ = handle. await ;
99+ }
78100 }
79101}
80102
@@ -89,14 +111,44 @@ impl Default for ProviderRegistry {
89111// ============================================================
90112
91113#[ derive( Clone ) ]
92- pub struct FeatureProviderWrapper ( Arc < dyn FeatureProvider > ) ;
114+ pub struct FeatureProviderWrapper ( Arc < ProviderEntry > ) ;
93115
94116impl FeatureProviderWrapper {
95117 pub fn new ( provider : impl FeatureProvider ) -> Self {
96- Self ( Arc :: new ( provider) )
118+ Self ( Arc :: new ( ProviderEntry :: new ( provider) ) )
97119 }
98120
99121 pub fn get ( & self ) -> Arc < dyn FeatureProvider > {
100- self . 0 . clone ( )
122+ self . 0 . provider . clone ( )
123+ }
124+
125+ pub fn shutdown_in_background ( & self ) -> Option < JoinHandle < ( ) > > {
126+ if self
127+ . 0
128+ . shutdown_started
129+ . compare_exchange ( false , true , Ordering :: AcqRel , Ordering :: Acquire )
130+ . is_ok ( )
131+ {
132+ let provider = self . get ( ) ;
133+ Some ( tokio:: spawn ( async move {
134+ provider. shutdown ( ) . await ;
135+ } ) )
136+ } else {
137+ None
138+ }
139+ }
140+ }
141+
142+ struct ProviderEntry {
143+ provider : Arc < dyn FeatureProvider > ,
144+ shutdown_started : AtomicBool ,
145+ }
146+
147+ impl ProviderEntry {
148+ fn new ( provider : impl FeatureProvider ) -> Self {
149+ Self {
150+ provider : Arc :: new ( provider) ,
151+ shutdown_started : AtomicBool :: new ( false ) ,
152+ }
101153 }
102154}
0 commit comments