-
Notifications
You must be signed in to change notification settings - Fork 791
Expand file tree
/
Copy pathnli_e2e_example.py
More file actions
104 lines (79 loc) · 2.79 KB
/
nli_e2e_example.py
File metadata and controls
104 lines (79 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import sys
import threading
import time
import requests
import uvicorn
from memos.extras.nli_model.client import NLIClient
from memos.extras.nli_model.server.serve import app
# Config
# TCP port for the in-process NLI server. It is bound on 127.0.0.1 only, and
# the same value is used to build the client base URL and the readiness-probe URL.
PORT = 32534
def run_server():
    """Serve the NLI ``app`` with uvicorn on 127.0.0.1:PORT.

    ``uvicorn.run`` blocks until the server stops, which is why ``main`` starts
    this function on a background thread instead of calling it directly.
    """
    host = "127.0.0.1"
    print(f"Starting server on port {PORT}...")
    uvicorn.run(app, log_level="info", port=PORT, host=host)
def _wait_for_server(server_thread, url, timeout=300, interval=2):
    """Poll ``url`` until it answers HTTP 200, the timeout expires, or the server dies.

    Args:
        server_thread: Thread running the server. If it exits, polling stops early
            because the server can no longer come up.
        url: Endpoint to probe with GET requests.
        timeout: Maximum number of seconds to keep polling.
        interval: Seconds to sleep between probes.

    Returns:
        True once ``url`` returns status 200. False on timeout or if
        ``server_thread`` is no longer alive.
    """
    # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if not server_thread.is_alive():
            return False
        try:
            resp = requests.get(url, timeout=1)
        except requests.RequestException:
            # Connection refused or read timeouts are expected while the server starts.
            pass
        else:
            if resp.status_code == 200:
                return True
        time.sleep(interval)
        print(".", end="", flush=True)
    return False


def _check_results(targets, results, expected):
    """Compare each result's ``.value`` with the expected label for its target.

    Prints one FAILURE line per mismatch. A count mismatch is reported
    explicitly instead of raising IndexError.

    Returns:
        True if every result matches its expected label, otherwise False.
    """
    if len(results) != len(expected):
        print(f"FAILURE: Expected {len(expected)} results, got {len(results)}")
        return False
    passed = True
    for target, result, want in zip(targets, results, expected, strict=True):
        if result.value != want:
            print(f"FAILURE: Expected {want} for '{target}', got {result.value}")
            passed = False
    return passed


def main():
    """Run the NLI end-to-end check against an in-process server.

    Starts ``run_server`` on a daemon thread, waits until ``/docs`` responds,
    sends one ``compare_one_to_many`` request through ``NLIClient``, prints the
    returned labels, and checks them against the expected labels.

    Exits with status 1 in any of these cases: the server does not become
    ready, the request raises, or any label differs from its expected value.
    """
    print("Initializing E2E Test...")
    # Daemon thread: the interpreter can exit without stopping uvicorn explicitly.
    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()

    print("Waiting for server to initialize (this may take time if downloading model)...")
    base_url = f"http://127.0.0.1:{PORT}"
    client = NLIClient(base_url=base_url)

    # Wait up to 5 minutes for model download and initialization.
    ready = _wait_for_server(server_thread, f"{base_url}/docs", timeout=300)
    print("\n")
    if not ready:
        if not server_thread.is_alive():
            print("Server thread exited before becoming ready.")
        else:
            print("Server failed to start in time.")
        sys.exit(1)

    print("Server is up! Sending request...")
    # Test Data
    source = "I like apples"
    targets = ["I like apples", "I hate apples", "Paris is a city"]
    expected = ["Duplicate", "Contradiction", "Unrelated"]

    # Only the request sits in the try, so validation bugs are not reported
    # as request errors.
    try:
        results = client.compare_one_to_many(source, targets)
    except Exception as e:
        print(f"Error during request: {e}")
        sys.exit(1)

    print("-" * 30)
    print(f"Source: {source}")
    print("Targets & Results:")
    for t, r in zip(targets, results, strict=False):
        print(f" - '{t}': {r.value}")
    print("-" * 30)

    if _check_results(targets, results, expected):
        print("\nSUCCESS: Logic verification passed!")
    else:
        print("\nFAILURE: Unexpected results!")
        # Non-zero status so CI and shell callers see the failure.
        sys.exit(1)
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C can abort the long server wait; report it without a traceback,
        # but exit non-zero (128 + SIGINT) so an aborted run is not mistaken
        # for a passing one.
        print("\nTest interrupted.")
        sys.exit(130)