-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_session_lifecycle.py
More file actions
181 lines (142 loc) · 5.28 KB
/
test_session_lifecycle.py
File metadata and controls
181 lines (142 loc) · 5.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import asyncio
import threading
from multiprocessing.pool import ThreadPool
from time import sleep
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
from sqlalchemy.orm import Session, scoped_session

from sqlalchemy_bind_manager._bind_manager import SQLAlchemyAsyncBind
from sqlalchemy_bind_manager._session_handler import AsyncSessionHandler, SessionHandler
async def test_session_is_removed_on_cleanup(session_handler_class, sa_bind):
    """Finalizing a session handler must call ``remove()`` on its scoped session."""
    handler = session_handler_class(sa_bind)
    registry = handler.scoped_session

    # Spy on remove() while still delegating to the real implementation.
    with patch.object(registry, "remove", wraps=registry.remove) as remove_spy:
        # Dropping what is presumably the last reference to the handler should
        # finalize it immediately under CPython refcounting, removing the session.
        del handler
        remove_spy.assert_called_once()
def test_session_is_removed_on_cleanup_even_if_loop_is_not_running(sa_manager):
    """The async handler still removes its session when finalized outside a loop.

    This is a plain (non-async) test, so no event loop is expected to be
    running when the handler is finalized. The finalizer is expected to look
    up a loop exactly once via ``asyncio.get_event_loop`` and to call
    ``remove()`` on the scoped session exactly once.
    """
    handler = AsyncSessionHandler(sa_manager.get_bind("async"))
    registry = handler.scoped_session

    # Both spies wrap the real callables, captured before any patch is active.
    remove_patcher = patch.object(registry, "remove", wraps=registry.remove)
    loop_lookup_patcher = patch(
        "asyncio.get_event_loop",
        wraps=asyncio.get_event_loop,
    )
    with remove_patcher as remove_spy, loop_lookup_patcher as get_loop_spy:
        # Drop the handler so its finalizer runs while the spies are installed.
        del handler
        get_loop_spy.assert_called_once()
        remove_spy.assert_called_once()
def test_session_is_removed_on_cleanup_even_if_loop_search_errors_out(sa_manager):
    """Session removal on finalization survives a failing event-loop lookup.

    ``asyncio.get_event_loop`` is patched to raise ``RuntimeError``. The
    handler's finalizer is still expected to attempt the lookup once and then
    call ``remove()`` on the scoped session exactly once.
    """
    handler = AsyncSessionHandler(sa_manager.get_bind("async"))
    registry = handler.scoped_session

    remove_patcher = patch.object(registry, "remove", wraps=registry.remove)
    # Simulate an environment where no event loop can be obtained.
    failing_loop_lookup = patch(
        "asyncio.get_event_loop",
        side_effect=RuntimeError(),
    )
    with remove_patcher as remove_spy, failing_loop_lookup as get_loop_spy:
        # Drop the handler so its finalizer runs while the patches are active.
        del handler
        get_loop_spy.assert_called_once()
        remove_spy.assert_called_once()
@pytest.mark.parametrize("read_only_flag", [True, False])
async def test_commit_is_called_only_if_not_read_only(
    read_only_flag,
    session_handler_class,
    model_class,
    sa_bind,
    sync_async_cm_wrapper,
):
    """``get_session`` commits on exit only when it is not read-only.

    ``commit`` is patched on the handler class, so no real commit is issued;
    the test only counts how often the handler tried to commit.
    """
    handler = session_handler_class(sa_bind)
    expected_commits = 0 if read_only_flag else 1

    # A fresh model instance to add inside the session block.
    new_row = model_class(name="Someone")

    commit_patcher = patch.object(session_handler_class, "commit", return_value=None)
    with commit_patcher as commit_spy:
        session_cm = sync_async_cm_wrapper(handler.get_session(read_only=read_only_flag))
        async with session_cm as session:
            session.add(new_row)

    assert commit_spy.call_count == expected_commits
