-
Notifications
You must be signed in to change notification settings - Fork 44
Expand file tree
/
Copy pathinfer.py
More file actions
64 lines (48 loc) · 2.18 KB
/
infer.py
File metadata and controls
64 lines (48 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import shutil
import torch
import soundfile as sf
from tqdm import tqdm
from omegaconf import OmegaConf
from models.gtcrn_end2end import GTCRN as Model
def _write_scp(scp_path, entries):
    """Write ``(uid, path)`` pairs to *scp_path*, one ``"<uid> <path>"`` line each.

    The format is whitespace-separated, so a path containing spaces would be
    split apart by a reader; callers are expected to use space-free paths.
    """
    with open(scp_path, "w", encoding="utf-8") as f:
        f.writelines(f"{uid} {audio_path}\n" for uid, audio_path in entries)


def main(args):
    """Enhance every ``.wav`` file in the test set and write evaluation scp lists.

    Loads the inference config from ``args.config`` and the network config it
    points to. It then builds the model and restores the ``'model'`` state dict
    from the checkpoint. Each noisy file is run through the model, and the
    enhanced signal is written to ``<enh_folder>/<uid>_enh.wav`` at the input
    sample rate. Finally ``inf.scp`` (uid -> enhanced path) and ``ref.scp``
    (uid -> clean reference path) are written into ``enh_folder``.

    Args:
        args: Parsed CLI namespace with ``config`` (path to the inference
            YAML) and ``device`` (CUDA device index, used only when CUDA is
            available; otherwise the CPU is used).
    """
    cfg_infer = OmegaConf.load(args.config)
    cfg_network = OmegaConf.load(cfg_infer.network.config)

    noisy_folder = cfg_infer.test_dataset.noisy_dir
    clean_folder = cfg_infer.test_dataset.clean_dir
    enh_folder = cfg_infer.network.enh_folder
    os.makedirs(enh_folder, exist_ok=True)

    device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
    model = Model(**cfg_network["network_config"]).to(device)
    # NOTE(review): torch.load unpickles the checkpoint file; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(cfg_infer.network.checkpoint, map_location=device)
    model.load_state_dict(checkpoint["model"])
    model.eval()

    # Require a real ".wav" extension (case-insensitive), not just any name
    # that happens to end in "wav".
    noisy_wavs = sorted(
        name for name in os.listdir(noisy_folder) if name.lower().endswith(".wav")
    )

    inf_scp_list = []
    ref_scp_list = []
    for wav_name in tqdm(noisy_wavs):
        noisy, fs = sf.read(os.path.join(noisy_folder, wav_name), dtype="float32")
        # sf.read already returned float32 data; add a leading batch dimension.
        noisy_tensor = torch.from_numpy(noisy).unsqueeze(0).to(device)
        with torch.inference_mode():
            output = model(noisy_tensor)
        enhanced = output.cpu().detach().numpy().squeeze()

        # splitext strips only the final extension, so stems that themselves
        # contain ".wav" still get distinct uids.
        uid = os.path.splitext(wav_name)[0]
        enh_path = os.path.join(enh_folder, f"{uid}_enh.wav")
        # The clean reference is assumed to share the noisy file's name.
        ref_path = os.path.join(clean_folder, wav_name)
        inf_scp_list.append((uid, enh_path))
        ref_scp_list.append((uid, ref_path))
        sf.write(enh_path, enhanced, fs)

    # Save paths into scp files for evaluation.
    _write_scp(os.path.join(enh_folder, "inf.scp"), inf_scp_list)
    _write_scp(os.path.join(enh_folder, "ref.scp"), ref_scp_list)
if __name__ == "__main__":
    # Command-line entry point: read the config path and GPU index, then
    # hand the parsed namespace to main().
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("-C", "--config", default="configs/cfg_infer.yaml")
    cli.add_argument("-D", "--device", default="0", help="Index of the gpu device")
    args = cli.parse_args()
    main(args)