Skip to content

Commit 56ee425

Browse files
committed
feat(batch): gate sync ThreadPoolExecutor via admission semaphore from _AdaptiveLimiter
## Purpose Sync bulk paths submitted all batches to the thread pool up front with no way to dynamically reduce in-flight concurrency under throttling. Without an admission gate, a pre-sized pool of 8 threads would keep 8 batches in-flight even when the adaptive limiter had dropped the limit to 2. ## Solution Added `limiter_registry` and `host` parameters (both optional, SDK-internal) to `batch_execute`. When wired, a `threading.Condition`-protected inflight counter gates submission: the main thread acquires before each `executor.submit()` and workers release (via `_wrapped_op` finally) on completion. `_acquire()` re-reads `limiter.current_limit()` on every wakeup, so a recovering limiter unlocks additional slots without any extra wiring. When `limiter_registry is None` the gate is a no-op: `_acquire()` and `_release()` return immediately, preserving the original submission behavior. ## Follow-ups Transport callback wiring (DX-0159 for REST, DX-0160 for gRPC) is needed before `report_throttled()` is called in production; until then `current_limit()` stays at ceiling and the gate is functionally a no-op.
1 parent c5f6c33 commit 56ee425

2 files changed

Lines changed: 142 additions & 4 deletions

File tree

pinecone/_internal/batch.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from __future__ import annotations
99

1010
import asyncio
11+
import threading
1112
from concurrent.futures import ThreadPoolExecutor, as_completed
1213
from typing import TYPE_CHECKING, Any, TypeVar
1314

