diff --git a/skyrl-tx/tx/tinker/engine.py b/skyrl-tx/tx/tinker/engine.py index f1a479ee9..74cee918f 100644 --- a/skyrl-tx/tx/tinker/engine.py +++ b/skyrl-tx/tx/tinker/engine.py @@ -293,7 +293,7 @@ def _checkpoint_status_context(self, model_id: str, checkpoint_id: str, checkpoi session.commit() def find_batchable_model_passes( - self, session: Session, request_type: types.RequestType + self, session: Session, request_type: types.RequestType, max_requests: int = 0 ) -> dict[str, tuple[str, types.ForwardBackwardInput]]: """Find all requests of the given type that come before any destructive update for their model. @@ -303,6 +303,7 @@ def find_batchable_model_passes( Args: session: Database session request_type: The type of request to find (e.g., FORWARD or FORWARD_BACKWARD) + max_requests: Maximum number of requests to return. 0 means no limit. Returns: Dict mapping request_id to (model_id, request_data) tuples @@ -329,7 +330,13 @@ def find_batchable_model_passes( ops = session.exec(query).all() # Filter: only include ops that come before their model's barrier - batchable = [op for op in ops if op.model_id not in barriers or op.request_id < barriers[op.model_id]] + batchable = [] + for op in ops: + if op.model_id in barriers and op.request_id >= barriers[op.model_id]: + continue + batchable.append(op) + if max_requests > 0 and len(batchable) >= max_requests: + break return { str(f.request_id): (f.model_id, types.ForwardBackwardInput.model_validate(f.request_data)) @@ -679,11 +686,12 @@ def process_pending_requests(self): while True: # Query for pending requests and extract data within session context with Session(self.db_engine) as session: - # Use look-ahead scheduling to find batchable forward_backward and forward model passes + # Process one request at a time to avoid large batches that cause long + # wait times and retrieve_future requests piling up. forward_backward_requests = self.find_batchable_model_passes( - session, types.RequestType.FORWARD_BACKWARD + session, types.RequestType.FORWARD_BACKWARD, max_requests=1 ) - forward_requests = self.find_batchable_model_passes(session, types.RequestType.FORWARD) + forward_requests = self.find_batchable_model_passes(session, types.RequestType.FORWARD, max_requests=1) # Find pending sample requests that can be batched sample_requests = self.find_batchable_sample(session) # Get other pending requests (non forward_backward and non sampling)