-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathserver.py
More file actions
253 lines (208 loc) · 9.08 KB
/
server.py
File metadata and controls
253 lines (208 loc) · 9.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
"""Socket.IO Server.
Responsible for managing the Socket.IO connection and events.
Directs events to the platform handler or handles them directly.
Usage:
Instantiate Server with the appropriate options, platform handler, and console.
Then call `launch()` to start the server.
```python
Server(options, platform_handler, console).launch()
```
"""
from asyncio import new_event_loop
from collections.abc import Callable, Coroutine
from json import JSONDecodeError, dumps, loads
from typing import Any, TypeVar, final
from aiohttp.web import Application, run_app
from pydantic import ValidationError
from socketio import AsyncServer # pyright: ignore [reportMissingTypeStubs]
from vbl_aquarium.models.ephys_link import (
EphysLinkOptions,
SetDepthRequest,
SetInsideBrainRequest,
SetPositionRequest,
)
from vbl_aquarium.utils.vbl_base_model import VBLBaseModel
from ephys_link.__about__ import __version__
from ephys_link.back_end.platform_handler import PlatformHandler
from ephys_link.front_end.console import Console
from ephys_link.utils.constants import (
MALFORMED_REQUEST_ERROR,
PORT,
UNKNOWN_EVENT_ERROR,
cannot_connect_as_client_is_already_connected_error,
client_disconnected_without_being_connected_error,
)
# Server message generic types.
# Both are bounded by VBLBaseModel. `Server._run_if_data_parses` uses them to tie the model a
# request payload is parsed into (INPUT_TYPE) to the model the handler function returns (OUTPUT_TYPE).
INPUT_TYPE = TypeVar("INPUT_TYPE", bound=VBLBaseModel)
OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=VBLBaseModel)
@final
class Server:
def __init__(self, options: EphysLinkOptions, platform_handler: PlatformHandler, console: Console) -> None:
"""Initialize server fields based on options and platform handler.
Args:
options: Launch options object.
platform_handler: Platform handler instance.
console: Console instance.
"""
# Save fields.
self._options = options
self._platform_handler = platform_handler
self._console = console
# Initialize server.
self._sio: AsyncServer = AsyncServer()
self._app = Application()
self._sio.attach(self._app) # pyright: ignore [reportUnknownMemberType]
# Bind connection events.
_ = self._sio.on("connect", self.connect) # pyright: ignore [reportUnknownMemberType, reportUnknownVariableType]
_ = self._sio.on("disconnect", self.disconnect) # pyright: ignore [reportUnknownMemberType, reportUnknownVariableType]
# Store connected client.
self._client_sid: str = ""
# Bind events.
_ = self._sio.on("*", self.platform_event_handler) # pyright: ignore [reportUnknownMemberType, reportUnknownVariableType]
def launch(self) -> None:
"""Launch the server."""
# List platform and available manipulators.
self._console.info_print("PLATFORM", self._platform_handler.get_display_name())
# Create a temporary event loop for getting manipulators
loop = new_event_loop()
try:
self._console.info_print(
"MANIPULATORS",
str(loop.run_until_complete(self._platform_handler.get_manipulators()).manipulators),
)
finally:
loop.close()
# Launch server
run_app(self._app, port=PORT)
# Helper functions.
def _malformed_request_response(self, request: str, data: tuple[tuple[Any], ...]) -> str: # pyright: ignore [reportExplicitAny]
"""Return a response for a malformed request.
Args:
request: Original request.
data: Request data.
Returns:
Response for a malformed request.
"""
self._console.error_print("MALFORMED REQUEST", f"{request}: {data}")
return dumps(MALFORMED_REQUEST_ERROR)
async def _run_if_data_available(
self,
function: Callable[[str], Coroutine[Any, Any, VBLBaseModel]], # pyright: ignore [reportExplicitAny]
event: str,
data: Any, # pyright: ignore [reportAny, reportExplicitAny]
) -> str:
"""Run a function if data is available.
Args:
function: Function to run.
event: Event name.
data: Event data.
Returns:
Response data from function.
"""
if data:
return str((await function(str(data))).to_json_string()) # pyright: ignore[reportAny]
return self._malformed_request_response(event, data) # pyright: ignore[reportAny]
async def _run_if_data_parses(
self,
function: Callable[[INPUT_TYPE], Coroutine[Any, Any, OUTPUT_TYPE]], # pyright: ignore [reportExplicitAny]
data_type: type[INPUT_TYPE],
event: str,
data: Any, # pyright: ignore [reportAny, reportExplicitAny]
) -> str:
"""Run a function if data parses.
Args:
function: Function to run.
data_type: Data type to parse.
event: Event name.
data: Event data.
Returns:
Response data from function.
"""
if data:
try:
parsed_data = data_type(**loads(str(data))) # pyright: ignore[reportAny]
except JSONDecodeError:
return self._malformed_request_response(event, data) # pyright: ignore[reportAny]
except ValidationError as e:
self._console.exception_error_print(event, e)
return self._malformed_request_response(event, data) # pyright: ignore[reportAny]
else:
return str((await function(parsed_data)).to_json_string())
return self._malformed_request_response(event, data) # pyright: ignore[reportAny]
# Event Handlers.
async def connect(self, sid: str, _: str) -> bool:
"""Handle connections to the server.
Args:
sid: Socket session ID.
_: Extra connection data (unused).
Returns:
False on error to refuse connection, True otherwise.
"""
self._console.info_print("CONNECTION REQUEST", sid)
if self._client_sid == "":
self._client_sid = sid
self._console.info_print("CONNECTION GRANTED", sid)
return True
self._console.error_print(
"CONNECTION REFUSED", cannot_connect_as_client_is_already_connected_error(sid, self._client_sid)
)
return False
async def disconnect(self, sid: str) -> None:
"""Handle disconnections from the server.
Args:
sid: Socket session ID.
"""
self._console.info_print("DISCONNECTION REQUEST", sid)
# Reset client SID if it matches.
if self._client_sid == sid:
self._client_sid = ""
self._console.info_print("DISCONNECTED", sid)
else:
self._console.error_print("DISCONNECTION", client_disconnected_without_being_connected_error(sid))
async def platform_event_handler(self, event: str, _: str, data: Any) -> str: # pyright: ignore [reportAny, reportExplicitAny]
"""Handle events from the server.
Matches incoming events based on the Socket.IO API.
Args:
event: Event name.
_: Socket session ID (unused).
data: Event data.
Returns:
Response data.
"""
# Log event.
self._console.debug_print("EVENT", event)
# Handle event.
match event:
# Server metadata.
case "get_version":
return __version__
case "get_platform_info":
return (await self._platform_handler.get_platform_info()).to_json_string()
# Manipulator commands.
case "get_manipulators":
return str((await self._platform_handler.get_manipulators()).to_json_string())
case "get_position":
return await self._run_if_data_available(self._platform_handler.get_position, event, data)
case "get_angles":
return await self._run_if_data_available(self._platform_handler.get_angles, event, data)
case "get_shank_count":
return await self._run_if_data_available(self._platform_handler.get_shank_count, event, data)
case "set_position":
return await self._run_if_data_parses(
self._platform_handler.set_position, SetPositionRequest, event, data
)
case "set_depth":
return await self._run_if_data_parses(self._platform_handler.set_depth, SetDepthRequest, event, data)
case "set_inside_brain":
return await self._run_if_data_parses(
self._platform_handler.set_inside_brain, SetInsideBrainRequest, event, data
)
case "stop":
if data:
return await self._platform_handler.stop(str(data)) # pyright: ignore[reportAny]
return self._malformed_request_response(event, data) # pyright: ignore[reportAny]
case "stop_all":
return await self._platform_handler.stop_all()
case _:
self._console.error_print("EVENT", f"Unknown event: {event}.")
return dumps(UNKNOWN_EVENT_ERROR)