
Commit 5e26e24

update the format
1 parent: 1451e32

6 files changed: 26 additions & 108 deletions


.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
@@ -46,4 +46,4 @@ jobs:
           push: true
           platforms: linux/arm64
           tags: ${{ steps.meta.outputs.tags }}
-          labels: ${{ steps.meta.outputs.labels }}
+          labels: ${{ steps.meta.outputs.labels }}

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
 functionGemma-finetuned-g1/
 venv/
 venv/
+__pycache__/
+*.pyc

benchmark-g1-server.py

Lines changed: 12 additions & 5 deletions
@@ -3,8 +3,9 @@
 Usage: python3 benchmark_client.py [--url http://localhost:8200]
 """
 
-import time
 import argparse
+import time
+
 import requests
 
 TESTS = [
@@ -25,6 +26,7 @@
     "Thank you so much",
 ]
 
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--url", default="http://localhost:8200")
@@ -52,9 +54,11 @@ def main():
         inference_ms = data["latency_ms"]
         times.append(inference_ms)
 
-        print(f" {inference_ms:5.0f}ms inference | {total_ms:5.0f}ms total | {t:<30s} -> {data['action']:<16s} {data['emotion']}")
+        print(
+            f" {inference_ms:5.0f}ms inference | {total_ms:5.0f}ms total | {t:<30s} -> {data['action']:<16s} {data['emotion']}"
+        )
 
-    print(f"\n--- Inference (model only) ---")
+    print("\n--- Inference (model only) ---")
     print(f"Min: {min(times):.0f}ms")
     print(f"Max: {max(times):.0f}ms")
     print(f"Average: {sum(times)/len(times):.0f}ms")
@@ -65,7 +69,10 @@ def main():
     r = requests.post(f"{args.url}/predict_batch", json={"texts": TESTS})
     total_ms = (time.time() - start) * 1000
     data = r.json()
-    print(f"Total: {data['total_latency_ms']:.0f}ms inference | {total_ms:.0f}ms with network")
+    print(
+        f"Total: {data['total_latency_ms']:.0f}ms inference | {total_ms:.0f}ms with network"
+    )
+
 
 if __name__ == "__main__":
-    main()
+    main()
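
The benchmark above keeps two timings separate: latency_ms, which the server reports for the model forward pass alone, and the client-side wall clock, which also includes network and serialization overhead. A minimal sketch of that measurement pattern, assuming only the /predict endpoint and response fields visible in this diff (the example text is illustrative):

import time

import requests

start = time.time()
r = requests.post("http://localhost:8200/predict", json={"text": "wave hello"})
total_ms = (time.time() - start) * 1000  # wall clock: network + serialization + model
data = r.json()
inference_ms = data["latency_ms"]  # model-only time, as reported by the server
print(f"{inference_ms:.0f}ms inference | {total_ms:.0f}ms total")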

chat_client.py

Lines changed: 6 additions & 2 deletions
@@ -3,10 +3,11 @@
 Usage: python3 chat_client.py [--url http://localhost:8200]
 """
 
-import json
 import argparse
+
 import requests
 
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--url", default="http://localhost:8200")
@@ -23,7 +24,10 @@ def main():
 
         r = requests.post(f"{args.url}/predict", json={"text": text})
         data = r.json()
-        print(f"Robot ({data['latency_ms']:.0f}ms): action={data['action']} emotion={data['emotion']}\n")
+        print(
+            f"Robot ({data['latency_ms']:.0f}ms): action={data['action']} emotion={data['emotion']}\n"
+        )
+
 
 if __name__ == "__main__":
     main()

src/functiongemma/server.py

Lines changed: 5 additions & 7 deletions
@@ -212,9 +212,7 @@ def generate_constrained(input_ids: torch.Tensor) -> tuple[str, str]:
     # Forward pass 2: feed action tokens + suffix, pick emotion
     combined = action_token_ids[chosen_action] + action_suffix
     combined_tensor = torch.tensor([combined], dtype=torch.long, device=device)
-    outputs = model(
-        input_ids=combined_tensor, past_key_values=past, use_cache=True
-    )
+    outputs = model(input_ids=combined_tensor, past_key_values=past, use_cache=True)
 
     logits = outputs.logits[:, -1, :]
     mask = torch.full_like(logits, float("-inf"))
@@ -319,9 +317,7 @@ def predict(req: PredictRequest):
         emotion,
         latency,
     )
-    return PredictResponse(
-        action=action, emotion=emotion, latency_ms=round(latency, 1)
-    )
+    return PredictResponse(action=action, emotion=emotion, latency_ms=round(latency, 1))
 
 
 @app.post("/predict_batch", response_model=BatchPredictResponse)
@@ -363,7 +359,9 @@ def predict_batch(req: BatchPredictRequest):
     )
 
     total_latency = (time.perf_counter() - total_start) * 1000
-    logger.info("predict_batch | count=%d | total=%.0fms", len(req.texts), total_latency)
+    logger.info(
+        "predict_batch | count=%d | total=%.0fms", len(req.texts), total_latency
+    )
     return BatchPredictResponse(
         results=results, count=len(results), total_latency_ms=round(total_latency, 1)
     )
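
For context on the hunks above: generate_constrained appears to pick each field from a fixed candidate set by masking the logits so that only the candidates' first-token ids can win the argmax, while past_key_values reuses the prompt's KV cache across the two forward passes. A minimal sketch of the masking step under those assumptions; pick_allowed and allowed_ids are illustrative names, not identifiers from this repo:

import torch

def pick_allowed(logits: torch.Tensor, allowed_ids: list[int]) -> int:
    """Return the candidate token id with the highest logit.

    logits: [1, vocab_size] tensor taken from the model's last position.
    allowed_ids: first-token ids of the candidate labels, precomputed.
    """
    mask = torch.full_like(logits, float("-inf"))  # block every token...
    mask[0, allowed_ids] = logits[0, allowed_ids]  # ...except the candidates
    return int(mask.argmax(dim=-1).item())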

tests/test_functiongemma.py

Lines changed: 0 additions & 93 deletions
This file was deleted.