@pytest.mark.parametrize("commit_fails", [True, False])
async def test_rollback_is_called_if_commit_fails(
    commit_fails,
    session_handler_class,
    sa_bind,
    sync_async_wrapper,
):
    """``commit`` rolls back and re-raises when the session commit fails.

    When the commit succeeds, no rollback must happen. When it fails, the very
    same exception must propagate to the caller and exactly one rollback must
    be issued. Propagation is asserted explicitly with ``pytest.raises``, so
    the test cannot pass vacuously if the handler swallows the error.
    """
    sh = session_handler_class(sa_bind)
    failure_exception = Exception("Some Error")
    # Match the mock flavour to the handler: async binds need awaitable
    # commit/rollback, sync binds need plain callables.
    mocked_session = (
        AsyncMock(spec=async_scoped_session)
        if isinstance(sa_bind, SQLAlchemyAsyncBind)
        else MagicMock(spec=scoped_session)
    )

    if commit_fails:
        mocked_session.commit.side_effect = failure_exception
        # sh.commit(...) is evaluated inside the block, so a synchronous raise
        # (sync handler) and an awaited raise (async handler) are both caught.
        with pytest.raises(Exception) as exc_info:
            await sync_async_wrapper(sh.commit(mocked_session))
        assert exc_info.value is failure_exception
    else:
        await sync_async_wrapper(sh.commit(mocked_session))

    assert mocked_session.commit.call_count == 1
    assert mocked_session.rollback.call_count == int(commit_fails)
async def test_session_is_different_on_different_asyncio_tasks(sa_manager):
    """Separate asyncio tasks receive separate ``AsyncSession`` objects.

    Within the test's own task, repeated calls to the scoped registry must
    return the same session. Calls made from two other tasks must each
    return a distinct session.
    """
    handler = AsyncSessionHandler(sa_manager.get_bind("async"))

    # Same task: the registry is expected to hand back one shared session.
    first, second = handler.scoped_session(), handler.scoped_session()
    for candidate in (first, second):
        assert isinstance(candidate, AsyncSession)
    assert first is second

    async def _session_from_task():
        return handler.scoped_session()

    # asyncio.gather schedules each coroutine as its own Task.
    from_task_a, from_task_b = await asyncio.gather(
        _session_from_task(),
        _session_from_task(),
    )
    assert isinstance(from_task_a, AsyncSession)
    assert isinstance(from_task_b, AsyncSession)
    assert from_task_a is not from_task_b
async def test_session_is_different_on_different_threads(sa_manager):
    """Separate threads receive separate ``Session`` objects from the sync handler.

    Within the calling thread, repeated calls to the scoped registry must
    return the same session. Two pool workers that are forced to run
    concurrently must each get a distinct session.
    """
    sh = SessionHandler(sa_manager.get_bind("sync"))

    # Same thread: the registry is expected to return one shared session.
    s1 = sh.scoped_session()
    s2 = sh.scoped_session()
    assert isinstance(s1, Session)
    assert isinstance(s2, Session)
    assert s1 is s2

    # Neither worker can pass the barrier until both have reached it. That is
    # only possible if the two tasks run at the same time on two distinct
    # threads. This replaces a timing-based sleep and no longer depends on
    # the host CPU count: ThreadPool() defaults to os.cpu_count() workers,
    # which on a single-core runner would run both tasks on one thread.
    rendezvous = threading.Barrier(2)

    def _get_session():
        # Fail with BrokenBarrierError instead of hanging if a peer never arrives.
        rendezvous.wait(timeout=10)
        return sh.scoped_session()

    with ThreadPool(processes=2) as pool:
        s3_task = pool.apply_async(_get_session)
        s4_task = pool.apply_async(_get_session)
        # Collect results before the pool's __exit__ terminates the workers.
        s3 = s3_task.get()
        s4 = s4_task.get()

    assert isinstance(s3, Session)
    assert isinstance(s4, Session)
    assert s3 is not s4