Skip to content

Commit 350354c

Browse files
authored
Update modules.py
1 parent 3336925 commit 350354c

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

lib/modules.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import glob
2222
from shutil import move
2323

24-
2524
sup_audioext = {
2625
"wav",
2726
"mp3",
@@ -49,21 +48,23 @@ def note_to_hz(note_name):
4948
except:
5049
return None
5150

52-
def load_hubert():
51+
def load_hubert(hubert_model_path, config):
5352
from fairseq import checkpoint_utils
5453

55-
hubert_path = "assets/hubert/hubert_base.pt"
56-
5754
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
58-
[hubert_path],
55+
[hubert_model_path],
5956
suffix="",
6057
)
6158
hubert_model = models[0]
62-
hubert_model = hubert_model.float()
63-
hubert_model.eval()
59+
hubert_model = hubert_model.to(config.device)
60+
if config.is_half:
61+
hubert_model = hubert_model.half()
62+
else:
63+
hubert_model = hubert_model.float()
64+
65+
hubert_models = hubert_model.eval()
66+
return hubert_models
6467

65-
return hubert_model
66-
6768
class VC:
6869
def __init__(self, config):
6970
self.n_spk = None
@@ -459,10 +460,10 @@ def vc_single(
459460
times = [0, 0, 0]
460461

461462
if self.hubert_model is None:
462-
self.hubert_model = load_hubert()
463+
self.hubert_model = load_hubert(hubert_model_path, self.config)
463464

464-
#try:
465-
# self.if_f0 = self.cpt.get("f0", 1)
465+
try:
466+
self.if_f0 = self.cpt.get("f0", 1)
466467
except NameError:
467468
message = "Model was not properly selected"
468469
print(message)

0 commit comments

Comments
 (0)