-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathinference_model.py
More file actions
121 lines (109 loc) · 4.77 KB
/
inference_model.py
File metadata and controls
121 lines (109 loc) · 4.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import torch.nn as nn
import logging
import numpy as np
import os
import torchaudio
import soundfile as sf
import time
import shutil
from funcineforge.register import tables
from funcineforge.auto.auto_model import AutoModel
from funcineforge.utils.set_all_random_seed import set_all_random_seed
from moviepy.video.io.VideoFileClip import VideoFileClip, AudioFileClip
@tables.register("model_classes", "FunCineForgeInferModel")
class FunCineForgeInferModel(nn.Module):
    """End-to-end audio-generation inference pipeline.

    Chains three sub-models: a language model (text -> speech codec tokens),
    a feature model (codec tokens -> acoustic features), and a vocoder
    (features -> waveform). When ``output_dir`` is supplied at inference
    time, the intermediate features, the waveform, and (if the sample
    carries a source video) an audio-muxed .mp4 are written to disk.
    """

    def __init__(
        self,
        lm_model: AutoModel,
        fm_model: AutoModel,
        voc_model: AutoModel,
        **kwargs
    ):
        """Unwrap the three AutoModel containers and derive sampling params.

        Args:
            lm_model: wrapper holding the token LM; must expose ``kwargs["tokenizer"]``.
            fm_model: wrapper holding the feature model; must expose ``kwargs["frontend"]``.
            voc_model: wrapper holding the vocoder.
            **kwargs: unused; accepted for registry-construction compatibility.
        """
        super().__init__()
        # Text tokenizer lives on the LM wrapper; acoustic frontend on the FM wrapper.
        self.tokenizer = lm_model.kwargs["tokenizer"]
        self.frontend = fm_model.kwargs["frontend"]
        self.lm_model = lm_model.model
        self.fm_model = fm_model.model
        self.voc_model = voc_model.model
        mel_extractor = self.fm_model.mel_extractor
        if mel_extractor:
            self.mel_frame_rate = mel_extractor.sampling_rate // mel_extractor.hop_length
            self.sample_rate = mel_extractor.sampling_rate
        else:
            # NOTE(review): hard-coded 480-sample hop when no mel extractor is
            # configured — confirm this matches the FM model's frame hop.
            self.mel_frame_rate = self.fm_model.sample_rate // 480
            self.sample_rate = self.fm_model.sample_rate

    @torch.no_grad()
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        **kwargs,
    ):
        """Run LM -> FM -> vocoder for one utterance; optionally save outputs.

        Args:
            data_in: batch of samples; ``data_in[0]`` is a dict that is mutated
                in place (``"codec"`` is inserted) and whose optional
                ``"video"`` entry points at a silent source video.
            data_lengths: lengths forwarded unchanged to every sub-model.
            key: list of utterance ids; only ``key[0]`` is used (batch size 1).
            **kwargs: forwarded to the sub-models. Recognized here:
                ``random_seed`` (int, default 0) and ``output_dir`` (str or None).

        Returns:
            ``([[wav]], {"batch_data_time": seconds})``. ``wav`` is ``None``
            and ``batch_data_time`` is 1.0 when the LM emitted no tokens.
        """
        uttid = key[0]
        logging.info(f"generating {uttid}")
        # text -> codec in [1, T]
        kwargs["tokenizer"] = self.tokenizer
        # Reseed before each stochastic stage so a given seed reproduces results.
        set_all_random_seed(kwargs.get("random_seed", 0))
        lm_time = time.time()
        codec, hit_eos, states = self.lm_model.inference(data_in, data_lengths, key, **kwargs)
        logging.info(f"[llm time]: {((time.time()-lm_time)*1000):.2f} ms, [hit_eos]: {hit_eos}, [gen len]: {codec.shape[1]}, [speech tokens]: {codec[0].cpu().tolist()}")
        wav, batch_data_time = None, 1.0
        if codec.shape[1] > 0:
            fm_time = time.time()
            data_in[0]["codec"] = codec
            set_all_random_seed(kwargs.get("random_seed", 0))
            feat = self.fm_model.inference(data_in, data_lengths, key, **kwargs)
            # feat -> wav
            set_all_random_seed(kwargs.get("random_seed", 0))
            wav = self.voc_model.inference([feat[0]], data_lengths, key, **kwargs)
            # output save
            output_dir = kwargs.get("output_dir", None)
            if output_dir is not None:
                self._save_outputs(output_dir, uttid, feat, wav, data_in)
            logging.info(f"fm_voc time: {((time.time()-fm_time)*1000):.2f} ms")
            batch_data_time = wav.shape[1] / self.voc_model.sample_rate
        return [[wav]], {"batch_data_time": batch_data_time}

    def _save_outputs(self, output_dir, uttid, feat, wav, data_in):
        """Persist feature (.npy) and waveform (.wav); if the sample has a
        source video, also copy it to gt/ and write an audio-muxed mp4/."""
        feat_out_dir = os.path.join(output_dir, "feat")
        os.makedirs(feat_out_dir, exist_ok=True)
        np.save(os.path.join(feat_out_dir, f"{uttid}.npy"), feat[0].cpu().numpy())
        wav_out_dir = os.path.join(output_dir, "wav")
        os.makedirs(wav_out_dir, exist_ok=True)
        output_wav_path = os.path.join(wav_out_dir, f"{uttid}.wav")
        sf.write(
            output_wav_path,
            wav.cpu().squeeze(0).numpy(),
            samplerate=self.sample_rate,
            subtype='PCM_16'
        )
        # Fix: use .get() — samples without a "video" entry previously raised
        # KeyError here, even though a missing file was already tolerated below.
        silent_video_path = data_in[0].get("video")
        if silent_video_path and os.path.exists(silent_video_path):
            video_out_dir = os.path.join(output_dir, "mp4")
            video_gt_dir = os.path.join(output_dir, "gt")
            os.makedirs(video_out_dir, exist_ok=True)
            os.makedirs(video_gt_dir, exist_ok=True)
            output_video_path = os.path.join(video_out_dir, f"{uttid}.mp4")
            copy_video_path = os.path.join(video_gt_dir, f"{uttid}.mp4")
            # Keep an untouched copy of the silent source alongside the muxed output.
            shutil.copy2(silent_video_path, copy_video_path)
            self.merge_video_audio(
                silent_video_path=silent_video_path,
                wav_path=output_wav_path,
                output_path=output_video_path,
            )

    def merge_video_audio(self, silent_video_path, wav_path, output_path):
        """Mux the generated audio onto a silent video.

        The audio is trimmed to the video duration when longer; when shorter
        it is attached as-is (the video tail stays silent).

        Args:
            silent_video_path: path to the source video without audio.
            wav_path: path to the generated .wav file.
            output_path: destination .mp4 path (H.264 video, AAC audio).
        """
        video_clip = VideoFileClip(silent_video_path)
        audio_clip = AudioFileClip(wav_path)
        try:
            video_duration = video_clip.duration
            if audio_clip.duration >= video_duration:
                audio_clip = audio_clip.subclipped(0, video_duration)
            video_clip = video_clip.with_audio(audio_clip)
            video_clip.write_videofile(
                output_path,
                codec='libx264',
                audio_codec='aac',
                fps=video_clip.fps,
                logger=None
            )
        finally:
            # Fix: always release the ffmpeg readers — the original closed
            # them only on success, leaking both handles if writing failed.
            video_clip.close()
            audio_clip.close()