From 163b53a6f1ff0726db40db7dc6a4f7805bf9053e Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 6 Jan 2026 20:05:57 +0800 Subject: [PATCH 1/3] Add default batch size for `SQLReader` --- tests/buffer/sql_test.py | 5 ++++- trinity/buffer/reader/sql_reader.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 4d63f041e4..e78c17dfc2 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -1,4 +1,5 @@ import os +from copy import deepcopy import ray import torch @@ -40,7 +41,9 @@ async def test_sql_exp_buffer_read_write(self, enable_replay: bool) -> None: ) if enable_replay: config.replay_buffer = ReplayBufferConfig(enable=True) - sql_writer = SQLWriter(config.to_storage_config()) + writer_config = deepcopy(config) + writer_config.batch_size = put_batch_size + sql_writer = SQLWriter(writer_config.to_storage_config()) sql_reader = SQLReader(config.to_storage_config()) exps = [ Experience( diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py index f7572c628c..b9e21207ca 100644 --- a/trinity/buffer/reader/sql_reader.py +++ b/trinity/buffer/reader/sql_reader.py @@ -16,15 +16,18 @@ class SQLReader(BufferReader): def __init__(self, config: StorageConfig) -> None: assert config.storage_type == StorageType.SQL.value self.wrap_in_ray = config.wrap_in_ray + self.read_batch_size = config.batch_size self.storage = SQLStorage.get_wrapper(config) def read(self, batch_size: Optional[int] = None, **kwargs) -> List: + batch_size = batch_size or self.read_batch_size if self.wrap_in_ray: return ray.get(self.storage.read.remote(batch_size, **kwargs)) else: return self.storage.read(batch_size, **kwargs) async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List: + batch_size = batch_size or self.read_batch_size if self.wrap_in_ray: try: return await self.storage.read.remote(batch_size, **kwargs) From 8aedc049e7c0e4feb5c590c49e6baa18d185f30f Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 6 Jan 2026 20:25:50 +0800 Subject: [PATCH 2/3] apply suggestions from gemini --- trinity/buffer/reader/queue_reader.py | 4 ++-- trinity/buffer/reader/sql_reader.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index b3b1d14c12..6f8e7c0334 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -21,7 +21,7 @@ def __init__(self, config: StorageConfig): def read(self, batch_size: Optional[int] = None, **kwargs) -> List: try: - batch_size = batch_size or self.read_batch_size + batch_size = self.read_batch_size if batch_size is None else batch_size exps = ray.get(self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs)) if len(exps) != batch_size: raise TimeoutError( @@ -32,7 +32,7 @@ def read(self, batch_size: Optional[int] = None, **kwargs) -> List: return exps async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List: - batch_size = batch_size or self.read_batch_size + batch_size = self.read_batch_size if batch_size is None else batch_size exps = await self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs) if len(exps) != batch_size: raise TimeoutError( diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py index b9e21207ca..fd1425bb8f 100644 --- a/trinity/buffer/reader/sql_reader.py +++ b/trinity/buffer/reader/sql_reader.py @@ -20,14 +20,14 @@ def __init__(self, config: StorageConfig) -> None: self.storage = SQLStorage.get_wrapper(config) def read(self, batch_size: Optional[int] = None, **kwargs) -> List: - batch_size = batch_size or self.read_batch_size + batch_size = self.read_batch_size if batch_size is None else batch_size if self.wrap_in_ray: return ray.get(self.storage.read.remote(batch_size, **kwargs)) else: return self.storage.read(batch_size, **kwargs) async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List: - batch_size = batch_size or self.read_batch_size + batch_size = self.read_batch_size if batch_size is None else batch_size if self.wrap_in_ray: try: return await self.storage.read.remote(batch_size, **kwargs) From 638f7e68670e7292e87e7740980775755f04e182 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 6 Jan 2026 20:32:28 +0800 Subject: [PATCH 3/3] apply suggestions from gemini --- tests/buffer/sql_test.py | 2 ++ trinity/buffer/storage/sql.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index e78c17dfc2..a1a9f39907 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -43,6 +43,8 @@ async def test_sql_exp_buffer_read_write(self, enable_replay: bool) -> None: config.replay_buffer = ReplayBufferConfig(enable=True) writer_config = deepcopy(config) writer_config.batch_size = put_batch_size + # Create buffer by writer, so buffer.batch_size will be set to put_batch_size + # This will check whether read_batch_size tasks effect sql_writer = SQLWriter(writer_config.to_storage_config()) sql_reader = SQLReader(config.to_storage_config()) exps = [ diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py index 04f0c20bda..eb1097cdc7 100644 --- a/trinity/buffer/storage/sql.py +++ b/trinity/buffer/storage/sql.py @@ -197,7 +197,7 @@ def read(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]: if self.stopped: raise StopIteration() - batch_size = batch_size or self.batch_size + batch_size = self.batch_size if batch_size is None else batch_size return self._read_method(batch_size, **kwargs) @classmethod @@ -248,7 +248,7 @@ def read(self, batch_size: Optional[int] = None) -> List[Task]: raise StopIteration() if self.offset > self.total_samples: raise StopIteration() - batch_size = batch_size or self.batch_size + batch_size = self.batch_size if batch_size is None else batch_size with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session: query = ( session.query(self.table_model_cls)