Skip to content

Commit a5fda9d

Browse files
committed
refactor(http-client): extract shared backoff helpers as module-level functions
## Purpose

`_compute_backoff`, `_compute_retry_after_delay`, and `_notify_throttle` were duplicated identically between `_RetryTransport` and `_AsyncRetryTransport`. Any future change to the decorrelated-jitter formula or Retry-After parsing had to be applied in two places, with no enforcement that they stayed in sync.

## Solution

Extracted all three as module-level functions that accept `config: RetryConfig` as their first argument. Both transport classes now call the shared functions instead of their own instance methods. Updated existing jitter-distribution tests in `test_http_client.py` (which were calling the instance methods directly) and added dedicated unit tests for each helper in `test_retry_transport.py`.
1 parent db6f54e commit a5fda9d

3 files changed

Lines changed: 166 additions & 87 deletions

File tree

pinecone/_internal/http_client.py

Lines changed: 43 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,43 @@ def _log_curl(
118118
logger.debug("curl equivalent:\n%s", curl_cmd)
119119

120120

121+
def _compute_backoff(config: RetryConfig, attempt: int, prev_delay: float | None) -> float:
122+
"""Decorrelated jitter: uniform(base, prev*3), capped at max_wait."""
123+
base_delay = config.backoff_factor
124+
if prev_delay is None:
125+
prev_delay = base_delay
126+
upper = min(config.max_wait, prev_delay * 3.0)
127+
return random.uniform(base_delay, max(base_delay, upper))
128+
129+
130+
def _compute_retry_after_delay(
131+
config: RetryConfig,
132+
response: httpx.Response,
133+
attempt: int,
134+
prev_delay: float | None,
135+
) -> float:
136+
retry_after = response.headers.get("retry-after")
137+
if retry_after is not None:
138+
try:
139+
ra = float(retry_after)
140+
if ra >= 0:
141+
smear = random.uniform(0.0, ra * 0.5)
142+
return ra + smear
143+
except (ValueError, TypeError):
144+
pass
145+
return _compute_backoff(config, attempt, prev_delay)
146+
147+
148+
def _notify_throttle(config: RetryConfig, request: httpx.Request) -> None:
149+
cb = config.on_throttle
150+
if cb is None:
151+
return
152+
try:
153+
cb(request.url.host)
154+
except Exception as exc:
155+
logger.debug("on_throttle callback raised, ignoring: %s", exc)
156+
157+
121158
class _RetryTransport(httpx.BaseTransport):
122159
"""Sync transport wrapper that retries on transient server errors."""
123160

@@ -145,17 +182,17 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
145182
self._config.max_retries + 1,
146183
exc,
147184
)
148-
delay = self._compute_backoff(attempt, prev_delay)
185+
delay = _compute_backoff(self._config, attempt, prev_delay)
149186
prev_delay = delay
150187
time.sleep(delay)
151188
continue
152189
last_exc = None
153190
if response.status_code not in self._config.retryable_status_codes:
154191
return response
155-
self._notify_throttle(request)
192+
_notify_throttle(self._config, request)
156193
if attempt < self._config.max_retries:
157194
response.close()
158-
delay = self._compute_retry_after_delay(response, attempt, prev_delay)
195+
delay = _compute_retry_after_delay(self._config, response, attempt, prev_delay)
159196
prev_delay = delay
160197
time.sleep(delay)
161198
else:
@@ -164,40 +201,6 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
164201
raise last_exc
165202
raise RuntimeError("max_retries must be non-negative")
166203

167-
def _compute_backoff(self, attempt: int, prev_delay: float | None) -> float:
168-
"""Decorrelated jitter: uniform(base, prev*3), capped at max_wait."""
169-
base_delay = self._config.backoff_factor
170-
if prev_delay is None:
171-
prev_delay = base_delay
172-
upper = min(self._config.max_wait, prev_delay * 3.0)
173-
return random.uniform(base_delay, max(base_delay, upper))
174-
175-
def _compute_retry_after_delay(
176-
self,
177-
response: httpx.Response,
178-
attempt: int,
179-
prev_delay: float | None,
180-
) -> float:
181-
retry_after = response.headers.get("retry-after")
182-
if retry_after is not None:
183-
try:
184-
ra = float(retry_after)
185-
if ra >= 0:
186-
smear = random.uniform(0.0, ra * 0.5)
187-
return ra + smear
188-
except (ValueError, TypeError):
189-
pass
190-
return self._compute_backoff(attempt, prev_delay)
191-
192-
def _notify_throttle(self, request: httpx.Request) -> None:
193-
cb = self._config.on_throttle
194-
if cb is None:
195-
return
196-
try:
197-
cb(request.url.host)
198-
except Exception as exc:
199-
logger.debug("on_throttle callback raised, ignoring: %s", exc)
200-
201204
def close(self) -> None:
202205
self._transport.close()
203206

