From e8708dc990daa6ac392bc89e7d6d2b69a33689f9 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 26 Jan 2026 15:40:43 -0800 Subject: [PATCH 1/6] [tx] Add max_micro_batches config to limit batch size in engine Add max_micro_batches config (default: 64) to EngineConfig to limit how many micro batches are processed before returning results to clients. This prevents long wait times when clients send large numbers of requests. The limit behavior depends on train_micro_batch_size: - When > 0: counts micro batches as ceil(sequences / micro_batch_size) - When = 0 (full batch mode): each request counts as 1 Always includes at least one request to avoid starvation. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/tinker/test_engine.py | 86 +++++++++++++++++++++++++++- skyrl-tx/tx/tinker/config.py | 4 ++ skyrl-tx/tx/tinker/engine.py | 21 +++++++ 3 files changed, 110 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tests/tinker/test_engine.py b/skyrl-tx/tests/tinker/test_engine.py index 3319a8c3e2..8756db3ca1 100644 --- a/skyrl-tx/tests/tinker/test_engine.py +++ b/skyrl-tx/tests/tinker/test_engine.py @@ -1,12 +1,13 @@ from cloudpathlib import AnyPath from datetime import datetime, timedelta, timezone +import pytest from sqlmodel import Session, SQLModel from tx.tinker.engine import TinkerEngine from tx.tinker.config import EngineConfig from tx.tinker import types -from tx.tinker.db_models import SessionDB, ModelDB +from tx.tinker.db_models import SessionDB, ModelDB, FutureDB, RequestStatus BASE_MODEL = "trl-internal-testing/tiny-Qwen3ForCausalLM" @@ -80,3 +81,86 @@ def test_cleanup_stale_sessions(): # Run cleanup and assert one model was unloaded assert engine.cleanup_stale_sessions() == 1 assert not engine.backend.has_model(model_id) + + +class TestMaxMicroBatches: + """Tests for max_micro_batches limiting in find_batchable_model_passes.""" + + @staticmethod + def _make_request_data(num_sequences: int) -> dict: + """Create a ForwardBackwardInput request data with the given number of sequences.""" + data = [] + for _ in range(num_sequences): + data.append({ + "model_input": {"chunks": [{"tokens": [1, 2, 3]}]}, + "loss_fn_inputs": { + "target_tokens": {"data": [2, 3, 4]}, + "weights": {"data": [1.0, 1.0, 1.0]}, + "advantages": {"data": [0.0, 0.0, 0.0]}, + "logprobs": {"data": [0.0, 0.0, 0.0]}, + }, + }) + return {"data": data, "loss_fn": "cross_entropy"} + + @staticmethod + def _create_engine(train_micro_batch_size: int, max_micro_batches: int) -> TinkerEngine: + """Create an engine with the given micro batch configuration.""" + config = EngineConfig( + base_model=BASE_MODEL, + checkpoints_base=AnyPath(""), + backend_config={"max_lora_adapters": 4, "max_lora_rank": 32, "train_micro_batch_size": train_micro_batch_size}, + max_micro_batches=max_micro_batches, + database_url="sqlite:///:memory:", + ) + engine = TinkerEngine(config) + SQLModel.metadata.create_all(engine.db_engine) + return engine + + def _add_requests(self, engine: TinkerEngine, sequence_counts: list[int]): + """Add FORWARD_BACKWARD requests with the given sequence counts.""" + with Session(engine.db_engine) as session: + for num_sequences in sequence_counts: + session.add(FutureDB( + request_type=types.RequestType.FORWARD_BACKWARD, + model_id="model1", + request_data=self._make_request_data(num_sequences), + status=RequestStatus.PENDING, + )) + session.commit() + + @pytest.mark.parametrize( + "train_micro_batch_size,max_micro_batches,sequence_counts,expected_count", + [ + # Gradient accumulation mode: ceil(16/4) + ceil(20/4) = 4 + 5 = 9 <= 10, ceil(8/4) = 2 would exceed + (4, 10, [16, 20, 8], 2), + # Full batch mode: each request counts as 1, so 3 requests fit in max_micro_batches=3 + (0, 3, [100, 200, 50, 75], 3), + # Disabled: all requests included when max_micro_batches=0 + (4, 0, [50] * 10, 10), + ], + ids=["gradient_accumulation", "full_batch_mode", "disabled"], + ) + def test_micro_batch_limiting(self, train_micro_batch_size, max_micro_batches, sequence_counts, expected_count): + """Test that micro batches are limited correctly under different configurations.""" + engine = self._create_engine(train_micro_batch_size, max_micro_batches) + self._add_requests(engine, sequence_counts) + + with Session(engine.db_engine) as session: + result = engine.find_batchable_model_passes(session, types.RequestType.FORWARD_BACKWARD) + + assert len(result) == expected_count + + def test_always_includes_at_least_one_request(self): + """Test that at least one request is always included even if it exceeds the limit.""" + # train_micro_batch_size=4, max_micro_batches=10 + # Request with 100 sequences = ceil(100/4) = 25 micro batches > 10 + # Should still be included to avoid starvation + engine = self._create_engine(train_micro_batch_size=4, max_micro_batches=10) + self._add_requests(engine, [100]) + + with Session(engine.db_engine) as session: + result = engine.find_batchable_model_passes(session, types.RequestType.FORWARD_BACKWARD) + + assert len(result) == 1 + _, req_data = list(result.values())[0] + assert len(req_data.data) == 100 diff --git a/skyrl-tx/tx/tinker/config.py b/skyrl-tx/tx/tinker/config.py index e126e5499e..ab11a3d339 100644 --- a/skyrl-tx/tx/tinker/config.py +++ b/skyrl-tx/tx/tinker/config.py @@ -51,6 +51,10 @@ class EngineConfig(BaseModel): default=300, description="Seconds without heartbeat before session is considered stale. Set to -1 to disable cleanup.", ) + max_micro_batches: int = Field( + default=64, + description="Maximum number of micro batches per forward/forward_backward batch. Limits how many are processed before returning results to clients. Set to 0 to disable.", + ) def convert_env_var(env_name: str, env_value: str, expected_type: type): diff --git a/skyrl-tx/tx/tinker/engine.py b/skyrl-tx/tx/tinker/engine.py index b7f4b29178..2ba174a738 100644 --- a/skyrl-tx/tx/tinker/engine.py +++ b/skyrl-tx/tx/tinker/engine.py @@ -1,6 +1,7 @@ """Background engine for processing training requests.""" import argparse +import math import time from contextlib import contextmanager from datetime import datetime, timedelta, timezone @@ -270,6 +271,26 @@ def find_batchable_model_passes( # Filter: only include ops that come before their model's barrier batchable = [op for op in ops if op.model_id not in barriers or op.request_id < barriers[op.model_id]] + # Limit total micro batches if configured + if self.config.max_micro_batches > 0 and isinstance(self.backend, JaxBackend): + micro_batch_size = self.backend.config.train_micro_batch_size + limited = [] + total_micro_batches = 0 + for op in batchable: + num_sequences = len(op.request_data.get("data", [])) + if micro_batch_size > 0: + # Gradient accumulation enabled: count actual micro batches + num_micro_batches = math.ceil(num_sequences / micro_batch_size) + else: + # Full batch mode: each request is processed as one unit + num_micro_batches = 1 + # Always include at least one request to avoid starvation + if limited and total_micro_batches + num_micro_batches > self.config.max_micro_batches: + break + limited.append(op) + total_micro_batches += num_micro_batches + batchable = limited + return { str(f.request_id): (f.model_id, types.ForwardBackwardInput.model_validate(f.request_data)) for f in batchable From be102bb8ea049ac32f8214be1379e0dc69430d51 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 26 Jan 2026 17:26:03 -0800 Subject: [PATCH 2/6] lint --- skyrl-tx/tests/tinker/test_engine.py | 40 +++++++++++++++++----------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/skyrl-tx/tests/tinker/test_engine.py b/skyrl-tx/tests/tinker/test_engine.py index 8756db3ca1..706ace219e 100644 --- a/skyrl-tx/tests/tinker/test_engine.py +++ b/skyrl-tx/tests/tinker/test_engine.py @@ -91,15 +91,17 @@ def _make_request_data(num_sequences: int) -> dict: """Create a ForwardBackwardInput request data with the given number of sequences.""" data = [] for _ in range(num_sequences): - data.append({ - "model_input": {"chunks": [{"tokens": [1, 2, 3]}]}, - "loss_fn_inputs": { - "target_tokens": {"data": [2, 3, 4]}, - "weights": {"data": [1.0, 1.0, 1.0]}, - "advantages": {"data": [0.0, 0.0, 0.0]}, - "logprobs": {"data": [0.0, 0.0, 0.0]}, - }, - }) + data.append( + { + "model_input": {"chunks": [{"tokens": [1, 2, 3]}]}, + "loss_fn_inputs": { + "target_tokens": {"data": [2, 3, 4]}, + "weights": {"data": [1.0, 1.0, 1.0]}, + "advantages": {"data": [0.0, 0.0, 0.0]}, + "logprobs": {"data": [0.0, 0.0, 0.0]}, + }, + } + ) return {"data": data, "loss_fn": "cross_entropy"} @staticmethod @@ -108,7 +110,11 @@ def _create_engine(train_micro_batch_size: int, max_micro_batches: int) -> Tinke config = EngineConfig( base_model=BASE_MODEL, checkpoints_base=AnyPath(""), - backend_config={"max_lora_adapters": 4, "max_lora_rank": 32, "train_micro_batch_size": train_micro_batch_size}, + backend_config={ + "max_lora_adapters": 4, + "max_lora_rank": 32, + "train_micro_batch_size": train_micro_batch_size, + }, max_micro_batches=max_micro_batches, database_url="sqlite:///:memory:", ) @@ -120,12 +126,14 @@ def _add_requests(self, engine: TinkerEngine, sequence_counts: list[int]): """Add FORWARD_BACKWARD requests with the given sequence counts.""" with Session(engine.db_engine) as session: for num_sequences in sequence_counts: - session.add(FutureDB( - request_type=types.RequestType.FORWARD_BACKWARD, - model_id="model1", - request_data=self._make_request_data(num_sequences), - status=RequestStatus.PENDING, - )) + session.add( + FutureDB( + request_type=types.RequestType.FORWARD_BACKWARD, + model_id="model1", + request_data=self._make_request_data(num_sequences), + status=RequestStatus.PENDING, + ) + ) session.commit() @pytest.mark.parametrize( From e4703e9f14668f50f05f54f77f517a55724bfcd2 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 11 Feb 2026 09:43:24 -0800 Subject: [PATCH 3/6] fix --- skyrl-tx/tx/tinker/engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skyrl-tx/tx/tinker/engine.py b/skyrl-tx/tx/tinker/engine.py index ef24e72b39..3838ae9e45 100644 --- a/skyrl-tx/tx/tinker/engine.py +++ b/skyrl-tx/tx/tinker/engine.py @@ -12,6 +12,7 @@ from pydantic import BaseModel from sqlmodel import create_engine, Session, select, update, func +from tx.tinker.backends.jax import JaxBackend from tx.tinker.db_models import ( FutureDB, RequestStatus, From 229e4b78e0cb46a81265ffdbaf4e224abf59faf4 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 12 Feb 2026 15:51:39 -0800 Subject: [PATCH 4/6] [tx] Remove batching for forward/forward_backward requests Process one request at a time instead of batching all pending requests together, to avoid large batches that cause long wait times and retrieve_future requests piling up. Replaces the max_micro_batches approach with a simpler max_requests=1 limit on find_batchable_model_passes. Co-Authored-By: Claude Opus 4.6 --- skyrl-tx/tests/tinker/test_engine.py | 94 +--------------------------- skyrl-tx/tx/tinker/config.py | 4 -- skyrl-tx/tx/tinker/engine.py | 35 +++-------- 3 files changed, 11 insertions(+), 122 deletions(-) diff --git a/skyrl-tx/tests/tinker/test_engine.py b/skyrl-tx/tests/tinker/test_engine.py index 5cb4397b3e..934dc3efa4 100644 --- a/skyrl-tx/tests/tinker/test_engine.py +++ b/skyrl-tx/tests/tinker/test_engine.py @@ -1,13 +1,12 @@ from cloudpathlib import AnyPath from datetime import datetime, timedelta, timezone -import pytest from sqlmodel import Session, SQLModel from tx.tinker.engine import TinkerEngine, prepare_model_pass_batch from tx.tinker.config import EngineConfig from tx.tinker import types -from tx.tinker.db_models import SessionDB, ModelDB, FutureDB, RequestStatus +from tx.tinker.db_models import SessionDB, ModelDB BASE_MODEL = "trl-internal-testing/tiny-Qwen3ForCausalLM" @@ -124,94 +123,3 @@ def test_prepare_model_pass_batch_loss_fn_config(): batch_no_config = prepare_model_pass_batch(requests_without_config) assert batch_no_config.all_loss_fns == ["cross_entropy"] assert batch_no_config.all_loss_fn_configs == [None] - - -class TestMaxMicroBatches: - """Tests for max_micro_batches limiting in find_batchable_model_passes.""" - - @staticmethod - def _make_request_data(num_sequences: int) -> dict: - """Create a ForwardBackwardInput request data with the given number of sequences.""" - data = [] - for _ in range(num_sequences): - data.append( - { - "model_input": {"chunks": [{"tokens": [1, 2, 3]}]}, - "loss_fn_inputs": { - "target_tokens": {"data": [2, 3, 4]}, - "weights": {"data": [1.0, 1.0, 1.0]}, - "advantages": {"data": [0.0, 0.0, 0.0]}, - "logprobs": {"data": [0.0, 0.0, 0.0]}, - }, - } - ) - return {"data": data, "loss_fn": "cross_entropy"} - - @staticmethod - def _create_engine(train_micro_batch_size: int, max_micro_batches: int) -> TinkerEngine: - """Create an engine with the given micro batch configuration.""" - config = EngineConfig( - base_model=BASE_MODEL, - checkpoints_base=AnyPath(""), - backend_config={ - "max_lora_adapters": 4, - "max_lora_rank": 32, - "train_micro_batch_size": train_micro_batch_size, - }, - max_micro_batches=max_micro_batches, - database_url="sqlite:///:memory:", - ) - engine = TinkerEngine(config) - SQLModel.metadata.create_all(engine.db_engine) - return engine - - def _add_requests(self, engine: TinkerEngine, sequence_counts: list[int]): - """Add FORWARD_BACKWARD requests with the given sequence counts.""" - with Session(engine.db_engine) as session: - for num_sequences in sequence_counts: - session.add( - FutureDB( - request_type=types.RequestType.FORWARD_BACKWARD, - model_id="model1", - request_data=self._make_request_data(num_sequences), - status=RequestStatus.PENDING, - ) - ) - session.commit() - - @pytest.mark.parametrize( - "train_micro_batch_size,max_micro_batches,sequence_counts,expected_count", - [ - # Gradient accumulation mode: ceil(16/4) + ceil(20/4) = 4 + 5 = 9 <= 10, ceil(8/4) = 2 would exceed - (4, 10, [16, 20, 8], 2), - # Full batch mode: each request counts as 1, so 3 requests fit in max_micro_batches=3 - (0, 3, [100, 200, 50, 75], 3), - # Disabled: all requests included when max_micro_batches=0 - (4, 0, [50] * 10, 10), - ], - ids=["gradient_accumulation", "full_batch_mode", "disabled"], - ) - def test_micro_batch_limiting(self, train_micro_batch_size, max_micro_batches, sequence_counts, expected_count): - """Test that micro batches are limited correctly under different configurations.""" - engine = self._create_engine(train_micro_batch_size, max_micro_batches) - self._add_requests(engine, sequence_counts) - - with Session(engine.db_engine) as session: - result = engine.find_batchable_model_passes(session, types.RequestType.FORWARD_BACKWARD) - - assert len(result) == expected_count - - def test_always_includes_at_least_one_request(self): - """Test that at least one request is always included even if it exceeds the limit.""" - # train_micro_batch_size=4, max_micro_batches=10 - # Request with 100 sequences = ceil(100/4) = 25 micro batches > 10 - # Should still be included to avoid starvation - engine = self._create_engine(train_micro_batch_size=4, max_micro_batches=10) - self._add_requests(engine, [100]) - - with Session(engine.db_engine) as session: - result = engine.find_batchable_model_passes(session, types.RequestType.FORWARD_BACKWARD) - - assert len(result) == 1 - _, req_data = list(result.values())[0] - assert len(req_data.data) == 100 diff --git a/skyrl-tx/tx/tinker/config.py b/skyrl-tx/tx/tinker/config.py index ab11a3d339..e126e5499e 100644 --- a/skyrl-tx/tx/tinker/config.py +++ b/skyrl-tx/tx/tinker/config.py @@ -51,10 +51,6 @@ class EngineConfig(BaseModel): default=300, description="Seconds without heartbeat before session is considered stale. Set to -1 to disable cleanup.", ) - max_micro_batches: int = Field( - default=64, - description="Maximum number of micro batches per forward/forward_backward batch. Limits how many are processed before returning results to clients. Set to 0 to disable.", - ) def convert_env_var(env_name: str, env_value: str, expected_type: type): diff --git a/skyrl-tx/tx/tinker/engine.py b/skyrl-tx/tx/tinker/engine.py index 3838ae9e45..342108fd6f 100644 --- a/skyrl-tx/tx/tinker/engine.py +++ b/skyrl-tx/tx/tinker/engine.py @@ -1,7 +1,6 @@ """Background engine for processing training requests.""" import argparse -import math import time from contextlib import contextmanager from datetime import datetime, timedelta, timezone @@ -12,7 +11,6 @@ from pydantic import BaseModel from sqlmodel import create_engine, Session, select, update, func -from tx.tinker.backends.jax import JaxBackend from tx.tinker.db_models import ( FutureDB, RequestStatus, @@ -264,7 +262,7 @@ def _checkpoint_status_context(self, model_id: str, checkpoint_id: str, checkpoi session.commit() def find_batchable_model_passes( - self, session: Session, request_type: types.RequestType + self, session: Session, request_type: types.RequestType, max_requests: int = 0 ) -> dict[str, tuple[str, types.ForwardBackwardInput]]: """Find all requests of the given type that come before any destructive update for their model. @@ -274,6 +272,7 @@ def find_batchable_model_passes( Args: session: Database session request_type: The type of request to find (e.g., FORWARD or FORWARD_BACKWARD) + max_requests: Maximum number of requests to return. 0 means no limit. Returns: Dict mapping request_id to (model_id, request_data) tuples @@ -302,25 +301,8 @@ def find_batchable_model_passes( # Filter: only include ops that come before their model's barrier batchable = [op for op in ops if op.model_id not in barriers or op.request_id < barriers[op.model_id]] - # Limit total micro batches if configured - if self.config.max_micro_batches > 0 and isinstance(self.backend, JaxBackend): - micro_batch_size = self.backend.config.train_micro_batch_size - limited = [] - total_micro_batches = 0 - for op in batchable: - num_sequences = len(op.request_data.get("data", [])) - if micro_batch_size > 0: - # Gradient accumulation enabled: count actual micro batches - num_micro_batches = math.ceil(num_sequences / micro_batch_size) - else: - # Full batch mode: each request is processed as one unit - num_micro_batches = 1 - # Always include at least one request to avoid starvation - if limited and total_micro_batches + num_micro_batches > self.config.max_micro_batches: - break - limited.append(op) - total_micro_batches += num_micro_batches - batchable = limited + if max_requests > 0: + batchable = batchable[:max_requests] return { str(f.request_id): (f.model_id, types.ForwardBackwardInput.model_validate(f.request_data)) @@ -652,11 +634,14 @@ def process_pending_requests(self): while True: # Query for pending requests and extract data within session context with Session(self.db_engine) as session: - # Use look-ahead scheduling to find batchable forward_backward and forward model passes + # Process one request at a time to avoid large batches that cause long + # wait times and retrieve_future requests piling up. forward_backward_requests = self.find_batchable_model_passes( - session, types.RequestType.FORWARD_BACKWARD + session, types.RequestType.FORWARD_BACKWARD, max_requests=1 + ) + forward_requests = self.find_batchable_model_passes( + session, types.RequestType.FORWARD, max_requests=1 ) - forward_requests = self.find_batchable_model_passes(session, types.RequestType.FORWARD) # Find pending sample requests that can be batched sample_requests = self.find_batchable_sample(session) # Get other pending requests (non forward_backward and non sampling) From 6a3a404a9f2a554c33efa55be2ce2087e2609a02 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 12 Feb 2026 15:54:05 -0800 Subject: [PATCH 5/6] Early break in find_batchable_model_passes when max_requests is reached Co-Authored-By: Claude Opus 4.6 --- skyrl-tx/tx/tinker/engine.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tx/tinker/engine.py b/skyrl-tx/tx/tinker/engine.py index 342108fd6f..f977c9f127 100644 --- a/skyrl-tx/tx/tinker/engine.py +++ b/skyrl-tx/tx/tinker/engine.py @@ -299,10 +299,13 @@ def find_batchable_model_passes( ops = session.exec(query).all() # Filter: only include ops that come before their model's barrier - batchable = [op for op in ops if op.model_id not in barriers or op.request_id < barriers[op.model_id]] - - if max_requests > 0: - batchable = batchable[:max_requests] + batchable = [] + for op in ops: + if op.model_id in barriers and op.request_id >= barriers[op.model_id]: + continue + batchable.append(op) + if max_requests > 0 and len(batchable) >= max_requests: + break return { str(f.request_id): (f.model_id, types.ForwardBackwardInput.model_validate(f.request_data)) From e8d95fd54c428ed6e3b44adf37dd9fb88f26d469 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 12 Feb 2026 15:58:30 -0800 Subject: [PATCH 6/6] lint --- skyrl-tx/tx/tinker/engine.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/skyrl-tx/tx/tinker/engine.py b/skyrl-tx/tx/tinker/engine.py index f977c9f127..d678aad1de 100644 --- a/skyrl-tx/tx/tinker/engine.py +++ b/skyrl-tx/tx/tinker/engine.py @@ -642,9 +642,7 @@ def process_pending_requests(self): forward_backward_requests = self.find_batchable_model_passes( session, types.RequestType.FORWARD_BACKWARD, max_requests=1 ) - forward_requests = self.find_batchable_model_passes( - session, types.RequestType.FORWARD, max_requests=1 - ) + forward_requests = self.find_batchable_model_passes(session, types.RequestType.FORWARD, max_requests=1) # Find pending sample requests that can be batched sample_requests = self.find_batchable_sample(session) # Get other pending requests (non forward_backward and non sampling)