-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgoogle_cloud.py
More file actions
116 lines (97 loc) · 4.13 KB
/
google_cloud.py
File metadata and controls
116 lines (97 loc) · 4.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from __future__ import annotations
import asyncio
from logging import getLogger
from google.cloud import speech
from ..static import LISTEN_AUDIO_FORMAT, LISTEN_LANGUAGE_CODE
from ..types import StreamingSpeechRecognizer, StreamingSpeechSession
# Module-level logger, named after this module.
logger = getLogger(__name__)
# Unique sentinel placed on a session's audio queue to mark end-of-stream.
# It is compared by identity (`is`), so it can never collide with audio bytes.
_STREAM_END = object()
class _GoogleCloudStreamingSession(StreamingSpeechSession):
    """A single streaming-recognition call against Google Cloud Speech.

    Audio passed to :meth:`push_audio` is placed on an internal queue and
    forwarded to the API by a background task started in ``__init__``.
    Transcripts reported by the API are accumulated until the session is
    ended with :meth:`finish` (returns the text) or :meth:`abort` (discards it).
    """

    def __init__(
        self,
        client: speech.SpeechAsyncClient,
    ) -> None:
        """Build the streaming config and start the background recognition task.

        Must be called while an event loop is running, because it schedules
        the worker with ``asyncio.create_task``.

        Args:
            client: Async Speech client used to open the streaming call.
        """
        self._client = client
        self._config = speech.StreamingRecognitionConfig(
            config=speech.RecognitionConfig(
                encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
                sample_rate_hertz=LISTEN_AUDIO_FORMAT.sample_rate_hz,
                language_code=LISTEN_LANGUAGE_CODE,
            ),
            interim_results=False,
            single_utterance=False,
        )
        # Holds audio chunks, terminated by the _STREAM_END sentinel.
        self._audio_queue: asyncio.Queue[bytes | object] = asyncio.Queue()
        # Set when the background task has finished, for any reason.
        self._done = asyncio.Event()
        # True once finish()/abort() has requested end-of-stream.
        self._closed = False
        # Non-cancellation failure captured by the background task, if any.
        self._error: Exception | None = None
        # Transcripts of results the API marked as final, in arrival order.
        self._final_transcripts: list[str] = []
        # Most recent non-final transcript; cleared when a final one arrives.
        self._latest_transcript = ""
        self._task = asyncio.create_task(self._run())

    async def push_audio(self, pcm_bytes: bytes) -> None:
        """Queue one chunk of audio to be sent to the recognizer.

        Empty chunks are ignored. Chunks pushed after the background task has
        already ended are dropped, because nothing would ever consume them.

        Args:
            pcm_bytes: Audio data; sent under the LINEAR16 config built in
                ``__init__``.

        Raises:
            RuntimeError: If :meth:`finish` or :meth:`abort` was already called.
        """
        if self._closed:
            raise RuntimeError("streaming speech session is already closed")
        if not pcm_bytes:
            return
        if self._done.is_set():
            # The recognition task is over (stream ended or failed), so the
            # request iterator is no longer draining the queue. Enqueuing
            # would only grow the unbounded queue; any stored error is still
            # reported by finish().
            return
        # Copy so later mutation of the caller's buffer cannot affect the
        # queued chunk.
        await self._audio_queue.put(bytes(pcm_bytes))

    async def finish(self) -> str:
        """End the audio stream, wait for recognition, and return the text.

        Returns:
            The concatenation of all final transcripts, or the latest
            non-final transcript if no final one was received (possibly ``""``).

        Raises:
            Exception: Whatever error the background task captured while
                talking to the API.
        """
        await self._close_stream()
        await self._task
        if self._error is not None:
            raise self._error
        transcript = "".join(self._final_transcripts)
        return transcript or self._latest_transcript

    async def abort(self) -> None:
        """End the stream and cancel the background task, discarding results."""
        await self._close_stream()
        self._task.cancel()
        # asyncio.wait() waits for the task to settle without re-raising its
        # CancelledError. Any CancelledError escaping from here is therefore
        # the *caller's* own cancellation, which must propagate rather than
        # being swallowed.
        await asyncio.wait({self._task})

    async def _close_stream(self) -> None:
        """Mark the session closed and enqueue the end-of-stream sentinel once."""
        if self._closed:
            return
        self._closed = True
        await self._audio_queue.put(_STREAM_END)

    async def _request_iter(self):
        """Yield the config request, then one audio request per queued chunk.

        Stops when the ``_STREAM_END`` sentinel is read from the queue.
        """
        # The first request carries only the streaming configuration.
        yield speech.StreamingRecognizeRequest(streaming_config=self._config)
        while True:
            chunk = await self._audio_queue.get()
            if chunk is _STREAM_END:
                break
            yield speech.StreamingRecognizeRequest(audio_content=chunk)

    async def _run(self) -> None:
        """Drive the streaming call and collect transcripts from responses.

        Cancellation is re-raised. Any other exception is stored in
        ``self._error`` for :meth:`finish` to raise. ``self._done`` is set on
        every exit path.
        """
        try:
            responses = await self._client.streaming_recognize(requests=self._request_iter())
            async for response in responses:
                for result in response.results:
                    if not result.alternatives:
                        continue
                    # Only the first alternative is used.
                    transcript = result.alternatives[0].transcript
                    if result.is_final:
                        logger.info("Streaming transcript(final): %s", transcript)
                        self._final_transcripts.append(transcript)
                        # A final result supersedes any pending interim text.
                        self._latest_transcript = ""
                    else:
                        # NOTE(review): presumably not reached while
                        # interim_results=False — confirm against the API.
                        logger.info("Streaming transcript(interim): %s", transcript)
                        self._latest_transcript = transcript
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            self._error = exc
        finally:
            self._done.set()
class GoogleCloudSpeechToText(StreamingSpeechRecognizer):
    """Speech recognizer backed by Google Cloud Speech.

    Offers one-shot recognition of a complete buffer (:meth:`transcribe`) and
    incremental recognition through a streaming session (:meth:`start_stream`).
    """

    def __init__(self, client: speech.SpeechAsyncClient | None = None) -> None:
        """Store the given client, or create a default ``SpeechAsyncClient``.

        Args:
            client: Optional pre-built async Speech client to use.
        """
        self._client = client or speech.SpeechAsyncClient()

    async def transcribe(self, pcm_bytes: bytes) -> str:
        """Recognize a complete audio buffer in a single request.

        Args:
            pcm_bytes: Audio data; sent with LINEAR16 encoding at the
                configured sample rate and language.

        Returns:
            The first alternative of every result, concatenated in order.
            Results that carry no alternatives are skipped.
        """
        audio = speech.RecognitionAudio(content=pcm_bytes)
        config = speech.RecognitionConfig(
            encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
            sample_rate_hertz=LISTEN_AUDIO_FORMAT.sample_rate_hz,
            language_code=LISTEN_LANGUAGE_CODE,
        )
        response = await self._client.recognize(config=config, audio=audio)
        # Skip results without alternatives instead of raising IndexError,
        # matching how the streaming session in this module handles them.
        return "".join(
            result.alternatives[0].transcript
            for result in response.results
            if result.alternatives
        )

    async def start_stream(self) -> StreamingSpeechSession:
        """Open a new streaming recognition session that shares this client."""
        return _GoogleCloudStreamingSession(self._client)
# Public API of this module: only the recognizer class is exported.
__all__ = ["GoogleCloudSpeechToText"]