@@ -229,17 +232,17 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
229232
self._config.max_retries + 1,
230233
exc,
231234
)
232-
delay = self._compute_backoff(attempt, prev_delay)
235+
delay = _compute_backoff(self._config, attempt, prev_delay)
233236
prev_delay = delay
234237
await asyncio.sleep(delay)
235238
continue
236239
last_exc = None
237240
if response.status_code not in self._config.retryable_status_codes:
238241
return response
239-
self._notify_throttle(request)
242+
_notify_throttle(self._config, request)
240243
if attempt < self._config.max_retries:
241244
await response.aclose()
242-
delay = self._compute_retry_after_delay(response, attempt, prev_delay)
245+
delay = _compute_retry_after_delay(self._config, response, attempt, prev_delay)
243246
prev_delay = delay
244247
await asyncio.sleep(delay)
245248
else:
@@ -248,40 +251,6 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
248251
raise last_exc
249252
raise RuntimeError("max_retries must be non-negative")
250253

251-
def _compute_backoff(self, attempt: int, prev_delay: float | None) -> float:
252-
"""Decorrelated jitter: uniform(base, prev*3), capped at max_wait."""
253-
base_delay = self._config.backoff_factor
254-
if prev_delay is None:
255-
prev_delay = base_delay
256-
upper = min(self._config.max_wait, prev_delay * 3.0)
257-
return random.uniform(base_delay, max(base_delay, upper))
258-
259-
def _compute_retry_after_delay(
260-
self,
261-
response: httpx.Response,
262-
attempt: int,
263-
prev_delay: float | None,
264-
) -> float:
265-
retry_after = response.headers.get("retry-after")
266-
if retry_after is not None:
267-
try:
268-
ra = float(retry_after)
269-
if ra >= 0:
270-
smear = random.uniform(0.0, ra * 0.5)
271-
return ra + smear
272-
except (ValueError, TypeError):
273-
pass
274-
return self._compute_backoff(attempt, prev_delay)
275-
276-
def _notify_throttle(self, request: httpx.Request) -> None:
277-
cb = self._config.on_throttle
278-
if cb is None:
279-
return
280-
try:
281-
cb(request.url.host)
282-
except Exception as exc:
283-
logger.debug("on_throttle callback raised, ignoring: %s", exc)
284-
285254
async def aclose(self) -> None:
286255
await self._transport.aclose()
287256

tests/unit/_internal/test_retry_transport.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
import pytest
99

1010
from pinecone._internal.config import RetryConfig
11-
from pinecone._internal.http_client import _AsyncRetryTransport, _RetryTransport
11+
from pinecone._internal.http_client import (
12+
_AsyncRetryTransport,
13+
_compute_backoff,
14+
_compute_retry_after_delay,
15+
_notify_throttle,
16+
_RetryTransport,
17+
)
1218

1319

