@@ -21,7 +21,7 @@ use pyo3::exceptions::PyException;
2121use pyo3:: prelude:: PyAnyMethods ;
2222use pyo3:: types:: PyTuple ;
2323use pyo3:: { Bound , IntoPyObject , PyAny , PyResult , Python } ;
24- use std:: collections:: { BTreeSet , HashMap } ;
24+ use std:: collections:: HashMap ;
2525use std:: path:: { Path , PathBuf } ;
2626use std:: time:: Duration ;
2727use tako:: gateway:: {
@@ -273,35 +273,29 @@ pub struct JobWaitStatus {
273273pub fn wait_for_jobs_impl (
274274 py : Python ,
275275 ctx : ClientContextPtr ,
276- job_ids : Vec < PyJobId > ,
276+ mut job_ids : Vec < PyJobId > ,
277277 callback : & Bound < ' _ , PyAny > ,
278278) -> PyResult < Vec < PyJobId > > {
279279 run_future ( async move {
280- let mut remaining_job_ids: BTreeSet < PyJobId > = job_ids. iter ( ) . copied ( ) . collect ( ) ;
280+ job_ids. sort_unstable ( ) ;
281+ let selector = IdSelector :: Specific ( IntArray :: from_sorted_ids ( job_ids. into_iter ( ) ) ) ;
281282
282283 let mut ctx = borrow_mut ! ( py, ctx) ;
283284 let mut response: JobInfoResponse ;
284285
285286 loop {
286- let selector =
287- IdSelector :: Specific ( IntArray :: from_sorted_ids ( remaining_job_ids. iter ( ) . copied ( ) ) ) ;
288-
289287 response = rpc_call ! (
290288 ctx. session. connection( ) ,
291289 FromClientMessage :: JobInfo ( JobInfoRequest {
292- selector,
290+ selector: selector . clone ( ) ,
293291 include_running_tasks: false
294292 } , None ) ,
295293 ToClientMessage :: JobInfoResponse ( r) => r
296294 )
297295 . await
298296 . map_py_err ( ) ?;
299297
300- for job in response. jobs . iter ( ) {
301- if is_terminated ( job) {
302- remaining_job_ids. remove ( & PyJobId :: from ( job. id ) ) ;
303- }
304- }
298+ let completed = response. jobs . iter ( ) . all ( is_terminated) ;
305299
306300 let status: HashMap < PyJobId , JobWaitStatus > = response
307301 . jobs
@@ -320,7 +314,7 @@ pub fn wait_for_jobs_impl(
320314 let args = PyTuple :: new ( py, & [ status. into_pyobject ( py) ?] ) ?;
321315 callback. call1 ( args) ?;
322316
323- if remaining_job_ids . is_empty ( ) {
317+ if completed {
324318 break ;
325319 }
326320
0 commit comments