Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions skyrl-tx/tx/tinker/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def _checkpoint_status_context(self, model_id: str, checkpoint_id: str, checkpoi
session.commit()

def find_batchable_model_passes(
self, session: Session, request_type: types.RequestType
self, session: Session, request_type: types.RequestType, max_requests: int = 0
) -> dict[str, tuple[str, types.ForwardBackwardInput]]:
"""Find all requests of the given type that come before any destructive update for their model.

Expand All @@ -303,6 +303,7 @@ def find_batchable_model_passes(
Args:
session: Database session
request_type: The type of request to find (e.g., FORWARD or FORWARD_BACKWARD)
max_requests: Maximum number of requests to return. 0 means no limit.

Returns:
Dict mapping request_id to (model_id, request_data) tuples
Expand All @@ -329,7 +330,13 @@ def find_batchable_model_passes(
ops = session.exec(query).all()

# Filter: only include ops that come before their model's barrier
batchable = [op for op in ops if op.model_id not in barriers or op.request_id < barriers[op.model_id]]
batchable = []
for op in ops:
if op.model_id in barriers and op.request_id >= barriers[op.model_id]:
continue
batchable.append(op)
if max_requests > 0 and len(batchable) >= max_requests:
break

return {
str(f.request_id): (f.model_id, types.ForwardBackwardInput.model_validate(f.request_data))
Expand Down Expand Up @@ -679,11 +686,12 @@ def process_pending_requests(self):
while True:
# Query for pending requests and extract data within session context
with Session(self.db_engine) as session:
# Use look-ahead scheduling to find batchable forward_backward and forward model passes
# Process one request at a time to avoid large batches, which cause long
# wait times and let retrieve_future requests pile up.
forward_backward_requests = self.find_batchable_model_passes(
session, types.RequestType.FORWARD_BACKWARD
session, types.RequestType.FORWARD_BACKWARD, max_requests=1
)
forward_requests = self.find_batchable_model_passes(session, types.RequestType.FORWARD)
forward_requests = self.find_batchable_model_passes(session, types.RequestType.FORWARD, max_requests=1)
# Find pending sample requests that can be batched
sample_requests = self.find_batchable_sample(session)
# Get other pending requests (non forward_backward and non sampling)
Expand Down
Loading