1420
def _transport(max_retries: int = 3) -> tuple[_RetryTransport, MagicMock]:
@@ -218,3 +224,107 @@ async def test_async_post_retried_on_408() -> None:
218224
result = await rt.handle_async_request(httpx.Request("POST", "https://example.com/query"))
219225
assert result.status_code == 200
220226
assert inner.handle_async_request.call_count == 2
227+
228+
229+
# --- module-level helper function tests ---
230+
231+
232+
def _cfg(backoff_factor: float = 0.5, max_wait: float = 60.0) -> RetryConfig:
233+
return RetryConfig(backoff_factor=backoff_factor, max_wait=max_wait)
234+
235+
236+
def test_compute_backoff_first_attempt_stays_at_base() -> None:
237+
cfg = _cfg(backoff_factor=0.5, max_wait=60.0)
238+
result = _compute_backoff(cfg, 0, None)
239+
assert 0.5 <= result <= 1.5 # uniform(base, max(base, min(max_wait, base*3)))
240+
241+
242+
def test_compute_backoff_grows_with_prev_delay() -> None:
243+
cfg = _cfg(backoff_factor=0.5, max_wait=60.0)
244+
result = _compute_backoff(cfg, 1, 2.0)
245+
# upper = min(60, 2.0 * 3) = 6.0; result in [0.5, 6.0]
246+
assert 0.5 <= result <= 6.0
247+
248+
249+
def test_compute_backoff_capped_at_max_wait() -> None:
250+
cfg = _cfg(backoff_factor=1.0, max_wait=2.0)
251+
result = _compute_backoff(cfg, 5, 100.0)
252+
# upper = min(2.0, 300.0) = 2.0; result in [1.0, 2.0]
253+
assert 1.0 <= result <= 2.0
254+
255+
256+
def test_compute_retry_after_uses_header_value() -> None:
257+
cfg = _cfg()
258+
response = httpx.Response(429, headers={"Retry-After": "5"})
259+
result = _compute_retry_after_delay(cfg, response, 0, None)
260+
# Should be between 5 (no smear) and 7.5 (5 + 5*0.5)
261+
assert 5.0 <= result <= 7.5
262+
263+
264+
def test_compute_retry_after_ignores_invalid_header() -> None:
265+
cfg = _cfg(backoff_factor=0.5, max_wait=60.0)
266+
response = httpx.Response(429, headers={"Retry-After": "not-a-number"})
267+
result = _compute_retry_after_delay(cfg, response, 0, None)
268+
# Falls back to _compute_backoff
269+
assert result >= 0.5
270+
271+
272+
def test_compute_retry_after_no_header_falls_back_to_backoff() -> None:
273+
cfg = _cfg(backoff_factor=0.5, max_wait=60.0)
274+
response = httpx.Response(503)
275+
result = _compute_retry_after_delay(cfg, response, 0, None)
276+
assert result >= 0.5
277+
278+
279+
def test_notify_throttle_calls_callback() -> None:
280+
calls: list[str] = []
281+
cfg = RetryConfig(on_throttle=lambda host: calls.append(host))
282+
request = httpx.Request("POST", "https://example.com/vectors/upsert")
283+
_notify_throttle(cfg, request)
284+
assert calls == ["example.com"]
285+
286+
287+
def test_notify_throttle_no_callback_is_noop() -> None:
288+
cfg = RetryConfig(on_throttle=None)
289+
request = httpx.Request("POST", "https://example.com/query")
290+
_notify_throttle(cfg, request) # should not raise
291+
292+
293+
def test_notify_throttle_swallows_callback_exception() -> None:
294+
def bad_cb(host: str) -> None:
295+
raise RuntimeError("oops")
296+
297+
cfg = RetryConfig(on_throttle=bad_cb)
298+
request = httpx.Request("POST", "https://example.com/query")
299+
_notify_throttle(cfg, request) # should not raise
300+
301+
302+
def test_sync_transport_calls_module_level_compute_backoff() -> None:
303+
"""Verify sync transport uses module-level _compute_backoff (not an instance method)."""
304+
rt, inner = _transport(max_retries=1)
305+
inner.handle_request.side_effect = [
306+
httpx.ConnectError("boom"),
307+
httpx.Response(200),
308+
]
309+
with patch("pinecone._internal.http_client._compute_backoff", return_value=0.001) as mock_cb:
310+
rt.handle_request(_req())
311+
mock_cb.assert_called_once()
312+
args = mock_cb.call_args[0]
313+
assert isinstance(args[0], RetryConfig)
314+
315+
316+
def test_async_transport_calls_module_level_compute_backoff() -> None:
317+
"""Verify async transport uses module-level _compute_backoff (not an instance method)."""
318+
import asyncio
319+
320+
rt, inner = _async_transport(max_retries=1)
321+
inner.handle_async_request.side_effect = [
322+
httpx.ConnectError("boom"),
323+
httpx.Response(200),
324+
]
325+
326+
with patch("pinecone._internal.http_client._compute_backoff", return_value=0.001) as mock_cb:
327+
asyncio.get_event_loop().run_until_complete(rt.handle_async_request(_req()))
328+
mock_cb.assert_called_once()
329+
args = mock_cb.call_args[0]
330+
assert isinstance(args[0], RetryConfig)

tests/unit/test_http_client.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
HTTPClient,
1919
_AsyncRetryTransport,
2020
_build_headers,
21+
_compute_backoff,
22+
_compute_retry_after_delay,
2123
_encode_json,
2224
_log_curl,
2325
_prepare_json_kwargs,
@@ -895,7 +897,7 @@ def test_backoff_first_attempt_delay_range(self) -> None:
895897

896898
random.seed(42)
897899
t = _make_sync_retry_transport()
898-
delays = [t._compute_backoff(0, None) for _ in range(200)]
900+
delays = [_compute_backoff(t._config, 0, None) for _ in range(200)]
899901
assert all(0.25 <= d <= 0.75 for d in delays), (
900902
f"out-of-range: {[d for d in delays if not (0.25 <= d <= 0.75)]}"
901903
)
@@ -908,26 +910,24 @@ def test_backoff_first_attempt_delay_range(self) -> None:
908910

