diff --git a/customerio/client_base.py b/customerio/client_base.py index 1096336..6e95bad 100644 --- a/customerio/client_base.py +++ b/customerio/client_base.py @@ -2,7 +2,6 @@ Implements the base client that is used by other classes to make requests. """ -import logging import math from datetime import datetime, timezone @@ -28,7 +27,7 @@ def __init__(self, retries=3, timeout=10, backoff_factor=0.02, use_connection_po @property def http(self): if self._current_session is None: - self._current_session = self._get_session() + self._current_session = self._build_session() return self._current_session @@ -36,12 +35,21 @@ def send_request(self, method, url, data): """Dispatches the request and returns a response.""" try: - response = self.http.request( - method, - url=url, - json=self._sanitize(data), - timeout=self.timeout, - ) + if self.use_connection_pooling: + response = self.http.request( + method, + url=url, + json=self._sanitize(data), + timeout=self.timeout, + ) + else: + with self._build_session() as http: + response = http.request( + method, + url=url, + json=self._sanitize(data), + timeout=self.timeout, + ) result_status = response.status_code if result_status != 200: @@ -57,9 +65,6 @@ def send_request(self, method, url, data): """ raise CustomerIOException(message) from e - finally: - self._close() - def _sanitize(self, data): return {key: self._sanitize_value(value) for key, value in data.items()} @@ -84,18 +89,6 @@ def _stringify_list(self, customer_ids): raise CustomerIOException(f"customer_ids cannot be {type(v)}") return customer_string_ids - def _get_session(self): - if self.use_connection_pooling: - if self._current_session is None: - self._current_session = self._build_session() - - logging.debug("Using existing session...") - return self._current_session - - logging.debug("Creating new session...") - self._current_session = self._build_session() - return self._current_session - def _build_session(self): session = Session() session.headers["User-Agent"] = f"Customer.io Python Client/{ClientVersion}" @@ -106,8 +99,3 @@ def _build_session(self): ) return session - - def _close(self): - if not self.use_connection_pooling and self._current_session is not None: - self._current_session.close() - self._current_session = None diff --git a/tests/test_client_base.py b/tests/test_client_base.py new file mode 100644 index 0000000..c6cef74 --- /dev/null +++ b/tests/test_client_base.py @@ -0,0 +1,87 @@ +import threading +import unittest + +from customerio.client_base import ClientBase + + +class FakeResponse: + status_code = 200 + text = "ok" + + +class FakeSession: + def __init__(self, request_barrier=None): + self.closed = False + self.request_barrier = request_barrier + self.request_count = 0 + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def request(self, *args, **kwargs): + self.request_count += 1 + if self.request_barrier is not None: + self.request_barrier.wait(timeout=5) + return FakeResponse() + + def close(self): + self.closed = True + + +class TestClientBase(unittest.TestCase): + def test_connection_pooling_reuses_session(self): + client = ClientBase(use_connection_pooling=True) + sessions = [] + + def build_session(): + session = FakeSession() + sessions.append(session) + return session + + client._build_session = build_session + + self.assertEqual(client.send_request("POST", "https://example.com", {}), "ok") + self.assertEqual(client.send_request("POST", "https://example.com", {}), "ok") + + self.assertEqual(len(sessions), 1) + self.assertFalse(sessions[0].closed) + + def test_disabled_connection_pooling_isolates_overlapping_requests(self): + client = ClientBase(use_connection_pooling=False) + request_barrier = threading.Barrier(2) + sessions = [] + responses = [] + errors = [] + + def build_session(): + session = FakeSession(request_barrier) + sessions.append(session) + return session + + def send_request(): + try: + responses.append(client.send_request("POST", "https://example.com", {})) + except Exception as exc: # pragma: no cover + errors.append(exc) + + client._build_session = build_session + + threads = [threading.Thread(target=send_request) for _ in range(2)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + self.assertEqual(errors, []) + self.assertEqual(sorted(responses), ["ok", "ok"]) + self.assertEqual(len(sessions), 2) + self.assertTrue(all(session.closed for session in sessions)) + self.assertTrue(all(session.request_count == 1 for session in sessions)) + self.assertIsNone(client._current_session) + + +if __name__ == "__main__": + unittest.main()