Skip to content

Commit b8bf44c

Browse files
authored
Update full_inference.py
1 parent 38e757f commit b8bf44c

File tree

1 file changed

+278
-2
lines changed

1 file changed

+278
-2
lines changed

tabs/full_inference.py

Lines changed: 278 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,282 @@
1-
from tabs.settinginf import *
2-
1+
from core import full_inference_program, download_music
2+
import sys, os
3+
import gradio as gr
4+
import regex as re
5+
from assets.i18n.i18n import I18nAuto
6+
import torch
7+
import shutil
8+
import unicodedata
39
import gradio as gr
10+
from assets.i18n.i18n import I18nAuto
11+
12+
13+
i18n = I18nAuto()
14+
15+
16+
now_dir = os.getcwd()
17+
sys.path.append(now_dir)
18+
19+
20+
model_root = os.path.join(now_dir, "logs")
21+
audio_root = os.path.join(now_dir, "audio_files", "original_files")
22+
23+
24+
model_root_relative = os.path.relpath(model_root, now_dir)
25+
audio_root_relative = os.path.relpath(audio_root, now_dir)
26+
27+
28+
sup_audioext = {
29+
"wav",
30+
"mp3",
31+
"flac",
32+
"ogg",
33+
"opus",
34+
"m4a",
35+
"mp4",
36+
"aac",
37+
"alac",
38+
"wma",
39+
"aiff",
40+
"webm",
41+
"ac3",
42+
}
43+
44+
45+
names = [
46+
os.path.join(root, file)
47+
for root, _, files in os.walk(model_root_relative, topdown=False)
48+
for file in files
49+
if (
50+
file.endswith((".pth", ".onnx"))
51+
and not (file.startswith("G_") or file.startswith("D_"))
52+
)
53+
]
54+
55+
56+
indexes_list = [
57+
os.path.join(root, name)
58+
for root, _, files in os.walk(model_root_relative, topdown=False)
59+
for name in files
60+
if name.endswith(".index") and "trained" not in name
61+
]
62+
63+
64+
audio_paths = [
65+
os.path.join(root, name)
66+
for root, _, files in os.walk(audio_root_relative, topdown=False)
67+
for name in files
68+
if name.endswith(tuple(sup_audioext))
69+
and root == audio_root_relative
70+
and "_output" not in name
71+
]
72+
73+
74+
vocals_model_names = [
75+
"Mel-Roformer by KimberleyJSN",
76+
"BS-Roformer by ViperX",
77+
"MDX23C",
78+
]
79+
80+
81+
karaoke_models_names = [
82+
"Mel-Roformer Karaoke by aufr33 and viperx",
83+
"UVR-BVE",
84+
]
85+
86+
87+
denoise_models_names = [
88+
"Mel-Roformer Denoise Normal by aufr33",
89+
"Mel-Roformer Denoise Aggressive by aufr33",
90+
"UVR Denoise",
91+
]
92+
93+
94+
dereverb_models_names = [
95+
"MDX23C DeReverb by aufr33 and jarredou",
96+
"UVR-Deecho-Dereverb",
97+
"MDX Reverb HQ by FoxJoy",
98+
"BS-Roformer Dereverb by anvuew",
99+
]
100+
101+
102+
deeecho_models_names = ["UVR-Deecho-Normal", "UVR-Deecho-Aggressive"]
103+
104+
105+
def get_indexes():
106+
107+
indexes_list = [
108+
os.path.join(dirpath, filename)
109+
for dirpath, _, filenames in os.walk(model_root_relative)
110+
for filename in filenames
111+
if filename.endswith(".index") and "trained" not in filename
112+
]
113+
114+
return indexes_list if indexes_list else ""
115+
116+
117+
def match_index(model_file_value):
118+
if model_file_value:
119+
model_folder = os.path.dirname(model_file_value)
120+
model_name = os.path.basename(model_file_value)
121+
index_files = get_indexes()
122+
pattern = r"^(.*?)_"
123+
match = re.match(pattern, model_name)
124+
for index_file in index_files:
125+
if os.path.dirname(index_file) == model_folder:
126+
return index_file
127+
128+
elif match and match.group(1) in os.path.basename(index_file):
129+
return index_file
130+
131+
elif model_name in os.path.basename(index_file):
132+
return index_file
133+
134+
return ""
135+
136+
137+
def output_path_fn(input_audio_path):
138+
original_name_without_extension = os.path.basename(input_audio_path).rsplit(".", 1)[
139+
0
140+
]
141+
new_name = original_name_without_extension + "_output.wav"
142+
output_path = os.path.join(os.path.dirname(input_audio_path), new_name)
143+
144+
return output_path
145+
146+
147+
def get_number_of_gpus():
148+
if torch.cuda.is_available():
149+
num_gpus = torch.cuda.device_count()
150+
151+
return "-".join(map(str, range(num_gpus)))
152+
153+
else:
154+
155+
return "-"
156+
157+
158+
def max_vram_gpu(gpu):
159+
160+
if torch.cuda.is_available():
161+
gpu_properties = torch.cuda.get_device_properties(gpu)
162+
total_memory_gb = round(gpu_properties.total_memory / 1024 / 1024 / 1024)
163+
164+
return total_memory_gb / 2
165+
166+
else:
167+
168+
return "0"
169+
170+
171+
def format_title(title):
172+
173+
formatted_title = (
174+
unicodedata.normalize("NFKD", title).encode("ascii", "ignore").decode("utf-8")
175+
)
176+
177+
formatted_title = re.sub(r"[\u2500-\u257F]+", "", formatted_title)
178+
formatted_title = re.sub(r"[^\w\s.-]", "", formatted_title)
179+
formatted_title = re.sub(r"\s+", "_", formatted_title)
180+
181+
return formatted_title
182+
183+
184+
def save_to_wav(upload_audio):
185+
186+
file_path = upload_audio
187+
formated_name = format_title(os.path.basename(file_path))
188+
target_path = os.path.join(audio_root_relative, formated_name)
189+
190+
if os.path.exists(target_path):
191+
os.remove(target_path)
192+
193+
os.makedirs(os.path.dirname(target_path), exist_ok=True)
194+
shutil.copy(file_path, target_path)
195+
196+
return target_path, output_path_fn(target_path)
197+
198+
199+
def delete_outputs():
200+
gr.Info(f"Outputs cleared!")
201+
for root, _, files in os.walk(audio_root_relative, topdown=False):
202+
for name in files:
203+
if name.endswith(tuple(sup_audioext)) and name.__contains__("_output"):
204+
os.remove(os.path.join(root, name))
205+
206+
207+
def change_choices():
208+
names = [
209+
os.path.join(root, file)
210+
for root, _, files in os.walk(model_root_relative, topdown=False)
211+
for file in files
212+
if (
213+
file.endswith((".pth", ".onnx"))
214+
and not (file.startswith("G_") or file.startswith("D_"))
215+
)
216+
]
217+
218+
indexes_list = [
219+
os.path.join(root, name)
220+
for root, _, files in os.walk(model_root_relative, topdown=False)
221+
for name in files
222+
if name.endswith(".index") and "trained" not in name
223+
]
224+
225+
audio_paths = [
226+
os.path.join(root, name)
227+
for root, _, files in os.walk(audio_root_relative, topdown=False)
228+
for name in files
229+
if name.endswith(tuple(sup_audioext))
230+
and root == audio_root_relative
231+
and "_output" not in name
232+
]
233+
234+
return (
235+
{"choices": sorted(names), "__type__": "update"},
236+
{"choices": sorted(indexes_list), "__type__": "update"},
237+
{"choices": sorted(audio_paths), "__type__": "update"},
238+
)
239+
240+
241+
242+
243+
def update_dropdown_visibility(checkbox):
244+
245+
return gr.update(visible=checkbox)
246+
247+
def update_reverb_sliders_visibility(reverb_checked):
248+
249+
return {
250+
reverb_room_size: gr.update(visible=reverb_checked),
251+
reverb_damping: gr.update(visible=reverb_checked),
252+
reverb_wet_gain: gr.update(visible=reverb_checked),
253+
reverb_dry_gain: gr.update(visible=reverb_checked),
254+
reverb_width: gr.update(visible=reverb_checked),
255+
}
256+
257+
def update_visibility_infer_backing(infer_backing_vocals):
258+
259+
visible = infer_backing_vocals
260+
261+
return (
262+
{"visible": visible, "__type__": "update"},
263+
{"visible": visible, "__type__": "update"},
264+
{"visible": visible, "__type__": "update"},
265+
{"visible": visible, "__type__": "update"},
266+
{"visible": visible, "__type__": "update"},
267+
)
268+
269+
def update_hop_length_visibility(pitch_extract_value):
270+
271+
return gr.update(visible=pitch_extract_value in ["crepe", "crepe-tiny"])
272+
273+
274+
275+
276+
277+
278+
279+
4280

5281

6282

0 commit comments

Comments
 (0)