909911
def test_backoff_evolves_with_prev_delay(self) -> None:
910912
random.seed(42)
911-
t = _make_sync_retry_transport()
912913
cfg = RetryConfig(max_wait=60.0)
913-
t._config = cfg
914-
delays = [t._compute_backoff(2, 5.0) for _ in range(200)]
914+
delays = [_compute_backoff(cfg, 2, 5.0) for _ in range(200)]
915915
# Decorrelated jitter: uniform(0.25, min(60.0, 5.0*3)) = uniform(0.25, 15.0)
916916
assert all(0.25 <= d <= 15.0 for d in delays)
917917
assert any(d > 5.0 for d in delays), "upper bound did not expand with prev_delay"
918918

919919
def test_backoff_capped_at_max_wait(self) -> None:
920920
random.seed(42)
921-
t = _make_sync_retry_transport(RetryConfig(max_wait=10.0))
922-
delays = [t._compute_backoff(0, 100.0) for _ in range(50)]
921+
cfg = RetryConfig(max_wait=10.0)
922+
delays = [_compute_backoff(cfg, 0, 100.0) for _ in range(50)]
923923
# prev_delay=100 would push upper to 300, but max_wait caps at 10.0
924924
assert all(d <= 10.0 for d in delays)
925925

926926
def test_retry_after_smear_range(self) -> None:
927927
random.seed(42)
928928
t = _make_sync_retry_transport()
929929
response = httpx.Response(429, headers={"retry-after": "60"})
930-
delays = [t._compute_retry_after_delay(response, 0, None) for _ in range(200)]
930+
delays = [_compute_retry_after_delay(t._config, response, 0, None) for _ in range(200)]
931931
# Smear: delay in [60, 60 + 0.5*60) = [60, 90)
932932
assert all(60.0 <= d < 90.0 for d in delays), (
933933
f"out-of-range delays: {[d for d in delays if not (60.0 <= d < 90.0)]}"
@@ -939,23 +939,23 @@ def test_retry_after_falls_back_to_backoff_when_invalid(self) -> None:
939939
random.seed(42)
940940
t = _make_sync_retry_transport()
941941
response = httpx.Response(429, headers={"retry-after": "Fri, 31 Dec 2026 23:59:59 GMT"})
942-
delays = [t._compute_retry_after_delay(response, 0, None) for _ in range(50)]
942+
delays = [_compute_retry_after_delay(t._config, response, 0, None) for _ in range(50)]
943943
# HTTP-date is not parseable as float; falls back to backoff which is uniform(0.25, 0.75)
944944
assert all(0.25 <= d <= 0.75 for d in delays)
945945

946946
def test_retry_after_falls_back_to_backoff_when_missing(self) -> None:
947947
random.seed(42)
948948
t = _make_sync_retry_transport()
949949
response = httpx.Response(500)
950-
delays = [t._compute_retry_after_delay(response, 0, None) for _ in range(50)]
950+
delays = [_compute_retry_after_delay(t._config, response, 0, None) for _ in range(50)]
951951
# Falls back to backoff: uniform(0.25, 0.75)
952952
assert all(0.25 <= d <= 0.75 for d in delays)
953953

954954
def test_negative_retry_after_falls_back_to_backoff(self) -> None:
955955
random.seed(42)
956956
t = _make_sync_retry_transport()
957957
response = httpx.Response(429, headers={"retry-after": "-1"})
958-
delays = [t._compute_retry_after_delay(response, 0, None) for _ in range(50)]
958+
delays = [_compute_retry_after_delay(t._config, response, 0, None) for _ in range(50)]
959959
# Negative values are ignored; falls back to backoff: uniform(0.25, 0.75)
960960
assert all(0.25 <= d <= 0.75 for d in delays)
961961

@@ -971,7 +971,7 @@ def test_backoff_first_attempt_delay_range(self) -> None:
971971

972972
random.seed(42)
973973
t = _make_async_retry_transport()
974-
delays = [t._compute_backoff(0, None) for _ in range(200)]
974+
delays = [_compute_backoff(t._config, 0, None) for _ in range(200)]
975975
assert all(0.25 <= d <= 0.75 for d in delays), (
976976
f"out-of-range: {[d for d in delays if not (0.25 <= d <= 0.75)]}"
977977
)
@@ -986,7 +986,7 @@ def test_retry_after_smear_range(self) -> None:
986986
random.seed(42)
987987
t = _make_async_retry_transport()
988988
response = httpx.Response(429, headers={"retry-after": "60"})
989-
delays = [t._compute_retry_after_delay(response, 0, None) for _ in range(200)]
989+
delays = [_compute_retry_after_delay(t._config, response, 0, None) for _ in range(200)]
990990
# Smear: delay in [60, 60 + 0.5*60) = [60, 90)
991991
assert all(60.0 <= d < 90.0 for d in delays), (
992992
f"out-of-range delays: {[d for d in delays if not (60.0 <= d < 90.0)]}"

0 commit comments

Comments
 (0)