-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathinferenceServer.py
More file actions
96 lines (75 loc) · 2.56 KB
/
inferenceServer.py
File metadata and controls
96 lines (75 loc) · 2.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import asyncio
import struct
import torch
from modules.pytorch_mlp import PytorchMLPReg
# The loaded model, shared by every client connection. It stays None until
# main() constructs it, just before the server starts accepting connections.
regr = None

# Wire protocol, network byte order ("!"):
#   request  -> 3 floats
#   response -> 4 floats
INPUT_FMT = "!3f"
OUTPUT_FMT = "!4f"
# Number of bytes that make up one complete request.
INPUT_SIZE = struct.Struct(INPUT_FMT).size
async def handle_client(reader: asyncio.StreamReader,
                        writer: asyncio.StreamWriter):
    """Serve inference requests on one TCP connection until it ends.

    Each request is ``INPUT_SIZE`` bytes that unpack (via ``INPUT_FMT``) into
    three values x, y, z. They are fed to the module-level model ``regr``
    (set by ``main``) as a single-row batch. The first four values of the
    prediction are packed with ``OUTPUT_FMT`` and written back. The loop
    ends when the peer disconnects or an error occurs. The writer is always
    closed, and errors raised while closing a dead connection are
    suppressed.

    Args:
        reader: Stream the client's requests are read from.
        writer: Stream the packed predictions are written to.
    """
    print("Client connected")
    try:
        while True:
            # Wait until one complete request has arrived.
            data = await reader.readexactly(INPUT_SIZE)
            x, y, z = struct.unpack(INPUT_FMT, data)
            print(f"Received: {x}, {y}, {z}")
            # Single-sample batch of shape (1, 3).
            target = torch.tensor(
                [[x, y, z]],
                dtype=torch.float32,
                device="cpu",
            )
            # NOTE: predict() is synchronous and runs on the event-loop
            # thread, so a slow model stalls every other connected client.
            with torch.inference_mode():
                output = regr.predict(target)[0]
            print("Infered:", output)
            output_bytes = struct.pack(
                OUTPUT_FMT,
                float(output[0]),
                float(output[1]),
                float(output[2]),
                float(output[3]),
            )
            writer.write(output_bytes)
            await writer.drain()
    except asyncio.IncompleteReadError as e:
        # An empty partial read means the peer closed cleanly between
        # requests; a non-empty one means it vanished mid-request.
        if e.partial:
            print("Client disconnected early")
        else:
            print("Client disconnected")
    except ConnectionError as e:
        print("Client connection lost:", e)
    except Exception as e:
        print("Server error:", e)
    finally:
        writer.close()
        try:
            await writer.wait_closed()
        except OSError:
            # wait_closed() re-raises the error that tore the connection
            # down (e.g. ConnectionResetError). The socket is already gone,
            # so there is nothing left to clean up; letting it escape would
            # surface as an unhandled error in the connection callback.
            pass
async def main(model_file, host="127.0.0.1", port=5000):
    """Load the model and serve inference requests until cancelled.

    Args:
        model_file: Path handed to ``PytorchMLPReg`` to load the model.
        host: Interface to bind; the default accepts loopback connections only.
        port: TCP port to listen on; ``0`` lets the OS pick a free port.
    """
    global regr
    regr = PytorchMLPReg(model_file=model_file, batch_size=1)
    server = await asyncio.start_server(handle_client, host, port)
    try:
        async with server:
            # Report the address(es) actually bound, so port=0 shows the
            # real port. For the defaults this prints "127.0.0.1:5000".
            bound = ", ".join(
                "{}:{}".format(*sock.getsockname()[:2]) for sock in server.sockets
            )
            print(f"serving on {bound}")
            await server.serve_forever()
    finally:
        # serve_forever() never returns normally; it exits by raising
        # (typically CancelledError on Ctrl+C). The message must therefore
        # live in a finally clause to be printed at all.
        print("Closed server")
if __name__ == "__main__":
    import argparse
    import os
    import sys

    parser = argparse.ArgumentParser(
        prog=sys.argv[0],
        description="Serve a PyTorch MLP regressor over TCP.",
    )
    parser.add_argument(
        metavar="model_file",
        type=str,
        nargs="?",
        help="the path to the file containing the model",
        dest="model_file",
    )
    # Warn about unrecognised arguments instead of discarding the whole
    # command line. (Catching SystemExit, as before, also swallowed --help
    # and then carried on with no model file.)
    args, unknown = parser.parse_known_args()
    if unknown:
        print(sys.argv[0], "Ignoring unrecognized arguments:", " ".join(unknown))
    if args.model_file is None:
        # There is no default model; os.path.join(..., None) would raise
        # TypeError, so fail with a usage message instead.
        parser.error("the model_file argument is required")

    # Relative paths resolve against this script's directory, not the
    # current working directory. Absolute paths pass through unchanged.
    model_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), args.model_file
    )
    print(
        os.path.basename(__file__),
        f"Using model file: {model_path}",
    )
    asyncio.run(main(model_path))