170 changes: 170 additions & 0 deletions backend.py
@@ -0,0 +1,170 @@
#!/usr/bin/env python3
"""
gRPC server of LocalAI for NVIDIA NEMO Toolkit ASR.
Fixed to handle transcription without dataloader issues.
"""
from concurrent import futures
import time
import argparse
import signal
import sys
import os
import backend_pb2
import backend_pb2_grpc
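# backend_pb2 / backend_pb2_grpc are assumed to be generated from LocalAI's
# backend.proto (e.g. with python -m grpc_tools.protoc).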
import torch
import nemo.collections.asr as nemo_asr
import numpy as np
import torchaudio

import grpc


def is_float(s):
    try:
        float(s)
        return True
    except ValueError:
        return False


def is_int(s):
    try:
        int(s)
        return True
    except ValueError:
        return False


_ONE_DAY_IN_SECONDS = 60 * 60 * 24
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
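# A single worker by default; the assumption here is that one loaded NeMo
# model instance should not serve concurrent requests.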


class BackendServicer(backend_pb2_grpc.BackendServicer):
    def Health(self, request, context):
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        if mps_available:
            device = "mps"
        if not torch.cuda.is_available() and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")

        self.device = device
        self.options = {}

        for opt in request.Options:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)
            # Check int before float so integer values keep their type
            if is_int(value):
                value = int(value)
            elif is_float(value):
                value = float(value)
            elif value.lower() in ["true", "false"]:
                value = value.lower() == "true"
            self.options[key] = value
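        # Example (illustrative option strings): ["beam_size:4", "temperature:0.5",
        # "timestamps:true"] parses to {"beam_size": 4, "temperature": 0.5, "timestamps": True}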

        model_name = request.Model or "nvidia/parakeet-tdt-0.6b-v3"

        try:
            print(f"Loading NEMO ASR model from {model_name}", file=sys.stderr)
            self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)
            self.model.to(self.device)
            self.model.eval()
            print("NEMO ASR model loaded successfully", file=sys.stderr)
        except Exception as err:
            print(f"[ERROR] LoadModel failed: {err}", file=sys.stderr)
            import traceback
            traceback.print_exc(file=sys.stderr)
            return backend_pb2.Result(success=False, message=str(err))

        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def AudioTranscription(self, request, context):
        result_segments = []
        text = ""
        try:
            audio_path = request.dst
            if not audio_path or not os.path.exists(audio_path):
                print(f"Error: Audio file not found: {audio_path}", file=sys.stderr)
                return backend_pb2.TranscriptResult(segments=[], text="")

            # Load audio file using torchaudio
            waveform, sample_rate = torchaudio.load(audio_path)

            # Resample if necessary to match the model's expected sample rate (16000 for most NeMo models)
            target_sample_rate = 16000
            if sample_rate != target_sample_rate:
                resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
                waveform = resampler(waveform)

            # Convert to mono if stereo
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            # Transcribe the preprocessed audio directly; passing raw numpy
            # arrays avoids the dataloader setup mentioned in the module docstring
            with torch.no_grad():
                audio_list = [waveform.squeeze().cpu().numpy()]
                results = self.model.transcribe(audio_list)

            if not results or len(results) == 0:
                print("No transcription results returned", file=sys.stderr)
                return backend_pb2.TranscriptResult(segments=[], text="")

            # Depending on the NeMo version, transcribe() returns plain strings
            # or Hypothesis objects carrying the text in a .text attribute
            first = results[0]
            text = first.text if hasattr(first, "text") else first
            if text:
                # Create a single segment with the full transcription
                result_segments.append(backend_pb2.TranscriptSegment(
                    id=0, start=0, end=0, text=text
                ))

        except Exception as err:
            print(f"Error in AudioTranscription: {err}", file=sys.stderr)
            import traceback
            traceback.print_exc(file=sys.stderr)
            return backend_pb2.TranscriptResult(segments=[], text="")

        return backend_pb2.TranscriptResult(segments=result_segments, text=text)


def serve(address):
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),
            ('grpc.max_send_message_length', 50 * 1024 * 1024),
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),
        ])
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...", file=sys.stderr)
        server.stop(0)
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the gRPC server.")
parser.add_argument("--addr", default="localhost:50051", help="The address to bind the server to.")
args = parser.parse_args()
serve(args.addr)
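
For reference, a minimal client sketch for exercising this backend. The request message names (ModelOptions, TranscriptRequest) come from LocalAI's backend.proto and are assumptions here, since the generated stubs are not part of this diff:

import grpc
import backend_pb2
import backend_pb2_grpc

def transcribe_file(addr, audio_path, model="nvidia/parakeet-tdt-0.6b-v3"):
    # Assumed from LocalAI's backend.proto: ModelOptions carries the model
    # name; TranscriptRequest carries the audio file path in its dst field.
    with grpc.insecure_channel(addr) as channel:
        stub = backend_pb2_grpc.BackendStub(channel)
        res = stub.LoadModel(backend_pb2.ModelOptions(Model=model))
        if not res.success:
            raise RuntimeError(res.message)
        result = stub.AudioTranscription(backend_pb2.TranscriptRequest(dst=audio_path))
        return result.text

if __name__ == "__main__":
    print(transcribe_file("localhost:50051", "/tmp/sample.wav"))

Note that the server reads the audio from a path on local disk (request.dst), so client and server must share a filesystem, which is how LocalAI typically hands audio to transcription backends.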
2 changes: 1 addition & 1 deletion backend/cpp/llama-cpp/Makefile
@@ -1,5 +1,5 @@

-LLAMA_VERSION?=57819b8d4b39d893408e51520dff3d47d1ebb757
+LLAMA_VERSION?=10e5b148b061569aaee8ae0cf72a703129df0eab
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp

CMAKE_ARGS?=
2 changes: 1 addition & 1 deletion backend/go/acestep-cpp/Makefile
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)

# acestep.cpp version
ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp
-ACESTEP_CPP_VERSION?=5aa065445541094cba934299cd498bbb9fa5c434
+ACESTEP_CPP_VERSION?=master
SO_TARGET?=libgoacestepcpp.so

CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF