-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpipeline.py
More file actions
127 lines (108 loc) · 4.25 KB
/
pipeline.py
File metadata and controls
127 lines (108 loc) · 4.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from dataclasses import dataclass, replace
import logging
import typing
from align import align, align_piecewise_linear
from annotate import annotate, AnnotationProgress
from transcode import transcode, TranscodingProgress
from transcribe import transcribe, TranscriptionProgress
import common
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Import-time side effect: configure the root logger at INFO level.
logging.basicConfig(level=logging.INFO)
import common
from common import app
class PipelineError(Exception):
    """Raised by the pipeline, e.g. when the transcription id is unknown."""
@dataclass
class PipelineProgress:
    """Stage-level status update yielded by ``pipeline``."""

    # Name of the stage being reported, e.g. "transcoding" or "completed".
    state: str
    # Transcription record; ``pipeline`` sets it on the "completed" update.
    transcription: typing.Union[common.Transcription, None] = None
def pipeline(
transcription_id: str,
language: str = None,
prompt: str = None,
media_path: str = common.MEDIA_PATH,
local_mode: bool = False,
):
"""
The media processing pipeline
"""
t = common.db.select(transcription_id)
if not t:
raise PipelineError(f"invalid id : {transcription_id}")
try:
# bit awkward. supports local modal tests
transcode_fn = transcode.remote_gen
transcribe_fn = transcribe.remote_gen
annotate_fn = annotate.remote
align_fn = align.remote
if local_mode:
transcode_fn = transcode.local
transcribe_fn = transcribe.local
annotate_fn = annotate.local
align_fn = align.local
# transcode
if (not t.transcoded) or (not t.track):
logger.info(f"transcoding...")
yield PipelineProgress(state="transcoding")
for update in transcode_fn(transcription_id, media_path=media_path):
match update:
case TranscodingProgress(percent_done, None):
yield update
case TranscodingProgress(
percent_done, track
) if track is not None:
logger.info(f"completed transcoding. {track}")
t = replace(t, transcoded=True, track=track)
common.db.create(t)
yield update
case x:
raise ValueError(
f"cannot parse TranscodingProgress: {x}"
)
else:
logger.info(f"already transcoded. continuing")
# transcribe
changing_language = language and language != t.language
if (not t.transcribed) or changing_language:
logger.info("transcribing...")
yield PipelineProgress(state="transcribing")
for update in transcribe_fn(transcription_id, language, prompt):
match update:
case int(percent_done):
yield TranscriptionProgress(percent_done=percent_done)
case dict(transcript):
# save results
if not language:
# save the detected language
language = transcript.get("language")
# completed
logger.info(f"completed transcription.")
t = replace(t, transcript=transcript, language=language)
common.db.create(t)
yield TranscriptionProgress(
percent_done=100, transcript=t.transcript
)
case _:
update_str = f"{update} ({type(update)})"
raise ValueError(
f"cannot parse TranscriptionProgress: {update_str}"
)
else:
logger.info(f"already transcribed. continuing")
# align
logger.info("aligning...")
yield PipelineProgress(state="aligning")
t.alignment = align_fn(transcription_id, language=language)
# ... and diarize
logger.info("diarizing...")
yield PipelineProgress(state="annotating")
t.diarization = annotate_fn(transcription_id)
# .. and save
common.db.create(t)
logger.info("competed.")
yield PipelineProgress(state="completed", transcription=t)
except Exception as e:
import traceback
print(traceback.format_exc())
logger.error(e)
yield PipelineProgress(state="error")