From bd6c072a873ce6d4d92ddc7a55af5b54b27bf3e3 Mon Sep 17 00:00:00 2001 From: Marcus Kinsella Date: Tue, 23 Oct 2018 21:13:43 +0000 Subject: [PATCH 1/2] Improve thread safety of SwaggerClient SwaggerClient discovers methods by reading a swagger document, creating instances of _ClientMethodFactory for paths in the swagger, and assigning those instances as class attributes of SwaggerClient. This means that a single instance of each _ClientMethodFactory is shared across threads. This is not so good because _ClientMethodFactories store state associated with requests. In particular, consider the following scenario: Thread-1: client = hca.dss.DSSClient() Thread-2: client = hca.dss.DSSClient() Thread-1: with client.get_file.stream(...) as handle: Thread-2: with client.get_file.stream(...) as handle: Thread-1: handle.raw.read() Depending on the execution of __enter__, Thread-1 may have just read from the file for Thread-2! Going on, Thread-1: __exit__ the with block Thread-2: handle.raw.read() Thread-2: __exit__ the with block Exception: NoneType has no attribute close() Or you know, something like that. This change makes the methods instance attributes. --- hca/util/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hca/util/__init__.py b/hca/util/__init__.py index 877510ae..138a36e3 100755 --- a/hca/util/__init__.py +++ b/hca/util/__init__.py @@ -468,7 +468,7 @@ def _build_client_method(self, http_method, http_path, method_data): docstring += "\n\n" + _md2rst(method_data["description"]) client_method.__doc__ = docstring - setattr(self.__class__, method_name, types.MethodType(client_method, SwaggerClient)) + setattr(self, method_name, types.MethodType(client_method, SwaggerClient)) self.methods[method_name] = dict(method_data, entry_point=getattr(self, method_name)._cli_call, signature=client_method.__signature__, args=method_args) From 3c57a7f29dd8a00fab0c79fbd9a8f5d0cd0c4b58 Mon Sep 17 00:00:00 2001 From: Marcus Kinsella Date: Sun, 28 Oct 2018 21:00:41 +0000 Subject: [PATCH 2/2] Add a troubling test of multithreaded SwaggerClient Really feel like this one should pass... --- requirements-dev.txt | 1 + test/integration/util/test_swagger_client.py | 57 +++++++++++++++++--- 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index c821f8b7..78aa67a6 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,6 @@ flake8==3.5.0 moto==1.3.3 +futures; python_version < '3.2' coverage pyyaml responses diff --git a/test/integration/util/test_swagger_client.py b/test/integration/util/test_swagger_client.py index 106a7ee7..be3e8cf3 100755 --- a/test/integration/util/test_swagger_client.py +++ b/test/integration/util/test_swagger_client.py @@ -2,10 +2,12 @@ # coding: utf-8 import argparse +import concurrent.futures import json import os import requests import sys +import time import unittest pkg_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) # noqa @@ -45,6 +47,7 @@ class TestSwaggerClient(unittest.TestCase): dummy_response = requests.models.Response() dummy_response.status_code = 200 dummy_response._content = "content" + dummy_response.headers["content-type"] = "audio/vnd.rn-realaudio" generated_method_names = [ # method names corresponding to all `paths` @@ -69,13 +72,21 @@ def setUpClass(cls): content = fh.read() swagger_response._content = content cls.test_swagger_json = json.loads(content.decode("utf-8")) - + cls.test_swagger_response = swagger_response cls.url_base = (cls.test_swagger_json['schemes'][0] + "://" + cls.test_swagger_json['host'] + cls.test_swagger_json['basePath']) + cls.client = cls.create_client( + cls.swagger_url, cls.test_swagger_response, cls.subparsers, cls.open_fn_name) + + @staticmethod + def create_client(swagger_url, swagger_response, subparsers, open_fn_name): + """ + Create and return a new SwaggerClient + """ with mock.patch('requests.Session.get') as mock_get, \ - mock.patch(cls.open_fn_name, mock_open()), \ + mock.patch(open_fn_name, mock_open()), \ mock.patch('hca.util.fs.atomic_write'), \ mock.patch('hca.dss.SwaggerClient.load_swagger_json') as mock_load_swagger_json: # init SwaggerClient with test swagger JSON file @@ -84,9 +95,10 @@ def setUpClass(cls): config = HCAConfig(save_on_exit=False) config['SwaggerClient'] = {} - config['SwaggerClient'].swagger_url = cls.swagger_url - cls.client = hca.util.SwaggerClient(config) - cls.client.build_argparse_subparsers(cls.subparsers) + config['SwaggerClient'].swagger_url = swagger_url + client = hca.util.SwaggerClient(config) + client.build_argparse_subparsers(subparsers) + return client @classmethod def tearDownClass(cls): @@ -97,8 +109,8 @@ def setUp(self): def test_client_methods_exist(self): for method_name in self.generated_method_names: - self.assertTrue(hasattr(self.client.__class__, method_name) and - callable(getattr(self.client.__class__, method_name))) + self.assertTrue(hasattr(self.client, method_name) and + callable(getattr(self.client, method_name))) def test_get_with_path_query_params(self): http_method = "get" @@ -331,6 +343,37 @@ def test_put_with_invalid_enum_param(self): '--query-param', query_param_invalid]) self.assertEqual(e.exception.code, 2) + def test_multithreaded(self): + http_method = "get" + path = "/with/path/query/params" + path_param = "path" + url = self.url_base + path + "/" + path_param + num_threads = 32 + num_attempts = 400 + + with concurrent.futures.ThreadPoolExecutor(num_threads) as exe, \ + mock.patch('requests.Session.request') as mock_request: + + mock_request.return_value = self.dummy_response + + def call_with_query_param(param): + client = self.create_client( + self.swagger_url, self.test_swagger_response, self.subparsers, + self.open_fn_name) + with client.get_with_path_query_params.stream( + path_param=path_param, query_param=param): + pass + + futures = [exe.submit(call_with_query_param, str(i)) for i in range(num_attempts)] + + while any(not f.done() for f in futures): + time.sleep(.5) + + called_query_params = set() + for call in mock_request.mock_calls: + if 'params' in call[2]: + called_query_params.add(call[2]["params"]["query_param"]) + self.assertSetEqual(called_query_params, set(str(i) for i in range(num_attempts))) if __name__ == "__main__": unittest.main()