Skip to content

Commit c58a4d0

Browse files
authored
feat: release gil during remote storage operations to prevent blocking (#29)
1 parent cd0c7c9 commit c58a4d0

1 file changed

Lines changed: 19 additions & 14 deletions

File tree

python/src/lib.rs

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ impl Context {
7171
#[pyo3(signature = (uri, *, storage_options=None))]
7272
fn create(
7373
_cls: &Bound<'_, PyType>,
74+
py: Python<'_>,
7475
uri: &str,
7576
storage_options: Option<&Bound<'_, PyDict>>,
7677
) -> PyResult<Self> {
@@ -80,9 +81,10 @@ impl Context {
8081
storage_options: storage_options_from_dict(storage_options)?,
8182
};
8283

83-
let store = runtime
84-
.block_on(ContextStore::open_with_options(uri, options))
85-
.map_err(to_py_err)?;
84+
let store_res = py.allow_threads(|| {
85+
runtime.block_on(ContextStore::open_with_options(uri, options))
86+
});
87+
let store = store_res.map_err(to_py_err)?;
8688
let run_id = new_run_id();
8789
Ok(Self {
8890
inner: RustContext::new(uri),
@@ -111,6 +113,7 @@ impl Context {
111113
#[pyo3(signature = (role, content, data_type = None))]
112114
fn add(
113115
&mut self,
116+
py: Python<'_>,
114117
role: &str,
115118
content: &Bound<'_, PyAny>,
116119
data_type: Option<&str>,
@@ -147,9 +150,11 @@ impl Context {
147150
embedding: None,
148151
};
149152

150-
self.runtime
151-
.block_on(self.store.add(std::slice::from_ref(&record)))
152-
.map_err(to_py_err)?;
153+
let add_res = py.allow_threads(|| {
154+
self.runtime
155+
.block_on(self.store.add(std::slice::from_ref(&record)))
156+
});
157+
add_res.map_err(to_py_err)?;
153158
self.inner.add(role, &inner_content, data_type);
154159
Ok(())
155160
}
@@ -168,10 +173,9 @@ impl Context {
168173
}
169174
}
170175

171-
fn checkout(&mut self, version_id: u64) -> PyResult<()> {
172-
self.runtime
173-
.block_on(self.store.checkout(version_id))
174-
.map_err(to_py_err)?;
176+
fn checkout(&mut self, py: Python<'_>, version_id: u64) -> PyResult<()> {
177+
let res = py.allow_threads(|| self.runtime.block_on(self.store.checkout(version_id)));
178+
res.map_err(to_py_err)?;
175179
self.run_id = new_run_id();
176180
Ok(())
177181
}
@@ -183,10 +187,11 @@ impl Context {
183187
query: Vec<f32>,
184188
limit: Option<usize>,
185189
) -> PyResult<Vec<PyObject>> {
186-
let hits = self
187-
.runtime
188-
.block_on(self.store.search(&query, limit))
189-
.map_err(to_py_err)?;
190+
let hits_res = py.allow_threads(|| {
191+
self.runtime
192+
.block_on(self.store.search(&query, limit))
193+
});
194+
let hits = hits_res.map_err(to_py_err)?;
190195
hits.into_iter()
191196
.map(|hit| search_hit_to_py(py, hit))
192197
.collect()

0 commit comments

Comments
 (0)