@@ -212,9 +212,7 @@ def generate_constrained(input_ids: torch.Tensor) -> tuple[str, str]:
     # Forward pass 2: feed action tokens + suffix, pick emotion
     combined = action_token_ids[chosen_action] + action_suffix
     combined_tensor = torch.tensor([combined], dtype=torch.long, device=device)
-    outputs = model(
-        input_ids=combined_tensor, past_key_values=past, use_cache=True
-    )
+    outputs = model(input_ids=combined_tensor, past_key_values=past, use_cache=True)
 
     logits = outputs.logits[:, -1, :]
     mask = torch.full_like(logits, float("-inf"))
@@ -319,9 +317,7 @@ def predict(req: PredictRequest):
         emotion,
         latency,
     )
-    return PredictResponse(
-        action=action, emotion=emotion, latency_ms=round(latency, 1)
-    )
+    return PredictResponse(action=action, emotion=emotion, latency_ms=round(latency, 1))
 
 
 @app.post("/predict_batch", response_model=BatchPredictResponse)
@@ -363,7 +359,9 @@ def predict_batch(req: BatchPredictRequest):
     )
 
     total_latency = (time.perf_counter() - total_start) * 1000
-    logger.info("predict_batch | count=%d | total=%.0fms", len(req.texts), total_latency)
+    logger.info(
+        "predict_batch | count=%d | total=%.0fms", len(req.texts), total_latency
+    )
     return BatchPredictResponse(
         results=results, count=len(results), total_latency_ms=round(total_latency, 1)
     )