@@ -142,6 +143,8 @@ def batch_execute(
142143
show_progress: bool = True,
143144
desc: str = "Batches",
144145
executor: ThreadPoolExecutor | None = None,
146+
limiter_registry: _AdaptiveLimiterRegistry | None = None,
147+
host: str | None = None,
145148
) -> BatchResult:
146149
"""Execute *operation* on *items* in parallel batches.
147150
@@ -163,6 +166,10 @@ def batch_execute(
163166
or torn down per call. Caller is responsible for ``shutdown()``.
164167
When ``None`` (default), a private executor is created and
165168
shut down at the end of this call.
169+
limiter_registry (_AdaptiveLimiterRegistry | None): Optional registry
170+
for adaptive concurrency. SDK-internal; not for user code.
171+
host (str | None): Host key for the limiter registry lookup.
172+
SDK-internal; not for user code.
166173
167174
Returns:
168175
BatchResult with aggregated success/failure counts.
@@ -182,16 +189,49 @@ def batch_execute(
182189
lsn_reconciled_values: list[int] = []
183190
lsn_committed_values: list[int] = []
184191

192+
if limiter_registry is not None and host is not None:
193+
limiter = limiter_registry.get(host, max_concurrency)
194+
else:
195+
limiter = None
196+
197+
condition = threading.Condition()
198+
inflight = 0
199+
200+
def _acquire() -> None:
201+
nonlocal inflight
202+
if limiter is None:
203+
return
204+
with condition:
205+
while inflight >= limiter.current_limit():
206+
condition.wait()
207+
inflight += 1
208+
209+
def _release() -> None:
210+
nonlocal inflight
211+
if limiter is None:
212+
return
213+
with condition:
214+
inflight -= 1
215+
condition.notify_all()
216+
217+
def _wrapped_op(batch: list[dict[str, Any]]) -> Any:
218+
try:
219+
return operation(batch)
220+
finally:
221+
_release()
222+
185223
progress = _create_progress_bar(total_batches, desc, show_progress)
186224

187225
own_executor = executor is None
188226
if executor is None:
189227
executor = ThreadPoolExecutor(max_workers=max_concurrency)
190228

191229
try:
192-
future_to_batch = {
193-
executor.submit(operation, batch): (idx, batch) for idx, batch in enumerate(batches)
194-
}
230+
future_to_batch: dict[Any, tuple[int, list[dict[str, Any]]]] = {}
231+
for idx, batch in enumerate(batches):
232+
_acquire()
233+
future = executor.submit(_wrapped_op, batch)
234+
future_to_batch[future] = (idx, batch)
195235

196236
for future in as_completed(future_to_batch):
197237
batch_idx, batch = future_to_batch[future]

tests/unit/_internal/test_batch.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
from __future__ import annotations
44

55
import asyncio
6+
import threading
7+
import time
68
from typing import Any
79

810
import pytest
911

1012
from pinecone._internal.adaptive import _AdaptiveLimiterRegistry
11-
from pinecone._internal.batch import async_batch_execute
13+
from pinecone._internal.batch import async_batch_execute, batch_execute
1214

1315

1416
class TestAdaptiveBatchExecute:
@@ -200,3 +202,99 @@ async def slow(batch: list[dict[str, Any]]) -> Any:
200202
host="host2",
201203
)
202204
assert peak[0] <= 3
205+
206+
207+
class TestAdaptiveBatchExecuteSync:
208+
def test_sync_no_limiter_unchanged(self) -> None:
209+
"""Without a limiter_registry, behavior is identical to baseline."""
210+
items = [{"id": str(i), "values": [0.1, 0.2]} for i in range(6)]
211+
212+
def op(batch: list[dict[str, Any]]) -> Any:
213+
return {"upserted_count": len(batch)}
214+
215+
result = batch_execute(
216+
items=items,
217+
operation=op,
218+
batch_size=2,
219+
max_concurrency=4,
220+
show_progress=False,
221+
)
222+
assert result.successful_item_count == 6
223+
assert result.failed_item_count == 0
224+
assert result.total_item_count == 6
225+
226+
def test_sync_uses_limiter(self) -> None:
227+
"""With a pre-throttled limiter (4), max observed inflight must not exceed 4."""
228+
registry = _AdaptiveLimiterRegistry()
229+
limiter = registry.get("test-host", 8)
230+
limiter.report_throttled() # 8 → 4
231+
assert limiter.current_limit() == 4
232+
233+
inflight_counter = 0
234+
max_observed = [0]
235+
lock = threading.Lock()
236+
237+
def slow_op(batch: list[dict[str, Any]]) -> Any:
238+
nonlocal inflight_counter
239+
with lock:
240+
inflight_counter += 1
241+
max_observed[0] = max(max_observed[0], inflight_counter)
242+
time.sleep(0.05)
243+
with lock:
244+
inflight_counter -= 1
245+
return {"upserted_count": len(batch)}
246+
247+
items = [{"id": str(i), "values": [0.1, 0.2]} for i in range(20)]
248+
result = batch_execute(
249+
items=items,
250+
operation=slow_op,
251+
batch_size=2, # 10 batches
252+
max_concurrency=8,
253+
show_progress=False,
254+
limiter_registry=registry,
255+
host="test-host",
256+
)
257+
assert max_observed[0] <= 4, (
258+
f"observed max concurrency {max_observed[0]} exceeds limiter cap"
259+
)
260+
assert result.successful_item_count == 20
261+
assert result.failed_item_count == 0
262+
263+
def test_sync_concurrency_recovers_after_signals(self) -> None:
264+
"""Concurrency recovers as limiter is signalled with successes mid-execution."""
265+
registry = _AdaptiveLimiterRegistry()
266+
limiter = registry.get("test-host", 8)
267+
# Pre-throttle to 1
268+
for _ in range(3):
269+
limiter.report_throttled()
270+
assert limiter.current_limit() == 1
271+
272+
call_count = 0
273+
call_lock = threading.Lock()
274+
limits_observed: list[int] = []
275+
276+
def op(batch: list[dict[str, Any]]) -> Any:
277+
nonlocal call_count
278+
with call_lock:
279+
call_count += 1
280+
current = call_count
281+
# After 3 calls, recover by reporting successes
282+
if current == 3:
283+
for _ in range(20):
284+
limiter.report_success()
285+
limits_observed.append(limiter.current_limit())
286+
return {"upserted_count": len(batch)}
287+
288+
items = [{"id": str(i)} for i in range(20)] # 10 batches of 2
289+
result = batch_execute(
290+
items=items,
291+
operation=op,
292+
batch_size=2,
293+
max_concurrency=8,
294+
show_progress=False,
295+
limiter_registry=registry,
296+
host="test-host",
297+
)
298+
assert result.successful_item_count == 20
299+
assert result.failed_item_count == 0
300+
assert max(limits_observed) > 1, "limiter should have recovered above 1"

0 commit comments

Comments
 (0)