Skip to content

Commit cd285af

Browse files
localai-botteam-coding-agent-1
authored andcommitted
fix: Fix nemo-parakeet-tdt-0.6b empty transcription by using direct audio input (#8682)
- Replace file path-based transcription with direct waveform input - Add torchaudio for audio loading and preprocessing - Resample audio to 16000Hz if necessary - Convert stereo to mono before transcription - Use model's transcribe method with numpy array input to avoid dataloader issues Fixes #8682 Signed-off-by: team-coding-agent-2 <team-coding-agent-2@example.com>
1 parent 6e5a58c commit cd285af

1 file changed

Lines changed: 170 additions & 0 deletions

File tree

backend.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
#!/usr/bin/env python3
2+
"""
3+
gRPC server of LocalAI for NVIDIA NEMO Toolkit ASR.
4+
Fixed to handle transcription without dataloader issues.
5+
"""
6+
from concurrent import futures
7+
import time
8+
import argparse
9+
import signal
10+
import sys
11+
import os
12+
import backend_pb2
13+
import backend_pb2_grpc
14+
import torch
15+
import nemo.collections.asr as nemo_asr
16+
import numpy as np
17+
import torchaudio
18+
19+
import grpc
20+
21+
22+
def is_float(s):
23+
try:
24+
float(s)
25+
return True
26+
except ValueError:
27+
return False
28+
29+
30+
def is_int(s):
31+
try:
32+
int(s)
33+
return True
34+
except ValueError:
35+
return False
36+
37+
38+
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
39+
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
40+
41+
42+
class BackendServicer(backend_pb2_grpc.BackendServicer):
43+
def Health(self, request, context):
44+
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
45+
46+
def LoadModel(self, request, context):
47+
if torch.cuda.is_available():
48+
device = "cuda"
49+
else:
50+
device = "cpu"
51+
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
52+
if mps_available:
53+
device = "mps"
54+
if not torch.cuda.is_available() and request.CUDA:
55+
return backend_pb2.Result(success=False, message="CUDA is not available")
56+
57+
self.device = device
58+
self.options = {}
59+
60+
for opt in request.Options:
61+
if ":" not in opt:
62+
continue
63+
key, value = opt.split(":", 1)
64+
if is_float(value):
65+
value = float(value)
66+
elif is_int(value):
67+
value = int(value)
68+
elif value.lower() in ["true", "false"]:
69+
value = value.lower() == "true"
70+
self.options[key] = value
71+
72+
model_name = request.Model or "nvidia/parakeet-tdt-0.6b-v3"
73+
74+
try:
75+
print(f"Loading NEMO ASR model from {model_name}", file=sys.stderr)
76+
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)
77+
self.model.to(self.device)
78+
self.model.eval()
79+
print("NEMO ASR model loaded successfully", file=sys.stderr)
80+
except Exception as err:
81+
print(f"[ERROR] LoadModel failed: {err}", file=sys.stderr)
82+
import traceback
83+
traceback.print_exc(file=sys.stderr)
84+
return backend_pb2.Result(success=False, message=str(err))
85+
86+
return backend_pb2.Result(message="Model loaded successfully", success=True)
87+
88+
def AudioTranscription(self, request, context):
89+
result_segments = []
90+
text = ""
91+
try:
92+
audio_path = request.dst
93+
if not audio_path or not os.path.exists(audio_path):
94+
print(f"Error: Audio file not found: {audio_path}", file=sys.stderr)
95+
return backend_pb2.TranscriptResult(segments=[], text="")
96+
97+
# Load audio file using torchaudio
98+
waveform, sample_rate = torchaudio.load(audio_path)
99+
100+
# Resample if necessary to match model's expected sample rate (16000 for most NEMO models)
101+
target_sample_rate = 16000
102+
if sample_rate != target_sample_rate:
103+
resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
104+
waveform = resampler(waveform)
105+
106+
# Convert to mono if stereo
107+
if waveform.shape[0] > 1:
108+
waveform = waveform.mean(dim=0, keepdim=True)
109+
110+
# Transcribe using the model's transcribe method with preprocessed audio
111+
# Use the simpler transcription path that doesn't require dataloader setup
112+
with torch.no_grad():
113+
# Convert waveform to the format expected by the model
114+
audio_list = [waveform.squeeze().cpu().numpy()]
115+
results = self.model.transcribe(audio_list)
116+
117+
if not results or len(results) == 0:
118+
print("No transcription results returned", file=sys.stderr)
119+
return backend_pb2.TranscriptResult(segments=[], text="")
120+
121+
# Get the transcript text from the first result
122+
text = results[0]
123+
if text:
124+
# Create a single segment with the full transcription
125+
result_segments.append(backend_pb2.TranscriptSegment(
126+
id=0, start=0, end=0, text=text
127+
))
128+
129+
except Exception as err:
130+
print(f"Error in AudioTranscription: {err}", file=sys.stderr)
131+
import traceback
132+
traceback.print_exc(file=sys.stderr)
133+
return backend_pb2.TranscriptResult(segments=[], text="")
134+
135+
return backend_pb2.TranscriptResult(segments=result_segments, text=text)
136+
137+
138+
def serve(address):
139+
server = grpc.server(
140+
futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
141+
options=[
142+
('grpc.max_message_length', 50 * 1024 * 1024),
143+
('grpc.max_send_message_length', 50 * 1024 * 1024),
144+
('grpc.max_receive_message_length', 50 * 1024 * 1024),
145+
])
146+
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
147+
server.add_insecure_port(address)
148+
server.start()
149+
print("Server started. Listening on: " + address, file=sys.stderr)
150+
151+
def signal_handler(sig, frame):
152+
print("Received termination signal. Shutting down...")
153+
server.stop(0)
154+
sys.exit(0)
155+
156+
signal.signal(signal.SIGINT, signal_handler)
157+
signal.signal(signal.SIGTERM, signal_handler)
158+
159+
try:
160+
while True:
161+
time.sleep(_ONE_DAY_IN_SECONDS)
162+
except KeyboardInterrupt:
163+
server.stop(0)
164+
165+
166+
if __name__ == "__main__":
167+
parser = argparse.ArgumentParser(description="Run the gRPC server.")
168+
parser.add_argument("--addr", default="localhost:50051", help="The address to bind the server to.")
169+
args = parser.parse_args()
170+
serve(args.addr)

0 commit comments

Comments
 (0)