Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 16 additions & 28 deletions customerio/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Implements the base client that is used by other classes to make requests.
"""

import logging
import math
from datetime import datetime, timezone

Expand All @@ -28,20 +27,29 @@ def __init__(self, retries=3, timeout=10, backoff_factor=0.02, use_connection_po
@property
def http(self):
if self._current_session is None:
self._current_session = self._get_session()
self._current_session = self._build_session()

return self._current_session

def send_request(self, method, url, data):
"""Dispatches the request and returns a response."""

try:
response = self.http.request(
method,
url=url,
json=self._sanitize(data),
timeout=self.timeout,
)
if self.use_connection_pooling:
response = self.http.request(
method,
url=url,
json=self._sanitize(data),
timeout=self.timeout,
)
else:
with self._build_session() as http:
response = http.request(
method,
url=url,
json=self._sanitize(data),
timeout=self.timeout,
)

result_status = response.status_code
if result_status != 200:
Expand All @@ -57,9 +65,6 @@ def send_request(self, method, url, data):
"""
raise CustomerIOException(message) from e

finally:
self._close()

def _sanitize(self, data):
return {key: self._sanitize_value(value) for key, value in data.items()}

Expand All @@ -84,18 +89,6 @@ def _stringify_list(self, customer_ids):
raise CustomerIOException(f"customer_ids cannot be {type(v)}")
return customer_string_ids

def _get_session(self):
if self.use_connection_pooling:
if self._current_session is None:
self._current_session = self._build_session()

logging.debug("Using existing session...")
return self._current_session

logging.debug("Creating new session...")
self._current_session = self._build_session()
return self._current_session

def _build_session(self):
session = Session()
session.headers["User-Agent"] = f"Customer.io Python Client/{ClientVersion}"
Expand All @@ -106,8 +99,3 @@ def _build_session(self):
)

return session

def _close(self):
if not self.use_connection_pooling and self._current_session is not None:
self._current_session.close()
self._current_session = None
87 changes: 87 additions & 0 deletions tests/test_client_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import threading
import unittest

from customerio.client_base import ClientBase


class FakeResponse:
status_code = 200
text = "ok"


class FakeSession:
def __init__(self, request_barrier=None):
self.closed = False
self.request_barrier = request_barrier
self.request_count = 0

def __enter__(self):
return self

def __exit__(self, *args):
self.close()

def request(self, *args, **kwargs):
self.request_count += 1
if self.request_barrier is not None:
self.request_barrier.wait(timeout=5)
return FakeResponse()

def close(self):
self.closed = True


class TestClientBase(unittest.TestCase):
def test_connection_pooling_reuses_session(self):
client = ClientBase(use_connection_pooling=True)
sessions = []

def build_session():
session = FakeSession()
sessions.append(session)
return session

client._build_session = build_session

self.assertEqual(client.send_request("POST", "https://example.com", {}), "ok")
self.assertEqual(client.send_request("POST", "https://example.com", {}), "ok")

self.assertEqual(len(sessions), 1)
self.assertFalse(sessions[0].closed)

def test_disabled_connection_pooling_isolates_overlapping_requests(self):
client = ClientBase(use_connection_pooling=False)
request_barrier = threading.Barrier(2)
sessions = []
responses = []
errors = []

def build_session():
session = FakeSession(request_barrier)
sessions.append(session)
return session

def send_request():
try:
responses.append(client.send_request("POST", "https://example.com", {}))
except Exception as exc: # pragma: no cover
errors.append(exc)

client._build_session = build_session

threads = [threading.Thread(target=send_request) for _ in range(2)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()

self.assertEqual(errors, [])
self.assertEqual(sorted(responses), ["ok", "ok"])
self.assertEqual(len(sessions), 2)
self.assertTrue(all(session.closed for session in sessions))
self.assertTrue(all(session.request_count == 1 for session in sessions))
self.assertIsNone(client._current_session)


if __name__ == "__main__":
unittest.main()