From 0cf084801adc1842ab081514ede407ad0d5bdb4e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 18 Feb 2025 04:13:38 +0000 Subject: [PATCH] chore(format): run black on master --- app.py | 17 +- assets/discord_presence.py | 10 +- assets/i18n/languages/id_ID.py | 2 +- assets/themes/Grheme.py | 1 - core.py | 12 +- programs/music_separation_code/ensemble.py | 71 +- programs/music_separation_code/inference.py | 1 - .../models/bandit/core/__init__.py | 297 ++-- .../models/bandit/core/data/__init__.py | 2 +- .../models/bandit/core/data/_types.py | 5 +- .../models/bandit/core/data/augmentation.py | 63 +- .../models/bandit/core/data/augmented.py | 17 +- .../models/bandit/core/data/base.py | 36 +- .../models/bandit/core/data/dnr/datamodule.py | 44 +- .../models/bandit/core/data/dnr/dataset.py | 252 ++-- .../models/bandit/core/data/dnr/preprocess.py | 13 +- .../bandit/core/data/musdb/datamodule.py | 68 +- .../models/bandit/core/data/musdb/dataset.py | 131 +- .../bandit/core/data/musdb/preprocess.py | 120 +- .../models/bandit/core/loss/__init__.py | 8 +- .../models/bandit/core/loss/_complex.py | 11 +- .../models/bandit/core/loss/_multistem.py | 14 +- .../models/bandit/core/loss/_timefreq.py | 78 +- .../models/bandit/core/loss/snr.py | 75 +- .../models/bandit/core/metrics/_squim.py | 98 +- .../models/bandit/core/metrics/snr.py | 81 +- .../models/bandit/core/model/_spectral.py | 74 +- .../bandit/core/model/bsrnn/bandsplit.py | 60 +- .../models/bandit/core/model/bsrnn/core.py | 802 ++++++----- .../bandit/core/model/bsrnn/maskestim.py | 277 ++-- .../models/bandit/core/model/bsrnn/tfmodel.py | 207 +-- .../models/bandit/core/model/bsrnn/utils.py | 328 ++--- .../models/bandit/core/model/bsrnn/wrapper.py | 1224 ++++++++--------- .../models/bandit/core/utils/audio.py | 289 ++-- .../models/bandit/model_from_config.py | 8 +- .../models/bandit_v2/bandit.py | 24 +- .../models/bandit_v2/film.py | 12 +- .../models/bs_roformer/attend.py | 54 +- .../models/bs_roformer/bs_roformer.py | 449 +++--- .../models/bs_roformer/mel_band_roformer.py | 418 +++--- .../music_separation_code/models/demucs4ht.py | 33 +- .../models/mdx23c_tfc_tdf_v3.py | 52 +- .../models/scnet/scnet.py | 167 ++- .../models/scnet/separation.py | 40 +- .../models/scnet_unofficial/__init__.py | 2 +- .../scnet_unofficial/modules/dualpath_rnn.py | 46 +- .../models/scnet_unofficial/scnet.py | 49 +- .../models/scnet_unofficial/utils.py | 8 +- .../models/segm_models.py | 50 +- .../models/torchseg_models.py | 50 +- .../models/upernet_swin_transformers.py | 76 +- programs/music_separation_code/utils.py | 127 +- tabs/full_inference.py | 757 +--------- tabs/settings.py | 6 - 54 files changed, 3226 insertions(+), 3990 deletions(-) diff --git a/app.py b/app.py index c74e412..8b869a2 100644 --- a/app.py +++ b/app.py @@ -3,7 +3,9 @@ from tabs.full_inference import full_inference_tab from tabs.download_model import download_model_tab from tabs.settings import theme_tab, lang_tab, restart_tab -from programs.applio_code.rvc.lib.tools.prerequisites_download import prequisites_download_pipeline +from programs.applio_code.rvc.lib.tools.prerequisites_download import ( + prequisites_download_pipeline, +) from tabs.presence import load_config_presence, presence_tab now_dir = os.getcwd() @@ -15,7 +17,7 @@ prequisites_download_pipeline( False, False, - True, + True, False, ) @@ -32,15 +34,14 @@ RPCManager.start_presence() - rvc_theme = loadThemes.load_theme() or "NoCrypt/miku" -with gr.Blocks( - theme=rvc_theme, title="Advanced RVC Inference" -) as rvc: +with 
gr.Blocks(theme=rvc_theme, title="Advanced RVC Inference") as rvc: gr.Markdown('
Advanced RVC Inference') - gr.Markdown('this project Maintained by NeoDev') - + gr.Markdown( + 'this project Maintained by NeoDev
' + ) + with gr.Tab(i18n("Full Inference")): full_inference_tab() with gr.Tab(i18n("Download Model")): diff --git a/assets/discord_presence.py b/assets/discord_presence.py index 45bc687..ab52ec8 100644 --- a/assets/discord_presence.py +++ b/assets/discord_presence.py @@ -30,8 +30,14 @@ def update_presence(self): state="Advanced-RVC", details="Advaced voice cloning with UVR5 feature", buttons=[ - {"label": "Home", "url": "https://github.com/ArkanDash/Advanced-RVC-Inference"}, - {"label": "Download", "url": "https://github.com/ArkanDash/Advanced-RVC-Inference/archive/refs/heads/master.zip"}, + { + "label": "Home", + "url": "https://github.com/ArkanDash/Advanced-RVC-Inference", + }, + { + "label": "Download", + "url": "https://github.com/ArkanDash/Advanced-RVC-Inference/archive/refs/heads/master.zip", + }, ], large_image="logo", large_text="Experimenting with Advanced-RVC", diff --git a/assets/i18n/languages/id_ID.py b/assets/i18n/languages/id_ID.py index 29c08f3..ff76067 100644 --- a/assets/i18n/languages/id_ID.py +++ b/assets/i18n/languages/id_ID.py @@ -85,5 +85,5 @@ "Export Audio": "Ekspor Audio", "Music URL": "URL Musik", "Download": "Unduh", - "Model URL": "URL Model" + "Model URL": "URL Model", } diff --git a/assets/themes/Grheme.py b/assets/themes/Grheme.py index b1c68db..66e36ec 100644 --- a/assets/themes/Grheme.py +++ b/assets/themes/Grheme.py @@ -1,5 +1,4 @@ from __future__ import annotations -import time from typing import Iterable import gradio as gr diff --git a/core.py b/core.py index e86f679..42cb21e 100644 --- a/core.py +++ b/core.py @@ -15,8 +15,7 @@ from programs.music_separation_code.inference import proc_file logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s" + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) now_dir = os.getcwd() sys.path.append(now_dir) @@ -1015,7 +1014,6 @@ def download_model(link): return "Model downloaded with success" - def download_music(link): if not link or not isinstance(link, str): logging.error("Invalid link provided.") @@ -1035,9 +1033,11 @@ def download_music(link): command = [ "yt-dlp", "-x", - "--audio-format", "wav", - "--output", output_template, - link + "--audio-format", + "wav", + "--output", + output_template, + link, ] try: diff --git a/programs/music_separation_code/ensemble.py b/programs/music_separation_code/ensemble.py index 76fec7b..1ad1ff9 100644 --- a/programs/music_separation_code/ensemble.py +++ b/programs/music_separation_code/ensemble.py @@ -1,5 +1,5 @@ # coding: utf-8 -__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' +__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/" import os import librosa @@ -81,42 +81,42 @@ def average_waveforms(pred_track, weights, algorithm): mod_track = [] for i in range(pred_track.shape[0]): - if algorithm == 'avg_wave': + if algorithm == "avg_wave": mod_track.append(pred_track[i] * weights[i]) - elif algorithm in ['median_wave', 'min_wave', 'max_wave']: + elif algorithm in ["median_wave", "min_wave", "max_wave"]: mod_track.append(pred_track[i]) - elif algorithm in ['avg_fft', 'min_fft', 'max_fft', 'median_fft']: + elif algorithm in ["avg_fft", "min_fft", "max_fft", "median_fft"]: spec = stft(pred_track[i], nfft=2048, hl=1024) - if algorithm in ['avg_fft']: + if algorithm in ["avg_fft"]: mod_track.append(spec * weights[i]) else: mod_track.append(spec) pred_track = np.array(mod_track) - if algorithm in ['avg_wave']: + if algorithm in ["avg_wave"]: pred_track = pred_track.sum(axis=0) 
pred_track /= np.array(weights).sum().T - elif algorithm in ['median_wave']: + elif algorithm in ["median_wave"]: pred_track = np.median(pred_track, axis=0) - elif algorithm in ['min_wave']: + elif algorithm in ["min_wave"]: pred_track = np.array(pred_track) pred_track = lambda_min(pred_track, axis=0, key=np.abs) - elif algorithm in ['max_wave']: + elif algorithm in ["max_wave"]: pred_track = np.array(pred_track) pred_track = lambda_max(pred_track, axis=0, key=np.abs) - elif algorithm in ['avg_fft']: + elif algorithm in ["avg_fft"]: pred_track = pred_track.sum(axis=0) pred_track /= np.array(weights).sum() pred_track = istft(pred_track, 1024, final_length) - elif algorithm in ['min_fft']: + elif algorithm in ["min_fft"]: pred_track = np.array(pred_track) pred_track = lambda_min(pred_track, axis=0, key=np.abs) pred_track = istft(pred_track, 1024, final_length) - elif algorithm in ['max_fft']: + elif algorithm in ["max_fft"]: pred_track = np.array(pred_track) pred_track = absmax(pred_track, axis=0) pred_track = istft(pred_track, 1024, final_length) - elif algorithm in ['median_fft']: + elif algorithm in ["median_fft"]: pred_track = np.array(pred_track) pred_track = np.median(pred_track, axis=0) pred_track = istft(pred_track, 1024, final_length) @@ -125,37 +125,58 @@ def average_waveforms(pred_track, weights, algorithm): def ensemble_files(args): parser = argparse.ArgumentParser() - parser.add_argument("--files", type=str, required=True, nargs='+', help="Path to all audio-files to ensemble") - parser.add_argument("--type", type=str, default='avg_wave', help="One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft") - parser.add_argument("--weights", type=float, nargs='+', help="Weights to create ensemble. Number of weights must be equal to number of files") - parser.add_argument("--output", default="res.wav", type=str, help="Path to wav file where ensemble result will be stored") + parser.add_argument( + "--files", + type=str, + required=True, + nargs="+", + help="Path to all audio-files to ensemble", + ) + parser.add_argument( + "--type", + type=str, + default="avg_wave", + help="One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft", + ) + parser.add_argument( + "--weights", + type=float, + nargs="+", + help="Weights to create ensemble. Number of weights must be equal to number of files", + ) + parser.add_argument( + "--output", + default="res.wav", + type=str, + help="Path to wav file where ensemble result will be stored", + ) if args is None: args = parser.parse_args() else: args = parser.parse_args(args) - print('Ensemble type: {}'.format(args.type)) - print('Number of input files: {}'.format(len(args.files))) + print("Ensemble type: {}".format(args.type)) + print("Number of input files: {}".format(len(args.files))) if args.weights is not None: weights = args.weights else: weights = np.ones(len(args.files)) - print('Weights: {}'.format(weights)) - print('Output file: {}'.format(args.output)) + print("Weights: {}".format(weights)) + print("Output file: {}".format(args.output)) data = [] for f in args.files: if not os.path.isfile(f): - print('Error. Can\'t find file: {}. Check paths.'.format(f)) + print("Error. Can't find file: {}. 
Check paths.".format(f)) exit() - print('Reading file: {}'.format(f)) + print("Reading file: {}".format(f)) wav, sr = librosa.load(f, sr=None, mono=False) # wav, sr = sf.read(f) print("Waveform shape: {} sample rate: {}".format(wav.shape, sr)) data.append(wav) data = np.array(data) res = average_waveforms(data, weights, args.type) - print('Result shape: {}'.format(res.shape)) - sf.write(args.output, res.T, sr, 'FLOAT') + print("Result shape: {}".format(res.shape)) + sf.write(args.output, res.T, sr, "FLOAT") if __name__ == "__main__": diff --git a/programs/music_separation_code/inference.py b/programs/music_separation_code/inference.py index 8c991d4..8ddfcb7 100644 --- a/programs/music_separation_code/inference.py +++ b/programs/music_separation_code/inference.py @@ -7,7 +7,6 @@ from tqdm import tqdm import sys import os -import glob import torch import numpy as np import soundfile as sf diff --git a/programs/music_separation_code/models/bandit/core/__init__.py b/programs/music_separation_code/models/bandit/core/__init__.py index a4d6d79..86e1557 100644 --- a/programs/music_separation_code/models/bandit/core/__init__.py +++ b/programs/music_separation_code/models/bandit/core/__init__.py @@ -1,20 +1,14 @@ import os.path from collections import defaultdict from itertools import chain, combinations -from typing import ( - Any, - Dict, - Iterator, - Mapping, Optional, - Tuple, Type, - TypedDict -) +from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict import pytorch_lightning as pl import torch import torchaudio as ta import torchmetrics as tm from asteroid import losses as asteroid_losses + # from deepspeed.ops.adam import DeepSpeedCPUAdam # from geoopt import optim as gooptim from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -30,7 +24,7 @@ # from pandas.io.json._normalize import nested_to_record -ConfigDict = TypedDict('ConfigDict', {'name': str, 'kwargs': Dict[str, Any]}) +ConfigDict = TypedDict("ConfigDict", {"name": str, "kwargs": Dict[str, Any]}) class SchedulerConfigDict(ConfigDict): @@ -38,9 +32,9 @@ class SchedulerConfigDict(ConfigDict): OptimizerSchedulerConfigDict = TypedDict( - 'OptimizerSchedulerConfigDict', - {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict}, - total=False + "OptimizerSchedulerConfigDict", + {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict}, + total=False, ) @@ -71,14 +65,13 @@ def get_optimizer_class(name: str) -> Type[optim.Optimizer]: def parse_optimizer_config( - config: OptimizerSchedulerConfigDict, - parameters: Iterator[nn.Parameter] + config: OptimizerSchedulerConfigDict, parameters: Iterator[nn.Parameter] ) -> ConfigureOptimizerReturnDict: optim_class = get_optimizer_class(config["optimizer"]["name"]) optimizer = optim_class(parameters, **config["optimizer"]["kwargs"]) optim_dict: ConfigureOptimizerReturnDict = { - "optimizer": optimizer, + "optimizer": optimizer, } if "scheduler" in config: @@ -86,10 +79,7 @@ def parse_optimizer_config( lr_scheduler_class_ = config["scheduler"]["name"] lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_] lr_scheduler_dict: LRSchedulerReturnDict = { - "scheduler": lr_scheduler_class( - optimizer, - **config["scheduler"]["kwargs"] - ) + "scheduler": lr_scheduler_class(optimizer, **config["scheduler"]["kwargs"]) } if lr_scheduler_class_ == "ReduceLROnPlateau": @@ -169,29 +159,26 @@ class LightningSystem(pl.LightningModule): _BG_STEMS = ["background", "effects", "mne"] def __init__( - self, - config: Dict, - loss_adjustment: float = 1.0, - attach_fader: 
bool = False - ) -> None: + self, config: Dict, loss_adjustment: float = 1.0, attach_fader: bool = False + ) -> None: super().__init__() self.optimizer_config = config["optimizer"] self.model = parse_model_config(config["model"]) self.loss = parse_loss_config(config["loss"]) self.metrics = nn.ModuleDict( - { - stem: parse_metric_config(config["metrics"]["dev"]) - for stem in self.model.stems - } + { + stem: parse_metric_config(config["metrics"]["dev"]) + for stem in self.model.stems + } ) self.metrics.disallow_fsdp = True self.test_metrics = nn.ModuleDict( - { - stem: parse_metric_config(config["metrics"]["test"]) - for stem in self.model.stems - } + { + stem: parse_metric_config(config["metrics"]["test"]) + for stem in self.model.stems + } ) self.test_metrics.disallow_fsdp = True @@ -216,22 +203,18 @@ def __init__( self.val_prefix = None self.test_prefix = None - def configure_optimizers(self) -> Any: return parse_optimizer_config( - self.optimizer_config, - self.trainer.model.parameters() - ) + self.optimizer_config, self.trainer.model.parameters() + ) - def compute_loss(self, batch: BatchedDataDict, output: OutputType) -> Dict[ - str, torch.Tensor]: + def compute_loss( + self, batch: BatchedDataDict, output: OutputType + ) -> Dict[str, torch.Tensor]: return {"loss": self.loss(output, batch)} def update_metrics( - self, - batch: BatchedDataDict, - output: OutputType, - mode: str + self, batch: BatchedDataDict, output: OutputType, mode: str ) -> None: if mode == "test": @@ -247,9 +230,9 @@ def update_metrics( # print(f"matching for {stem}") if mode == "train": metric.update( - output["audio"][stem],#.cpu(), - batch["audio"][stem],#.cpu() - ) + output["audio"][stem], # .cpu(), + batch["audio"][stem], # .cpu() + ) else: if stem not in batch["audio"]: matched = False @@ -273,16 +256,18 @@ def update_metrics( if matched: # print(f"matched {stem}!") if stem == "mne" and "mne" not in output["audio"]: - output["audio"]["mne"] = output["audio"]["music"] + output["audio"]["effects"] - + output["audio"]["mne"] = ( + output["audio"]["music"] + output["audio"]["effects"] + ) + metric.update( - output["audio"][stem],#.cpu(), - batch["audio"][stem],#.cpu(), + output["audio"][stem], # .cpu(), + batch["audio"][stem], # .cpu(), ) # print(metric.compute()) - def compute_metrics(self, mode: str="dev") -> Dict[ - str, torch.Tensor]: + + def compute_metrics(self, mode: str = "dev") -> Dict[str, torch.Tensor]: if mode == "test": metrics = self.test_metrics @@ -293,10 +278,8 @@ def compute_metrics(self, mode: str="dev") -> Dict[ for stem, metric in metrics.items(): md = metric.compute() - metric_dict.update( - {f"{stem}/{k}": v for k, v in md.items()} - ) - + metric_dict.update({f"{stem}/{k}": v for k, v in md.items()}) + self.log_dict(metric_dict, prog_bar=True, logger=False) return metric_dict @@ -311,10 +294,8 @@ def reset_metrics(self, test_mode: bool = False) -> None: for _, metric in metrics.items(): metric.reset() - def forward(self, batch: BatchedDataDict) -> Any: batch, output = self.model(batch) - return batch, output @@ -332,7 +313,6 @@ def common_step(self, batch: BatchedDataDict, mode: str) -> Any: return output, loss_dict - def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]: if self.augmentation is not None: @@ -343,9 +323,7 @@ def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]: with torch.inference_mode(): self.log_dict_with_prefix( - loss_dict, - "train", - batch_size=batch["audio"]["mixture"].shape[0] + loss_dict, "train", 
batch_size=batch["audio"]["mixture"].shape[0] ) loss_dict["loss"] *= self.loss_adjustment @@ -353,7 +331,7 @@ def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]: return loss_dict def on_train_batch_end( - self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int + self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int ) -> None: metric_dict = self.compute_metrics() @@ -361,10 +339,7 @@ def on_train_batch_end( self.reset_metrics() def validation_step( - self, - batch: BatchedDataDict, - batch_idx: int, - dataloader_idx: int = 0 + self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0 ) -> Dict[str, Any]: with torch.inference_mode(): @@ -378,11 +353,11 @@ def validation_step( _, loss_dict = self.common_step(batch, mode="val") self.log_dict_with_prefix( - loss_dict, - self.val_prefix, - batch_size=batch["audio"]["mixture"].shape[0], - prog_bar=True, - add_dataloader_idx=False + loss_dict, + self.val_prefix, + batch_size=batch["audio"]["mixture"].shape[0], + prog_bar=True, + add_dataloader_idx=False, ) return loss_dict @@ -392,29 +367,23 @@ def on_validation_epoch_end(self) -> None: def _on_validation_epoch_end(self) -> None: metric_dict = self.compute_metrics() - self.log_dict_with_prefix(metric_dict, self.val_prefix, prog_bar=True, - add_dataloader_idx=False) + self.log_dict_with_prefix( + metric_dict, self.val_prefix, prog_bar=True, add_dataloader_idx=False + ) # self.logger.save() # print(self.val_prefix, "Validation metrics:", metric_dict) self.reset_metrics() - def old_predtest_step( - self, - batch: BatchedDataDict, - batch_idx: int, - dataloader_idx: int = 0 + self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0 ) -> Tuple[BatchedDataDict, OutputType]: audio_batch = batch["audio"]["mixture"] track_batch = batch.get("track", ["" for _ in range(len(audio_batch))]) output_list_of_dicts = [ - self.fader( - audio[None, ...], - lambda a: self.test_forward(a, track) - ) - for audio, track in zip(audio_batch, track_batch) + self.fader(audio[None, ...], lambda a: self.test_forward(a, track)) + for audio, track in zip(audio_batch, track_batch) ] output_dict_of_lists = defaultdict(list) @@ -424,19 +393,16 @@ def old_predtest_step( output_dict_of_lists[stem].append(audio) output = { - "audio": { - stem: torch.concat(output_list, dim=0) - for stem, output_list in output_dict_of_lists.items() - } + "audio": { + stem: torch.concat(output_list, dim=0) + for stem, output_list in output_dict_of_lists.items() + } } return batch, output def predtest_step( - self, - batch: BatchedDataDict, - batch_idx: int = -1, - dataloader_idx: int = 0 + self, batch: BatchedDataDict, batch_idx: int = -1, dataloader_idx: int = 0 ) -> Tuple[BatchedDataDict, OutputType]: if getattr(self.model, "bypass_fader", False): @@ -444,17 +410,13 @@ def predtest_step( else: audio_batch = batch["audio"]["mixture"] output = self.fader( - audio_batch, - lambda a: self.test_forward(a, "", batch=batch) + audio_batch, lambda a: self.test_forward(a, "", batch=batch) ) return batch, output def test_forward( - self, - audio: torch.Tensor, - track: str = "", - batch: BatchedDataDict = None + self, audio: torch.Tensor, track: str = "", batch: BatchedDataDict = None ) -> torch.Tensor: if self.fader is None: @@ -466,10 +428,11 @@ def test_forward( cond = cond.repeat(audio.shape[0], 1) _, output = self.forward( - {"audio": {"mixture": audio}, - "track": track, - "condition": cond, - } + { + "audio": {"mixture": audio}, + "track": track, + "condition": cond, + } ) # TODO: support 
track properly return output["audio"] @@ -478,10 +441,7 @@ def on_test_epoch_start(self) -> None: self.attach_fader(force_reattach=True) def test_step( - self, - batch: BatchedDataDict, - batch_idx: int, - dataloader_idx: int = 0 + self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0 ) -> Any: curr_test_prefix = f"test{dataloader_idx}" @@ -505,22 +465,23 @@ def on_test_epoch_end(self) -> None: def _on_test_epoch_end(self) -> None: metric_dict = self.compute_metrics(mode="test") - self.log_dict_with_prefix(metric_dict, self.test_prefix, prog_bar=True, - add_dataloader_idx=False) + self.log_dict_with_prefix( + metric_dict, self.test_prefix, prog_bar=True, add_dataloader_idx=False + ) # self.logger.save() # print(self.test_prefix, "Test metrics:", metric_dict) self.reset_metrics() def predict_step( - self, - batch: BatchedDataDict, - batch_idx: int = 0, - dataloader_idx: int = 0, - include_track_name: Optional[bool] = None, - get_no_vox_combinations: bool = True, - get_residual: bool = False, - treat_batch_as_channels: bool = False, - fs: Optional[int] = None, + self, + batch: BatchedDataDict, + batch_idx: int = 0, + dataloader_idx: int = 0, + include_track_name: Optional[bool] = None, + get_no_vox_combinations: bool = True, + get_residual: bool = False, + treat_batch_as_channels: bool = False, + fs: Optional[int] = None, ) -> Any: assert self.predict_output_path is not None @@ -531,7 +492,7 @@ def predict_step( with torch.inference_mode(): batch, output = self.predtest_step(batch, batch_idx, dataloader_idx) - print('Pred test finished...') + print("Pred test finished...") torch.cuda.empty_cache() metric_dict = {} @@ -545,24 +506,22 @@ def predict_step( if get_no_vox_combinations: no_vox_stems = [ - stem for stem in output["audio"] if - stem not in self._VOX_STEMS + stem for stem in output["audio"] if stem not in self._VOX_STEMS ] no_vox_combinations = chain.from_iterable( - combinations(no_vox_stems, r) for r in - range(2, len(no_vox_stems) + 1) + combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1) ) for combination in no_vox_combinations: combination_ = list(combination) output["audio"]["+".join(combination_)] = sum( - [output["audio"][stem] for stem in combination_] + [output["audio"][stem] for stem in combination_] ) if treat_batch_as_channels: for stem in output["audio"]: output["audio"][stem] = output["audio"][stem].reshape( - 1, -1, output["audio"][stem].shape[-1] + 1, -1, output["audio"][stem].shape[-1] ) batch_size = 1 @@ -575,28 +534,24 @@ def predict_step( if batch.get("audio", {}).get(stem, None) is not None: self.test_metrics[stem].reset() metrics = self.test_metrics[stem]( - batch["audio"][stem][[b], ...], - output["audio"][stem][[b], ...] + batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...] 
) snr = metrics["snr"] sisnr = metrics["sisnr"] sdr = metrics["sdr"] metric_dict[stem] = metrics print( - track_name, - f"snr={snr:2.2f} dB", - f"sisnr={sisnr:2.2f}", - f"sdr={sdr:2.2f} dB", + track_name, + f"snr={snr:2.2f} dB", + f"sisnr={sisnr:2.2f}", + f"sdr={sdr:2.2f} dB", ) filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav" else: filename = f"{stem}.wav" if include_track_name: - output_dir = os.path.join( - self.predict_output_path, - track_name - ) + output_dir = os.path.join(self.predict_output_path, track_name) else: output_dir = self.predict_output_path @@ -606,23 +561,23 @@ def predict_step( fs = self.fs ta.save( - os.path.join(output_dir, filename), - output["audio"][stem][b, ...].cpu(), - fs, + os.path.join(output_dir, filename), + output["audio"][stem][b, ...].cpu(), + fs, ) return metric_dict def get_stems( - self, - batch: BatchedDataDict, - batch_idx: int = 0, - dataloader_idx: int = 0, - include_track_name: Optional[bool] = None, - get_no_vox_combinations: bool = True, - get_residual: bool = False, - treat_batch_as_channels: bool = False, - fs: Optional[int] = None, + self, + batch: BatchedDataDict, + batch_idx: int = 0, + dataloader_idx: int = 0, + include_track_name: Optional[bool] = None, + get_no_vox_combinations: bool = True, + get_residual: bool = False, + treat_batch_as_channels: bool = False, + fs: Optional[int] = None, ) -> Any: assert self.predict_output_path is not None @@ -646,24 +601,22 @@ def get_stems( if get_no_vox_combinations: no_vox_stems = [ - stem for stem in output["audio"] if - stem not in self._VOX_STEMS + stem for stem in output["audio"] if stem not in self._VOX_STEMS ] no_vox_combinations = chain.from_iterable( - combinations(no_vox_stems, r) for r in - range(2, len(no_vox_stems) + 1) + combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1) ) for combination in no_vox_combinations: combination_ = list(combination) output["audio"]["+".join(combination_)] = sum( - [output["audio"][stem] for stem in combination_] + [output["audio"][stem] for stem in combination_] ) if treat_batch_as_channels: for stem in output["audio"]: output["audio"][stem] = output["audio"][stem].reshape( - 1, -1, output["audio"][stem].shape[-1] + 1, -1, output["audio"][stem].shape[-1] ) batch_size = 1 @@ -675,28 +628,24 @@ def get_stems( if batch.get("audio", {}).get(stem, None) is not None: self.test_metrics[stem].reset() metrics = self.test_metrics[stem]( - batch["audio"][stem][[b], ...], - output["audio"][stem][[b], ...] + batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...] 
) snr = metrics["snr"] sisnr = metrics["sisnr"] sdr = metrics["sdr"] metric_dict[stem] = metrics print( - track_name, - f"snr={snr:2.2f} dB", - f"sisnr={sisnr:2.2f}", - f"sdr={sdr:2.2f} dB", + track_name, + f"snr={snr:2.2f} dB", + f"sisnr={sisnr:2.2f}", + f"sdr={sdr:2.2f} dB", ) filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav" else: filename = f"{stem}.wav" if include_track_name: - output_dir = os.path.join( - self.predict_output_path, - track_name - ) + output_dir = os.path.join(self.predict_output_path, track_name) else: output_dir = self.predict_output_path @@ -710,12 +659,11 @@ def get_stems( return result def load_state_dict( - self, state_dict: Mapping[str, Any], strict: bool = False + self, state_dict: Mapping[str, Any], strict: bool = False ) -> Any: return super().load_state_dict(state_dict, strict=False) - def set_predict_output_path(self, path: str) -> None: self.predict_output_path = path os.makedirs(self.predict_output_path, exist_ok=True) @@ -727,18 +675,17 @@ def attach_fader(self, force_reattach=False) -> None: self.fader = parse_fader_config(self.fader_config) self.fader.to(self.device) - def log_dict_with_prefix( - self, - dict_: Dict[str, torch.Tensor], - prefix: str, - batch_size: Optional[int] = None, - **kwargs: Any + self, + dict_: Dict[str, torch.Tensor], + prefix: str, + batch_size: Optional[int] = None, + **kwargs: Any, ) -> None: self.log_dict( - {f"{prefix}/{k}": v for k, v in dict_.items()}, - batch_size=batch_size, - logger=True, - sync_dist=True, - **kwargs, - ) \ No newline at end of file + {f"{prefix}/{k}": v for k, v in dict_.items()}, + batch_size=batch_size, + logger=True, + sync_dist=True, + **kwargs, + ) diff --git a/programs/music_separation_code/models/bandit/core/data/__init__.py b/programs/music_separation_code/models/bandit/core/data/__init__.py index 1087fe2..a9d4d67 100644 --- a/programs/music_separation_code/models/bandit/core/data/__init__.py +++ b/programs/music_separation_code/models/bandit/core/data/__init__.py @@ -1,2 +1,2 @@ from .dnr.datamodule import DivideAndRemasterDataModule -from .musdb.datamodule import MUSDB18DataModule \ No newline at end of file +from .musdb.datamodule import MUSDB18DataModule diff --git a/programs/music_separation_code/models/bandit/core/data/_types.py b/programs/music_separation_code/models/bandit/core/data/_types.py index 9499f9a..65e4607 100644 --- a/programs/music_separation_code/models/bandit/core/data/_types.py +++ b/programs/music_separation_code/models/bandit/core/data/_types.py @@ -4,11 +4,10 @@ AudioDict = Dict[str, torch.Tensor] -DataDict = TypedDict('DataDict', {'audio': AudioDict, 'track': str}) +DataDict = TypedDict("DataDict", {"audio": AudioDict, "track": str}) BatchedDataDict = TypedDict( - 'BatchedDataDict', - {'audio': AudioDict, 'track': Sequence[str]} + "BatchedDataDict", {"audio": AudioDict, "track": Sequence[str]} ) diff --git a/programs/music_separation_code/models/bandit/core/data/augmentation.py b/programs/music_separation_code/models/bandit/core/data/augmentation.py index 238214b..1aa2a9c 100644 --- a/programs/music_separation_code/models/bandit/core/data/augmentation.py +++ b/programs/music_separation_code/models/bandit/core/data/augmentation.py @@ -9,18 +9,19 @@ class BaseAugmentor(nn.Module, ABC): - def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[ - DataDict, BatchedDataDict]: + def forward( + self, item: Union[DataDict, BatchedDataDict] + ) -> Union[DataDict, BatchedDataDict]: raise NotImplementedError class StemAugmentor(BaseAugmentor): def 
__init__( - self, - audiomentations: Dict[str, Dict[str, Any]], - fix_clipping: bool = True, - scaler_margin: float = 0.5, - apply_both_default_and_common: bool = False, + self, + audiomentations: Dict[str, Dict[str, Any]], + fix_clipping: bool = True, + scaler_margin: float = 0.5, + apply_both_default_and_common: bool = False, ) -> None: super().__init__() @@ -32,23 +33,16 @@ def __init__( for stem in audiomentations: if audiomentations[stem]["name"] == "Compose": - augmentations[stem] = getattr( - tam, - audiomentations[stem]["name"] - )( - [ - getattr(tam, aug["name"])(**aug["kwargs"]) - for aug in - audiomentations[stem]["kwargs"]["transforms"] - ], - **audiomentations[stem]["kwargs"]["kwargs"], + augmentations[stem] = getattr(tam, audiomentations[stem]["name"])( + [ + getattr(tam, aug["name"])(**aug["kwargs"]) + for aug in audiomentations[stem]["kwargs"]["transforms"] + ], + **audiomentations[stem]["kwargs"]["kwargs"], ) else: - augmentations[stem] = getattr( - tam, - audiomentations[stem]["name"] - )( - **audiomentations[stem]["kwargs"] + augmentations[stem] = getattr(tam, audiomentations[stem]["name"])( + **audiomentations[stem]["kwargs"] ) self.augmentations = nn.ModuleDict(augmentations) @@ -56,7 +50,7 @@ def __init__( self.scaler_margin = scaler_margin def check_and_fix_clipping( - self, item: Union[DataDict, BatchedDataDict] + self, item: Union[DataDict, BatchedDataDict] ) -> Union[DataDict, BatchedDataDict]: max_abs = [] @@ -64,18 +58,20 @@ def check_and_fix_clipping( max_abs.append(item["audio"][stem].abs().max().item()) if max(max_abs) > 1.0: - scaler = 1.0 / (max(max_abs) + torch.rand( - (1,), - device=item["audio"]["mixture"].device - ) * self.scaler_margin) + scaler = 1.0 / ( + max(max_abs) + + torch.rand((1,), device=item["audio"]["mixture"].device) + * self.scaler_margin + ) for stem in item["audio"]: item["audio"][stem] *= scaler return item - def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[ - DataDict, BatchedDataDict]: + def forward( + self, item: Union[DataDict, BatchedDataDict] + ) -> Union[DataDict, BatchedDataDict]: for stem in item["audio"]: if stem == "mixture": @@ -83,22 +79,21 @@ def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[ if self.has_common: item["audio"][stem] = self.augmentations["[common]"]( - item["audio"][stem] + item["audio"][stem] ).samples if stem in self.augmentations: item["audio"][stem] = self.augmentations[stem]( - item["audio"][stem] + item["audio"][stem] ).samples elif self.has_default: if not self.has_common or self.apply_both_default_and_common: item["audio"][stem] = self.augmentations["[default]"]( - item["audio"][stem] + item["audio"][stem] ).samples item["audio"]["mixture"] = sum( - [item["audio"][stem] for stem in item["audio"] - if stem != "mixture"] + [item["audio"][stem] for stem in item["audio"] if stem != "mixture"] ) # type: ignore[call-overload, assignment] if self.fix_clipping: diff --git a/programs/music_separation_code/models/bandit/core/data/augmented.py b/programs/music_separation_code/models/bandit/core/data/augmented.py index 84d1959..3c05244 100644 --- a/programs/music_separation_code/models/bandit/core/data/augmented.py +++ b/programs/music_separation_code/models/bandit/core/data/augmented.py @@ -8,15 +8,15 @@ class AugmentedDataset(data.Dataset): def __init__( - self, - dataset: data.Dataset, - augmentation: nn.Module = nn.Identity(), - target_length: Optional[int] = None, + self, + dataset: data.Dataset, + augmentation: nn.Module = nn.Identity(), + target_length: Optional[int] = 
None, ) -> None: warnings.warn( - "This class is no longer used. Attach augmentation to " - "the LightningSystem instead.", - DeprecationWarning, + "This class is no longer used. Attach augmentation to " + "the LightningSystem instead.", + DeprecationWarning, ) self.dataset = dataset @@ -25,8 +25,7 @@ def __init__( self.ds_length: int = len(dataset) # type: ignore[arg-type] self.length = target_length if target_length is not None else self.ds_length - def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, - torch.Tensor]]]: + def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]: item = self.dataset[index % self.ds_length] item = self.augmentation(item) return item diff --git a/programs/music_separation_code/models/bandit/core/data/base.py b/programs/music_separation_code/models/bandit/core/data/base.py index a7b6c33..18e3739 100644 --- a/programs/music_separation_code/models/bandit/core/data/base.py +++ b/programs/music_separation_code/models/bandit/core/data/base.py @@ -1,6 +1,5 @@ -import os from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List import numpy as np import pedalboard as pb @@ -13,14 +12,15 @@ class BaseSourceSeparationDataset(data.Dataset, ABC): def __init__( - self, split: str, - stems: List[str], - files: List[str], - data_path: str, - fs: int, - npy_memmap: bool, - recompute_mixture: bool - ): + self, + split: str, + stems: List[str], + files: List[str], + data_path: str, + fs: int, + npy_memmap: bool, + recompute_mixture: bool, + ): self.split = split self.stems = stems self.stems_no_mixture = [s for s in stems if s != "mixture"] @@ -31,12 +31,7 @@ def __init__( self.recompute_mixture = recompute_mixture @abstractmethod - def get_stem( - self, - *, - stem: str, - identifier: Dict[str, Any] - ) -> torch.Tensor: + def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor: raise NotImplementedError def _get_audio(self, stems, identifier: Dict[str, Any]): @@ -49,10 +44,7 @@ def _get_audio(self, stems, identifier: Dict[str, Any]): def get_audio(self, identifier: Dict[str, Any]) -> AudioDict: if self.recompute_mixture: - audio = self._get_audio( - self.stems_no_mixture, - identifier=identifier - ) + audio = self._get_audio(self.stems_no_mixture, identifier=identifier) audio["mixture"] = self.compute_mixture(audio) return audio else: @@ -64,6 +56,4 @@ def get_identifier(self, index: int) -> Dict[str, Any]: def compute_mixture(self, audio: AudioDict) -> torch.Tensor: - return sum( - audio[stem] for stem in audio if stem != "mixture" - ) + return sum(audio[stem] for stem in audio if stem != "mixture") diff --git a/programs/music_separation_code/models/bandit/core/data/dnr/datamodule.py b/programs/music_separation_code/models/bandit/core/data/dnr/datamodule.py index dc55506..2971d41 100644 --- a/programs/music_separation_code/models/bandit/core/data/dnr/datamodule.py +++ b/programs/music_separation_code/models/bandit/core/data/dnr/datamodule.py @@ -7,20 +7,20 @@ DivideAndRemasterDataset, DivideAndRemasterDeterministicChunkDataset, DivideAndRemasterRandomChunkDataset, - DivideAndRemasterRandomChunkDatasetWithSpeechReverb + DivideAndRemasterRandomChunkDatasetWithSpeechReverb, ) def DivideAndRemasterDataModule( - data_root: str = "$DATA_ROOT/DnR/v2", - batch_size: int = 2, - num_workers: int = 8, - train_kwargs: Optional[Mapping] = None, - val_kwargs: Optional[Mapping] = None, - test_kwargs: Optional[Mapping] = None, - datamodule_kwargs: Optional[Mapping] = 
None, - use_speech_reverb: bool = False - # augmentor=None + data_root: str = "$DATA_ROOT/DnR/v2", + batch_size: int = 2, + num_workers: int = 8, + train_kwargs: Optional[Mapping] = None, + val_kwargs: Optional[Mapping] = None, + test_kwargs: Optional[Mapping] = None, + datamodule_kwargs: Optional[Mapping] = None, + use_speech_reverb: bool = False, + # augmentor=None ) -> pl.LightningDataModule: if train_kwargs is None: train_kwargs = {} @@ -47,26 +47,20 @@ def DivideAndRemasterDataModule( else: train_cls = DivideAndRemasterRandomChunkDataset - train_dataset = train_cls( - data_root, "train", **train_kwargs - ) + train_dataset = train_cls(data_root, "train", **train_kwargs) # if augmentor is not None: # train_dataset = AugmentedDataset(train_dataset, augmentor) datamodule = pl.LightningDataModule.from_datasets( - train_dataset=train_dataset, - val_dataset=DivideAndRemasterDeterministicChunkDataset( - data_root, "val", **val_kwargs - ), - test_dataset=DivideAndRemasterDataset( - data_root, - "test", - **test_kwargs - ), - batch_size=batch_size, - num_workers=num_workers, - **datamodule_kwargs + train_dataset=train_dataset, + val_dataset=DivideAndRemasterDeterministicChunkDataset( + data_root, "val", **val_kwargs + ), + test_dataset=DivideAndRemasterDataset(data_root, "test", **test_kwargs), + batch_size=batch_size, + num_workers=num_workers, + **datamodule_kwargs ) datamodule.predict_dataloader = datamodule.test_dataloader # type: ignore[method-assign] diff --git a/programs/music_separation_code/models/bandit/core/data/dnr/dataset.py b/programs/music_separation_code/models/bandit/core/data/dnr/dataset.py index 639290d..00142c7 100644 --- a/programs/music_separation_code/models/bandit/core/data/dnr/dataset.py +++ b/programs/music_separation_code/models/bandit/core/data/dnr/dataset.py @@ -15,10 +15,10 @@ class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC): ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"] STEM_NAME_MAP = { - "mixture": "mix", - "speech": "speech", - "music": "music", - "effects": "sfx", + "mixture": "mix", + "speech": "speech", + "music": "music", + "effects": "sfx", } SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"} @@ -26,52 +26,42 @@ class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC): FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100 def __init__( - self, - split: str, - stems: List[str], - files: List[str], - data_path: str, - fs: int = 44100, - npy_memmap: bool = True, - recompute_mixture: bool = False, + self, + split: str, + stems: List[str], + files: List[str], + data_path: str, + fs: int = 44100, + npy_memmap: bool = True, + recompute_mixture: bool = False, ) -> None: super().__init__( - split=split, - stems=stems, - files=files, - data_path=data_path, - fs=fs, - npy_memmap=npy_memmap, - recompute_mixture=recompute_mixture + split=split, + stems=stems, + files=files, + data_path=data_path, + fs=fs, + npy_memmap=npy_memmap, + recompute_mixture=recompute_mixture, ) - def get_stem( - self, - *, - stem: str, - identifier: Dict[str, Any] - ) -> torch.Tensor: - + def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor: + if stem == "mne": - return self.get_stem( - stem="music", - identifier=identifier) + self.get_stem( - stem="effects", - identifier=identifier) + return self.get_stem(stem="music", identifier=identifier) + self.get_stem( + stem="effects", identifier=identifier + ) track = identifier["track"] path = os.path.join(self.data_path, track) if self.npy_memmap: audio = 
np.load( - os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), - mmap_mode="r" + os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), mmap_mode="r" ) else: # noinspection PyUnresolvedReferences - audio, _ = ta.load( - os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav") - ) + audio, _ = ta.load(os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav")) return audio @@ -87,12 +77,12 @@ def __getitem__(self, index: int) -> DataDict: class DivideAndRemasterDataset(DivideAndRemasterBaseDataset): def __init__( - self, - data_root: str, - split: str, - stems: Optional[List[str]] = None, - fs: int = 44100, - npy_memmap: bool = True, + self, + data_root: str, + split: str, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, ) -> None: if stems is None: @@ -103,11 +93,9 @@ def __init__( files = sorted(os.listdir(data_path)) files = [ - f - for f in files - if (not f.startswith(".")) and os.path.isdir( - os.path.join(data_path, f) - ) + f + for f in files + if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f)) ] # pprint(list(enumerate(files))) if split == "train": @@ -120,12 +108,12 @@ def __init__( self.n_tracks = len(files) super().__init__( - data_path=data_path, - split=split, - stems=stems, - files=files, - fs=fs, - npy_memmap=npy_memmap, + data_path=data_path, + split=split, + stems=stems, + files=files, + fs=fs, + npy_memmap=npy_memmap, ) def __len__(self) -> int: @@ -134,14 +122,14 @@ def __len__(self) -> int: class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset): def __init__( - self, - data_root: str, - split: str, - target_length: int, - chunk_size_second: float, - stems: Optional[List[str]] = None, - fs: int = 44100, - npy_memmap: bool = True, + self, + data_root: str, + split: str, + target_length: int, + chunk_size_second: float, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, ) -> None: if stems is None: @@ -152,11 +140,9 @@ def __init__( files = sorted(os.listdir(data_path)) files = [ - f - for f in files - if (not f.startswith(".")) and os.path.isdir( - os.path.join(data_path, f) - ) + f + for f in files + if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f)) ] if split == "train": @@ -172,12 +158,12 @@ def __init__( self.chunk_size = int(chunk_size_second * fs) super().__init__( - data_path=data_path, - split=split, - stems=stems, - files=files, - fs=fs, - npy_memmap=npy_memmap, + data_path=data_path, + split=split, + stems=stems, + files=files, + fs=fs, + npy_memmap=npy_memmap, ) def __len__(self) -> int: @@ -187,22 +173,18 @@ def get_identifier(self, index): return super().get_identifier(index % self.n_tracks) def get_stem( - self, - *, - stem: str, - identifier: Dict[str, Any], - chunk_here: bool = False, - ) -> torch.Tensor: - - stem = super().get_stem( - stem=stem, - identifier=identifier - ) + self, + *, + stem: str, + identifier: Dict[str, Any], + chunk_here: bool = False, + ) -> torch.Tensor: + + stem = super().get_stem(stem=stem, identifier=identifier) if chunk_here: start = np.random.randint( - 0, - self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size + 0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size ) end = start + self.chunk_size @@ -216,29 +198,24 @@ def __getitem__(self, index: int) -> DataDict: audio = self.get_audio(identifier) # self.index_lock = None - start = np.random.randint( - 0, - self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size - ) + start = np.random.randint(0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size) end = start + 
self.chunk_size - audio = { - k: v[:, start:end] for k, v in audio.items() - } + audio = {k: v[:, start:end] for k, v in audio.items()} return {"audio": audio, "track": f"{self.split}/{identifier['track']}"} class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset): def __init__( - self, - data_root: str, - split: str, - chunk_size_second: float, - hop_size_second: float, - stems: Optional[List[str]] = None, - fs: int = 44100, - npy_memmap: bool = True, + self, + data_root: str, + split: str, + chunk_size_second: float, + hop_size_second: float, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, ) -> None: if stems is None: @@ -249,11 +226,9 @@ def __init__( files = sorted(os.listdir(data_path)) files = [ - f - for f in files - if (not f.startswith(".")) and os.path.isdir( - os.path.join(data_path, f) - ) + f + for f in files + if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f)) ] # pprint(list(enumerate(files))) if split == "train": @@ -268,19 +243,18 @@ def __init__( self.chunk_size = int(chunk_size_second * fs) self.hop_size = int(hop_size_second * fs) self.n_chunks_per_track = int( - ( - self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second + (self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second ) self.length = self.n_tracks * self.n_chunks_per_track super().__init__( - data_path=data_path, - split=split, - stems=stems, - files=files, - fs=fs, - npy_memmap=npy_memmap, + data_path=data_path, + split=split, + stems=stems, + files=files, + fs=fs, + npy_memmap=npy_memmap, ) def get_identifier(self, index): @@ -308,17 +282,17 @@ def __getitem__(self, item: int) -> DataDict: class DivideAndRemasterRandomChunkDatasetWithSpeechReverb( - DivideAndRemasterRandomChunkDataset + DivideAndRemasterRandomChunkDataset ): def __init__( - self, - data_root: str, - split: str, - target_length: int, - chunk_size_second: float, - stems: Optional[List[str]] = None, - fs: int = 44100, - npy_memmap: bool = True, + self, + data_root: str, + split: str, + target_length: int, + chunk_size_second: float, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, ) -> None: if stems is None: @@ -327,13 +301,13 @@ def __init__( stems_no_mixture = [s for s in stems if s != "mixture"] super().__init__( - data_root=data_root, - split=split, - target_length=target_length, - chunk_size_second=chunk_size_second, - stems=stems_no_mixture, - fs=fs, - npy_memmap=npy_memmap, + data_root=data_root, + split=split, + target_length=target_length, + chunk_size_second=chunk_size_second, + stems=stems_no_mixture, + fs=fs, + npy_memmap=npy_memmap, ) self.stems = stems @@ -349,17 +323,17 @@ def __getitem__(self, index: int) -> DataDict: wet_level = np.random.rand() speech = pb.Reverb( - room_size=np.random.rand(), - damping=np.random.rand(), - wet_level=wet_level, - dry_level=(1 - wet_level), - width=np.random.rand() + room_size=np.random.rand(), + damping=np.random.rand(), + wet_level=wet_level, + dry_level=(1 - wet_level), + width=np.random.rand(), ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples] data_["audio"]["speech"] = speech data_["audio"]["mixture"] = sum( - [data_["audio"][s] for s in self.stems_no_mixture] + [data_["audio"][s] for s in self.stems_no_mixture] ) return data_ @@ -375,10 +349,10 @@ def __len__(self) -> int: for split_ in ["train", "val", "test"]: ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb( - data_root="$DATA_ROOT/DnR/v2np", - split=split_, - 
target_length=100, - chunk_size_second=6.0 + data_root="$DATA_ROOT/DnR/v2np", + split=split_, + target_length=100, + chunk_size_second=6.0, ) print(split_, len(ds)) diff --git a/programs/music_separation_code/models/bandit/core/data/dnr/preprocess.py b/programs/music_separation_code/models/bandit/core/data/dnr/preprocess.py index 9d0b586..18d68b1 100644 --- a/programs/music_separation_code/models/bandit/core/data/dnr/preprocess.py +++ b/programs/music_separation_code/models/bandit/core/data/dnr/preprocess.py @@ -16,7 +16,9 @@ def process_one(inputs: Tuple[str, str, int]) -> None: data, fs = ta.load(infile) if fs != target_fs: - data = ta.functional.resample(data, fs, target_fs, resampling_method="sinc_interp_kaiser") + data = ta.functional.resample( + data, fs, target_fs, resampling_method="sinc_interp_kaiser" + ) fs = target_fs data = data.numpy() @@ -30,16 +32,11 @@ def process_one(inputs: Tuple[str, str, int]) -> None: np.save(outfile, data) -def preprocess( - data_path: str, - output_path: str, - fs: int -) -> None: +def preprocess(data_path: str, output_path: str, fs: int) -> None: files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True) print(files) outfiles = [ - f.replace(data_path, output_path).replace(".wav", ".npy") for f in - files + f.replace(data_path, output_path).replace(".wav", ".npy") for f in files ] os.makedirs(output_path, exist_ok=True) diff --git a/programs/music_separation_code/models/bandit/core/data/musdb/datamodule.py b/programs/music_separation_code/models/bandit/core/data/musdb/datamodule.py index a8984da..7b3c25e 100644 --- a/programs/music_separation_code/models/bandit/core/data/musdb/datamodule.py +++ b/programs/music_separation_code/models/bandit/core/data/musdb/datamodule.py @@ -7,21 +7,21 @@ MUSDB18BaseDataset, MUSDB18FullTrackDataset, MUSDB18SadDataset, - MUSDB18SadOnTheFlyAugmentedDataset + MUSDB18SadOnTheFlyAugmentedDataset, ) def MUSDB18DataModule( - data_root: str = "$DATA_ROOT/MUSDB18/HQ", - target_stem: str = "vocals", - batch_size: int = 2, - num_workers: int = 8, - train_kwargs: Optional[Mapping] = None, - val_kwargs: Optional[Mapping] = None, - test_kwargs: Optional[Mapping] = None, - datamodule_kwargs: Optional[Mapping] = None, - use_on_the_fly: bool = True, - npy_memmap: bool = True + data_root: str = "$DATA_ROOT/MUSDB18/HQ", + target_stem: str = "vocals", + batch_size: int = 2, + num_workers: int = 8, + train_kwargs: Optional[Mapping] = None, + val_kwargs: Optional[Mapping] = None, + test_kwargs: Optional[Mapping] = None, + datamodule_kwargs: Optional[Mapping] = None, + use_on_the_fly: bool = True, + npy_memmap: bool = True, ) -> pl.LightningDataModule: if train_kwargs is None: train_kwargs = {} @@ -39,39 +39,37 @@ def MUSDB18DataModule( if use_on_the_fly: train_dataset = MUSDB18SadOnTheFlyAugmentedDataset( - data_root=os.path.join(data_root, "saded-np"), - split="train", - target_stem=target_stem, - **train_kwargs + data_root=os.path.join(data_root, "saded-np"), + split="train", + target_stem=target_stem, + **train_kwargs ) else: train_dataset = MUSDB18SadDataset( - data_root=os.path.join(data_root, "saded-np"), - split="train", - target_stem=target_stem, - **train_kwargs + data_root=os.path.join(data_root, "saded-np"), + split="train", + target_stem=target_stem, + **train_kwargs ) datamodule = pl.LightningDataModule.from_datasets( - train_dataset=train_dataset, - val_dataset=MUSDB18SadDataset( - data_root=os.path.join(data_root, "saded-np"), - split="val", - target_stem=target_stem, - **val_kwargs - ), - 
test_dataset=MUSDB18FullTrackDataset( - data_root=os.path.join(data_root, "canonical"), - split="test", - **test_kwargs - ), - batch_size=batch_size, - num_workers=num_workers, - **datamodule_kwargs + train_dataset=train_dataset, + val_dataset=MUSDB18SadDataset( + data_root=os.path.join(data_root, "saded-np"), + split="val", + target_stem=target_stem, + **val_kwargs + ), + test_dataset=MUSDB18FullTrackDataset( + data_root=os.path.join(data_root, "canonical"), split="test", **test_kwargs + ), + batch_size=batch_size, + num_workers=num_workers, + **datamodule_kwargs ) datamodule.predict_dataloader = ( # type: ignore[method-assign] - datamodule.test_dataloader + datamodule.test_dataloader ) return datamodule diff --git a/programs/music_separation_code/models/bandit/core/data/musdb/dataset.py b/programs/music_separation_code/models/bandit/core/data/musdb/dataset.py index c59a07d..f66319f 100644 --- a/programs/music_separation_code/models/bandit/core/data/musdb/dataset.py +++ b/programs/music_separation_code/models/bandit/core/data/musdb/dataset.py @@ -16,22 +16,22 @@ class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC): ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"] def __init__( - self, - split: str, - stems: List[str], - files: List[str], - data_path: str, - fs: int = 44100, - npy_memmap=False, + self, + split: str, + stems: List[str], + files: List[str], + data_path: str, + fs: int = 44100, + npy_memmap=False, ) -> None: super().__init__( - split=split, - stems=stems, - files=files, - data_path=data_path, - fs=fs, - npy_memmap=npy_memmap, - recompute_mixture=False + split=split, + stems=stems, + files=files, + data_path=data_path, + fs=fs, + npy_memmap=npy_memmap, + recompute_mixture=False, ) def get_stem(self, *, stem: str, identifier) -> torch.Tensor: @@ -61,25 +61,24 @@ class MUSDB18FullTrackDataset(MUSDB18BaseDataset): N_TRAIN_TRACKS = 100 N_TEST_TRACKS = 50 VALIDATION_FILES = [ - "Actions - One Minute Smile", - "Clara Berry And Wooldog - Waltz For My Victims", - "Johnny Lokke - Promises & Lies", - "Patrick Talbot - A Reason To Leave", - "Triviul - Angelsaint", - "Alexander Ross - Goodbye Bolero", - "Fergessen - Nos Palpitants", - "Leaf - Summerghost", - "Skelpolu - Human Mistakes", - "Young Griffo - Pennies", - "ANiMAL - Rockshow", - "James May - On The Line", - "Meaxic - Take A Step", - "Traffic Experiment - Sirens", + "Actions - One Minute Smile", + "Clara Berry And Wooldog - Waltz For My Victims", + "Johnny Lokke - Promises & Lies", + "Patrick Talbot - A Reason To Leave", + "Triviul - Angelsaint", + "Alexander Ross - Goodbye Bolero", + "Fergessen - Nos Palpitants", + "Leaf - Summerghost", + "Skelpolu - Human Mistakes", + "Young Griffo - Pennies", + "ANiMAL - Rockshow", + "James May - On The Line", + "Meaxic - Take A Step", + "Traffic Experiment - Sirens", ] def __init__( - self, data_root: str, split: str, stems: Optional[List[ - str]] = None + self, data_root: str, split: str, stems: Optional[List[str]] = None ) -> None: if stems is None: @@ -112,25 +111,21 @@ def __init__( self.n_tracks = len(files) - super().__init__( - data_path=data_path, - split=split, - stems=stems, - files=files - ) + super().__init__(data_path=data_path, split=split, stems=stems, files=files) def __len__(self) -> int: return self.n_tracks + class MUSDB18SadDataset(MUSDB18BaseDataset): def __init__( - self, - data_root: str, - split: str, - target_stem: str, - stems: Optional[List[str]] = None, - target_length: Optional[int] = None, - npy_memmap=False, + self, + data_root: str, + 
split: str, + target_stem: str, + stems: Optional[List[str]] = None, + target_length: Optional[int] = None, + npy_memmap=False, ) -> None: if stems is None: @@ -142,16 +137,16 @@ def __init__( files = [f for f in files if not f.startswith(".")] super().__init__( - data_path=data_path, - split=split, - stems=stems, - files=files, - npy_memmap=npy_memmap + data_path=data_path, + split=split, + stems=stems, + files=files, + npy_memmap=npy_memmap, ) self.n_segments = len(files) self.target_stem = target_stem self.target_length = ( - target_length if target_length is not None else self.n_segments + target_length if target_length is not None else self.n_segments ) def __len__(self) -> int: @@ -169,23 +164,22 @@ def get_identifier(self, index): class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset): def __init__( - self, - data_root: str, - split: str, - target_stem: str, - stems: Optional[List[str]] = None, - target_length: int = 20000, - apply_probability: Optional[float] = None, - chunk_size_second: float = 3.0, - random_scale_range_db: Tuple[float, float] = (-10, 10), - drop_probability: float = 0.1, - rescale: bool = True, + self, + data_root: str, + split: str, + target_stem: str, + stems: Optional[List[str]] = None, + target_length: int = 20000, + apply_probability: Optional[float] = None, + chunk_size_second: float = 3.0, + random_scale_range_db: Tuple[float, float] = (-10, 10), + drop_probability: float = 0.1, + rescale: bool = True, ) -> None: super().__init__(data_root, split, target_stem, stems) if apply_probability is None: - apply_probability = ( - target_length - self.n_segments) / target_length + apply_probability = (target_length - self.n_segments) / target_length self.apply_probability = apply_probability self.drop_probability = drop_probability @@ -226,7 +220,7 @@ def __getitem__(self, index: int) -> DataDict: if self.chunk_size_sample < audio[stem].shape[-1]: chunk_start = np.random.randint( - audio[stem].shape[-1] - self.chunk_size_sample + audio[stem].shape[-1] - self.chunk_size_sample ) else: chunk_start = 0 @@ -239,18 +233,16 @@ def __getitem__(self, index: int) -> DataDict: linear_scale = np.power(10, db_scale / 20) # db_scale = f"{db_scale:+2.1f}" # print(linear_scale) - audio[stem][..., - chunk_start: chunk_start + self.chunk_size_sample] = ( - linear_scale - * audio[stem][..., - chunk_start: chunk_start + self.chunk_size_sample] + audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample] = ( + linear_scale + * audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample] ) audio["mixture"] = self.compute_mixture(audio) if self.rescale: max_abs_val = max( - [torch.max(torch.abs(audio[stem])) for stem in self.stems] + [torch.max(torch.abs(audio[stem])) for stem in self.stems] ) # type: ignore[type-var] if max_abs_val > 1: audio = {k: v / max_abs_val for k, v in audio.items()} @@ -259,6 +251,7 @@ def __getitem__(self, index: int) -> DataDict: return {"audio": audio, "track": f"{self.split}/{track}"} + # if __name__ == "__main__": # # from pprint import pprint diff --git a/programs/music_separation_code/models/bandit/core/data/musdb/preprocess.py b/programs/music_separation_code/models/bandit/core/data/musdb/preprocess.py index 45b3fe4..bbc02b1 100644 --- a/programs/music_separation_code/models/bandit/core/data/musdb/preprocess.py +++ b/programs/music_separation_code/models/bandit/core/data/musdb/preprocess.py @@ -12,20 +12,21 @@ from core.data.musdb.dataset import MUSDB18FullTrackDataset import pyloudnorm as pyln + class 
SourceActivityDetector(nn.Module): def __init__( - self, - analysis_stem: str, - output_path: str, - fs: int = 44100, - segment_length_second: float = 6.0, - hop_length_second: float = 3.0, - n_chunks: int = 10, - chunk_epsilon: float = 1e-5, - energy_threshold_quantile: float = 0.15, - segment_epsilon: float = 1e-3, - salient_proportion_threshold: float = 0.5, - target_lufs: float = -24 + self, + analysis_stem: str, + output_path: str, + fs: int = 44100, + segment_length_second: float = 6.0, + hop_length_second: float = 3.0, + n_chunks: int = 10, + chunk_epsilon: float = 1e-5, + energy_threshold_quantile: float = 0.15, + segment_epsilon: float = 1e-3, + salient_proportion_threshold: float = 0.5, + target_lufs: float = -24, ) -> None: super().__init__() @@ -48,8 +49,7 @@ def __init__( def forward(self, data: DataDict) -> None: - stem_ = self.analysis_stem if ( - self.analysis_stem != "none") else "mixture" + stem_ = self.analysis_stem if (self.analysis_stem != "none") else "mixture" x = data["audio"][stem_] @@ -69,9 +69,7 @@ def forward(self, data: DataDict) -> None: n_chan, n_samples = x.shape n_segments = ( - int( - np.ceil((n_samples - self.segment_length) / self.hop_length) - ) + 1 + int(np.ceil((n_samples - self.segment_length) / self.hop_length)) + 1 ) segments = torch.zeros((n_segments, n_chan, self.segment_length)) @@ -84,16 +82,12 @@ def forward(self, data: DataDict) -> None: if end - start < self.segment_length: xseg = F.pad( - xseg, - pad=(0, self.segment_length - (end - start)), - value=torch.nan + xseg, pad=(0, self.segment_length - (end - start)), value=torch.nan ) segments[i, :, :] = xseg - chunks = segments.reshape( - (n_segments, n_chan, self.n_chunks, self.chunk_size) - ) + chunks = segments.reshape((n_segments, n_chan, self.n_chunks, self.chunk_size)) if self.analysis_stem != "none": chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3)) @@ -101,7 +95,7 @@ def forward(self, data: DataDict) -> None: chunk_energies[chunk_energies == 0] = self.chunk_epsilon energy_threshold = torch.nanquantile( - chunk_energies, q=self.energy_threshold_quantile + chunk_energies, q=self.energy_threshold_quantile ) if energy_threshold < self.segment_epsilon: @@ -109,11 +103,11 @@ def forward(self, data: DataDict) -> None: chunks_above_threshold = chunk_energies > energy_threshold n_chunks_above_threshold = torch.mean( - chunks_above_threshold.to(torch.float), dim=-1 + chunks_above_threshold.to(torch.float), dim=-1 ) segment_above_threshold = ( - n_chunks_above_threshold > self.salient_proportion_threshold + n_chunks_above_threshold > self.salient_proportion_threshold ) if torch.sum(segment_above_threshold) == 0: @@ -127,9 +121,9 @@ def forward(self, data: DataDict) -> None: continue outpath = os.path.join( - self.output_path, - self.analysis_stem, - f"{data['track']} - {self.analysis_stem}{i:03d}", + self.output_path, + self.analysis_stem, + f"{data['track']} - {self.analysis_stem}{i:03d}", ) os.makedirs(outpath, exist_ok=True) @@ -145,8 +139,7 @@ def forward(self, data: DataDict) -> None: if end - start < self.segment_length: segment = F.pad( - segment, - (0, self.segment_length - (end - start)) + segment, (0, self.segment_length - (end - start)) ) assert segment.shape[-1] == self.segment_length, segment.shape @@ -157,35 +150,35 @@ def forward(self, data: DataDict) -> None: def preprocess( - analysis_stem: str, - output_path: str = "/data/MUSDB18/HQ/saded-np", - fs: int = 44100, - segment_length_second: float = 6.0, - hop_length_second: float = 3.0, - n_chunks: int = 10, - 
chunk_epsilon: float = 1e-5, - energy_threshold_quantile: float = 0.15, - segment_epsilon: float = 1e-3, - salient_proportion_threshold: float = 0.5, + analysis_stem: str, + output_path: str = "/data/MUSDB18/HQ/saded-np", + fs: int = 44100, + segment_length_second: float = 6.0, + hop_length_second: float = 3.0, + n_chunks: int = 10, + chunk_epsilon: float = 1e-5, + energy_threshold_quantile: float = 0.15, + segment_epsilon: float = 1e-3, + salient_proportion_threshold: float = 0.5, ) -> None: sad = SourceActivityDetector( - analysis_stem=analysis_stem, - output_path=output_path, - fs=fs, - segment_length_second=segment_length_second, - hop_length_second=hop_length_second, - n_chunks=n_chunks, - chunk_epsilon=chunk_epsilon, - energy_threshold_quantile=energy_threshold_quantile, - segment_epsilon=segment_epsilon, - salient_proportion_threshold=salient_proportion_threshold, + analysis_stem=analysis_stem, + output_path=output_path, + fs=fs, + segment_length_second=segment_length_second, + hop_length_second=hop_length_second, + n_chunks=n_chunks, + chunk_epsilon=chunk_epsilon, + energy_threshold_quantile=energy_threshold_quantile, + segment_epsilon=segment_epsilon, + salient_proportion_threshold=salient_proportion_threshold, ) for split in ["train", "val", "test"]: ds = MUSDB18FullTrackDataset( - data_root="/data/MUSDB18/HQ/canonical", - split=split, + data_root="/data/MUSDB18/HQ/canonical", + split=split, ) tracks = [] @@ -196,9 +189,8 @@ def preprocess( tracks.append(track) process_map(sad, tracks, max_workers=8) -def loudness_norm_one( - inputs -): + +def loudness_norm_one(inputs): infile, outfile, target_lufs = inputs audio, fs = ta.load(infile) @@ -211,25 +203,21 @@ def loudness_norm_one( os.makedirs(os.path.dirname(outfile), exist_ok=True) np.save(outfile, audio.T) + def loudness_norm( - data_path: str, - # output_path: str, - target_lufs = -17.0, + data_path: str, + # output_path: str, + target_lufs=-17.0, ): - files = glob.glob( - os.path.join(data_path, "**", "*.wav"), recursive=True - ) + files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True) - outfiles = [ - f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files - ] + outfiles = [f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files] files = [(f, o, target_lufs) for f, o in zip(files, outfiles)] process_map(loudness_norm_one, files, chunksize=2) - if __name__ == "__main__": from tqdm import tqdm diff --git a/programs/music_separation_code/models/bandit/core/loss/__init__.py b/programs/music_separation_code/models/bandit/core/loss/__init__.py index 0ab803a..993be52 100644 --- a/programs/music_separation_code/models/bandit/core/loss/__init__.py +++ b/programs/music_separation_code/models/bandit/core/loss/__init__.py @@ -1,2 +1,8 @@ from ._multistem import MultiStemWrapperFromConfig -from ._timefreq import ReImL1Loss, ReImL2Loss, TimeFreqL1Loss, TimeFreqL2Loss, TimeFreqSignalNoisePNormRatioLoss +from ._timefreq import ( + ReImL1Loss, + ReImL2Loss, + TimeFreqL1Loss, + TimeFreqL2Loss, + TimeFreqSignalNoisePNormRatioLoss, +) diff --git a/programs/music_separation_code/models/bandit/core/loss/_complex.py b/programs/music_separation_code/models/bandit/core/loss/_complex.py index 1d97e5d..68c82f2 100644 --- a/programs/music_separation_code/models/bandit/core/loss/_complex.py +++ b/programs/music_separation_code/models/bandit/core/loss/_complex.py @@ -11,15 +11,8 @@ def __init__(self, module: _Loss) -> None: super().__init__() self.module = module - def forward( - self, - preds: torch.Tensor, - 
target: torch.Tensor - ) -> torch.Tensor: - return self.module( - torch.view_as_real(preds), - torch.view_as_real(target) - ) + def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + return self.module(torch.view_as_real(preds), torch.view_as_real(target)) class ReImL1Loss(ReImLossWrapper): diff --git a/programs/music_separation_code/models/bandit/core/loss/_multistem.py b/programs/music_separation_code/models/bandit/core/loss/_multistem.py index 675e0ff..e9c4a4f 100644 --- a/programs/music_separation_code/models/bandit/core/loss/_multistem.py +++ b/programs/music_separation_code/models/bandit/core/loss/_multistem.py @@ -24,16 +24,14 @@ def __init__(self, module: _Loss, modality: str = "audio") -> None: self.modality = modality def forward( - self, - preds: Dict[str, Dict[str, torch.Tensor]], - target: Dict[str, Dict[str, torch.Tensor]], + self, + preds: Dict[str, Dict[str, torch.Tensor]], + target: Dict[str, Dict[str, torch.Tensor]], ) -> torch.Tensor: loss = { - stem: self.loss( - preds[self.modality][stem], - target[self.modality][stem] - ) - for stem in preds[self.modality] if stem in target[self.modality] + stem: self.loss(preds[self.modality][stem], target[self.modality][stem]) + for stem in preds[self.modality] + if stem in target[self.modality] } return sum(list(loss.values())) diff --git a/programs/music_separation_code/models/bandit/core/loss/_timefreq.py b/programs/music_separation_code/models/bandit/core/loss/_timefreq.py index 6ea9d59..96080e8 100644 --- a/programs/music_separation_code/models/bandit/core/loss/_timefreq.py +++ b/programs/music_separation_code/models/bandit/core/loss/_timefreq.py @@ -8,14 +8,15 @@ from models.bandit.core.loss._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper from models.bandit.core.loss.snr import SignalNoisePNormRatio + class TimeFreqWrapper(_Loss): def __init__( - self, - time_module: _Loss, - freq_module: Optional[_Loss] = None, - time_weight: float = 1.0, - freq_weight: float = 1.0, - multistem: bool = True, + self, + time_module: _Loss, + freq_module: Optional[_Loss] = None, + time_weight: float = 1.0, + freq_weight: float = 1.0, + multistem: bool = True, ) -> None: super().__init__() @@ -36,42 +37,36 @@ def __init__( def forward(self, preds: Any, target: Any) -> torch.Tensor: return self.time_weight * self.time_module( - preds, target + preds, target ) + self.freq_weight * self.freq_module(preds, target) class TimeFreqL1Loss(TimeFreqWrapper): def __init__( - self, - time_weight: float = 1.0, - freq_weight: float = 1.0, - tkwargs: Optional[Dict[str, Any]] = None, - fkwargs: Optional[Dict[str, Any]] = None, - multistem: bool = True, + self, + time_weight: float = 1.0, + freq_weight: float = 1.0, + tkwargs: Optional[Dict[str, Any]] = None, + fkwargs: Optional[Dict[str, Any]] = None, + multistem: bool = True, ) -> None: if tkwargs is None: tkwargs = {} if fkwargs is None: fkwargs = {} - time_module = (nn.L1Loss(**tkwargs)) + time_module = nn.L1Loss(**tkwargs) freq_module = ReImL1Loss(**fkwargs) - super().__init__( - time_module, - freq_module, - time_weight, - freq_weight, - multistem - ) + super().__init__(time_module, freq_module, time_weight, freq_weight, multistem) class TimeFreqL2Loss(TimeFreqWrapper): def __init__( - self, - time_weight: float = 1.0, - freq_weight: float = 1.0, - tkwargs: Optional[Dict[str, Any]] = None, - fkwargs: Optional[Dict[str, Any]] = None, - multistem: bool = True, + self, + time_weight: float = 1.0, + freq_weight: float = 1.0, + tkwargs: Optional[Dict[str, Any]] = None, + fkwargs: 
Optional[Dict[str, Any]] = None, + multistem: bool = True, ) -> None: if tkwargs is None: tkwargs = {} @@ -79,24 +74,17 @@ def __init__( fkwargs = {} time_module = nn.MSELoss(**tkwargs) freq_module = ReImL2Loss(**fkwargs) - super().__init__( - time_module, - freq_module, - time_weight, - freq_weight, - multistem - ) - + super().__init__(time_module, freq_module, time_weight, freq_weight, multistem) class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper): def __init__( - self, - time_weight: float = 1.0, - freq_weight: float = 1.0, - tkwargs: Optional[Dict[str, Any]] = None, - fkwargs: Optional[Dict[str, Any]] = None, - multistem: bool = True, + self, + time_weight: float = 1.0, + freq_weight: float = 1.0, + tkwargs: Optional[Dict[str, Any]] = None, + fkwargs: Optional[Dict[str, Any]] = None, + multistem: bool = True, ) -> None: if tkwargs is None: tkwargs = {} @@ -104,10 +92,4 @@ def __init__( fkwargs = {} time_module = SignalNoisePNormRatio(**tkwargs) freq_module = SignalNoisePNormRatio(**fkwargs) - super().__init__( - time_module, - freq_module, - time_weight, - freq_weight, - multistem - ) + super().__init__(time_module, freq_module, time_weight, freq_weight, multistem) diff --git a/programs/music_separation_code/models/bandit/core/loss/snr.py b/programs/music_separation_code/models/bandit/core/loss/snr.py index 2996dd5..8d712a5 100644 --- a/programs/music_separation_code/models/bandit/core/loss/snr.py +++ b/programs/music_separation_code/models/bandit/core/loss/snr.py @@ -2,15 +2,16 @@ from torch.nn.modules.loss import _Loss from torch.nn import functional as F + class SignalNoisePNormRatio(_Loss): def __init__( - self, - p: float = 1.0, - scale_invariant: bool = False, - zero_mean: bool = False, - take_log: bool = True, - reduction: str = "mean", - EPS: float = 1e-3, + self, + p: float = 1.0, + scale_invariant: bool = False, + zero_mean: bool = False, + take_log: bool = True, + reduction: str = "mean", + EPS: float = 1e-3, ) -> None: assert reduction != "sum", NotImplementedError super().__init__(reduction=reduction) @@ -23,23 +24,21 @@ def __init__( self.scale_invariant = scale_invariant - def forward( - self, - est_target: torch.Tensor, - target: torch.Tensor - ) -> torch.Tensor: + def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor: target_ = target if self.scale_invariant: ndim = target.ndim dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True) - s_target_energy = ( - torch.sum(target * torch.conj(target), dim=-1, keepdim=True) + s_target_energy = torch.sum( + target * torch.conj(target), dim=-1, keepdim=True ) if ndim > 2: dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True) - s_target_energy = torch.sum(s_target_energy, dim=list(range(1, ndim)), keepdim=True) + s_target_energy = torch.sum( + s_target_energy, dim=list(range(1, ndim)), keepdim=True + ) target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8) target = target_ * target_scaler @@ -48,25 +47,26 @@ def forward( est_target = torch.view_as_real(est_target) target = torch.view_as_real(target) - batch_size = est_target.shape[0] est_target = est_target.reshape(batch_size, -1) target = target.reshape(batch_size, -1) # target_ = target_.reshape(batch_size, -1) if self.p == 1: - e_error = torch.abs(est_target-target).mean(dim=-1) + e_error = torch.abs(est_target - target).mean(dim=-1) e_target = torch.abs(target).mean(dim=-1) elif self.p == 2: - e_error = torch.square(est_target-target).mean(dim=-1) + e_error = torch.square(est_target - target).mean(dim=-1) e_target = 
torch.square(target).mean(dim=-1) else: raise NotImplementedError - + if self.take_log: - loss = 10*(torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)) + loss = 10 * ( + torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS) + ) else: - loss = (e_error + self.EPS)/(e_target + self.EPS) + loss = (e_error + self.EPS) / (e_target + self.EPS) if self.reduction == "mean": loss = loss.mean() @@ -75,17 +75,16 @@ def forward( return loss - class MultichannelSingleSrcNegSDR(_Loss): def __init__( - self, - sdr_type: str, - p: float = 2.0, - zero_mean: bool = True, - take_log: bool = True, - reduction: str = "mean", - EPS: float = 1e-8, + self, + sdr_type: str, + p: float = 2.0, + zero_mean: bool = True, + take_log: bool = True, + reduction: str = "mean", + EPS: float = 1e-8, ) -> None: assert reduction != "sum", NotImplementedError super().__init__(reduction=reduction) @@ -98,14 +97,10 @@ def __init__( self.p = p - def forward( - self, - est_target: torch.Tensor, - target: torch.Tensor - ) -> torch.Tensor: + def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.size() != est_target.size() or target.ndim != 3: raise TypeError( - f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead" + f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead" ) # Step 1. Zero-mean norm if self.zero_mean: @@ -118,9 +113,7 @@ def forward( # [batch, 1] dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True) # [batch, 1] - s_target_energy = ( - torch.sum(target ** 2, dim=[1, 2], keepdim=True) + self.EPS - ) + s_target_energy = torch.sum(target**2, dim=[1, 2], keepdim=True) + self.EPS # [batch, time] scaled_target = dot * target / s_target_energy else: @@ -133,12 +126,12 @@ def forward( # [batch] if self.p == 2.0: - losses = torch.sum(scaled_target ** 2, dim=[1, 2]) / ( - torch.sum(e_noise ** 2, dim=[1, 2]) + self.EPS + losses = torch.sum(scaled_target**2, dim=[1, 2]) / ( + torch.sum(e_noise**2, dim=[1, 2]) + self.EPS ) else: losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / ( - torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS + torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS ) if self.take_log: losses = 10 * torch.log10(losses + self.EPS) diff --git a/programs/music_separation_code/models/bandit/core/metrics/_squim.py b/programs/music_separation_code/models/bandit/core/metrics/_squim.py index ec76b5f..71c993a 100644 --- a/programs/music_separation_code/models/bandit/core/metrics/_squim.py +++ b/programs/music_separation_code/models/bandit/core/metrics/_squim.py @@ -40,7 +40,10 @@ def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None: self.sigmoid: nn.modules.Module = nn.Sigmoid() def forward(self, x: torch.Tensor) -> torch.Tensor: - out = self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + self.val_range[0] + out = ( + self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + + self.val_range[0] + ) return out @@ -72,7 +75,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class SingleRNN(nn.Module): - def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None: + def __init__( + self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0 + ) -> None: super(SingleRNN, self).__init__() self.rnn_type = rnn_type @@ -144,7 +149,10 @@ def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]: # input shape: (B, N, T) seq_len = x.shape[-1] 
- rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size + rest = ( + self.chunk_size + - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size + ) out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride]) return out, rest @@ -153,18 +161,42 @@ def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]: out, rest = self.pad_chunk(x) batch_size, feat_dim, seq_len = out.shape - segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size) - segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size) + segments1 = ( + out[:, :, : -self.chunk_stride] + .contiguous() + .view(batch_size, feat_dim, -1, self.chunk_size) + ) + segments2 = ( + out[:, :, self.chunk_stride :] + .contiguous() + .view(batch_size, feat_dim, -1, self.chunk_size) + ) out = torch.cat([segments1, segments2], dim=3) - out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous() + out = ( + out.view(batch_size, feat_dim, -1, self.chunk_size) + .transpose(2, 3) + .contiguous() + ) return out, rest def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor: batch_size, dim, _, _ = x.shape - out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2) - out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :] - out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride] + out = ( + x.transpose(2, 3) + .contiguous() + .view(batch_size, dim, -1, self.chunk_size * 2) + ) + out1 = ( + out[:, :, :, : self.chunk_size] + .contiguous() + .view(batch_size, dim, -1)[:, :, self.chunk_stride :] + ) + out2 = ( + out[:, :, :, self.chunk_size :] + .contiguous() + .view(batch_size, dim, -1)[:, :, : -self.chunk_stride] + ) out = out1 + out2 if rest > 0: out = out[:, :, :-rest] @@ -175,16 +207,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x, rest = self.chunking(x) batch_size, _, dim1, dim2 = x.shape out = x - for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm): - row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1).contiguous() + for row_rnn, row_norm, col_rnn, col_norm in zip( + self.row_rnn, self.row_norm, self.col_rnn, self.col_norm + ): + row_in = ( + out.permute(0, 3, 2, 1) + .contiguous() + .view(batch_size * dim2, dim1, -1) + .contiguous() + ) row_out = row_rnn(row_in) - row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous() + row_out = ( + row_out.view(batch_size, dim2, dim1, -1) + .permute(0, 3, 2, 1) + .contiguous() + ) row_out = row_norm(row_out) out = out + row_out - col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1).contiguous() + col_in = ( + out.permute(0, 2, 3, 1) + .contiguous() + .view(batch_size * dim1, dim2, -1) + .contiguous() + ) col_out = col_rnn(col_in) - col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous() + col_out = ( + col_out.view(batch_size, dim1, dim2, -1) + .permute(0, 3, 1, 2) + .contiguous() + ) col_out = col_norm(col_out) out = out + col_out out = self.conv(out) @@ -236,7 +288,9 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]: List(torch.Tensor): List of score Tenosrs. Each Tensor is with dimension `(batch,)`. """ if x.ndim != 2: - raise ValueError(f"The input must be a 2D Tensor. 
Found dimension {x.ndim}.") + raise ValueError( + f"The input must be a 2D Tensor. Found dimension {x.ndim}." + ) x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20) out = self.encoder(x) out = self.dprnn(out) @@ -257,7 +311,9 @@ def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module: Returns: (nn.Module): Returned module to predict corresponding metric score. """ - layer1 = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True) + layer1 = nn.TransformerEncoderLayer( + d_model, nhead, d_model * 4, dropout=0.0, batch_first=True + ) layer2 = AutoPool() if metric == "stoi": layer3 = nn.Sequential( @@ -274,7 +330,9 @@ def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module: RangeSigmoid(val_range=PESQRange), ) else: - layer3: nn.modules.Module = nn.Sequential(nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)) + layer3: nn.modules.Module = nn.Sequential( + nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1) + ) return nn.Sequential(layer1, layer2, layer3) @@ -305,7 +363,9 @@ def squim_objective_model( if chunk_stride is None: chunk_stride = chunk_size // 2 encoder = Encoder(feat_dim, win_len) - dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride) + dprnn = DPRNN( + feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride + ) branches = nn.ModuleList( [ _create_branch(d_model, nhead, "stoi"), @@ -329,6 +389,7 @@ def squim_objective_base() -> SquimObjective: chunk_size=71, ) + @dataclass class SquimObjectiveBundle: @@ -380,4 +441,3 @@ def sample_rate(self): Please refer to :py:class:`SquimObjectiveBundle` for usage instructions. """ - diff --git a/programs/music_separation_code/models/bandit/core/metrics/snr.py b/programs/music_separation_code/models/bandit/core/metrics/snr.py index d2830b2..6b7a168 100644 --- a/programs/music_separation_code/models/bandit/core/metrics/snr.py +++ b/programs/music_separation_code/models/bandit/core/metrics/snr.py @@ -25,11 +25,11 @@ def compute(self) -> Any: class BaseChunkMedianSignalRatio(tm.Metric): def __init__( - self, - func: Callable, - window_size: int, - hop_size: int = None, - zero_mean: bool = False, + self, + func: Callable, + window_size: int, + hop_size: int = None, + zero_mean: bool = False, ) -> None: super().__init__() @@ -40,20 +40,14 @@ def __init__( hop_size = window_size self.hop_size = hop_size - self.add_state( - "sum_snr", - default=torch.tensor(0.0), - dist_reduce_fx="sum" - ) + self.add_state("sum_snr", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: n_samples = target.shape[-1] - n_chunks = int( - np.ceil((n_samples - self.window_size) / self.hop_size) + 1 - ) + n_chunks = int(np.ceil((n_samples - self.window_size) / self.hop_size) + 1) snr_chunk = [] @@ -66,10 +60,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: end = start + self.window_size try: - chunk_snr = self.func( - preds[..., start:end], - target[..., start:end] - ) + chunk_snr = self.func(preds[..., start:end], target[..., start:end]) # print(preds.shape, chunk_snr.shape) @@ -90,61 +81,47 @@ def compute(self) -> Any: class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio): def __init__( - self, - window_size: int, - hop_size: int = None, - zero_mean: bool = False + self, window_size: int, hop_size: int = None, zero_mean: bool = False ) -> 
None: super().__init__( - func=tmF.signal_noise_ratio, - window_size=window_size, - hop_size=hop_size, - zero_mean=zero_mean, + func=tmF.signal_noise_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, ) class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio): def __init__( - self, - window_size: int, - hop_size: int = None, - zero_mean: bool = False + self, window_size: int, hop_size: int = None, zero_mean: bool = False ) -> None: super().__init__( - func=tmF.scale_invariant_signal_noise_ratio, - window_size=window_size, - hop_size=hop_size, - zero_mean=zero_mean, + func=tmF.scale_invariant_signal_noise_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, ) class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio): def __init__( - self, - window_size: int, - hop_size: int = None, - zero_mean: bool = False + self, window_size: int, hop_size: int = None, zero_mean: bool = False ) -> None: super().__init__( - func=tmF.signal_distortion_ratio, - window_size=window_size, - hop_size=hop_size, - zero_mean=zero_mean, + func=tmF.signal_distortion_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, ) -class ChunkMedianScaleInvariantSignalDistortionRatio( - BaseChunkMedianSignalRatio - ): +class ChunkMedianScaleInvariantSignalDistortionRatio(BaseChunkMedianSignalRatio): def __init__( - self, - window_size: int, - hop_size: int = None, - zero_mean: bool = False + self, window_size: int, hop_size: int = None, zero_mean: bool = False ) -> None: super().__init__( - func=tmF.scale_invariant_signal_distortion_ratio, - window_size=window_size, - hop_size=hop_size, - zero_mean=zero_mean, + func=tmF.scale_invariant_signal_distortion_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, ) diff --git a/programs/music_separation_code/models/bandit/core/model/_spectral.py b/programs/music_separation_code/models/bandit/core/model/_spectral.py index 564cd28..6af5cbd 100644 --- a/programs/music_separation_code/models/bandit/core/model/_spectral.py +++ b/programs/music_separation_code/models/bandit/core/model/_spectral.py @@ -7,18 +7,18 @@ class _SpectralComponent(nn.Module): def __init__( - self, - n_fft: int = 2048, - win_length: Optional[int] = 2048, - hop_length: int = 512, - window_fn: str = "hann_window", - wkwargs: Optional[Dict] = None, - power: Optional[int] = None, - center: bool = True, - normalized: bool = True, - pad_mode: str = "constant", - onesided: bool = True, - **kwargs, + self, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + **kwargs, ) -> None: super().__init__() @@ -26,33 +26,29 @@ def __init__( window_fn = torch.__dict__[window_fn] - self.stft = ( - ta.transforms.Spectrogram( - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - pad_mode=pad_mode, - pad=0, - window_fn=window_fn, - wkwargs=wkwargs, - power=power, - normalized=normalized, - center=center, - onesided=onesided, - ) + self.stft = ta.transforms.Spectrogram( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + pad_mode=pad_mode, + pad=0, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + normalized=normalized, + center=center, + onesided=onesided, ) - self.istft = ( - ta.transforms.InverseSpectrogram( - n_fft=n_fft, - 
win_length=win_length, - hop_length=hop_length, - pad_mode=pad_mode, - pad=0, - window_fn=window_fn, - wkwargs=wkwargs, - normalized=normalized, - center=center, - onesided=onesided, - ) + self.istft = ta.transforms.InverseSpectrogram( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + pad_mode=pad_mode, + pad=0, + window_fn=window_fn, + wkwargs=wkwargs, + normalized=normalized, + center=center, + onesided=onesided, ) diff --git a/programs/music_separation_code/models/bandit/core/model/bsrnn/bandsplit.py b/programs/music_separation_code/models/bandit/core/model/bsrnn/bandsplit.py index 63e6255..4321765 100644 --- a/programs/music_separation_code/models/bandit/core/model/bsrnn/bandsplit.py +++ b/programs/music_separation_code/models/bandit/core/model/bsrnn/bandsplit.py @@ -13,12 +13,12 @@ class NormFC(nn.Module): def __init__( - self, - emb_dim: int, - bandwidth: int, - in_channel: int, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, + self, + emb_dim: int, + bandwidth: int, + in_channel: int, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, ) -> None: super().__init__() @@ -67,14 +67,14 @@ def forward(self, xb): class BandSplitModule(nn.Module): def __init__( - self, - band_specs: List[Tuple[float, float]], - emb_dim: int, - in_channel: int, - require_no_overlap: bool = False, - require_no_gap: bool = True, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + in_channel: int, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, ) -> None: super().__init__() @@ -94,18 +94,18 @@ def __init__( self.emb_dim = emb_dim self.norm_fc_modules = nn.ModuleList( - [ # type: ignore - ( - NormFC( - emb_dim=emb_dim, - bandwidth=bw, - in_channel=in_channel, - normalize_channel_independently=normalize_channel_independently, - treat_channel_as_feature=treat_channel_as_feature, - ) - ) - for bw in self.band_widths - ] + [ # type: ignore + ( + NormFC( + emb_dim=emb_dim, + bandwidth=bw, + in_channel=in_channel, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + ) + ) + for bw in self.band_widths + ] ) def forward(self, x: torch.Tensor): @@ -114,15 +114,11 @@ def forward(self, x: torch.Tensor): batch, in_chan, _, n_time = x.shape z = torch.zeros( - size=(batch, self.n_bands, n_time, self.emb_dim), - device=x.device + size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device ) xr = torch.view_as_real(x) # batch, in_chan, n_freq, n_time, 2 - xr = torch.permute( - xr, - (0, 3, 1, 4, 2) - ) # batch, n_time, in_chan, 2, n_freq + xr = torch.permute(xr, (0, 3, 1, 4, 2)) # batch, n_time, in_chan, 2, n_freq batch, n_time, in_chan, reim, band_width = xr.shape for i, nfm in enumerate(self.norm_fc_modules): # print(f"bandsplit/band{i:02d}") diff --git a/programs/music_separation_code/models/bandit/core/model/bsrnn/core.py b/programs/music_separation_code/models/bandit/core/model/bsrnn/core.py index 7fd3625..1dbfb32 100644 --- a/programs/music_separation_code/models/bandit/core/model/bsrnn/core.py +++ b/programs/music_separation_code/models/bandit/core/model/bsrnn/core.py @@ -8,12 +8,12 @@ from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule from models.bandit.core.model.bsrnn.maskestim import ( MaskEstimationModule, - 
OverlappingMaskEstimationModule + OverlappingMaskEstimationModule, ) from models.bandit.core.model.bsrnn.tfmodel import ( ConvolutionalTimeFreqModule, SeqBandModellingModule, - TransformerTimeFreqModule + TransformerTimeFreqModule, ) @@ -36,7 +36,6 @@ def forward(self, x, cond=None, compute_residual: bool = True): q = self.tf_model(z) # (batch, emb_dim, n_band, n_time) # print(q) - # if torch.any(torch.isnan(q)): # raise ValueError("q nan") @@ -54,25 +53,23 @@ def forward(self, x, cond=None, compute_residual: bool = True): return {"spectrogram": out} - - - def instantiate_mask_estim(self, - in_channel: int, - stems: List[str], - band_specs: List[Tuple[float, float]], - emb_dim: int, - mlp_dim: int, - cond_dim: int, - hidden_activation: str, - - hidden_activation_kwargs: Optional[Dict] = None, - complex_mask: bool = True, - overlapping_band: bool = False, - freq_weights: Optional[List[torch.Tensor]] = None, - n_freq: Optional[int] = None, - use_freq_weights: bool = True, - mult_add_mask: bool = False - ): + def instantiate_mask_estim( + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + cond_dim: int, + hidden_activation: str, + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + use_freq_weights: bool = True, + mult_add_mask: bool = False, + ): if hidden_activation_kwargs is None: hidden_activation_kwargs = {} @@ -86,75 +83,77 @@ def instantiate_mask_estim(self, if mult_add_mask: self.mask_estim = nn.ModuleDict( - { - stem: MultAddMaskEstimationModule( - band_specs=band_specs, - freq_weights=freq_weights, - n_freq=n_freq, - emb_dim=emb_dim, - mlp_dim=mlp_dim, - in_channel=in_channel, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - use_freq_weights=use_freq_weights, - ) - for stem in stems - } + { + stem: MultAddMaskEstimationModule( + band_specs=band_specs, + freq_weights=freq_weights, + n_freq=n_freq, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + use_freq_weights=use_freq_weights, + ) + for stem in stems + } ) else: self.mask_estim = nn.ModuleDict( - { - stem: OverlappingMaskEstimationModule( - band_specs=band_specs, - freq_weights=freq_weights, - n_freq=n_freq, - emb_dim=emb_dim, - mlp_dim=mlp_dim, - in_channel=in_channel, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - use_freq_weights=use_freq_weights, - ) - for stem in stems - } + { + stem: OverlappingMaskEstimationModule( + band_specs=band_specs, + freq_weights=freq_weights, + n_freq=n_freq, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + use_freq_weights=use_freq_weights, + ) + for stem in stems + } ) else: self.mask_estim = nn.ModuleDict( - { - stem: MaskEstimationModule( - band_specs=band_specs, - emb_dim=emb_dim, - mlp_dim=mlp_dim, - cond_dim=cond_dim, - in_channel=in_channel, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - ) - for stem in stems - } + { + stem: MaskEstimationModule( + band_specs=band_specs, + 
emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + for stem in stems + } ) - def instantiate_bandsplit(self, - in_channel: int, - band_specs: List[Tuple[float, float]], - require_no_overlap: bool = False, - require_no_gap: bool = True, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, - emb_dim: int = 128 - ): + def instantiate_bandsplit( + self, + in_channel: int, + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + emb_dim: int = 128, + ): self.band_split = BandSplitModule( - in_channel=in_channel, - band_specs=band_specs, - require_no_overlap=require_no_overlap, - require_no_gap=require_no_gap, - normalize_channel_independently=normalize_channel_independently, - treat_channel_as_feature=treat_channel_as_feature, - emb_dim=emb_dim, - ) + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + class SingleMaskBandsplitCoreBase(BandsplitCoreBase): def __init__(self, **kwargs) -> None: @@ -172,169 +171,166 @@ def forward(self, x): class SingleMaskBandsplitCoreRNN( - SingleMaskBandsplitCoreBase, + SingleMaskBandsplitCoreBase, ): def __init__( - self, - in_channel: int, - band_specs: List[Tuple[float, float]], - require_no_overlap: bool = False, - require_no_gap: bool = True, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, - n_sqm_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 256, - bidirectional: bool = True, - rnn_type: str = "LSTM", - mlp_dim: int = 512, - hidden_activation: str = "Tanh", - hidden_activation_kwargs: Optional[Dict] = None, - complex_mask: bool = True, + self, + in_channel: int, + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, ) -> None: super().__init__() - self.band_split = (BandSplitModule( - in_channel=in_channel, - band_specs=band_specs, - require_no_overlap=require_no_overlap, - require_no_gap=require_no_gap, - normalize_channel_independently=normalize_channel_independently, - treat_channel_as_feature=treat_channel_as_feature, - emb_dim=emb_dim, - )) - self.tf_model = (SeqBandModellingModule( - n_modules=n_sqm_modules, - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - rnn_type=rnn_type, - )) - self.mask_estim = (MaskEstimationModule( - in_channel=in_channel, - band_specs=band_specs, - emb_dim=emb_dim, - mlp_dim=mlp_dim, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - )) + self.band_split = BandSplitModule( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + 
normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + self.tf_model = SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) + self.mask_estim = MaskEstimationModule( + in_channel=in_channel, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) class SingleMaskBandsplitCoreTransformer( - SingleMaskBandsplitCoreBase, + SingleMaskBandsplitCoreBase, ): def __init__( - self, - in_channel: int, - band_specs: List[Tuple[float, float]], - require_no_overlap: bool = False, - require_no_gap: bool = True, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, - n_sqm_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 256, - bidirectional: bool = True, - tf_dropout: float = 0.0, - mlp_dim: int = 512, - hidden_activation: str = "Tanh", - hidden_activation_kwargs: Optional[Dict] = None, - complex_mask: bool = True, + self, + in_channel: int, + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, ) -> None: super().__init__() self.band_split = BandSplitModule( - in_channel=in_channel, - band_specs=band_specs, - require_no_overlap=require_no_overlap, - require_no_gap=require_no_gap, - normalize_channel_independently=normalize_channel_independently, - treat_channel_as_feature=treat_channel_as_feature, - emb_dim=emb_dim, + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, ) self.tf_model = TransformerTimeFreqModule( - n_modules=n_sqm_modules, - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - dropout=tf_dropout, + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=tf_dropout, ) self.mask_estim = MaskEstimationModule( - in_channel=in_channel, - band_specs=band_specs, - emb_dim=emb_dim, - mlp_dim=mlp_dim, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, + in_channel=in_channel, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, ) class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase): def __init__( - self, - in_channel: int, - stems: List[str], - band_specs: List[Tuple[float, float]], - require_no_overlap: bool = False, - require_no_gap: bool = True, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, - n_sqm_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 256, - bidirectional: bool = True, - rnn_type: str = "LSTM", - mlp_dim: int = 512, - cond_dim: int = 0, - hidden_activation: str = "Tanh", - 
hidden_activation_kwargs: Optional[Dict] = None, - complex_mask: bool = True, - overlapping_band: bool = False, - freq_weights: Optional[List[torch.Tensor]] = None, - n_freq: Optional[int] = None, - use_freq_weights: bool = True, - mult_add_mask: bool = False + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + cond_dim: int = 0, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + use_freq_weights: bool = True, + mult_add_mask: bool = False, ) -> None: super().__init__() self.instantiate_bandsplit( - in_channel=in_channel, - band_specs=band_specs, - require_no_overlap=require_no_overlap, - require_no_gap=require_no_gap, - normalize_channel_independently=normalize_channel_independently, - treat_channel_as_feature=treat_channel_as_feature, - emb_dim=emb_dim + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, ) - - self.tf_model = ( - SeqBandModellingModule( - n_modules=n_sqm_modules, - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - rnn_type=rnn_type, - ) + self.tf_model = SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, ) self.mult_add_mask = mult_add_mask self.instantiate_mask_estim( - in_channel=in_channel, - stems=stems, - band_specs=band_specs, - emb_dim=emb_dim, - mlp_dim=mlp_dim, - cond_dim=cond_dim, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - overlapping_band=overlapping_band, - freq_weights=freq_weights, - n_freq=n_freq, - use_freq_weights=use_freq_weights, - mult_add_mask=mult_add_mask + in_channel=in_channel, + stems=stems, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=overlapping_band, + freq_weights=freq_weights, + n_freq=n_freq, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask, ) @staticmethod @@ -358,133 +354,132 @@ def mask(self, x, m): class MultiSourceMultiMaskBandSplitCoreTransformer( - MultiMaskBandSplitCoreBase, + MultiMaskBandSplitCoreBase, ): def __init__( - self, - in_channel: int, - stems: List[str], - band_specs: List[Tuple[float, float]], - require_no_overlap: bool = False, - require_no_gap: bool = True, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, - n_sqm_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 256, - bidirectional: bool = True, - tf_dropout: float = 0.0, - mlp_dim: int = 512, - hidden_activation: str = "Tanh", - hidden_activation_kwargs: Optional[Dict] = None, - complex_mask: bool = True, - overlapping_band: bool = False, - freq_weights: Optional[List[torch.Tensor]] = None, - n_freq: Optional[int] 
= None, - use_freq_weights:bool=True, - rnn_type: str = "LSTM", - cond_dim: int = 0, - mult_add_mask: bool = False + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + use_freq_weights: bool = True, + rnn_type: str = "LSTM", + cond_dim: int = 0, + mult_add_mask: bool = False, ) -> None: super().__init__() self.instantiate_bandsplit( - in_channel=in_channel, - band_specs=band_specs, - require_no_overlap=require_no_overlap, - require_no_gap=require_no_gap, - normalize_channel_independently=normalize_channel_independently, - treat_channel_as_feature=treat_channel_as_feature, - emb_dim=emb_dim + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, ) self.tf_model = TransformerTimeFreqModule( - n_modules=n_sqm_modules, - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - dropout=tf_dropout, + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=tf_dropout, ) - + self.instantiate_mask_estim( - in_channel=in_channel, - stems=stems, - band_specs=band_specs, - emb_dim=emb_dim, - mlp_dim=mlp_dim, - cond_dim=cond_dim, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - overlapping_band=overlapping_band, - freq_weights=freq_weights, - n_freq=n_freq, - use_freq_weights=use_freq_weights, - mult_add_mask=mult_add_mask + in_channel=in_channel, + stems=stems, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=overlapping_band, + freq_weights=freq_weights, + n_freq=n_freq, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask, ) - class MultiSourceMultiMaskBandSplitCoreConv( - MultiMaskBandSplitCoreBase, + MultiMaskBandSplitCoreBase, ): def __init__( - self, - in_channel: int, - stems: List[str], - band_specs: List[Tuple[float, float]], - require_no_overlap: bool = False, - require_no_gap: bool = True, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, - n_sqm_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 256, - bidirectional: bool = True, - tf_dropout: float = 0.0, - mlp_dim: int = 512, - hidden_activation: str = "Tanh", - hidden_activation_kwargs: Optional[Dict] = None, - complex_mask: bool = True, - overlapping_band: bool = False, - freq_weights: Optional[List[torch.Tensor]] = None, - n_freq: Optional[int] = None, - use_freq_weights:bool=True, - rnn_type: str = "LSTM", - cond_dim: int = 0, - mult_add_mask: bool = False + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool 
= True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + use_freq_weights: bool = True, + rnn_type: str = "LSTM", + cond_dim: int = 0, + mult_add_mask: bool = False, ) -> None: super().__init__() self.instantiate_bandsplit( - in_channel=in_channel, - band_specs=band_specs, - require_no_overlap=require_no_overlap, - require_no_gap=require_no_gap, - normalize_channel_independently=normalize_channel_independently, - treat_channel_as_feature=treat_channel_as_feature, - emb_dim=emb_dim + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, ) self.tf_model = ConvolutionalTimeFreqModule( - n_modules=n_sqm_modules, - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - dropout=tf_dropout, + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=tf_dropout, ) - + self.instantiate_mask_estim( - in_channel=in_channel, - stems=stems, - band_specs=band_specs, - emb_dim=emb_dim, - mlp_dim=mlp_dim, - cond_dim=cond_dim, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - overlapping_band=overlapping_band, - freq_weights=freq_weights, - n_freq=n_freq, - use_freq_weights=use_freq_weights, - mult_add_mask=mult_add_mask + in_channel=in_channel, + stems=stems, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=overlapping_band, + freq_weights=freq_weights, + n_freq=n_freq, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask, ) @@ -500,40 +495,40 @@ def mask(self, x, m): padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2) xf = F.unfold( - x, - kernel_size=(kernel_freq, kernel_time), - padding=padding, - stride=(1, 1), + x, + kernel_size=(kernel_freq, kernel_time), + padding=padding, + stride=(1, 1), ) xf = xf.view( - -1, - n_channel, - kernel_freq, - kernel_time, - n_freq, - n_time, + -1, + n_channel, + kernel_freq, + kernel_time, + n_freq, + n_time, ) sf = xf * m sf = sf.view( - -1, - n_channel * kernel_freq * kernel_time, - n_freq * n_time, + -1, + n_channel * kernel_freq * kernel_time, + n_freq * n_time, ) s = F.fold( - sf, - output_size=(n_freq, n_time), - kernel_size=(kernel_freq, kernel_time), - padding=padding, - stride=(1, 1), + sf, + output_size=(n_freq, n_time), + kernel_size=(kernel_freq, kernel_time), + padding=padding, + stride=(1, 1), ).view( - -1, - n_channel, - n_freq, - n_time, + -1, + n_channel, + n_freq, + n_time, ) return s @@ -570,64 +565,59 @@ def old_mask(self, x, m): fslice = slice(max(0, df), min(n_freq, n_freq + df)) tslice = slice(max(0, dt), min(n_time, n_time + dt)) - s[:, :, fslice, tslice] += x[:, :, fslice, tslice] * m[ifreq, - itime, :, - :, fslice, - tslice] + s[:, :, fslice, tslice] += ( + x[:, :, fslice, tslice] * m[ifreq, 
itime, :, :, fslice, tslice] + ) return s -class MultiSourceMultiPatchingMaskBandSplitCoreRNN( - PatchingMaskBandsplitCoreBase -): +class MultiSourceMultiPatchingMaskBandSplitCoreRNN(PatchingMaskBandsplitCoreBase): def __init__( - self, - in_channel: int, - stems: List[str], - band_specs: List[Tuple[float, float]], - mask_kernel_freq: int, - mask_kernel_time: int, - conv_kernel_freq: int, - conv_kernel_time: int, - kernel_norm_mlp_version: int, - require_no_overlap: bool = False, - require_no_gap: bool = True, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, - n_sqm_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 256, - bidirectional: bool = True, - rnn_type: str = "LSTM", - mlp_dim: int = 512, - hidden_activation: str = "Tanh", - hidden_activation_kwargs: Optional[Dict] = None, - complex_mask: bool = True, - overlapping_band: bool = False, - freq_weights: Optional[List[torch.Tensor]] = None, - n_freq: Optional[int] = None, + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + mask_kernel_freq: int, + mask_kernel_time: int, + conv_kernel_freq: int, + conv_kernel_time: int, + kernel_norm_mlp_version: int, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, ) -> None: super().__init__() self.band_split = BandSplitModule( - in_channel=in_channel, - band_specs=band_specs, - require_no_overlap=require_no_overlap, - require_no_gap=require_no_gap, - normalize_channel_independently=normalize_channel_independently, - treat_channel_as_feature=treat_channel_as_feature, - emb_dim=emb_dim, + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, ) - self.tf_model = ( - SeqBandModellingModule( - n_modules=n_sqm_modules, - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - rnn_type=rnn_type, - ) + self.tf_model = SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, ) if hidden_activation_kwargs is None: @@ -637,25 +627,25 @@ def __init__( assert freq_weights is not None assert n_freq is not None self.mask_estim = nn.ModuleDict( - { - stem: PatchingMaskEstimationModule( - band_specs=band_specs, - freq_weights=freq_weights, - n_freq=n_freq, - emb_dim=emb_dim, - mlp_dim=mlp_dim, - in_channel=in_channel, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - mask_kernel_freq=mask_kernel_freq, - mask_kernel_time=mask_kernel_time, - conv_kernel_freq=conv_kernel_freq, - conv_kernel_time=conv_kernel_time, - kernel_norm_mlp_version=kernel_norm_mlp_version - ) - for stem in stems - } + { + stem: PatchingMaskEstimationModule( + band_specs=band_specs, + freq_weights=freq_weights, + n_freq=n_freq, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + 
hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + mask_kernel_freq=mask_kernel_freq, + mask_kernel_time=mask_kernel_time, + conv_kernel_freq=conv_kernel_freq, + conv_kernel_time=conv_kernel_time, + kernel_norm_mlp_version=kernel_norm_mlp_version, + ) + for stem in stems + } ) else: raise NotImplementedError diff --git a/programs/music_separation_code/models/bandit/core/model/bsrnn/maskestim.py b/programs/music_separation_code/models/bandit/core/model/bsrnn/maskestim.py index 0b9289d..6049596 100644 --- a/programs/music_separation_code/models/bandit/core/model/bsrnn/maskestim.py +++ b/programs/music_separation_code/models/bandit/core/model/bsrnn/maskestim.py @@ -1,4 +1,3 @@ -import warnings from typing import Dict, List, Optional, Tuple, Type import torch @@ -15,26 +14,27 @@ class BaseNormMLP(nn.Module): def __init__( - self, - emb_dim: int, - mlp_dim: int, - bandwidth: int, - in_channel: Optional[int], - hidden_activation: str = "Tanh", - hidden_activation_kwargs=None, - complex_mask: bool = True, ): + self, + emb_dim: int, + mlp_dim: int, + bandwidth: int, + in_channel: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs=None, + complex_mask: bool = True, + ): super().__init__() if hidden_activation_kwargs is None: hidden_activation_kwargs = {} self.hidden_activation_kwargs = hidden_activation_kwargs self.norm = nn.LayerNorm(emb_dim) - self.hidden = torch.jit.script(nn.Sequential( + self.hidden = torch.jit.script( + nn.Sequential( nn.Linear(in_features=emb_dim, out_features=mlp_dim), - activation.__dict__[hidden_activation]( - **self.hidden_activation_kwargs - ), - )) + activation.__dict__[hidden_activation](**self.hidden_activation_kwargs), + ) + ) self.bandwidth = bandwidth self.in_channel = in_channel @@ -46,33 +46,33 @@ def __init__( class NormMLP(BaseNormMLP): def __init__( - self, - emb_dim: int, - mlp_dim: int, - bandwidth: int, - in_channel: Optional[int], - hidden_activation: str = "Tanh", - hidden_activation_kwargs=None, - complex_mask: bool = True, + self, + emb_dim: int, + mlp_dim: int, + bandwidth: int, + in_channel: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs=None, + complex_mask: bool = True, ) -> None: super().__init__( - emb_dim=emb_dim, - mlp_dim=mlp_dim, - bandwidth=bandwidth, - in_channel=in_channel, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + bandwidth=bandwidth, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, ) self.output = torch.jit.script( - nn.Sequential( - nn.Linear( - in_features=mlp_dim, - out_features=bandwidth * in_channel * self.reim * 2, - ), - nn.GLU(dim=-1), - ) + nn.Sequential( + nn.Linear( + in_features=mlp_dim, + out_features=bandwidth * in_channel * self.reim * 2, + ), + nn.GLU(dim=-1), + ) ) def reshape_output(self, mb): @@ -80,23 +80,14 @@ def reshape_output(self, mb): batch, n_time, _ = mb.shape if self.complex_mask: mb = mb.reshape( - batch, - n_time, - self.in_channel, - self.bandwidth, - self.reim + batch, n_time, self.in_channel, self.bandwidth, self.reim ).contiguous() # print(mb.shape) - mb = torch.view_as_complex( - mb - ) # (batch, n_time, in_channel, bandwidth) + mb = torch.view_as_complex(mb) # (batch, n_time, in_channel, bandwidth) else: mb = mb.reshape(batch, n_time, self.in_channel, 
self.bandwidth) - mb = torch.permute( - mb, - (0, 2, 3, 1) - ) # (batch, in_channel, bandwidth, n_time) + mb = torch.permute(mb, (0, 2, 3, 1)) # (batch, in_channel, bandwidth, n_time) return mb @@ -106,7 +97,6 @@ def forward(self, qb): # if torch.any(torch.isnan(qb)): # raise ValueError("qb0") - qb = self.norm(qb) # (batch, n_time, emb_dim) # if torch.any(torch.isnan(qb)): @@ -124,17 +114,34 @@ def forward(self, qb): class MultAddNormMLP(NormMLP): - def __init__(self, emb_dim: int, mlp_dim: int, bandwidth: int, in_channel: "int | None", hidden_activation: str = "Tanh", hidden_activation_kwargs=None, complex_mask: bool = True) -> None: - super().__init__(emb_dim, mlp_dim, bandwidth, in_channel, hidden_activation, hidden_activation_kwargs, complex_mask) + def __init__( + self, + emb_dim: int, + mlp_dim: int, + bandwidth: int, + in_channel: "int | None", + hidden_activation: str = "Tanh", + hidden_activation_kwargs=None, + complex_mask: bool = True, + ) -> None: + super().__init__( + emb_dim, + mlp_dim, + bandwidth, + in_channel, + hidden_activation, + hidden_activation_kwargs, + complex_mask, + ) self.output2 = torch.jit.script( - nn.Sequential( - nn.Linear( - in_features=mlp_dim, - out_features=bandwidth * in_channel * self.reim * 2, - ), - nn.GLU(dim=-1), - ) + nn.Sequential( + nn.Linear( + in_features=mlp_dim, + out_features=bandwidth * in_channel * self.reim * 2, + ), + nn.GLU(dim=-1), + ) ) def forward(self, qb): @@ -155,16 +162,16 @@ class MaskEstimationModuleSuperBase(nn.Module): class MaskEstimationModuleBase(MaskEstimationModuleSuperBase): def __init__( - self, - band_specs: List[Tuple[float, float]], - emb_dim: int, - mlp_dim: int, - in_channel: Optional[int], - hidden_activation: str = "Tanh", - hidden_activation_kwargs: Dict = None, - complex_mask: bool = True, - norm_mlp_cls: Type[nn.Module] = NormMLP, - norm_mlp_kwargs: Dict = None, + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + in_channel: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + norm_mlp_cls: Type[nn.Module] = NormMLP, + norm_mlp_kwargs: Dict = None, ) -> None: super().__init__() @@ -178,21 +185,21 @@ def __init__( norm_mlp_kwargs = {} self.norm_mlp = nn.ModuleList( - [ - ( - norm_mlp_cls( - bandwidth=self.band_widths[b], - emb_dim=emb_dim, - mlp_dim=mlp_dim, - in_channel=in_channel, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - **norm_mlp_kwargs, - ) - ) - for b in range(self.n_bands) - ] + [ + ( + norm_mlp_cls( + bandwidth=self.band_widths[b], + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + **norm_mlp_kwargs, + ) + ) + for b in range(self.n_bands) + ] ) def compute_masks(self, q): @@ -209,23 +216,22 @@ def compute_masks(self, q): return masks - class OverlappingMaskEstimationModule(MaskEstimationModuleBase): def __init__( - self, - in_channel: int, - band_specs: List[Tuple[float, float]], - freq_weights: List[torch.Tensor], - n_freq: int, - emb_dim: int, - mlp_dim: int, - cond_dim: int = 0, - hidden_activation: str = "Tanh", - hidden_activation_kwargs: Dict = None, - complex_mask: bool = True, - norm_mlp_cls: Type[nn.Module] = NormMLP, - norm_mlp_kwargs: Dict = None, - use_freq_weights: bool = True, + self, + in_channel: int, + band_specs: List[Tuple[float, float]], + freq_weights: 
List[torch.Tensor], + n_freq: int, + emb_dim: int, + mlp_dim: int, + cond_dim: int = 0, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + norm_mlp_cls: Type[nn.Module] = NormMLP, + norm_mlp_kwargs: Dict = None, + use_freq_weights: bool = True, ) -> None: check_nonzero_bandwidth(band_specs) check_no_gap(band_specs) @@ -234,15 +240,15 @@ def __init__( # raise NotImplementedError super().__init__( - band_specs=band_specs, - emb_dim=emb_dim + cond_dim, - mlp_dim=mlp_dim, - in_channel=in_channel, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - norm_mlp_cls=norm_mlp_cls, - norm_mlp_kwargs=norm_mlp_kwargs, + band_specs=band_specs, + emb_dim=emb_dim + cond_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + norm_mlp_cls=norm_mlp_cls, + norm_mlp_kwargs=norm_mlp_kwargs, ) self.n_freq = n_freq @@ -276,22 +282,22 @@ def forward(self, q, cond=None): q = torch.cat([q, cond], dim=-1) elif self.cond_dim > 0: cond = torch.ones( - (batch, n_bands, n_time, self.cond_dim), - device=q.device, - dtype=q.dtype, + (batch, n_bands, n_time, self.cond_dim), + device=q.device, + dtype=q.dtype, ) q = torch.cat([q, cond], dim=-1) else: pass mask_list = self.compute_masks( - q + q ) # [n_bands * (batch, in_channel, bandwidth, n_time)] masks = torch.zeros( - (batch, self.in_channel, self.n_freq, n_time), - device=q.device, - dtype=mask_list[0].dtype, + (batch, self.in_channel, self.n_freq, n_time), + device=q.device, + dtype=mask_list[0].dtype, ) for im, mask in enumerate(mask_list): @@ -306,42 +312,39 @@ def forward(self, q, cond=None): class MaskEstimationModule(OverlappingMaskEstimationModule): def __init__( - self, - band_specs: List[Tuple[float, float]], - emb_dim: int, - mlp_dim: int, - in_channel: Optional[int], - hidden_activation: str = "Tanh", - hidden_activation_kwargs: Dict = None, - complex_mask: bool = True, - **kwargs, + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + in_channel: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + **kwargs, ) -> None: check_nonzero_bandwidth(band_specs) check_no_gap(band_specs) check_no_overlap(band_specs) super().__init__( - in_channel=in_channel, - band_specs=band_specs, - freq_weights=None, - n_freq=None, - emb_dim=emb_dim, - mlp_dim=mlp_dim, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, + in_channel=in_channel, + band_specs=band_specs, + freq_weights=None, + n_freq=None, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, ) def forward(self, q, cond=None): # q = (batch, n_bands, n_time, emb_dim) masks = self.compute_masks( - q + q ) # [n_bands * (batch, in_channel, bandwidth, n_time)] # TODO: currently this requires band specs to have no gap and no overlap - masks = torch.concat( - masks, - dim=2 - ) # (batch, in_channel, n_freq, n_time) + masks = torch.concat(masks, dim=2) # (batch, in_channel, n_freq, n_time) return masks diff --git a/programs/music_separation_code/models/bandit/core/model/bsrnn/tfmodel.py b/programs/music_separation_code/models/bandit/core/model/bsrnn/tfmodel.py index ba71079..f482a11 100644 --- 
a/programs/music_separation_code/models/bandit/core/model/bsrnn/tfmodel.py +++ b/programs/music_separation_code/models/bandit/core/model/bsrnn/tfmodel.py @@ -15,13 +15,13 @@ def __init__(self) -> None: class ResidualRNN(nn.Module): def __init__( - self, - emb_dim: int, - rnn_dim: int, - bidirectional: bool = True, - rnn_type: str = "LSTM", - use_batch_trick: bool = True, - use_layer_norm: bool = True, + self, + emb_dim: int, + rnn_dim: int, + bidirectional: bool = True, + rnn_type: str = "LSTM", + use_batch_trick: bool = True, + use_layer_norm: bool = True, ) -> None: # n_group is the size of the 2nd dim super().__init__() @@ -33,16 +33,15 @@ def __init__( self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim) self.rnn = rnn.__dict__[rnn_type]( - input_size=emb_dim, - hidden_size=rnn_dim, - num_layers=1, - batch_first=True, - bidirectional=bidirectional, + input_size=emb_dim, + hidden_size=rnn_dim, + num_layers=1, + batch_first=True, + bidirectional=bidirectional, ) self.fc = nn.Linear( - in_features=rnn_dim * (2 if bidirectional else 1), - out_features=emb_dim + in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim ) self.use_batch_trick = use_batch_trick @@ -60,13 +59,13 @@ def forward(self, z): z = self.norm(z) # (batch, n_uncrossed, n_across, emb_dim) else: z = torch.permute( - z, (0, 3, 1, 2) + z, (0, 3, 1, 2) ) # (batch, emb_dim, n_uncrossed, n_across) z = self.norm(z) # (batch, emb_dim, n_uncrossed, n_across) z = torch.permute( - z, (0, 2, 3, 1) + z, (0, 2, 3, 1) ) # (batch, n_uncrossed, n_across, emb_dim) batch, n_uncrossed, n_across, emb_dim = z.shape @@ -74,7 +73,9 @@ def forward(self, z): if self.use_batch_trick: z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim)) - z = self.rnn(z.contiguous())[0] # (batch * n_uncrossed, n_across, dir_rnn_dim) + z = self.rnn(z.contiguous())[ + 0 + ] # (batch * n_uncrossed, n_across, dir_rnn_dim) z = torch.reshape(z, (batch, n_uncrossed, n_across, -1)) # (batch, n_uncrossed, n_across, dir_rnn_dim) @@ -85,10 +86,7 @@ def forward(self, z): zi = self.rnn(z[:, i, :, :])[0] # (batch, n_across, emb_dim) zlist.append(zi) - z = torch.stack( - zlist, - dim=1 - ) # (batch, n_uncrossed, n_across, dir_rnn_dim) + z = torch.stack(zlist, dim=1) # (batch, n_uncrossed, n_across, dir_rnn_dim) z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim) @@ -99,13 +97,13 @@ def forward(self, z): class SeqBandModellingModule(TimeFrequencyModellingModule): def __init__( - self, - n_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 256, - bidirectional: bool = True, - rnn_type: str = "LSTM", - parallel_mode=False, + self, + n_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + parallel_mode=False, ) -> None: super().__init__() self.seqband = nn.ModuleList([]) @@ -113,31 +111,33 @@ def __init__( if parallel_mode: for _ in range(n_modules): self.seqband.append( - nn.ModuleList( - [ResidualRNN( - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - rnn_type=rnn_type, - ), - ResidualRNN( - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - rnn_type=rnn_type, - )] - ) + nn.ModuleList( + [ + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ), + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ), + ] + ) ) else: for _ in range(2 * n_modules): self.seqband.append( - ResidualRNN( - emb_dim=emb_dim, - rnn_dim=rnn_dim, - 
bidirectional=bidirectional, - rnn_type=rnn_type, - ) + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) ) self.parallel_mode = parallel_mode @@ -149,8 +149,8 @@ def forward(self, z): for sbm_pair in self.seqband: # z: (batch, n_bands, n_time, emb_dim) sbm_t, sbm_f = sbm_pair[0], sbm_pair[1] - zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim) - zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim) + zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim) + zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim) z = zt + zf.transpose(1, 2) else: for sbm in self.seqband: @@ -169,20 +169,17 @@ def forward(self, z): class ResidualTransformer(nn.Module): def __init__( - self, - emb_dim: int = 128, - rnn_dim: int = 256, - bidirectional: bool = True, - dropout: float = 0.0, + self, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, ) -> None: # n_group is the size of the 2nd dim super().__init__() self.tf = nn.TransformerEncoderLayer( - d_model=emb_dim, - nhead=4, - dim_feedforward=rnn_dim, - batch_first=True + d_model=emb_dim, nhead=4, dim_feedforward=rnn_dim, batch_first=True ) self.is_causal = not bidirectional @@ -191,7 +188,9 @@ def __init__( def forward(self, z): batch, n_uncrossed, n_across, emb_dim = z.shape z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim)) - z = self.tf(z, is_causal=self.is_causal) # (batch, n_uncrossed, n_across, emb_dim) + z = self.tf( + z, is_causal=self.is_causal + ) # (batch, n_uncrossed, n_across, emb_dim) z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim)) return z @@ -199,12 +198,12 @@ def forward(self, z): class TransformerTimeFreqModule(TimeFrequencyModellingModule): def __init__( - self, - n_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 256, - bidirectional: bool = True, - dropout: float = 0.0, + self, + n_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, ) -> None: super().__init__() self.norm = nn.LayerNorm(emb_dim) @@ -212,12 +211,12 @@ def __init__( for _ in range(2 * n_modules): self.seqband.append( - ResidualTransformer( - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - dropout=dropout, - ) + ResidualTransformer( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=dropout, + ) ) def forward(self, z): @@ -238,14 +237,13 @@ def forward(self, z): return q # (batch, n_bands, n_time, emb_dim) - class ResidualConvolution(nn.Module): def __init__( - self, - emb_dim: int = 128, - rnn_dim: int = 256, - bidirectional: bool = True, - dropout: float = 0.0, + self, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, ) -> None: # n_group is the size of the 2nd dim super().__init__() @@ -258,22 +256,21 @@ def __init__( kernel_size=(3, 3), padding="same", stride=(1, 1), - ), - nn.Tanhshrink() + ), + nn.Tanhshrink(), ) self.is_causal = not bidirectional self.dropout = dropout self.fc = nn.Conv2d( - in_channels=rnn_dim, - out_channels=emb_dim, - kernel_size=(1, 1), - padding="same", - stride=(1, 1), + in_channels=rnn_dim, + out_channels=emb_dim, + kernel_size=(1, 1), + padding="same", + stride=(1, 1), ) - def forward(self, z): # z = (batch, n_uncrossed, n_across, emb_dim) @@ -289,29 +286,35 @@ def forward(self, z): class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule): def __init__( - self, - n_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 
256, - bidirectional: bool = True, - dropout: float = 0.0, + self, + n_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, ) -> None: super().__init__() - self.seqband = torch.jit.script(nn.Sequential( - *[ResidualConvolution( - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - dropout=dropout, - ) for _ in range(2 * n_modules) ])) + self.seqband = torch.jit.script( + nn.Sequential( + *[ + ResidualConvolution( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=dropout, + ) + for _ in range(2 * n_modules) + ] + ) + ) def forward(self, z): # z = (batch, n_bands, n_time, emb_dim) - z = torch.permute(z, (0, 3, 1, 2)) # (batch, emb_dim, n_bands, n_time) + z = torch.permute(z, (0, 3, 1, 2)) # (batch, emb_dim, n_bands, n_time) - z = self.seqband(z) # (batch, emb_dim, n_bands, n_time) + z = self.seqband(z) # (batch, emb_dim, n_bands, n_time) - z = torch.permute(z, (0, 2, 3, 1)) # (batch, n_bands, n_time, emb_dim) + z = torch.permute(z, (0, 2, 3, 1)) # (batch, n_bands, n_time, emb_dim) return z diff --git a/programs/music_separation_code/models/bandit/core/model/bsrnn/utils.py b/programs/music_separation_code/models/bandit/core/model/bsrnn/utils.py index bf8636e..d5f32ba 100644 --- a/programs/music_separation_code/models/bandit/core/model/bsrnn/utils.py +++ b/programs/music_separation_code/models/bandit/core/model/bsrnn/utils.py @@ -1,6 +1,6 @@ import os from abc import abstractmethod -from typing import Any, Callable +from typing import Callable import numpy as np import torch @@ -70,12 +70,7 @@ def hertz_to_index(self, hz: float, round: bool = True): return index - def get_band_specs_with_bandwidth( - self, - start_index, - end_index, - bandwidth_hz - ): + def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz): band_specs = [] lower = start_index @@ -105,110 +100,84 @@ def get_band_specs(self): @property def version1(self): return self.get_band_specs_with_bandwidth( - start_index=0, end_index=self.max_index, bandwidth_hz=1000 + start_index=0, end_index=self.max_index, bandwidth_hz=1000 ) def version2(self): below16k = self.get_band_specs_with_bandwidth( - start_index=0, end_index=self.split16k, bandwidth_hz=1000 + start_index=0, end_index=self.split16k, bandwidth_hz=1000 ) below20k = self.get_band_specs_with_bandwidth( - start_index=self.split16k, - end_index=self.split20k, - bandwidth_hz=2000 + start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000 ) return below16k + below20k + self.above20k def version3(self): below8k = self.get_band_specs_with_bandwidth( - start_index=0, end_index=self.split8k, bandwidth_hz=1000 + start_index=0, end_index=self.split8k, bandwidth_hz=1000 ) below16k = self.get_band_specs_with_bandwidth( - start_index=self.split8k, - end_index=self.split16k, - bandwidth_hz=2000 + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000 ) return below8k + below16k + self.above16k def version4(self): below1k = self.get_band_specs_with_bandwidth( - start_index=0, end_index=self.split1k, bandwidth_hz=100 + start_index=0, end_index=self.split1k, bandwidth_hz=100 ) below8k = self.get_band_specs_with_bandwidth( - start_index=self.split1k, - end_index=self.split8k, - bandwidth_hz=1000 + start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000 ) below16k = self.get_band_specs_with_bandwidth( - start_index=self.split8k, - end_index=self.split16k, - bandwidth_hz=2000 + start_index=self.split8k, end_index=self.split16k, 
bandwidth_hz=2000 ) return below1k + below8k + below16k + self.above16k def version5(self): below1k = self.get_band_specs_with_bandwidth( - start_index=0, end_index=self.split1k, bandwidth_hz=100 + start_index=0, end_index=self.split1k, bandwidth_hz=100 ) below16k = self.get_band_specs_with_bandwidth( - start_index=self.split1k, - end_index=self.split16k, - bandwidth_hz=1000 + start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000 ) below20k = self.get_band_specs_with_bandwidth( - start_index=self.split16k, - end_index=self.split20k, - bandwidth_hz=2000 + start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000 ) return below1k + below16k + below20k + self.above20k def version6(self): below1k = self.get_band_specs_with_bandwidth( - start_index=0, end_index=self.split1k, bandwidth_hz=100 + start_index=0, end_index=self.split1k, bandwidth_hz=100 ) below4k = self.get_band_specs_with_bandwidth( - start_index=self.split1k, - end_index=self.split4k, - bandwidth_hz=500 + start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500 ) below8k = self.get_band_specs_with_bandwidth( - start_index=self.split4k, - end_index=self.split8k, - bandwidth_hz=1000 + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000 ) below16k = self.get_band_specs_with_bandwidth( - start_index=self.split8k, - end_index=self.split16k, - bandwidth_hz=2000 + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000 ) return below1k + below4k + below8k + below16k + self.above16k def version7(self): below1k = self.get_band_specs_with_bandwidth( - start_index=0, end_index=self.split1k, bandwidth_hz=100 + start_index=0, end_index=self.split1k, bandwidth_hz=100 ) below4k = self.get_band_specs_with_bandwidth( - start_index=self.split1k, - end_index=self.split4k, - bandwidth_hz=250 + start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250 ) below8k = self.get_band_specs_with_bandwidth( - start_index=self.split4k, - end_index=self.split8k, - bandwidth_hz=500 + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500 ) below16k = self.get_band_specs_with_bandwidth( - start_index=self.split8k, - end_index=self.split16k, - bandwidth_hz=1000 + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000 ) below20k = self.get_band_specs_with_bandwidth( - start_index=self.split16k, - end_index=self.split20k, - bandwidth_hz=2000 + start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000 ) return below1k + below4k + below8k + below16k + below20k + self.above20k @@ -224,27 +193,19 @@ def __init__(self, nfft: int, fs: int, version: str = "7") -> None: def get_band_specs(self): below500 = self.get_band_specs_with_bandwidth( - start_index=0, end_index=self.split500, bandwidth_hz=50 + start_index=0, end_index=self.split500, bandwidth_hz=50 ) below1k = self.get_band_specs_with_bandwidth( - start_index=self.split500, - end_index=self.split1k, - bandwidth_hz=100 + start_index=self.split500, end_index=self.split1k, bandwidth_hz=100 ) below4k = self.get_band_specs_with_bandwidth( - start_index=self.split1k, - end_index=self.split4k, - bandwidth_hz=500 + start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500 ) below8k = self.get_band_specs_with_bandwidth( - start_index=self.split4k, - end_index=self.split8k, - bandwidth_hz=1000 + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000 ) below16k = self.get_band_specs_with_bandwidth( - start_index=self.split8k, - end_index=self.split16k, - bandwidth_hz=2000 + 
start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000 ) above16k = [(self.split16k, self.max_index)] @@ -257,59 +218,43 @@ def __init__(self, nfft: int, fs: int) -> None: def get_band_specs(self): below1k = self.get_band_specs_with_bandwidth( - start_index=0, end_index=self.split1k, bandwidth_hz=50 + start_index=0, end_index=self.split1k, bandwidth_hz=50 ) below2k = self.get_band_specs_with_bandwidth( - start_index=self.split1k, - end_index=self.split2k, - bandwidth_hz=100 + start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100 ) below4k = self.get_band_specs_with_bandwidth( - start_index=self.split2k, - end_index=self.split4k, - bandwidth_hz=250 + start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250 ) below8k = self.get_band_specs_with_bandwidth( - start_index=self.split4k, - end_index=self.split8k, - bandwidth_hz=500 + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500 ) below16k = self.get_band_specs_with_bandwidth( - start_index=self.split8k, - end_index=self.split16k, - bandwidth_hz=1000 + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000 ) above16k = [(self.split16k, self.max_index)] return below1k + below2k + below4k + below8k + below16k + above16k - - class PerceptualBandsplitSpecification(BandsplitSpecification): def __init__( - self, - nfft: int, - fs: int, - fbank_fn: Callable[[int, int, float, float, int], torch.Tensor], - n_bands: int, - f_min: float = 0.0, - f_max: float = None + self, + nfft: int, + fs: int, + fbank_fn: Callable[[int, int, float, float, int], torch.Tensor], + n_bands: int, + f_min: float = 0.0, + f_max: float = None, ) -> None: super().__init__(nfft=nfft, fs=fs) self.n_bands = n_bands if f_max is None: f_max = fs / 2 - self.filterbank = fbank_fn( - n_bands, fs, f_min, f_max, self.max_index - ) + self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index) - weight_per_bin = torch.sum( - self.filterbank, - dim=0, - keepdim=True - ) # (1, n_freqs) + weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True) # (1, n_freqs) normalized_mel_fb = self.filterbank / weight_per_bin # (n_mels, n_freqs) freq_weights = [] @@ -342,22 +287,23 @@ def save_to_file(self, dir_path: str) -> None: with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f: pickle.dump( - { - "band_specs": self.band_specs, - "freq_weights": self.freq_weights, - "filterbank": self.filterbank, - }, - f, + { + "band_specs": self.band_specs, + "freq_weights": self.freq_weights, + "filterbank": self.filterbank, + }, + f, ) + def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs): fb = taF.melscale_fbanks( - n_mels=n_bands, - sample_rate=fs, - f_min=f_min, - f_max=f_max, - n_freqs=n_freqs, - ).T + n_mels=n_bands, + sample_rate=fs, + f_min=f_min, + f_max=f_max, + n_freqs=n_freqs, + ).T fb[0, 0] = 1.0 @@ -366,17 +312,19 @@ def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs): class MelBandsplitSpecification(PerceptualBandsplitSpecification): def __init__( - self, - nfft: int, - fs: int, - n_bands: int, - f_min: float = 0.0, - f_max: float = None + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None ) -> None: - super().__init__(fbank_fn=mel_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + super().__init__( + fbank_fn=mel_filterbank, + nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) + -def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, - scale="constant"): +def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, 
scale="constant"): nfft = 2 * (n_freqs - 1) df = fs / nfft @@ -403,55 +351,57 @@ def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, fb = np.zeros((n_bands, n_freqs)) for i in range(n_bands): - fb[i, low_bins[i]:high_bins[i]+1] = 1.0 + fb[i, low_bins[i] : high_bins[i] + 1] = 1.0 - fb[0, :low_bins[0]] = 1.0 - fb[-1, high_bins[-1]+1:] = 1.0 + fb[0, : low_bins[0]] = 1.0 + fb[-1, high_bins[-1] + 1 :] = 1.0 return torch.as_tensor(fb) + class MusicalBandsplitSpecification(PerceptualBandsplitSpecification): def __init__( - self, - nfft: int, - fs: int, - n_bands: int, - f_min: float = 0.0, - f_max: float = None + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None ) -> None: - super().__init__(fbank_fn=musical_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + super().__init__( + fbank_fn=musical_filterbank, + nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) -def bark_filterbank( - n_bands, fs, f_min, f_max, n_freqs -): - nfft = 2 * (n_freqs -1) +def bark_filterbank(n_bands, fs, f_min, f_max, n_freqs): + nfft = 2 * (n_freqs - 1) fb, _ = bark_fbanks.bark_filter_banks( - nfilts=n_bands, - nfft=nfft, - fs=fs, - low_freq=f_min, - high_freq=f_max, - scale="constant" + nfilts=n_bands, + nfft=nfft, + fs=fs, + low_freq=f_min, + high_freq=f_max, + scale="constant", ) return torch.as_tensor(fb) + class BarkBandsplitSpecification(PerceptualBandsplitSpecification): def __init__( - self, - nfft: int, - fs: int, - n_bands: int, - f_min: float = 0.0, - f_max: float = None + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None ) -> None: - super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + super().__init__( + fbank_fn=bark_filterbank, + nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) -def triangular_bark_filterbank( - n_bands, fs, f_min, f_max, n_freqs -): +def triangular_bark_filterbank(n_bands, fs, f_min, f_max, n_freqs): all_freqs = torch.linspace(0, fs // 2, n_freqs) @@ -474,47 +424,41 @@ def triangular_bark_filterbank( return fb + class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification): def __init__( - self, - nfft: int, - fs: int, - n_bands: int, - f_min: float = 0.0, - f_max: float = None + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None ) -> None: - super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) - + super().__init__( + fbank_fn=triangular_bark_filterbank, + nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) -def minibark_filterbank( - n_bands, fs, f_min, f_max, n_freqs -): - fb = bark_filterbank( - n_bands, - fs, - f_min, - f_max, - n_freqs - ) +def minibark_filterbank(n_bands, fs, f_min, f_max, n_freqs): + fb = bark_filterbank(n_bands, fs, f_min, f_max, n_freqs) fb[fb < np.sqrt(0.5)] = 0.0 return fb + class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification): def __init__( - self, - nfft: int, - fs: int, - n_bands: int, - f_min: float = 0.0, - f_max: float = None + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None ) -> None: - super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) - - - + super().__init__( + fbank_fn=minibark_filterbank, + nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) def erb_filterbank( @@ -533,14 +477,13 @@ def erb_filterbank( m_max = 
hz2erb(f_max) m_pts = torch.linspace(m_min, m_max, n_bands + 2) - f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437 + f_pts = (torch.pow(10, (m_pts / A)) - 1) / 0.00437 # create filterbank fb = _create_triangular_filterbank(all_freqs, f_pts) fb = fb.T - first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0] first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0] @@ -549,35 +492,34 @@ def erb_filterbank( return fb - class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification): def __init__( - self, - nfft: int, - fs: int, - n_bands: int, - f_min: float = 0.0, - f_max: float = None + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None ) -> None: - super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + super().__init__( + fbank_fn=erb_filterbank, + nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) + if __name__ == "__main__": import pandas as pd band_defs = [] - for bands in [VocalBandsplitSpecification]: + for bands in [VocalBandsplitSpecification]: band_name = bands.__name__.replace("BandsplitSpecification", "") mbs = bands(nfft=2048, fs=44100).get_band_specs() for i, (f_min, f_max) in enumerate(mbs): - band_defs.append({ - "band": band_name, - "band_index": i, - "f_min": f_min, - "f_max": f_max - }) + band_defs.append( + {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max} + ) df = pd.DataFrame(band_defs) - df.to_csv("vox7bands.csv", index=False) \ No newline at end of file + df.to_csv("vox7bands.csv", index=False) diff --git a/programs/music_separation_code/models/bandit/core/model/bsrnn/wrapper.py b/programs/music_separation_code/models/bandit/core/model/bsrnn/wrapper.py index a31c087..6f26e9d 100644 --- a/programs/music_separation_code/models/bandit/core/model/bsrnn/wrapper.py +++ b/programs/music_separation_code/models/bandit/core/model/bsrnn/wrapper.py @@ -1,4 +1,3 @@ -from pprint import pprint from typing import Dict, List, Optional, Tuple, Union import torch @@ -6,76 +5,62 @@ from models.bandit.core.model._spectral import _SpectralComponent from models.bandit.core.model.bsrnn.utils import ( - BarkBandsplitSpecification, BassBandsplitSpecification, + BarkBandsplitSpecification, + BassBandsplitSpecification, DrumBandsplitSpecification, - EquivalentRectangularBandsplitSpecification, MelBandsplitSpecification, - MusicalBandsplitSpecification, OtherBandsplitSpecification, - TriangularBarkBandsplitSpecification, VocalBandsplitSpecification, + EquivalentRectangularBandsplitSpecification, + MelBandsplitSpecification, + MusicalBandsplitSpecification, + OtherBandsplitSpecification, + TriangularBarkBandsplitSpecification, + VocalBandsplitSpecification, ) from .core import ( MultiSourceMultiMaskBandSplitCoreConv, MultiSourceMultiMaskBandSplitCoreRNN, MultiSourceMultiMaskBandSplitCoreTransformer, - MultiSourceMultiPatchingMaskBandSplitCoreRNN, SingleMaskBandsplitCoreRNN, + MultiSourceMultiPatchingMaskBandSplitCoreRNN, + SingleMaskBandsplitCoreRNN, SingleMaskBandsplitCoreTransformer, ) import pytorch_lightning as pl + def get_band_specs(band_specs, n_fft, fs, n_bands=None): if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]: - bsm = VocalBandsplitSpecification( - nfft=n_fft, fs=fs - ).get_band_specs() + bsm = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs() freq_weights = None overlapping_band = False elif "tribark" in band_specs: assert n_bands is not None - specs = TriangularBarkBandsplitSpecification( 
- nfft=n_fft, - fs=fs, - n_bands=n_bands - ) + specs = TriangularBarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands) bsm = specs.get_band_specs() freq_weights = specs.get_freq_weights() overlapping_band = True elif "bark" in band_specs: assert n_bands is not None - specs = BarkBandsplitSpecification( - nfft=n_fft, - fs=fs, - n_bands=n_bands - ) + specs = BarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands) bsm = specs.get_band_specs() freq_weights = specs.get_freq_weights() overlapping_band = True elif "erb" in band_specs: assert n_bands is not None specs = EquivalentRectangularBandsplitSpecification( - nfft=n_fft, - fs=fs, - n_bands=n_bands + nfft=n_fft, fs=fs, n_bands=n_bands ) bsm = specs.get_band_specs() freq_weights = specs.get_freq_weights() overlapping_band = True elif "musical" in band_specs: assert n_bands is not None - specs = MusicalBandsplitSpecification( - nfft=n_fft, - fs=fs, - n_bands=n_bands - ) + specs = MusicalBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands) bsm = specs.get_band_specs() freq_weights = specs.get_freq_weights() overlapping_band = True elif band_specs == "dnr:mel" or "mel" in band_specs: assert n_bands is not None - specs = MelBandsplitSpecification( - nfft=n_fft, - fs=fs, - n_bands=n_bands - ) + specs = MelBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands) bsm = specs.get_band_specs() freq_weights = specs.get_freq_weights() overlapping_band = True @@ -88,38 +73,24 @@ def get_band_specs(band_specs, n_fft, fs, n_bands=None): def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None): if band_specs_map == "musdb:all": bsm = { - "vocals": VocalBandsplitSpecification( - nfft=n_fft, fs=fs - ).get_band_specs(), - "drums": DrumBandsplitSpecification( - nfft=n_fft, fs=fs - ).get_band_specs(), - "bass": BassBandsplitSpecification( - nfft=n_fft, fs=fs - ).get_band_specs(), - "other": OtherBandsplitSpecification( - nfft=n_fft, fs=fs - ).get_band_specs(), + "vocals": VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(), + "drums": DrumBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(), + "bass": BassBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(), + "other": OtherBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(), } freq_weights = None overlapping_band = False elif band_specs_map == "dnr:vox7": bsm_, freq_weights, overlapping_band = get_band_specs( - "dnr:speech", n_fft, fs, n_bands + "dnr:speech", n_fft, fs, n_bands ) - bsm = { - "speech": bsm_, - "music": bsm_, - "effects": bsm_ - } + bsm = {"speech": bsm_, "music": bsm_, "effects": bsm_} elif "dnr:vox7:" in band_specs_map: stem = band_specs_map.split(":")[-1] bsm_, freq_weights, overlapping_band = get_band_specs( - "dnr:speech", n_fft, fs, n_bands + "dnr:speech", n_fft, fs, n_bands ) - bsm = { - stem: bsm_ - } + bsm = {stem: bsm_} else: raise NameError @@ -128,51 +99,45 @@ def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None): class BandSplitWrapperBase(pl.LightningModule): bsrnn: nn.Module - + def __init__(self, **kwargs): super().__init__() -class SingleMaskMultiSourceBandSplitBase( - BandSplitWrapperBase, - _SpectralComponent -): +class SingleMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent): def __init__( - self, - band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]], - fs: int = 44100, - n_fft: int = 2048, - win_length: Optional[int] = 2048, - hop_length: int = 512, - window_fn: str = "hann_window", - wkwargs: Optional[Dict] = None, - power: Optional[int] = None, - center: bool = True, - 
normalized: bool = True, - pad_mode: str = "constant", - onesided: bool = True, - n_bands: int = None, + self, + band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]], + fs: int = 44100, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, ) -> None: super().__init__( - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - window_fn=window_fn, - wkwargs=wkwargs, - power=power, - center=center, - normalized=normalized, - pad_mode=pad_mode, - onesided=onesided, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, ) if isinstance(band_specs_map, str): - self.band_specs_map, self.freq_weights, self.overlapping_band = get_band_specs_map( - band_specs_map, - n_fft, - fs, - n_bands=n_bands - ) + self.band_specs_map, self.freq_weights, self.overlapping_band = ( + get_band_specs_map(band_specs_map, n_fft, fs, n_bands=n_bands) + ) self.stems = list(self.band_specs_map.keys()) @@ -180,8 +145,7 @@ def forward(self, batch): audio = batch["audio"] with torch.no_grad(): - batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in - audio} + batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio} X = batch["spectrogram"]["mixture"] length = batch["audio"]["mixture"].shape[-1] @@ -197,47 +161,41 @@ def forward(self, batch): return batch, output -class MultiMaskMultiSourceBandSplitBase( - BandSplitWrapperBase, - _SpectralComponent -): +class MultiMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent): def __init__( - self, - stems: List[str], - band_specs: Union[str, List[Tuple[float, float]]], - fs: int = 44100, - n_fft: int = 2048, - win_length: Optional[int] = 2048, - hop_length: int = 512, - window_fn: str = "hann_window", - wkwargs: Optional[Dict] = None, - power: Optional[int] = None, - center: bool = True, - normalized: bool = True, - pad_mode: str = "constant", - onesided: bool = True, - n_bands: int = None, + self, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, ) -> None: super().__init__( - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - window_fn=window_fn, - wkwargs=wkwargs, - power=power, - center=center, - normalized=normalized, - pad_mode=pad_mode, - onesided=onesided, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, ) if isinstance(band_specs, str): self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs( - band_specs, - n_fft, - fs, - n_bands - ) + band_specs, n_fft, fs, n_bands + ) self.stems = stems @@ -246,8 +204,7 @@ def forward(self, batch): audio = batch["audio"] cond = batch.get("condition", None) with torch.no_grad(): - batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in - 
audio} + batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio} X = batch["spectrogram"]["mixture"] length = batch["audio"]["mixture"].shape[-1] @@ -262,47 +219,41 @@ def forward(self, batch): return batch, output -class MultiMaskMultiSourceBandSplitBaseSimple( - BandSplitWrapperBase, - _SpectralComponent -): +class MultiMaskMultiSourceBandSplitBaseSimple(BandSplitWrapperBase, _SpectralComponent): def __init__( - self, - stems: List[str], - band_specs: Union[str, List[Tuple[float, float]]], - fs: int = 44100, - n_fft: int = 2048, - win_length: Optional[int] = 2048, - hop_length: int = 512, - window_fn: str = "hann_window", - wkwargs: Optional[Dict] = None, - power: Optional[int] = None, - center: bool = True, - normalized: bool = True, - pad_mode: str = "constant", - onesided: bool = True, - n_bands: int = None, + self, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, ) -> None: super().__init__( - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - window_fn=window_fn, - wkwargs=wkwargs, - power=power, - center=center, - normalized=normalized, - pad_mode=pad_mode, - onesided=onesided, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, ) if isinstance(band_specs, str): self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs( - band_specs, - n_fft, - fs, - n_bands - ) + band_specs, n_fft, fs, n_bands + ) self.stems = stems @@ -321,221 +272,219 @@ def forward(self, batch): class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase): def __init__( - self, - in_channel: int, - band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]], - fs: int = 44100, - require_no_overlap: bool = False, - require_no_gap: bool = True, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, - n_sqm_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 256, - bidirectional: bool = True, - rnn_type: str = "LSTM", - mlp_dim: int = 512, - hidden_activation: str = "Tanh", - hidden_activation_kwargs: Optional[Dict] = None, - complex_mask: bool = True, - n_fft: int = 2048, - win_length: Optional[int] = 2048, - hop_length: int = 512, - window_fn: str = "hann_window", - wkwargs: Optional[Dict] = None, - power: Optional[int] = None, - center: bool = True, - normalized: bool = True, - pad_mode: str = "constant", - onesided: bool = True, + self, + in_channel: int, + band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + 
power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, ) -> None: super().__init__( - band_specs_map=band_specs_map, - fs=fs, - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - window_fn=window_fn, - wkwargs=wkwargs, - power=power, - center=center, - normalized=normalized, - pad_mode=pad_mode, - onesided=onesided, + band_specs_map=band_specs_map, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, ) self.bsrnn = nn.ModuleDict( - { - src: SingleMaskBandsplitCoreRNN( - band_specs=specs, - in_channel=in_channel, - require_no_overlap=require_no_overlap, - require_no_gap=require_no_gap, - normalize_channel_independently=normalize_channel_independently, - treat_channel_as_feature=treat_channel_as_feature, - n_sqm_modules=n_sqm_modules, - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - rnn_type=rnn_type, - mlp_dim=mlp_dim, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - ) - for src, specs in self.band_specs_map.items() - } + { + src: SingleMaskBandsplitCoreRNN( + band_specs=specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + for src, specs in self.band_specs_map.items() + } ) -class SingleMaskMultiSourceBandSplitTransformer( - SingleMaskMultiSourceBandSplitBase -): +class SingleMaskMultiSourceBandSplitTransformer(SingleMaskMultiSourceBandSplitBase): def __init__( - self, - in_channel: int, - band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]], - fs: int = 44100, - require_no_overlap: bool = False, - require_no_gap: bool = True, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, - n_sqm_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 256, - bidirectional: bool = True, - tf_dropout: float = 0.0, - mlp_dim: int = 512, - hidden_activation: str = "Tanh", - hidden_activation_kwargs: Optional[Dict] = None, - complex_mask: bool = True, - n_fft: int = 2048, - win_length: Optional[int] = 2048, - hop_length: int = 512, - window_fn: str = "hann_window", - wkwargs: Optional[Dict] = None, - power: Optional[int] = None, - center: bool = True, - normalized: bool = True, - pad_mode: str = "constant", - onesided: bool = True, + self, + in_channel: int, + band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + 
wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, ) -> None: super().__init__( - band_specs_map=band_specs_map, - fs=fs, - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - window_fn=window_fn, - wkwargs=wkwargs, - power=power, - center=center, - normalized=normalized, - pad_mode=pad_mode, - onesided=onesided, + band_specs_map=band_specs_map, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, ) self.bsrnn = nn.ModuleDict( - { - src: SingleMaskBandsplitCoreTransformer( - band_specs=specs, - in_channel=in_channel, - require_no_overlap=require_no_overlap, - require_no_gap=require_no_gap, - normalize_channel_independently=normalize_channel_independently, - treat_channel_as_feature=treat_channel_as_feature, - n_sqm_modules=n_sqm_modules, - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - tf_dropout=tf_dropout, - mlp_dim=mlp_dim, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - ) - for src, specs in self.band_specs_map.items() - } + { + src: SingleMaskBandsplitCoreTransformer( + band_specs=specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + tf_dropout=tf_dropout, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + for src, specs in self.band_specs_map.items() + } ) class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase): def __init__( - self, - in_channel: int, - stems: List[str], - band_specs: Union[str, List[Tuple[float, float]]], - fs: int = 44100, - require_no_overlap: bool = False, - require_no_gap: bool = True, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, - n_sqm_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 256, - cond_dim: int = 0, - bidirectional: bool = True, - rnn_type: str = "LSTM", - mlp_dim: int = 512, - hidden_activation: str = "Tanh", - hidden_activation_kwargs: Optional[Dict] = None, - complex_mask: bool = True, - n_fft: int = 2048, - win_length: Optional[int] = 2048, - hop_length: int = 512, - window_fn: str = "hann_window", - wkwargs: Optional[Dict] = None, - power: Optional[int] = None, - center: bool = True, - normalized: bool = True, - pad_mode: str = "constant", - onesided: bool = True, - n_bands: int = None, - use_freq_weights: bool = True, - normalize_input: bool = False, - mult_add_mask: bool = False, - freeze_encoder: bool = False, + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = 
None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False, + freeze_encoder: bool = False, ) -> None: super().__init__( - stems=stems, - band_specs=band_specs, - fs=fs, - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - window_fn=window_fn, - wkwargs=wkwargs, - power=power, - center=center, - normalized=normalized, - pad_mode=pad_mode, - onesided=onesided, - n_bands=n_bands, + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, ) self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN( - stems=stems, - band_specs=self.band_specs, - in_channel=in_channel, - require_no_overlap=require_no_overlap, - require_no_gap=require_no_gap, - normalize_channel_independently=normalize_channel_independently, - treat_channel_as_feature=treat_channel_as_feature, - n_sqm_modules=n_sqm_modules, - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - rnn_type=rnn_type, - mlp_dim=mlp_dim, - cond_dim=cond_dim, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - overlapping_band=self.overlapping_band, - freq_weights=self.freq_weights, - n_freq=n_fft // 2 + 1, - use_freq_weights=use_freq_weights, - mult_add_mask=mult_add_mask + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask, ) self.normalize_input = normalize_input @@ -551,81 +500,81 @@ def __init__( class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple): def __init__( - self, - in_channel: int, - stems: List[str], - band_specs: Union[str, List[Tuple[float, float]]], - fs: int = 44100, - require_no_overlap: bool = False, - require_no_gap: bool = True, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, - n_sqm_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 256, - cond_dim: int = 0, - bidirectional: bool = True, - rnn_type: str = "LSTM", - mlp_dim: int = 512, - hidden_activation: str = "Tanh", - hidden_activation_kwargs: Optional[Dict] = None, - complex_mask: bool = True, - n_fft: int = 2048, - win_length: Optional[int] = 2048, - hop_length: int = 512, - window_fn: str = "hann_window", - wkwargs: Optional[Dict] = None, - power: Optional[int] = None, - center: bool = True, - normalized: bool = True, - pad_mode: str = "constant", - onesided: bool = 
True, - n_bands: int = None, - use_freq_weights: bool = True, - normalize_input: bool = False, - mult_add_mask: bool = False, - freeze_encoder: bool = False, + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False, + freeze_encoder: bool = False, ) -> None: super().__init__( - stems=stems, - band_specs=band_specs, - fs=fs, - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - window_fn=window_fn, - wkwargs=wkwargs, - power=power, - center=center, - normalized=normalized, - pad_mode=pad_mode, - onesided=onesided, - n_bands=n_bands, + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, ) self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN( - stems=stems, - band_specs=self.band_specs, - in_channel=in_channel, - require_no_overlap=require_no_overlap, - require_no_gap=require_no_gap, - normalize_channel_independently=normalize_channel_independently, - treat_channel_as_feature=treat_channel_as_feature, - n_sqm_modules=n_sqm_modules, - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - rnn_type=rnn_type, - mlp_dim=mlp_dim, - cond_dim=cond_dim, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - overlapping_band=self.overlapping_band, - freq_weights=self.freq_weights, - n_freq=n_fft // 2 + 1, - use_freq_weights=use_freq_weights, - mult_add_mask=mult_add_mask + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask, ) self.normalize_input = normalize_input @@ -639,244 +588,241 @@ def __init__( param.requires_grad = False -class MultiMaskMultiSourceBandSplitTransformer( - MultiMaskMultiSourceBandSplitBase -): +class MultiMaskMultiSourceBandSplitTransformer(MultiMaskMultiSourceBandSplitBase): def __init__( - self, - in_channel: int, - stems: 
List[str], - band_specs: Union[str, List[Tuple[float, float]]], - fs: int = 44100, - require_no_overlap: bool = False, - require_no_gap: bool = True, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, - n_sqm_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 256, - cond_dim: int = 0, - bidirectional: bool = True, - rnn_type: str = "LSTM", - mlp_dim: int = 512, - hidden_activation: str = "Tanh", - hidden_activation_kwargs: Optional[Dict] = None, - complex_mask: bool = True, - n_fft: int = 2048, - win_length: Optional[int] = 2048, - hop_length: int = 512, - window_fn: str = "hann_window", - wkwargs: Optional[Dict] = None, - power: Optional[int] = None, - center: bool = True, - normalized: bool = True, - pad_mode: str = "constant", - onesided: bool = True, - n_bands: int = None, - use_freq_weights: bool = True, - normalize_input: bool = False, - mult_add_mask: bool = False + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False, ) -> None: super().__init__( - stems=stems, - band_specs=band_specs, - fs=fs, - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - window_fn=window_fn, - wkwargs=wkwargs, - power=power, - center=center, - normalized=normalized, - pad_mode=pad_mode, - onesided=onesided, - n_bands=n_bands, + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, ) self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer( - stems=stems, - band_specs=self.band_specs, - in_channel=in_channel, - require_no_overlap=require_no_overlap, - require_no_gap=require_no_gap, - normalize_channel_independently=normalize_channel_independently, - treat_channel_as_feature=treat_channel_as_feature, - n_sqm_modules=n_sqm_modules, - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - rnn_type=rnn_type, - mlp_dim=mlp_dim, - cond_dim=cond_dim, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - overlapping_band=self.overlapping_band, - freq_weights=self.freq_weights, - n_freq=n_fft // 2 + 1, - use_freq_weights=use_freq_weights, - mult_add_mask=mult_add_mask + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + 
emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask, ) - -class MultiMaskMultiSourceBandSplitConv( - MultiMaskMultiSourceBandSplitBase -): +class MultiMaskMultiSourceBandSplitConv(MultiMaskMultiSourceBandSplitBase): def __init__( - self, - in_channel: int, - stems: List[str], - band_specs: Union[str, List[Tuple[float, float]]], - fs: int = 44100, - require_no_overlap: bool = False, - require_no_gap: bool = True, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, - n_sqm_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 256, - cond_dim: int = 0, - bidirectional: bool = True, - rnn_type: str = "LSTM", - mlp_dim: int = 512, - hidden_activation: str = "Tanh", - hidden_activation_kwargs: Optional[Dict] = None, - complex_mask: bool = True, - n_fft: int = 2048, - win_length: Optional[int] = 2048, - hop_length: int = 512, - window_fn: str = "hann_window", - wkwargs: Optional[Dict] = None, - power: Optional[int] = None, - center: bool = True, - normalized: bool = True, - pad_mode: str = "constant", - onesided: bool = True, - n_bands: int = None, - use_freq_weights: bool = True, - normalize_input: bool = False, - mult_add_mask: bool = False + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False, ) -> None: super().__init__( - stems=stems, - band_specs=band_specs, - fs=fs, - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - window_fn=window_fn, - wkwargs=wkwargs, - power=power, - center=center, - normalized=normalized, - pad_mode=pad_mode, - onesided=onesided, - n_bands=n_bands, + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, ) self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv( - stems=stems, - band_specs=self.band_specs, - in_channel=in_channel, - require_no_overlap=require_no_overlap, - require_no_gap=require_no_gap, - normalize_channel_independently=normalize_channel_independently, - treat_channel_as_feature=treat_channel_as_feature, - n_sqm_modules=n_sqm_modules, - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - rnn_type=rnn_type, - 
mlp_dim=mlp_dim, - cond_dim=cond_dim, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - overlapping_band=self.overlapping_band, - freq_weights=self.freq_weights, - n_freq=n_fft // 2 + 1, - use_freq_weights=use_freq_weights, - mult_add_mask=mult_add_mask + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask, ) + + class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase): def __init__( - self, - in_channel: int, - stems: List[str], - band_specs: Union[str, List[Tuple[float, float]]], - kernel_norm_mlp_version: int = 1, - mask_kernel_freq: int = 3, - mask_kernel_time: int = 3, - conv_kernel_freq: int = 1, - conv_kernel_time: int = 1, - fs: int = 44100, - require_no_overlap: bool = False, - require_no_gap: bool = True, - normalize_channel_independently: bool = False, - treat_channel_as_feature: bool = True, - n_sqm_modules: int = 12, - emb_dim: int = 128, - rnn_dim: int = 256, - bidirectional: bool = True, - rnn_type: str = "LSTM", - mlp_dim: int = 512, - hidden_activation: str = "Tanh", - hidden_activation_kwargs: Optional[Dict] = None, - complex_mask: bool = True, - n_fft: int = 2048, - win_length: Optional[int] = 2048, - hop_length: int = 512, - window_fn: str = "hann_window", - wkwargs: Optional[Dict] = None, - power: Optional[int] = None, - center: bool = True, - normalized: bool = True, - pad_mode: str = "constant", - onesided: bool = True, - n_bands: int = None, + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + kernel_norm_mlp_version: int = 1, + mask_kernel_freq: int = 3, + mask_kernel_time: int = 3, + conv_kernel_freq: int = 1, + conv_kernel_time: int = 1, + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, ) -> None: super().__init__( - stems=stems, - band_specs=band_specs, - fs=fs, - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - window_fn=window_fn, - wkwargs=wkwargs, - power=power, - center=center, - normalized=normalized, - pad_mode=pad_mode, - onesided=onesided, - n_bands=n_bands, + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + 
hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, ) self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN( - stems=stems, - band_specs=self.band_specs, - in_channel=in_channel, - require_no_overlap=require_no_overlap, - require_no_gap=require_no_gap, - normalize_channel_independently=normalize_channel_independently, - treat_channel_as_feature=treat_channel_as_feature, - n_sqm_modules=n_sqm_modules, - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - rnn_type=rnn_type, - mlp_dim=mlp_dim, - hidden_activation=hidden_activation, - hidden_activation_kwargs=hidden_activation_kwargs, - complex_mask=complex_mask, - overlapping_band=self.overlapping_band, - freq_weights=self.freq_weights, - n_freq=n_fft // 2 + 1, - mask_kernel_freq=mask_kernel_freq, - mask_kernel_time=mask_kernel_time, - conv_kernel_freq=conv_kernel_freq, - conv_kernel_time=conv_kernel_time, - kernel_norm_mlp_version=kernel_norm_mlp_version, + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + mask_kernel_freq=mask_kernel_freq, + mask_kernel_time=mask_kernel_time, + conv_kernel_freq=conv_kernel_freq, + conv_kernel_time=conv_kernel_time, + kernel_norm_mlp_version=kernel_norm_mlp_version, ) diff --git a/programs/music_separation_code/models/bandit/core/utils/audio.py b/programs/music_separation_code/models/bandit/core/utils/audio.py index e4066d7..6bdea55 100644 --- a/programs/music_separation_code/models/bandit/core/utils/audio.py +++ b/programs/music_separation_code/models/bandit/core/utils/audio.py @@ -1,7 +1,7 @@ from collections import defaultdict from tqdm import tqdm -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, Tuple import numpy as np import torch @@ -11,19 +11,17 @@ @torch.jit.script def merge( - combined: torch.Tensor, - original_batch_size: int, - n_channel: int, - n_chunks: int, - chunk_size: int, ): + combined: torch.Tensor, + original_batch_size: int, + n_channel: int, + n_chunks: int, + chunk_size: int, +): combined = torch.reshape( - combined, - (original_batch_size, n_chunks, n_channel, chunk_size) + combined, (original_batch_size, n_chunks, n_channel, chunk_size) ) combined = torch.permute(combined, (0, 2, 3, 1)).reshape( - original_batch_size * n_channel, - chunk_size, - n_chunks + original_batch_size * n_channel, chunk_size, n_chunks ) return combined @@ -31,33 +29,23 @@ def merge( @torch.jit.script def unfold( - padded_audio: torch.Tensor, - original_batch_size: int, - n_channel: int, - chunk_size: int, - hop_size: int - ) -> torch.Tensor: + padded_audio: torch.Tensor, + original_batch_size: int, + n_channel: int, + chunk_size: int, + hop_size: int, +) -> torch.Tensor: unfolded_input = F.unfold( - padded_audio[:, :, None, :], - kernel_size=(1, chunk_size), - stride=(1, hop_size) + padded_audio[:, :, None, :], kernel_size=(1, chunk_size), stride=(1, 
hop_size) ) _, _, n_chunks = unfolded_input.shape unfolded_input = unfolded_input.view( - original_batch_size, - n_channel, - chunk_size, - n_chunks + original_batch_size, n_channel, chunk_size, n_chunks ) - unfolded_input = torch.permute( - unfolded_input, - (0, 3, 1, 2) - ).reshape( - original_batch_size * n_chunks, - n_channel, - chunk_size + unfolded_input = torch.permute(unfolded_input, (0, 3, 1, 2)).reshape( + original_batch_size * n_chunks, n_channel, chunk_size ) return unfolded_input @@ -66,40 +54,31 @@ def unfold( @torch.jit.script # @torch.compile def merge_chunks_all( - combined: torch.Tensor, - original_batch_size: int, - n_channel: int, - n_samples: int, - n_padded_samples: int, - n_chunks: int, - chunk_size: int, - hop_size: int, - edge_frame_pad_sizes: Tuple[int, int], - standard_window: torch.Tensor, - first_window: torch.Tensor, - last_window: torch.Tensor + combined: torch.Tensor, + original_batch_size: int, + n_channel: int, + n_samples: int, + n_padded_samples: int, + n_chunks: int, + chunk_size: int, + hop_size: int, + edge_frame_pad_sizes: Tuple[int, int], + standard_window: torch.Tensor, + first_window: torch.Tensor, + last_window: torch.Tensor, ): - combined = merge( - combined, - original_batch_size, - n_channel, - n_chunks, - chunk_size - ) + combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size) combined = combined * standard_window[:, None].to(combined.device) combined = F.fold( - combined.to(torch.float32), output_size=(1, n_padded_samples), - kernel_size=(1, chunk_size), - stride=(1, hop_size) + combined.to(torch.float32), + output_size=(1, n_padded_samples), + kernel_size=(1, chunk_size), + stride=(1, hop_size), ) - combined = combined.view( - original_batch_size, - n_channel, - n_padded_samples - ) + combined = combined.view(original_batch_size, n_channel, n_padded_samples) pad_front, pad_back = edge_frame_pad_sizes combined = combined[..., pad_front:-pad_back] @@ -112,43 +91,33 @@ def merge_chunks_all( def merge_chunks_edge( - combined: torch.Tensor, - original_batch_size: int, - n_channel: int, - n_samples: int, - n_padded_samples: int, - n_chunks: int, - chunk_size: int, - hop_size: int, - edge_frame_pad_sizes: Tuple[int, int], - standard_window: torch.Tensor, - first_window: torch.Tensor, - last_window: torch.Tensor + combined: torch.Tensor, + original_batch_size: int, + n_channel: int, + n_samples: int, + n_padded_samples: int, + n_chunks: int, + chunk_size: int, + hop_size: int, + edge_frame_pad_sizes: Tuple[int, int], + standard_window: torch.Tensor, + first_window: torch.Tensor, + last_window: torch.Tensor, ): - combined = merge( - combined, - original_batch_size, - n_channel, - n_chunks, - chunk_size - ) + combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size) combined[..., 0] = combined[..., 0] * first_window combined[..., -1] = combined[..., -1] * last_window - combined[..., 1:-1] = combined[..., - 1:-1] * standard_window[:, None] + combined[..., 1:-1] = combined[..., 1:-1] * standard_window[:, None] combined = F.fold( - combined, output_size=(1, n_padded_samples), - kernel_size=(1, chunk_size), - stride=(1, hop_size) + combined, + output_size=(1, n_padded_samples), + kernel_size=(1, chunk_size), + stride=(1, hop_size), ) - combined = combined.view( - original_batch_size, - n_channel, - n_padded_samples - ) + combined = combined.view(original_batch_size, n_channel, n_padded_samples) combined = combined[..., :n_samples] @@ -157,12 +126,12 @@ def merge_chunks_edge( class BaseFader(nn.Module): def 
__init__( - self, - chunk_size_second: float, - hop_size_second: float, - fs: int, - fade_edge_frames: bool, - batch_size: int, + self, + chunk_size_second: float, + hop_size_second: float, + fs: int, + fade_edge_frames: bool, + batch_size: int, ) -> None: super().__init__() @@ -179,9 +148,7 @@ def prepare(self, audio): audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect") n_samples = audio.shape[-1] - n_chunks = int( - np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1 - ) + n_chunks = int(np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1) padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size pad_size = padded_size - n_samples @@ -191,9 +158,9 @@ def prepare(self, audio): return padded_audio, n_chunks def forward( - self, - audio: torch.Tensor, - model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]], + self, + audio: torch.Tensor, + model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]], ): original_dtype = audio.dtype @@ -208,14 +175,11 @@ def forward( if n_channel > 1: padded_audio = padded_audio.view( - original_batch_size * n_channel, 1, n_padded_samples + original_batch_size * n_channel, 1, n_padded_samples ) unfolded_input = unfold( - padded_audio, - original_batch_size, - n_channel, - self.chunk_size, self.hop_size + padded_audio, original_batch_size, n_channel, self.chunk_size, self.hop_size ) n_total_chunks, n_channel, chunk_size = unfolded_input.shape @@ -223,15 +187,12 @@ def forward( n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int) chunks_in = [ - unfolded_input[ - b * self.batch_size:(b + 1) * self.batch_size, ...].clone() - for b in range(n_batch) + unfolded_input[b * self.batch_size : (b + 1) * self.batch_size, ...].clone() + for b in range(n_batch) ] all_chunks_out = defaultdict( - lambda: torch.zeros_like( - unfolded_input, device="cpu" - ) + lambda: torch.zeros_like(unfolded_input, device="cpu") ) # for b, cin in enumerate(tqdm(chunks_in)): @@ -243,8 +204,9 @@ def forward( chunks_out = model_fn(cin.to(original_device)) del cin for s, c in chunks_out.items(): - all_chunks_out[s][b * self.batch_size:(b + 1) * self.batch_size, - ...] = c.cpu() + all_chunks_out[s][ + b * self.batch_size : (b + 1) * self.batch_size, ... 
+ ] = c.cpu() del chunks_out del unfolded_input @@ -260,28 +222,24 @@ def forward( for s, c in all_chunks_out.items(): combined: torch.Tensor = fn( - c, - original_batch_size, - n_channel, - n_samples, - n_padded_samples, - n_chunks, - self.chunk_size, - self.hop_size, - self.edge_frame_pad_sizes, - self.standard_window, - self.__dict__.get("first_window", self.standard_window), - self.__dict__.get("last_window", self.standard_window) + c, + original_batch_size, + n_channel, + n_samples, + n_padded_samples, + n_chunks, + self.chunk_size, + self.hop_size, + self.edge_frame_pad_sizes, + self.standard_window, + self.__dict__.get("first_window", self.standard_window), + self.__dict__.get("last_window", self.standard_window), ) - outputs[s] = combined.to( - dtype=original_dtype, - device=original_device - ) + outputs[s] = combined.to(dtype=original_dtype, device=original_device) + + return {"audio": outputs} - return { - "audio": outputs - } # # def old_forward( # self, @@ -366,22 +324,22 @@ def forward( class LinearFader(BaseFader): def __init__( - self, - chunk_size_second: float, - hop_size_second: float, - fs: int, - fade_edge_frames: bool = False, - batch_size: int = 1, + self, + chunk_size_second: float, + hop_size_second: float, + fs: int, + fade_edge_frames: bool = False, + batch_size: int = 1, ) -> None: assert hop_size_second >= chunk_size_second / 2 super().__init__( - chunk_size_second=chunk_size_second, - hop_size_second=hop_size_second, - fs=fs, - fade_edge_frames=fade_edge_frames, - batch_size=batch_size, + chunk_size_second=chunk_size_second, + hop_size_second=hop_size_second, + fs=fs, + fade_edge_frames=fade_edge_frames, + batch_size=batch_size, ) in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1] @@ -391,8 +349,7 @@ def __init__( # using nn.Parameters allows lightning to take care of devices for us self.register_buffer( - "standard_window", - torch.concat([in_fade, center_ones, out_fade]) + "standard_window", torch.concat([in_fade, center_ones, out_fade]) ) self.fade_edge_frames = fade_edge_frames @@ -400,23 +357,21 @@ def __init__( if not self.fade_edge_frames: self.first_window = nn.Parameter( - torch.concat([inout_ones, center_ones, out_fade]), - requires_grad=False + torch.concat([inout_ones, center_ones, out_fade]), requires_grad=False ) self.last_window = nn.Parameter( - torch.concat([in_fade, center_ones, inout_ones]), - requires_grad=False + torch.concat([in_fade, center_ones, inout_ones]), requires_grad=False ) class OverlapAddFader(BaseFader): def __init__( - self, - window_type: str, - chunk_size_second: float, - hop_size_second: float, - fs: int, - batch_size: int = 1, + self, + window_type: str, + chunk_size_second: float, + hop_size_second: float, + fs: int, + batch_size: int = 1, ) -> None: assert (chunk_size_second / hop_size_second) % 2 == 0 assert int(chunk_size_second * fs) % 2 == 0 @@ -432,31 +387,25 @@ def __init__( self.hop_multiplier = self.chunk_size / (2 * self.hop_size) # print(f"hop multiplier: {self.hop_multiplier}") - self.edge_frame_pad_sizes = ( - 2 * self.overlap_size, - 2 * self.overlap_size - ) + self.edge_frame_pad_sizes = (2 * self.overlap_size, 2 * self.overlap_size) self.register_buffer( - "standard_window", torch.windows.__dict__[window_type]( - self.chunk_size, sym=False, # dtype=torch.float64 - ) / self.hop_multiplier + "standard_window", + torch.windows.__dict__[window_type]( + self.chunk_size, + sym=False, # dtype=torch.float64 + ) + / self.hop_multiplier, ) if __name__ == "__main__": import torchaudio as ta + fs = 44100 - 
ola = OverlapAddFader( - "hann", - 6.0, - 1.0, - fs, - batch_size=16 - ) + ola = OverlapAddFader("hann", 6.0, 1.0, fs, batch_size=16) audio_, _ = ta.load( - "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too " - "Much/vocals.wav" + "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too " "Much/vocals.wav" ) audio_ = audio_[None, ...] out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"] diff --git a/programs/music_separation_code/models/bandit/model_from_config.py b/programs/music_separation_code/models/bandit/model_from_config.py index 00ea586..9735bda 100644 --- a/programs/music_separation_code/models/bandit/model_from_config.py +++ b/programs/music_separation_code/models/bandit/model_from_config.py @@ -2,7 +2,7 @@ import os.path import torch -code_path = os.path.dirname(os.path.abspath(__file__)) + '/' +code_path = os.path.dirname(os.path.abspath(__file__)) + "/" sys.path.append(code_path) import yaml @@ -22,10 +22,8 @@ def get_model( config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader)) f.close() - model = MultiMaskMultiSourceBandSplitRNNSimple( - **config.model - ) - d = torch.load(code_path + 'model_bandit_plus_dnr_sdr_11.47.chpt') + model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model) + d = torch.load(code_path + "model_bandit_plus_dnr_sdr_11.47.chpt") model.load_state_dict(d) model.to(device) return model, config diff --git a/programs/music_separation_code/models/bandit_v2/bandit.py b/programs/music_separation_code/models/bandit_v2/bandit.py index ac4e13f..fba3296 100644 --- a/programs/music_separation_code/models/bandit_v2/bandit.py +++ b/programs/music_separation_code/models/bandit_v2/bandit.py @@ -11,7 +11,6 @@ from .utils import MusicalBandsplitSpecification - class BaseEndToEndModule(pl.LightningModule): def __init__( self, @@ -178,12 +177,12 @@ def instantiate_tf_modelling( ) except Exception as e: self.tf_model = SeqBandModellingModule( - n_modules=n_sqm_modules, - emb_dim=emb_dim, - rnn_dim=rnn_dim, - bidirectional=bidirectional, - rnn_type=rnn_type, - ) + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) def mask(self, x, m): return x * m @@ -193,11 +192,7 @@ def forward(self, batch, mode="train"): init_shape = batch.shape if not isinstance(batch, dict): mono = batch.view(-1, 1, batch.shape[-1]) - batch = { - "mixture": { - "audio": mono - } - } + batch = {"mixture": {"audio": mono}} with torch.no_grad(): mixture = batch["mixture"]["audio"] @@ -217,7 +212,9 @@ def forward(self, batch, mode="train"): b = [] for s in self.stems: # We need to obtain stereo again - r = batch['estimates'][s]['audio'].view(-1, init_shape[1], init_shape[2]) + r = batch["estimates"][s]["audio"].view( + -1, init_shape[1], init_shape[2] + ) b.append(r) # And we need to return back tensor and not independent stems batch = torch.stack(b, dim=1) @@ -364,4 +361,3 @@ def separate(self, batch): } return batch - diff --git a/programs/music_separation_code/models/bandit_v2/film.py b/programs/music_separation_code/models/bandit_v2/film.py index e307953..253594a 100644 --- a/programs/music_separation_code/models/bandit_v2/film.py +++ b/programs/music_separation_code/models/bandit_v2/film.py @@ -1,10 +1,11 @@ from torch import nn import torch + class FiLM(nn.Module): def __init__(self): super().__init__() - + def forward(self, x, gamma, beta): return gamma * x + beta @@ -13,13 +14,10 @@ class BTFBroadcastedFiLM(nn.Module): def __init__(self): super().__init__() self.film = FiLM() - + def forward(self, x, gamma, beta): - + gamma = 
gamma[None, None, None, :] beta = beta[None, None, None, :] - + return self.film(x, gamma, beta) - - - \ No newline at end of file diff --git a/programs/music_separation_code/models/bs_roformer/attend.py b/programs/music_separation_code/models/bs_roformer/attend.py index d6dc4b3..9ebb7c9 100644 --- a/programs/music_separation_code/models/bs_roformer/attend.py +++ b/programs/music_separation_code/models/bs_roformer/attend.py @@ -11,18 +11,24 @@ # constants -FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) +FlashAttentionConfig = namedtuple( + "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"] +) # helpers + def exists(val): return val is not None + def default(v, d): return v if exists(v) else d + def once(fn): called = False + @wraps(fn) def inner(x): nonlocal called @@ -30,26 +36,26 @@ def inner(x): return called = True return fn(x) + return inner + print_once = once(print) # main class + class Attend(nn.Module): - def __init__( - self, - dropout = 0., - flash = False, - scale = None - ): + def __init__(self, dropout=0.0, flash=False, scale=None): super().__init__() self.scale = scale self.dropout = dropout self.attn_dropout = nn.Dropout(dropout) self.flash = flash - assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' + assert not ( + flash and version.parse(torch.__version__) < version.parse("2.0.0") + ), "in order to use flash attention, you must be using pytorch 2.0 or above" # determine efficient attention configs for cuda and cpu @@ -59,22 +65,35 @@ def __init__( if not torch.cuda.is_available() or not flash: return - device_properties = torch.cuda.get_device_properties(torch.device('cuda')) - device_version = version.parse(f'{device_properties.major}.{device_properties.minor}') + device_properties = torch.cuda.get_device_properties(torch.device("cuda")) + device_version = version.parse( + f"{device_properties.major}.{device_properties.minor}" + ) - if device_version >= version.parse('8.0'): - if os.name == 'nt': - print_once('Windows OS detected, using math or mem efficient attention if input tensor is on cuda') + if device_version >= version.parse("8.0"): + if os.name == "nt": + print_once( + "Windows OS detected, using math or mem efficient attention if input tensor is on cuda" + ) self.cuda_config = FlashAttentionConfig(False, True, True) else: - print_once('GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda') + print_once( + "GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda" + ) self.cuda_config = FlashAttentionConfig(True, False, False) else: - print_once('GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda') + print_once( + "GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda" + ) self.cuda_config = FlashAttentionConfig(False, True, True) def flash_attn(self, q, k, v): - _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device + _, heads, q_len, _, k_len, is_cuda, device = ( + *q.shape, + k.shape[-2], + q.is_cuda, + q.device, + ) if exists(self.scale): default_scale = q.shape[-1] ** -0.5 @@ -88,8 +107,7 @@ def flash_attn(self, q, k, v): with torch.backends.cuda.sdp_kernel(**config._asdict()): out = F.scaled_dot_product_attention( - q, k, v, - dropout_p = self.dropout if 
self.training else 0. + q, k, v, dropout_p=self.dropout if self.training else 0.0 ) return out diff --git a/programs/music_separation_code/models/bs_roformer/bs_roformer.py b/programs/music_separation_code/models/bs_roformer/bs_roformer.py index 2fda0cc..3ed1544 100644 --- a/programs/music_separation_code/models/bs_roformer/bs_roformer.py +++ b/programs/music_separation_code/models/bs_roformer/bs_roformer.py @@ -17,6 +17,7 @@ # helper functions + def exists(val): return val is not None @@ -35,14 +36,15 @@ def unpack_one(t, ps, pattern): # norm + def l2norm(t): - return F.normalize(t, dim = -1, p = 2) + return F.normalize(t, dim=-1, p=2) class RMSNorm(Module): def __init__(self, dim): super().__init__() - self.scale = dim ** 0.5 + self.scale = dim**0.5 self.gamma = nn.Parameter(torch.ones(dim)) def forward(self, x): @@ -51,13 +53,9 @@ def forward(self, x): # attention + class FeedForward(Module): - def __init__( - self, - dim, - mult=4, - dropout=0. - ): + def __init__(self, dim, mult=4, dropout=0.0): super().__init__() dim_inner = int(dim * mult) self.net = nn.Sequential( @@ -66,7 +64,7 @@ def __init__( nn.GELU(), nn.Dropout(dropout), nn.Linear(dim_inner, dim), - nn.Dropout(dropout) + nn.Dropout(dropout), ) def forward(self, x): @@ -75,17 +73,11 @@ def forward(self, x): class Attention(Module): def __init__( - self, - dim, - heads=8, - dim_head=64, - dropout=0., - rotary_embed=None, - flash=True + self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True ): super().__init__() self.heads = heads - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 dim_inner = heads * dim_head self.rotary_embed = rotary_embed @@ -98,14 +90,15 @@ def __init__( self.to_gates = nn.Linear(dim, heads) self.to_out = nn.Sequential( - nn.Linear(dim_inner, dim, bias=False), - nn.Dropout(dropout) + nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout) ) def forward(self, x): x = self.norm(x) - q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads) + q, k, v = rearrange( + self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads + ) if exists(self.rotary_embed): q = self.rotary_embed.rotate_queries_or_keys(q) @@ -114,9 +107,9 @@ def forward(self, x): out = self.attend(q, k, v) gates = self.to_gates(x) - out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid() + out = out * rearrange(gates, "b n h -> b h n 1").sigmoid() - out = rearrange(out, 'b h n d -> b n (h d)') + out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) @@ -126,42 +119,25 @@ class LinearAttention(Module): """ @beartype - def __init__( - self, - *, - dim, - dim_head=32, - heads=8, - scale=8, - flash=False, - dropout=0. 
- ): + def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0): super().__init__() dim_inner = dim_head * heads self.norm = RMSNorm(dim) self.to_qkv = nn.Sequential( nn.Linear(dim, dim_inner * 3, bias=False), - Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads) + Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads), ) self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) - self.attend = Attend( - scale=scale, - dropout=dropout, - flash=flash - ) + self.attend = Attend(scale=scale, dropout=dropout, flash=flash) self.to_out = nn.Sequential( - Rearrange('b h d n -> b n (h d)'), - nn.Linear(dim_inner, dim, bias=False) + Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False) ) - def forward( - self, - x - ): + def forward(self, x): x = self.norm(x) q, k, v = self.to_qkv(x) @@ -176,34 +152,47 @@ def forward( class Transformer(Module): def __init__( - self, - *, - dim, - depth, - dim_head=64, - heads=8, - attn_dropout=0., - ff_dropout=0., - ff_mult=4, - norm_output=True, - rotary_embed=None, - flash_attn=True, - linear_attn=False + self, + *, + dim, + depth, + dim_head=64, + heads=8, + attn_dropout=0.0, + ff_dropout=0.0, + ff_mult=4, + norm_output=True, + rotary_embed=None, + flash_attn=True, + linear_attn=False, ): super().__init__() self.layers = ModuleList([]) for _ in range(depth): if linear_attn: - attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn) + attn = LinearAttention( + dim=dim, + dim_head=dim_head, + heads=heads, + dropout=attn_dropout, + flash=flash_attn, + ) else: - attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, - rotary_embed=rotary_embed, flash=flash_attn) - - self.layers.append(ModuleList([ - attn, - FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout) - ])) + attn = Attention( + dim=dim, + dim_head=dim_head, + heads=heads, + dropout=attn_dropout, + rotary_embed=rotary_embed, + flash=flash_attn, + ) + + self.layers.append( + ModuleList( + [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)] + ) + ) self.norm = RMSNorm(dim) if norm_output else nn.Identity() @@ -218,22 +207,16 @@ def forward(self, x): # bandsplit module + class BandSplit(Module): @beartype - def __init__( - self, - dim, - dim_inputs: Tuple[int, ...] 
- ): + def __init__(self, dim, dim_inputs: Tuple[int, ...]): super().__init__() self.dim_inputs = dim_inputs self.to_features = ModuleList([]) for dim_in in dim_inputs: - net = nn.Sequential( - RMSNorm(dim_in), - nn.Linear(dim_in, dim) - ) + net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim)) self.to_features.append(net) @@ -248,13 +231,7 @@ def forward(self, x): return torch.stack(outs, dim=-2) -def MLP( - dim_in, - dim_out, - dim_hidden=None, - depth=1, - activation=nn.Tanh -): +def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh): dim_hidden = default(dim_hidden, dim_in) net = [] @@ -275,13 +252,7 @@ def MLP( class MaskEstimator(Module): @beartype - def __init__( - self, - dim, - dim_inputs: Tuple[int, ...], - depth, - mlp_expansion_factor=4 - ): + def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4): super().__init__() self.dim_inputs = dim_inputs self.to_freqs = ModuleList([]) @@ -291,8 +262,7 @@ def __init__( net = [] mlp = nn.Sequential( - MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), - nn.GLU(dim=-1) + MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1) ) self.to_freqs.append(mlp) @@ -312,14 +282,68 @@ def forward(self, x): # main class DEFAULT_FREQS_PER_BANDS = ( - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, - 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, - 12, 12, 12, 12, 12, 12, 12, 12, - 24, 24, 24, 24, 24, 24, 24, 24, - 48, 48, 48, 48, 48, 48, 48, 48, - 128, 129, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 48, + 48, + 48, + 48, + 48, + 48, + 48, + 48, + 128, + 129, ) @@ -327,35 +351,41 @@ class BSRoformer(Module): @beartype def __init__( - self, - dim, - *, - depth, - stereo=False, - num_stems=1, - time_transformer_depth=2, - freq_transformer_depth=2, - linear_transformer_depth=0, - freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS, - # in the paper, they divide into ~60 bands, test with 1 for starters - dim_head=64, - heads=8, - attn_dropout=0., - ff_dropout=0., - flash_attn=True, - dim_freqs_in=1025, - stft_n_fft=2048, - stft_hop_length=512, - # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction - stft_win_length=2048, - stft_normalized=False, - stft_window_fn: Optional[Callable] = None, - mask_estimator_depth=2, - multi_stft_resolution_loss_weight=1., - multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256), - multi_stft_hop_size=147, - multi_stft_normalized=False, - multi_stft_window_fn: Callable = torch.hann_window + self, + dim, + *, + depth, + stereo=False, + num_stems=1, + time_transformer_depth=2, + freq_transformer_depth=2, + linear_transformer_depth=0, + freqs_per_bands: Tuple[int, ...] 
= DEFAULT_FREQS_PER_BANDS, + # in the paper, they divide into ~60 bands, test with 1 for starters + dim_head=64, + heads=8, + attn_dropout=0.0, + ff_dropout=0.0, + flash_attn=True, + dim_freqs_in=1025, + stft_n_fft=2048, + stft_hop_length=512, + # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction + stft_win_length=2048, + stft_normalized=False, + stft_window_fn: Optional[Callable] = None, + mask_estimator_depth=2, + multi_stft_resolution_loss_weight=1.0, + multi_stft_resolutions_window_sizes: Tuple[int, ...] = ( + 4096, + 2048, + 1024, + 512, + 256, + ), + multi_stft_hop_size=147, + multi_stft_normalized=False, + multi_stft_window_fn: Callable = torch.hann_window, ): super().__init__() @@ -372,7 +402,7 @@ def __init__( attn_dropout=attn_dropout, ff_dropout=ff_dropout, flash_attn=flash_attn, - norm_output=False + norm_output=False, ) time_rotary_embed = RotaryEmbedding(dim=dim_head) @@ -381,12 +411,26 @@ def __init__( for _ in range(depth): tran_modules = [] if linear_transformer_depth > 0: - tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs)) + tran_modules.append( + Transformer( + depth=linear_transformer_depth, + linear_attn=True, + **transformer_kwargs, + ) + ) tran_modules.append( - Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs) + Transformer( + depth=time_transformer_depth, + rotary_embed=time_rotary_embed, + **transformer_kwargs, + ) ) tran_modules.append( - Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs) + Transformer( + depth=freq_transformer_depth, + rotary_embed=freq_rotary_embed, + **transformer_kwargs, + ) ) self.layers.append(nn.ModuleList(tran_modules)) @@ -396,31 +440,38 @@ def __init__( n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, - normalized=stft_normalized + normalized=stft_normalized, ) - self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length) + self.stft_window_fn = partial( + default(stft_window_fn, torch.hann_window), stft_win_length + ) - freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1] + freqs = torch.stft( + torch.randn(1, 4096), + **self.stft_kwargs, + window=torch.ones(stft_n_fft), + return_complex=True, + ).shape[1] assert len(freqs_per_bands) > 1 - assert sum( - freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}' + assert ( + sum(freqs_per_bands) == freqs + ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}" - freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands) - - self.band_split = BandSplit( - dim=dim, - dim_inputs=freqs_per_bands_with_complex + freqs_per_bands_with_complex = tuple( + 2 * f * self.audio_channels for f in freqs_per_bands ) + self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex) + self.mask_estimators = nn.ModuleList([]) for _ in range(num_stems): mask_estimator = MaskEstimator( dim=dim, dim_inputs=freqs_per_bands_with_complex, - depth=mask_estimator_depth + depth=mask_estimator_depth, ) self.mask_estimators.append(mask_estimator) @@ -433,16 +484,10 @@ def __init__( self.multi_stft_window_fn = multi_stft_window_fn self.multi_stft_kwargs = dict( - hop_length=multi_stft_hop_size, - 
normalized=multi_stft_normalized + hop_length=multi_stft_hop_size, normalized=multi_stft_normalized ) - def forward( - self, - raw_audio, - target=None, - return_loss_breakdown=False - ): + def forward(self, raw_audio, target=None, return_loss_breakdown=False): """ einops @@ -461,32 +506,41 @@ def forward( x_is_mps = True if device.type == "mps" else False if raw_audio.ndim == 2: - raw_audio = rearrange(raw_audio, 'b t -> b 1 t') + raw_audio = rearrange(raw_audio, "b t -> b 1 t") channels = raw_audio.shape[1] assert (not self.stereo and channels == 1) or ( - self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)' + self.stereo and channels == 2 + ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)" # to stft - raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') + raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t") stft_window = self.stft_window_fn(device=device) # RuntimeError: FFT operations are only supported on MacOS 14+ # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used try: - stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True) + stft_repr = torch.stft( + raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True + ) except: - stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(device) + stft_repr = torch.stft( + raw_audio.cpu() if x_is_mps else raw_audio, + **self.stft_kwargs, + window=stft_window.cpu() if x_is_mps else stft_window, + return_complex=True, + ).to(device) stft_repr = torch.view_as_real(stft_repr) - stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c') - stft_repr = rearrange(stft_repr, - 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting + stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c") + stft_repr = rearrange( + stft_repr, "b s f t c -> b (f s) t c" + ) # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting - x = rearrange(stft_repr, 'b f t c -> b t (f c)') + x = rearrange(stft_repr, "b f t c -> b t (f c)") x = self.band_split(x) @@ -495,37 +549,39 @@ def forward( for transformer_block in self.layers: if len(transformer_block) == 3: - linear_transformer, time_transformer, freq_transformer = transformer_block + linear_transformer, time_transformer, freq_transformer = ( + transformer_block + ) - x, ft_ps = pack([x], 'b * d') + x, ft_ps = pack([x], "b * d") x = linear_transformer(x) - x, = unpack(x, ft_ps, 'b * d') + (x,) = unpack(x, ft_ps, "b * d") else: time_transformer, freq_transformer = transformer_block - x = rearrange(x, 'b t f d -> b f t d') - x, ps = pack([x], '* t d') + x = rearrange(x, "b t f d -> b f t d") + x, ps = pack([x], "* t d") x = time_transformer(x) - x, = unpack(x, ps, '* t d') - x = rearrange(x, 'b f t d -> b t f d') - x, ps = pack([x], '* f d') + (x,) = unpack(x, ps, "* t d") + x = rearrange(x, "b f t d -> b t f d") + x, ps = pack([x], "* f d") x = freq_transformer(x) - x, = unpack(x, ps, '* f d') + (x,) = unpack(x, ps, "* f d") x = self.final_norm(x) num_stems = len(self.mask_estimators) mask = 
torch.stack([fn(x) for fn in self.mask_estimators], dim=1) - mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2) + mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2) # modulate frequency representation - stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c') + stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c") # complex number multiplication @@ -536,18 +592,29 @@ def forward( # istft - stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels) + stft_repr = rearrange( + stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels + ) # same as torch.stft() fix for MacOS MPS above try: - recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False) + recon_audio = torch.istft( + stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False + ) except: - recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False).to(device) - - recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems) + recon_audio = torch.istft( + stft_repr.cpu() if x_is_mps else stft_repr, + **self.stft_kwargs, + window=stft_window.cpu() if x_is_mps else stft_window, + return_complex=False, + ).to(device) + + recon_audio = rearrange( + recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems + ) if num_stems == 1: - recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t') + recon_audio = rearrange(recon_audio, "b 1 s t -> b s t") # if a target is passed in, calculate loss for learning @@ -558,33 +625,45 @@ def forward( assert target.ndim == 4 and target.shape[1] == self.num_stems if target.ndim == 2: - target = rearrange(target, '... t -> ... 1 t') + target = rearrange(target, "... t -> ... 1 t") - target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft + target = target[ + ..., : recon_audio.shape[-1] + ] # protect against lost length on istft loss = F.l1_loss(recon_audio, target) - multi_stft_resolution_loss = 0. + multi_stft_resolution_loss = 0.0 for window_size in self.multi_stft_resolutions_window_sizes: res_stft_kwargs = dict( - n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft + n_fft=max( + window_size, self.multi_stft_n_fft + ), # not sure what n_fft is across multi resolution stft win_length=window_size, return_complex=True, window=self.multi_stft_window_fn(window_size, device=device), **self.multi_stft_kwargs, ) - recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs) - target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs) + recon_Y = torch.stft( + rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs + ) + target_Y = torch.stft( + rearrange(target, "... s t -> (... 
s) t"), **res_stft_kwargs + ) - multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y) + multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss( + recon_Y, target_Y + ) - weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + weighted_multi_resolution_loss = ( + multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + ) total_loss = loss + weighted_multi_resolution_loss if not return_loss_breakdown: return total_loss - return total_loss, (loss, multi_stft_resolution_loss) \ No newline at end of file + return total_loss, (loss, multi_stft_resolution_loss) diff --git a/programs/music_separation_code/models/bs_roformer/mel_band_roformer.py b/programs/music_separation_code/models/bs_roformer/mel_band_roformer.py index 3ce7fe1..105ced1 100644 --- a/programs/music_separation_code/models/bs_roformer/mel_band_roformer.py +++ b/programs/music_separation_code/models/bs_roformer/mel_band_roformer.py @@ -20,6 +20,7 @@ # helper functions + def exists(val): return val is not None @@ -36,9 +37,9 @@ def unpack_one(t, ps, pattern): return unpack(t, ps, pattern)[0] -def pad_at_dim(t, pad, dim=-1, value=0.): - dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1) - zeros = ((0, 0) * dims_from_right) +def pad_at_dim(t, pad, dim=-1, value=0.0): + dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1) + zeros = (0, 0) * dims_from_right return F.pad(t, (*zeros, *pad), value=value) @@ -48,10 +49,11 @@ def l2norm(t): # norm + class RMSNorm(Module): def __init__(self, dim): super().__init__() - self.scale = dim ** 0.5 + self.scale = dim**0.5 self.gamma = nn.Parameter(torch.ones(dim)) def forward(self, x): @@ -60,13 +62,9 @@ def forward(self, x): # attention + class FeedForward(Module): - def __init__( - self, - dim, - mult=4, - dropout=0. 
- ): + def __init__(self, dim, mult=4, dropout=0.0): super().__init__() dim_inner = int(dim * mult) self.net = nn.Sequential( @@ -75,7 +73,7 @@ def __init__( nn.GELU(), nn.Dropout(dropout), nn.Linear(dim_inner, dim), - nn.Dropout(dropout) + nn.Dropout(dropout), ) def forward(self, x): @@ -84,17 +82,11 @@ def forward(self, x): class Attention(Module): def __init__( - self, - dim, - heads=8, - dim_head=64, - dropout=0., - rotary_embed=None, - flash=True + self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True ): super().__init__() self.heads = heads - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 dim_inner = heads * dim_head self.rotary_embed = rotary_embed @@ -107,14 +99,15 @@ def __init__( self.to_gates = nn.Linear(dim, heads) self.to_out = nn.Sequential( - nn.Linear(dim_inner, dim, bias=False), - nn.Dropout(dropout) + nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout) ) def forward(self, x): x = self.norm(x) - q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads) + q, k, v = rearrange( + self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads + ) if exists(self.rotary_embed): q = self.rotary_embed.rotate_queries_or_keys(q) @@ -123,9 +116,9 @@ def forward(self, x): out = self.attend(q, k, v) gates = self.to_gates(x) - out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid() + out = out * rearrange(gates, "b n h -> b h n 1").sigmoid() - out = rearrange(out, 'b h n d -> b n (h d)') + out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) @@ -135,42 +128,25 @@ class LinearAttention(Module): """ @beartype - def __init__( - self, - *, - dim, - dim_head=32, - heads=8, - scale=8, - flash=False, - dropout=0. - ): + def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0): super().__init__() dim_inner = dim_head * heads self.norm = RMSNorm(dim) self.to_qkv = nn.Sequential( nn.Linear(dim, dim_inner * 3, bias=False), - Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads) + Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads), ) self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) - self.attend = Attend( - scale=scale, - dropout=dropout, - flash=flash - ) + self.attend = Attend(scale=scale, dropout=dropout, flash=flash) self.to_out = nn.Sequential( - Rearrange('b h d n -> b n (h d)'), - nn.Linear(dim_inner, dim, bias=False) + Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False) ) - def forward( - self, - x - ): + def forward(self, x): x = self.norm(x) q, k, v = self.to_qkv(x) @@ -185,34 +161,47 @@ def forward( class Transformer(Module): def __init__( - self, - *, - dim, - depth, - dim_head=64, - heads=8, - attn_dropout=0., - ff_dropout=0., - ff_mult=4, - norm_output=True, - rotary_embed=None, - flash_attn=True, - linear_attn=False + self, + *, + dim, + depth, + dim_head=64, + heads=8, + attn_dropout=0.0, + ff_dropout=0.0, + ff_mult=4, + norm_output=True, + rotary_embed=None, + flash_attn=True, + linear_attn=False, ): super().__init__() self.layers = ModuleList([]) for _ in range(depth): if linear_attn: - attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn) + attn = LinearAttention( + dim=dim, + dim_head=dim_head, + heads=heads, + dropout=attn_dropout, + flash=flash_attn, + ) else: - attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, - rotary_embed=rotary_embed, flash=flash_attn) - - self.layers.append(ModuleList([ - attn, - FeedForward(dim=dim, mult=ff_mult, 
dropout=ff_dropout) - ])) + attn = Attention( + dim=dim, + dim_head=dim_head, + heads=heads, + dropout=attn_dropout, + rotary_embed=rotary_embed, + flash=flash_attn, + ) + + self.layers.append( + ModuleList( + [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)] + ) + ) self.norm = RMSNorm(dim) if norm_output else nn.Identity() @@ -227,22 +216,16 @@ def forward(self, x): # bandsplit module + class BandSplit(Module): @beartype - def __init__( - self, - dim, - dim_inputs: Tuple[int, ...] - ): + def __init__(self, dim, dim_inputs: Tuple[int, ...]): super().__init__() self.dim_inputs = dim_inputs self.to_features = ModuleList([]) for dim_in in dim_inputs: - net = nn.Sequential( - RMSNorm(dim_in), - nn.Linear(dim_in, dim) - ) + net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim)) self.to_features.append(net) @@ -257,13 +240,7 @@ def forward(self, x): return torch.stack(outs, dim=-2) -def MLP( - dim_in, - dim_out, - dim_hidden=None, - depth=1, - activation=nn.Tanh -): +def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh): dim_hidden = default(dim_hidden, dim_in) net = [] @@ -284,13 +261,7 @@ def MLP( class MaskEstimator(Module): @beartype - def __init__( - self, - dim, - dim_inputs: Tuple[int, ...], - depth, - mlp_expansion_factor=4 - ): + def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4): super().__init__() self.dim_inputs = dim_inputs self.to_freqs = ModuleList([]) @@ -300,8 +271,7 @@ def __init__( net = [] mlp = nn.Sequential( - MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), - nn.GLU(dim=-1) + MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1) ) self.to_freqs.append(mlp) @@ -320,40 +290,47 @@ def forward(self, x): # main class + class MelBandRoformer(Module): @beartype def __init__( - self, - dim, - *, - depth, - stereo=False, - num_stems=1, - time_transformer_depth=2, - freq_transformer_depth=2, - linear_transformer_depth=0, - num_bands=60, - dim_head=64, - heads=8, - attn_dropout=0.1, - ff_dropout=0.1, - flash_attn=True, - dim_freqs_in=1025, - sample_rate=44100, # needed for mel filter bank from librosa - stft_n_fft=2048, - stft_hop_length=512, - # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction - stft_win_length=2048, - stft_normalized=False, - stft_window_fn: Optional[Callable] = None, - mask_estimator_depth=1, - multi_stft_resolution_loss_weight=1., - multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256), - multi_stft_hop_size=147, - multi_stft_normalized=False, - multi_stft_window_fn: Callable = torch.hann_window, - match_input_audio_length=False, # if True, pad output tensor to match length of input tensor + self, + dim, + *, + depth, + stereo=False, + num_stems=1, + time_transformer_depth=2, + freq_transformer_depth=2, + linear_transformer_depth=0, + num_bands=60, + dim_head=64, + heads=8, + attn_dropout=0.1, + ff_dropout=0.1, + flash_attn=True, + dim_freqs_in=1025, + sample_rate=44100, # needed for mel filter bank from librosa + stft_n_fft=2048, + stft_hop_length=512, + # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction + stft_win_length=2048, + stft_normalized=False, + stft_window_fn: Optional[Callable] = None, + mask_estimator_depth=1, + multi_stft_resolution_loss_weight=1.0, + multi_stft_resolutions_window_sizes: Tuple[int, ...] 
= ( + 4096, + 2048, + 1024, + 512, + 256, + ), + multi_stft_hop_size=147, + multi_stft_normalized=False, + multi_stft_window_fn: Callable = torch.hann_window, + match_input_audio_length=False, # if True, pad output tensor to match length of input tensor ): super().__init__() @@ -369,7 +346,7 @@ def __init__( dim_head=dim_head, attn_dropout=attn_dropout, ff_dropout=ff_dropout, - flash_attn=flash_attn + flash_attn=flash_attn, ) time_rotary_embed = RotaryEmbedding(dim=dim_head) @@ -378,80 +355,104 @@ def __init__( for _ in range(depth): tran_modules = [] if linear_transformer_depth > 0: - tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs)) + tran_modules.append( + Transformer( + depth=linear_transformer_depth, + linear_attn=True, + **transformer_kwargs, + ) + ) tran_modules.append( - Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs) + Transformer( + depth=time_transformer_depth, + rotary_embed=time_rotary_embed, + **transformer_kwargs, + ) ) tran_modules.append( - Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs) + Transformer( + depth=freq_transformer_depth, + rotary_embed=freq_rotary_embed, + **transformer_kwargs, + ) ) self.layers.append(nn.ModuleList(tran_modules)) - self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length) + self.stft_window_fn = partial( + default(stft_window_fn, torch.hann_window), stft_win_length + ) self.stft_kwargs = dict( n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, - normalized=stft_normalized + normalized=stft_normalized, ) - freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1] + freqs = torch.stft( + torch.randn(1, 4096), + **self.stft_kwargs, + window=torch.ones(stft_n_fft), + return_complex=True, + ).shape[1] # create mel filter bank # with librosa.filters.mel as in section 2 of paper - mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands) + mel_filter_bank_numpy = filters.mel( + sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands + ) mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy) # for some reason, it doesn't include the first freq? just force a value for now - mel_filter_bank[0][0] = 1. + mel_filter_bank[0][0] = 1.0 # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position, # so let's force a positive value - mel_filter_bank[-1, -1] = 1. 
+ mel_filter_bank[-1, -1] = 1.0 # binary as in paper (then estimated masks are averaged for overlapping regions) freqs_per_band = mel_filter_bank > 0 - assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now' + assert freqs_per_band.any( + dim=0 + ).all(), "all frequencies need to be covered by all bands for now" - repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands) + repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands) freq_indices = repeated_freq_indices[freqs_per_band] if stereo: - freq_indices = repeat(freq_indices, 'f -> f s', s=2) + freq_indices = repeat(freq_indices, "f -> f s", s=2) freq_indices = freq_indices * 2 + torch.arange(2) - freq_indices = rearrange(freq_indices, 'f s -> (f s)') + freq_indices = rearrange(freq_indices, "f s -> (f s)") - self.register_buffer('freq_indices', freq_indices, persistent=False) - self.register_buffer('freqs_per_band', freqs_per_band, persistent=False) + self.register_buffer("freq_indices", freq_indices, persistent=False) + self.register_buffer("freqs_per_band", freqs_per_band, persistent=False) - num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum') - num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum') + num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum") + num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum") - self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False) - self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False) + self.register_buffer("num_freqs_per_band", num_freqs_per_band, persistent=False) + self.register_buffer("num_bands_per_freq", num_bands_per_freq, persistent=False) # band split and mask estimator - freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist()) - - self.band_split = BandSplit( - dim=dim, - dim_inputs=freqs_per_bands_with_complex + freqs_per_bands_with_complex = tuple( + 2 * f * self.audio_channels for f in num_freqs_per_band.tolist() ) + self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex) + self.mask_estimators = nn.ModuleList([]) for _ in range(num_stems): mask_estimator = MaskEstimator( dim=dim, dim_inputs=freqs_per_bands_with_complex, - depth=mask_estimator_depth + depth=mask_estimator_depth, ) self.mask_estimators.append(mask_estimator) @@ -464,18 +465,12 @@ def __init__( self.multi_stft_window_fn = multi_stft_window_fn self.multi_stft_kwargs = dict( - hop_length=multi_stft_hop_size, - normalized=multi_stft_normalized + hop_length=multi_stft_hop_size, normalized=multi_stft_normalized ) self.match_input_audio_length = match_input_audio_length - def forward( - self, - raw_audio, - target=None, - return_loss_breakdown=False - ): + def forward(self, raw_audio, target=None, return_loss_breakdown=False): """ einops @@ -491,27 +486,31 @@ def forward( device = raw_audio.device if raw_audio.ndim == 2: - raw_audio = rearrange(raw_audio, 'b t -> b 1 t') + raw_audio = rearrange(raw_audio, "b t -> b 1 t") batch, channels, raw_audio_length = raw_audio.shape istft_length = raw_audio_length if self.match_input_audio_length else None assert (not self.stereo and channels == 1) or ( - self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). 
also need to be False if mono (channel dimension of 1)' + self.stereo and channels == 2 + ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)" # to stft - raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') + raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t") stft_window = self.stft_window_fn(device=device) - stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True) + stft_repr = torch.stft( + raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True + ) stft_repr = torch.view_as_real(stft_repr) - stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c') - stft_repr = rearrange(stft_repr, - 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting + stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c") + stft_repr = rearrange( + stft_repr, "b s f t c -> b (f s) t c" + ) # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting # index out all frequencies for all frequency ranges across bands ascending in one go @@ -523,7 +522,7 @@ def forward( # fold the complex (real and imag) into the frequencies dimension - x = rearrange(x, 'b f t c -> b t (f c)') + x = rearrange(x, "b f t c -> b t (f c)") x = self.band_split(x) @@ -532,35 +531,37 @@ def forward( for transformer_block in self.layers: if len(transformer_block) == 3: - linear_transformer, time_transformer, freq_transformer = transformer_block + linear_transformer, time_transformer, freq_transformer = ( + transformer_block + ) - x, ft_ps = pack([x], 'b * d') + x, ft_ps = pack([x], "b * d") x = linear_transformer(x) - x, = unpack(x, ft_ps, 'b * d') + (x,) = unpack(x, ft_ps, "b * d") else: time_transformer, freq_transformer = transformer_block - x = rearrange(x, 'b t f d -> b f t d') - x, ps = pack([x], '* t d') + x = rearrange(x, "b t f d -> b f t d") + x, ps = pack([x], "* t d") x = time_transformer(x) - x, = unpack(x, ps, '* t d') - x = rearrange(x, 'b f t d -> b t f d') - x, ps = pack([x], '* f d') + (x,) = unpack(x, ps, "* t d") + x = rearrange(x, "b f t d -> b t f d") + x, ps = pack([x], "* f d") x = freq_transformer(x) - x, = unpack(x, ps, '* f d') + (x,) = unpack(x, ps, "* f d") num_stems = len(self.mask_estimators) masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) - masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2) + masks = rearrange(masks, "b n t (f c) -> b n f t c", c=2) # modulate frequency representation - stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c') + stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c") # complex number multiplication @@ -571,12 +572,20 @@ def forward( # need to average the estimated mask for the overlapped frequencies - scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1]) + scatter_indices = repeat( + self.freq_indices, + "f -> b n f t", + b=batch, + n=num_stems, + t=stft_repr.shape[-1], + ) - stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems) - masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks) + stft_repr_expanded_stems = repeat(stft_repr, "b 1 ... 
-> b n ...", n=num_stems) + masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_( + 2, scatter_indices, masks + ) - denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels) + denom = repeat(self.num_bands_per_freq, "f -> (f r) 1", r=channels) masks_averaged = masks_summed / denom.clamp(min=1e-8) @@ -586,15 +595,28 @@ def forward( # istft - stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels) + stft_repr = rearrange( + stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels + ) - recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, - length=istft_length) + recon_audio = torch.istft( + stft_repr, + **self.stft_kwargs, + window=stft_window, + return_complex=False, + length=istft_length, + ) - recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems) + recon_audio = rearrange( + recon_audio, + "(b n s) t -> b n s t", + b=batch, + s=self.audio_channels, + n=num_stems, + ) if num_stems == 1: - recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t') + recon_audio = rearrange(recon_audio, "b 1 s t -> b s t") # if a target is passed in, calculate loss for learning @@ -605,29 +627,41 @@ def forward( assert target.ndim == 4 and target.shape[1] == self.num_stems if target.ndim == 2: - target = rearrange(target, '... t -> ... 1 t') + target = rearrange(target, "... t -> ... 1 t") - target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft + target = target[ + ..., : recon_audio.shape[-1] + ] # protect against lost length on istft loss = F.l1_loss(recon_audio, target) - multi_stft_resolution_loss = 0. + multi_stft_resolution_loss = 0.0 for window_size in self.multi_stft_resolutions_window_sizes: res_stft_kwargs = dict( - n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft + n_fft=max( + window_size, self.multi_stft_n_fft + ), # not sure what n_fft is across multi resolution stft win_length=window_size, return_complex=True, window=self.multi_stft_window_fn(window_size, device=device), **self.multi_stft_kwargs, ) - recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs) - target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs) + recon_Y = torch.stft( + rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs + ) + target_Y = torch.stft( + rearrange(target, "... s t -> (... 
s) t"), **res_stft_kwargs + ) - multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y) + multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss( + recon_Y, target_Y + ) - weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + weighted_multi_resolution_loss = ( + multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + ) total_loss = loss + weighted_multi_resolution_loss diff --git a/programs/music_separation_code/models/demucs4ht.py b/programs/music_separation_code/models/demucs4ht.py index 06c279c..bf87cb1 100644 --- a/programs/music_separation_code/models/demucs4ht.py +++ b/programs/music_separation_code/models/demucs4ht.py @@ -1,11 +1,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from functools import partial import numpy as np import torch -import json from omegaconf import OmegaConf from demucs.demucs import Demucs from demucs.hdemucs import HDemucs @@ -317,7 +315,7 @@ def __init__( dconv=dconv_mode & 1, context=context_enc, empty=last_freq, - **kwt + **kwt, ) self.tencoder.append(tenc) @@ -337,7 +335,7 @@ def __init__( dconv=dconv_mode & 2, last=index == 0, context=context, - **kw_dec + **kw_dec, ) if multi: dec = MultiWrap(dec, multi_freqs) @@ -349,7 +347,7 @@ def __init__( empty=last_freq, last=index == 0, context=context, - **kwt + **kwt, ) self.tdecoder.insert(0, tdec) self.decoder.insert(0, dec) @@ -443,7 +441,7 @@ def _spec(self, x): z = spectro(x, nfft, hl)[..., :-1, :] assert z.shape[-1] == le + 4, (z.shape, x.shape, le) - z = z[..., 2: 2 + le] + z = z[..., 2 : 2 + le] return z def _ispec(self, z, length=None, scale=0): @@ -453,7 +451,7 @@ def _ispec(self, z, length=None, scale=0): pad = hl // 2 * 3 le = hl * int(math.ceil(length / hl)) + 2 * pad x = ispectro(z, hl, length=le) - x = x[..., pad: pad + length] + x = x[..., pad : pad + length] return x def _magnitude(self, z): @@ -527,8 +525,9 @@ def valid_length(self, length: int): training_length = int(self.segment * self.samplerate) if training_length < length: raise ValueError( - f"Given length {length} is longer than " - f"training length {training_length}") + f"Given length {length} is longer than " + f"training length {training_length}" + ) return training_length def cac2cws(self, x): @@ -695,19 +694,17 @@ def forward(self, mix): def get_model(args): extra = { - 'sources': list(args.training.instruments), - 'audio_channels': args.training.channels, - 'samplerate': args.training.samplerate, + "sources": list(args.training.instruments), + "audio_channels": args.training.channels, + "samplerate": args.training.samplerate, # 'segment': args.model_segment or 4 * args.dset.segment, - 'segment': args.training.segment, + "segment": args.training.segment, } klass = { - 'demucs': Demucs, - 'hdemucs': HDemucs, - 'htdemucs': HTDemucs, + "demucs": Demucs, + "hdemucs": HDemucs, + "htdemucs": HTDemucs, }[args.model] kw = OmegaConf.to_container(getattr(args, args.model), resolve=True) model = klass(**extra, **kw) return model - - diff --git a/programs/music_separation_code/models/mdx23c_tfc_tdf_v3.py b/programs/music_separation_code/models/mdx23c_tfc_tdf_v3.py index caa818c..ad89c85 100644 --- a/programs/music_separation_code/models/mdx23c_tfc_tdf_v3.py +++ b/programs/music_separation_code/models/mdx23c_tfc_tdf_v3.py @@ -22,12 +22,14 @@ def __call__(self, x): hop_length=self.hop_length, window=window, center=True, - return_complex=True + return_complex=True, ) x = torch.view_as_real(x) x = x.permute([0, 3, 
1, 2]) - x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) - return x[..., :self.dim_f, :] + x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape( + [*batch_dims, c * 2, -1, x.shape[-1]] + ) + return x[..., : self.dim_f, :] def inverse(self, x): window = self.window.to(x.device) @@ -38,20 +40,22 @@ def inverse(self, x): x = torch.cat([x, f_pad], -2) x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) x = x.permute([0, 2, 3, 1]) - x = x[..., 0] + x[..., 1] * 1.j - x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True) + x = x[..., 0] + x[..., 1] * 1.0j + x = torch.istft( + x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True + ) x = x.reshape([*batch_dims, 2, -1]) return x def get_norm(norm_type): def norm(c, norm_type): - if norm_type == 'BatchNorm': + if norm_type == "BatchNorm": return nn.BatchNorm2d(c) - elif norm_type == 'InstanceNorm': + elif norm_type == "InstanceNorm": return nn.InstanceNorm2d(c, affine=True) - elif 'GroupNorm' in norm_type: - g = int(norm_type.replace('GroupNorm', '')) + elif "GroupNorm" in norm_type: + g = int(norm_type.replace("GroupNorm", "")) return nn.GroupNorm(num_groups=g, num_channels=c) else: return nn.Identity() @@ -60,12 +64,12 @@ def norm(c, norm_type): def get_act(act_type): - if act_type == 'gelu': + if act_type == "gelu": return nn.GELU() - elif act_type == 'relu': + elif act_type == "relu": return nn.ReLU() - elif act_type[:3] == 'elu': - alpha = float(act_type.replace('elu', '')) + elif act_type[:3] == "elu": + alpha = float(act_type.replace("elu", "")) return nn.ELU(alpha) else: raise Exception @@ -77,7 +81,13 @@ def __init__(self, in_c, out_c, scale, norm, act): self.conv = nn.Sequential( norm(in_c), act, - nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) + nn.ConvTranspose2d( + in_channels=in_c, + out_channels=out_c, + kernel_size=scale, + stride=scale, + bias=False, + ), ) def forward(self, x): @@ -90,7 +100,13 @@ def __init__(self, in_c, out_c, scale, norm, act): self.conv = nn.Sequential( norm(in_c), act, - nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) + nn.Conv2d( + in_channels=in_c, + out_channels=out_c, + kernel_size=scale, + stride=scale, + bias=False, + ), ) def forward(self, x): @@ -146,7 +162,9 @@ def __init__(self, config): norm = get_norm(norm_type=config.model.norm) act = get_act(act_type=config.model.act) - self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments) + self.num_target_instruments = ( + 1 if config.training.target_instrument else len(config.training.instruments) + ) self.num_subbands = config.model.num_subbands dim_c = self.num_subbands * config.audio.num_channels * 2 @@ -183,7 +201,7 @@ def __init__(self, config): self.final_conv = nn.Sequential( nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), act, - nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) + nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False), ) self.stft = STFT(config.audio) diff --git a/programs/music_separation_code/models/scnet/scnet.py b/programs/music_separation_code/models/scnet/scnet.py index b27704d..37bdaad 100644 --- a/programs/music_separation_code/models/scnet/scnet.py +++ b/programs/music_separation_code/models/scnet/scnet.py @@ -3,7 +3,6 @@ import torch.nn.functional as F from collections import deque from .separation import 
SeparationNet -import typing as tp import math @@ -21,7 +20,7 @@ class ConvolutionModule(nn.Module): depth (int): number of layers in the residual branch. Each layer has its own compress (float): amount of channel compression. kernel (int): kernel size for the convolutions. - """ + """ def __init__(self, channels, depth=2, compress=4, kernel=3): super().__init__() @@ -31,12 +30,18 @@ def __init__(self, channels, depth=2, compress=4, kernel=3): norm = lambda d: nn.GroupNorm(1, d) self.layers = nn.ModuleList([]) for _ in range(self.depth): - padding = (kernel // 2) + padding = kernel // 2 mods = [ norm(channels), nn.Conv1d(channels, hidden_size * 2, kernel, padding=padding), nn.GLU(1), - nn.Conv1d(hidden_size, hidden_size, kernel, padding=padding, groups=hidden_size), + nn.Conv1d( + hidden_size, + hidden_size, + kernel, + padding=padding, + groups=hidden_size, + ), norm(hidden_size), Swish(), nn.Conv1d(hidden_size, channels, 1), @@ -63,7 +68,9 @@ class FusionLayer(nn.Module): def __init__(self, channels, kernel_size=3, stride=1, padding=1): super(FusionLayer, self).__init__() - self.conv = nn.Conv2d(channels * 2, channels * 2, kernel_size, stride=stride, padding=padding) + self.conv = nn.Conv2d( + channels * 2, channels * 2, kernel_size, stride=stride, padding=padding + ) def forward(self, x, skip=None): if skip is not None: @@ -96,13 +103,20 @@ def __init__(self, channels_in, channels_out, band_configs): self.kernels = [] for config in band_configs.values(): self.convs.append( - nn.Conv2d(channels_in, channels_out, (config['kernel'], 1), (config['stride'], 1), (0, 0))) - self.strides.append(config['stride']) - self.kernels.append(config['kernel']) + nn.Conv2d( + channels_in, + channels_out, + (config["kernel"], 1), + (config["stride"], 1), + (0, 0), + ) + ) + self.strides.append(config["stride"]) + self.kernels.append(config["kernel"]) # Saving rate proportions for determining splits - self.SR_low = band_configs['low']['SR'] - self.SR_mid = band_configs['mid']['SR'] + self.SR_low = band_configs["low"]["SR"] + self.SR_mid = band_configs["mid"]["SR"] def forward(self, x): B, C, Fr, T = x.shape @@ -110,13 +124,15 @@ def forward(self, x): splits = [ (0, math.ceil(Fr * self.SR_low)), (math.ceil(Fr * self.SR_low), math.ceil(Fr * (self.SR_low + self.SR_mid))), - (math.ceil(Fr * (self.SR_low + self.SR_mid)), Fr) + (math.ceil(Fr * (self.SR_low + self.SR_mid)), Fr), ] # Processing each band with the corresponding convolution outputs = [] original_lengths = [] - for conv, stride, kernel, (start, end) in zip(self.convs, self.strides, self.kernels, splits): + for conv, stride, kernel, (start, end) in zip( + self.convs, self.strides, self.kernels, splits + ): extracted = x[:, :, start:end, :] original_lengths.append(end - start) current_length = extracted.shape[2] @@ -151,10 +167,17 @@ def __init__(self, channels_in, channels_out, band_configs): super(SUlayer, self).__init__() # Initializing convolutional layers for each band - self.convtrs = nn.ModuleList([ - nn.ConvTranspose2d(channels_in, channels_out, [config['kernel'], 1], [config['stride'], 1]) - for _, config in band_configs.items() - ]) + self.convtrs = nn.ModuleList( + [ + nn.ConvTranspose2d( + channels_in, + channels_out, + [config["kernel"], 1], + [config["stride"], 1], + ) + for _, config in band_configs.items() + ] + ) def forward(self, x, lengths, origin_lengths): B, C, Fr, T = x.shape @@ -162,7 +185,7 @@ def forward(self, x, lengths, origin_lengths): splits = [ (0, lengths[0]), (lengths[0], lengths[0] + lengths[1]), - (lengths[0] + 
lengths[1], None) + (lengths[0] + lengths[1], None), ] # Processing each band with the corresponding convolution outputs = [] @@ -173,7 +196,7 @@ def forward(self, x, lengths, origin_lengths): dist = abs(origin_lengths[idx] - current_Fr_length) // 2 # Trim the output to the original length symmetrically - trimmed_out = out[:, :, dist:dist + origin_lengths[idx], :] + trimmed_out = out[:, :, dist : dist + origin_lengths[idx], :] outputs.append(trimmed_out) @@ -195,16 +218,26 @@ class SDblock(nn.Module): - depths (list of int): List specifying the convolution depths for low, mid, and high frequency bands. """ - def __init__(self, channels_in, channels_out, band_configs={}, conv_config={}, depths=[3, 2, 1], kernel_size=3): + def __init__( + self, + channels_in, + channels_out, + band_configs={}, + conv_config={}, + depths=[3, 2, 1], + kernel_size=3, + ): super(SDblock, self).__init__() self.SDlayer = SDlayer(channels_in, channels_out, band_configs) # Dynamically create convolution modules for each band based on depths - self.conv_modules = nn.ModuleList([ - ConvolutionModule(channels_out, depth, **conv_config) for depth in depths - ]) + self.conv_modules = nn.ModuleList( + [ConvolutionModule(channels_out, depth, **conv_config) for depth in depths] + ) # Set the kernel_size to an odd number. - self.globalconv = nn.Conv2d(channels_out, channels_out, kernel_size, 1, (kernel_size - 1) // 2) + self.globalconv = nn.Conv2d( + channels_out, channels_out, kernel_size, 1, (kernel_size - 1) // 2 + ) def forward(self, x): bands, original_lengths = self.SDlayer(x) @@ -216,7 +249,6 @@ def forward(self, x): .permute(0, 2, 1, 3) ) for conv, band in zip(self.conv_modules, bands) - ] lengths = [band.size(-2) for band in bands] full_band = torch.cat(bands, dim=2) @@ -250,47 +282,54 @@ class SCNet(nn.Module): """ - def __init__(self, - sources=['drums', 'bass', 'other', 'vocals'], - audio_channels=2, - # Main structure - dims=[4, 32, 64, 128], # dims = [4, 64, 128, 256] in SCNet-large - # STFT - nfft=4096, - hop_size=1024, - win_size=4096, - normalized=True, - # SD/SU layer - band_SR=[0.175, 0.392, 0.433], - band_stride=[1, 4, 16], - band_kernel=[3, 4, 16], - # Convolution Module - conv_depths=[3, 2, 1], - compress=4, - conv_kernel=3, - # Dual-path RNN - num_dplayer=6, - expand=1, - ): + def __init__( + self, + sources=["drums", "bass", "other", "vocals"], + audio_channels=2, + # Main structure + dims=[4, 32, 64, 128], # dims = [4, 64, 128, 256] in SCNet-large + # STFT + nfft=4096, + hop_size=1024, + win_size=4096, + normalized=True, + # SD/SU layer + band_SR=[0.175, 0.392, 0.433], + band_stride=[1, 4, 16], + band_kernel=[3, 4, 16], + # Convolution Module + conv_depths=[3, 2, 1], + compress=4, + conv_kernel=3, + # Dual-path RNN + num_dplayer=6, + expand=1, + ): super().__init__() self.sources = sources self.audio_channels = audio_channels self.dims = dims - band_keys = ['low', 'mid', 'high'] - self.band_configs = {band_keys[i]: {'SR': band_SR[i], 'stride': band_stride[i], 'kernel': band_kernel[i]} for i - in range(len(band_keys))} + band_keys = ["low", "mid", "high"] + self.band_configs = { + band_keys[i]: { + "SR": band_SR[i], + "stride": band_stride[i], + "kernel": band_kernel[i], + } + for i in range(len(band_keys)) + } self.hop_length = hop_size self.conv_config = { - 'compress': compress, - 'kernel': conv_kernel, + "compress": compress, + "kernel": conv_kernel, } self.stft_config = { - 'n_fft': nfft, - 'hop_length': hop_size, - 'win_length': win_size, - 'center': True, - 'normalized': normalized + "n_fft": 
nfft, + "hop_length": hop_size, + "win_length": win_size, + "center": True, + "normalized": normalized, } self.encoder = nn.ModuleList() @@ -302,7 +341,7 @@ def __init__(self, channels_out=dims[index + 1], band_configs=self.band_configs, conv_config=self.conv_config, - depths=conv_depths + depths=conv_depths, ) self.encoder.append(enc) @@ -310,9 +349,11 @@ def __init__(self, FusionLayer(channels=dims[index + 1]), SUlayer( channels_in=dims[index + 1], - channels_out=dims[index] if index != 0 else dims[index] * len(sources), + channels_out=( + dims[index] if index != 0 else dims[index] * len(sources) + ), band_configs=self.band_configs, - ) + ), ) self.decoder.insert(0, dec) @@ -337,8 +378,12 @@ def forward(self, x): x = x.reshape(-1, L) x = torch.stft(x, **self.stft_config, return_complex=True) x = torch.view_as_real(x) - x = x.permute(0, 3, 1, 2).reshape(x.shape[0] // self.audio_channels, x.shape[3] * self.audio_channels, - x.shape[1], x.shape[2]) + x = x.permute(0, 3, 1, 2).reshape( + x.shape[0] // self.audio_channels, + x.shape[3] * self.audio_channels, + x.shape[1], + x.shape[2], + ) B, C, Fr, T = x.shape diff --git a/programs/music_separation_code/models/scnet/separation.py b/programs/music_separation_code/models/scnet/separation.py index d902dac..8965e2c 100644 --- a/programs/music_separation_code/models/scnet/separation.py +++ b/programs/music_separation_code/models/scnet/separation.py @@ -21,8 +21,8 @@ def forward(self, x): # B, C, F, T = x.shape if self.inverse: x = x.float() - x_r = x[:, :self.channels // 2, :, :] - x_i = x[:, self.channels // 2:, :, :] + x_r = x[:, : self.channels // 2, :, :] + x_i = x[:, self.channels // 2 :, :, :] x = torch.complex(x_r, x_i) x = torch.fft.irfft(x, dim=3, norm="ortho") else: @@ -51,12 +51,22 @@ def __init__(self, d_model, expand, bidirectional=True): self.hidden_size = d_model * expand self.bidirectional = bidirectional # Initialize LSTM layers and normalization layers - self.lstm_layers = nn.ModuleList([self._init_lstm_layer(self.d_model, self.hidden_size) for _ in range(2)]) - self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_size * 2, self.d_model) for _ in range(2)]) + self.lstm_layers = nn.ModuleList( + [self._init_lstm_layer(self.d_model, self.hidden_size) for _ in range(2)] + ) + self.linear_layers = nn.ModuleList( + [nn.Linear(self.hidden_size * 2, self.d_model) for _ in range(2)] + ) self.norm_layers = nn.ModuleList([nn.GroupNorm(1, d_model) for _ in range(2)]) def _init_lstm_layer(self, d_model, hidden_size): - return LSTM(d_model, hidden_size, num_layers=1, bidirectional=self.bidirectional, batch_first=True) + return LSTM( + d_model, + hidden_size, + num_layers=1, + bidirectional=self.bidirectional, + batch_first=True, + ) def forward(self, x): B, C, F, T = x.shape @@ -98,13 +108,19 @@ def __init__(self, channels, expand=1, num_layers=6): self.num_layers = num_layers - self.dp_modules = nn.ModuleList([ - DualPathRNN(channels * (2 if i % 2 == 1 else 1), expand) for i in range(num_layers) - ]) - - self.feature_conversion = nn.ModuleList([ - FeatureConversion(channels * 2, inverse=False if i % 2 == 0 else True) for i in range(num_layers) - ]) + self.dp_modules = nn.ModuleList( + [ + DualPathRNN(channels * (2 if i % 2 == 1 else 1), expand) + for i in range(num_layers) + ] + ) + + self.feature_conversion = nn.ModuleList( + [ + FeatureConversion(channels * 2, inverse=False if i % 2 == 0 else True) + for i in range(num_layers) + ] + ) def forward(self, x): for i in range(self.num_layers): diff --git 
a/programs/music_separation_code/models/scnet_unofficial/__init__.py b/programs/music_separation_code/models/scnet_unofficial/__init__.py index 6d034d3..298d993 100644 --- a/programs/music_separation_code/models/scnet_unofficial/__init__.py +++ b/programs/music_separation_code/models/scnet_unofficial/__init__.py @@ -1 +1 @@ -from models.scnet_unofficial.scnet import SCNet \ No newline at end of file +from models.scnet_unofficial.scnet import SCNet diff --git a/programs/music_separation_code/models/scnet_unofficial/modules/dualpath_rnn.py b/programs/music_separation_code/models/scnet_unofficial/modules/dualpath_rnn.py index 2dfcdbc..644d05a 100644 --- a/programs/music_separation_code/models/scnet_unofficial/modules/dualpath_rnn.py +++ b/programs/music_separation_code/models/scnet_unofficial/modules/dualpath_rnn.py @@ -2,10 +2,11 @@ import torch.nn as nn import torch.nn.functional as Func + class RMSNorm(nn.Module): def __init__(self, dim): super().__init__() - self.scale = dim ** 0.5 + self.scale = dim**0.5 self.gamma = nn.Parameter(torch.ones(dim)) def forward(self, x): @@ -17,11 +18,8 @@ def __init__(self, d_model, d_state, d_conv, d_expand): super().__init__() self.norm = RMSNorm(dim=d_model) self.mamba = Mamba( - d_model=d_model, - d_state=d_state, - d_conv=d_conv, - d_expand=d_expand - ) + d_model=d_model, d_state=d_state, d_conv=d_conv, d_expand=d_expand + ) def forward(self, x): x = x + self.mamba(self.norm(x)) @@ -128,7 +126,7 @@ def forward(self, x: torch.Tensor, time_dim: int) -> torch.Tensor: x = x.reshape(B, F, T, D // 2, 2) x = torch.view_as_complex(x) x = torch.fft.irfft(x, n=time_dim, dim=2) - + x = x.to(dtype) return x @@ -166,11 +164,10 @@ def __init__( n_layers: int, input_dim: int, hidden_dim: int, - use_mamba: bool = False, d_state: int = 16, d_conv: int = 4, - d_expand: int = 2 + d_expand: int = 2, ): """ Initializes DualPathRNN with the specified number of layers, input dimension, and hidden dimension. @@ -179,9 +176,20 @@ def __init__( if use_mamba: from mamba_ssm.modules.mamba_simple import Mamba + net = MambaModule - dkwargs = {"d_model": input_dim, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand} - ukwargs = {"d_model": input_dim * 2, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand * 2} + dkwargs = { + "d_model": input_dim, + "d_state": d_state, + "d_conv": d_conv, + "d_expand": d_expand, + } + ukwargs = { + "d_model": input_dim * 2, + "d_state": d_state, + "d_conv": d_conv, + "d_expand": d_expand * 2, + } else: net = RNNModule dkwargs = {"input_dim": input_dim, "hidden_dim": hidden_dim} @@ -190,13 +198,15 @@ def __init__( self.layers = nn.ModuleList() for i in range(1, n_layers + 1): kwargs = dkwargs if i % 2 == 1 else ukwargs - layer = nn.ModuleList([ - net(**kwargs), - net(**kwargs), - RFFTModule(inverse=(i % 2 == 0)), - ]) + layer = nn.ModuleList( + [ + net(**kwargs), + net(**kwargs), + RFFTModule(inverse=(i % 2 == 0)), + ] + ) self.layers.append(layer) - + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Performs forward pass through the DualPathRNN. 
@@ -224,5 +234,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.permute(0, 2, 1, 3) x = rfft_layer(x, time_dim) - + return x diff --git a/programs/music_separation_code/models/scnet_unofficial/scnet.py b/programs/music_separation_code/models/scnet_unofficial/scnet.py index d076f85..d6dcf72 100644 --- a/programs/music_separation_code/models/scnet_unofficial/scnet.py +++ b/programs/music_separation_code/models/scnet_unofficial/scnet.py @@ -1,8 +1,8 @@ -''' +""" SCNet - great paper, great implementation https://arxiv.org/pdf/2401.13276.pdf https://github.com/amanteur/SCNet-PyTorch -''' +""" from typing import List @@ -20,6 +20,7 @@ from beartype.typing import Tuple, Optional, List, Callable from beartype import beartype + def exists(val): return val is not None @@ -39,7 +40,7 @@ def unpack_one(t, ps, pattern): class RMSNorm(nn.Module): def __init__(self, dim): super().__init__() - self.scale = dim ** 0.5 + self.scale = dim**0.5 self.gamma = nn.Parameter(torch.ones(dim)) def forward(self, x): @@ -48,20 +49,13 @@ def forward(self, x): class BandSplit(nn.Module): @beartype - def __init__( - self, - dim, - dim_inputs: Tuple[int, ...] - ): + def __init__(self, dim, dim_inputs: Tuple[int, ...]): super().__init__() self.dim_inputs = dim_inputs self.to_features = ModuleList([]) for dim_in in dim_inputs: - net = nn.Sequential( - RMSNorm(dim_in), - nn.Linear(dim_in, dim) - ) + net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim)) self.to_features.append(net) @@ -107,6 +101,7 @@ class SCNet(nn.Module): C is channel dim (mono / stereo), T is sequence length, """ + @beartype def __init__( self, @@ -122,7 +117,7 @@ def __init__( win_length: int = 4096, stft_window_fn: Optional[Callable] = None, stft_normalized: bool = False, - **kwargs + **kwargs, ): """ Initializes SCNet with input parameters. 
@@ -156,7 +151,7 @@ def __init__( n_layers=n_rnn_layers, input_dim=dims[-1], hidden_dim=rnn_hidden_dim, - **kwargs + **kwargs, ) self.su_blocks = nn.ModuleList( SUBlock( @@ -174,10 +169,12 @@ def __init__( n_fft=n_fft, hop_length=hop_length, win_length=win_length, - normalized=stft_normalized + normalized=stft_normalized, ) - self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), win_length) + self.stft_window_fn = partial( + default(stft_window_fn, torch.hann_window), win_length + ) self.n_sources = n_sources self.hop_length = hop_length @@ -208,19 +205,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: stft_window = self.stft_window_fn(device=device) if x.ndim == 2: - x = rearrange(x, 'b t -> b 1 t') + x = rearrange(x, "b t -> b 1 t") c = x.shape[1] - + stft_pad = self.hop_length - x.shape[-1] % self.hop_length x = F.pad(x, (0, stft_pad)) # stft - x, ps = pack_one(x, '* t') + x, ps = pack_one(x, "* t") x = torch.stft(x, **self.stft_kwargs, window=stft_window, return_complex=True) x = torch.view_as_real(x) - x = unpack_one(x, ps, '* c f t') - x = rearrange(x, 'b c f t r -> b f t (c r)') + x = unpack_one(x, ps, "* c f t") + x = rearrange(x, "b c f t r -> b f t (c r)") # encoder part x_skips = [] @@ -236,14 +233,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = su_block(x, x_skip) # istft - x = rearrange(x, 'b f t (c r n) -> b n c f t r', c=c, n=self.n_sources, r=2) + x = rearrange(x, "b f t (c r n) -> b n c f t r", c=c, n=self.n_sources, r=2) x = x.contiguous() - x = torch.view_as_complex(x) - x = rearrange(x, 'b n c f t -> (b n c) f t') + x = torch.view_as_complex(x) + x = rearrange(x, "b n c f t -> (b n c) f t") x = torch.istft(x, **self.stft_kwargs, window=stft_window, return_complex=False) - x = rearrange(x, '(b n c) t -> b n c t', c=c, n=self.n_sources) + x = rearrange(x, "(b n c) t -> b n c t", c=c, n=self.n_sources) - x = x[..., :-stft_pad] + x = x[..., :-stft_pad] return x diff --git a/programs/music_separation_code/models/scnet_unofficial/utils.py b/programs/music_separation_code/models/scnet_unofficial/utils.py index aae1afc..d236d49 100644 --- a/programs/music_separation_code/models/scnet_unofficial/utils.py +++ b/programs/music_separation_code/models/scnet_unofficial/utils.py @@ -1,8 +1,8 @@ -''' +""" SCNet - great paper, great implementation https://arxiv.org/pdf/2401.13276.pdf https://github.com/amanteur/SCNet-PyTorch -''' +""" from typing import List, Tuple, Union @@ -10,7 +10,7 @@ def create_intervals( - splits: List[Union[float, int]] + splits: List[Union[float, int]], ) -> List[Union[Tuple[float, float], Tuple[int, int]]]: """ Create intervals based on splits provided. 
@@ -132,4 +132,4 @@ def compute_gcr(subband_shapes: List[List[int]]) -> float: gcr = torch.stack( [(1 - t[i + 1] / t[i]).mean() for i in range(0, len(t) - 1)] ).mean() - return float(gcr) \ No newline at end of file + return float(gcr) diff --git a/programs/music_separation_code/models/segm_models.py b/programs/music_separation_code/models/segm_models.py index cf858ec..537d94a 100644 --- a/programs/music_separation_code/models/segm_models.py +++ b/programs/music_separation_code/models/segm_models.py @@ -21,12 +21,14 @@ def __call__(self, x): hop_length=self.hop_length, window=window, center=True, - return_complex=True + return_complex=True, ) x = torch.view_as_real(x) x = x.permute([0, 3, 1, 2]) - x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) - return x[..., :self.dim_f, :] + x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape( + [*batch_dims, c * 2, -1, x.shape[-1]] + ) + return x[..., : self.dim_f, :] def inverse(self, x): window = self.window.to(x.device) @@ -37,25 +39,21 @@ def inverse(self, x): x = torch.cat([x, f_pad], -2) x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) x = x.permute([0, 2, 3, 1]) - x = x[..., 0] + x[..., 1] * 1.j + x = x[..., 0] + x[..., 1] * 1.0j x = torch.istft( - x, - n_fft=self.n_fft, - hop_length=self.hop_length, - window=window, - center=True + x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True ) x = x.reshape([*batch_dims, 2, -1]) return x def get_act(act_type): - if act_type == 'gelu': + if act_type == "gelu": return nn.GELU() - elif act_type == 'relu': + elif act_type == "relu": return nn.ReLU() - elif act_type[:3] == 'elu': - alpha = float(act_type.replace('elu', '')) + elif act_type[:3] == "elu": + alpha = float(act_type.replace("elu", "")) return nn.ELU(alpha) else: raise Exception @@ -64,7 +62,7 @@ def get_act(act_type): def get_decoder(config, c): decoder = None decoder_options = dict() - if config.model.decoder_type == 'unet': + if config.model.decoder_type == "unet": try: decoder_options = dict(config.decoder_unet) except: @@ -76,7 +74,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'fpn': + elif config.model.decoder_type == "fpn": try: decoder_options = dict(config.decoder_fpn) except: @@ -88,7 +86,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'unet++': + elif config.model.decoder_type == "unet++": try: decoder_options = dict(config.decoder_unet_plus_plus) except: @@ -100,7 +98,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'manet': + elif config.model.decoder_type == "manet": try: decoder_options = dict(config.decoder_manet) except: @@ -112,7 +110,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'linknet': + elif config.model.decoder_type == "linknet": try: decoder_options = dict(config.decoder_linknet) except: @@ -124,7 +122,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'pspnet': + elif config.model.decoder_type == "pspnet": try: decoder_options = dict(config.decoder_pspnet) except: @@ -136,7 +134,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'pspnet': + elif config.model.decoder_type == "pspnet": try: decoder_options = dict(config.decoder_pspnet) except: @@ -148,7 +146,7 @@ def get_decoder(config, c): classes=c, 
**decoder_options, ) - elif config.model.decoder_type == 'pan': + elif config.model.decoder_type == "pan": try: decoder_options = dict(config.decoder_pan) except: @@ -160,7 +158,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'deeplabv3': + elif config.model.decoder_type == "deeplabv3": try: decoder_options = dict(config.decoder_deeplabv3) except: @@ -172,7 +170,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'deeplabv3plus': + elif config.model.decoder_type == "deeplabv3plus": try: decoder_options = dict(config.decoder_deeplabv3plus) except: @@ -194,7 +192,9 @@ def __init__(self, config): act = get_act(act_type=config.model.act) - self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments) + self.num_target_instruments = ( + 1 if config.training.target_instrument else len(config.training.instruments) + ) self.num_subbands = config.model.num_subbands dim_c = self.num_subbands * config.audio.num_channels * 2 @@ -208,7 +208,7 @@ def __init__(self, config): self.final_conv = nn.Sequential( nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), act, - nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) + nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False), ) self.stft = STFT(config.audio) diff --git a/programs/music_separation_code/models/torchseg_models.py b/programs/music_separation_code/models/torchseg_models.py index fb4bd9f..92fec69 100644 --- a/programs/music_separation_code/models/torchseg_models.py +++ b/programs/music_separation_code/models/torchseg_models.py @@ -21,12 +21,14 @@ def __call__(self, x): hop_length=self.hop_length, window=window, center=True, - return_complex=True + return_complex=True, ) x = torch.view_as_real(x) x = x.permute([0, 3, 1, 2]) - x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) - return x[..., :self.dim_f, :] + x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape( + [*batch_dims, c * 2, -1, x.shape[-1]] + ) + return x[..., : self.dim_f, :] def inverse(self, x): window = self.window.to(x.device) @@ -37,25 +39,21 @@ def inverse(self, x): x = torch.cat([x, f_pad], -2) x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) x = x.permute([0, 2, 3, 1]) - x = x[..., 0] + x[..., 1] * 1.j + x = x[..., 0] + x[..., 1] * 1.0j x = torch.istft( - x, - n_fft=self.n_fft, - hop_length=self.hop_length, - window=window, - center=True + x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True ) x = x.reshape([*batch_dims, 2, -1]) return x def get_act(act_type): - if act_type == 'gelu': + if act_type == "gelu": return nn.GELU() - elif act_type == 'relu': + elif act_type == "relu": return nn.ReLU() - elif act_type[:3] == 'elu': - alpha = float(act_type.replace('elu', '')) + elif act_type[:3] == "elu": + alpha = float(act_type.replace("elu", "")) return nn.ELU(alpha) else: raise Exception @@ -64,7 +62,7 @@ def get_act(act_type): def get_decoder(config, c): decoder = None decoder_options = dict() - if config.model.decoder_type == 'unet': + if config.model.decoder_type == "unet": try: decoder_options = dict(config.decoder_unet) except: @@ -76,7 +74,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'fpn': + elif config.model.decoder_type == "fpn": try: decoder_options = dict(config.decoder_fpn) except: @@ -88,7 +86,7 @@ def get_decoder(config, c): classes=c, 
**decoder_options, ) - elif config.model.decoder_type == 'unet++': + elif config.model.decoder_type == "unet++": try: decoder_options = dict(config.decoder_unet_plus_plus) except: @@ -100,7 +98,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'manet': + elif config.model.decoder_type == "manet": try: decoder_options = dict(config.decoder_manet) except: @@ -112,7 +110,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'linknet': + elif config.model.decoder_type == "linknet": try: decoder_options = dict(config.decoder_linknet) except: @@ -124,7 +122,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'pspnet': + elif config.model.decoder_type == "pspnet": try: decoder_options = dict(config.decoder_pspnet) except: @@ -136,7 +134,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'pspnet': + elif config.model.decoder_type == "pspnet": try: decoder_options = dict(config.decoder_pspnet) except: @@ -148,7 +146,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'pan': + elif config.model.decoder_type == "pan": try: decoder_options = dict(config.decoder_pan) except: @@ -160,7 +158,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'deeplabv3': + elif config.model.decoder_type == "deeplabv3": try: decoder_options = dict(config.decoder_deeplabv3) except: @@ -172,7 +170,7 @@ def get_decoder(config, c): classes=c, **decoder_options, ) - elif config.model.decoder_type == 'deeplabv3plus': + elif config.model.decoder_type == "deeplabv3plus": try: decoder_options = dict(config.decoder_deeplabv3plus) except: @@ -194,7 +192,9 @@ def __init__(self, config): act = get_act(act_type=config.model.act) - self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments) + self.num_target_instruments = ( + 1 if config.training.target_instrument else len(config.training.instruments) + ) self.num_subbands = config.model.num_subbands dim_c = self.num_subbands * config.audio.num_channels * 2 @@ -208,7 +208,7 @@ def __init__(self, config): self.final_conv = nn.Sequential( nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), act, - nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) + nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False), ) self.stft = STFT(config.audio) diff --git a/programs/music_separation_code/models/upernet_swin_transformers.py b/programs/music_separation_code/models/upernet_swin_transformers.py index d20e289..27f32f4 100644 --- a/programs/music_separation_code/models/upernet_swin_transformers.py +++ b/programs/music_separation_code/models/upernet_swin_transformers.py @@ -22,12 +22,14 @@ def __call__(self, x): hop_length=self.hop_length, window=window, center=True, - return_complex=True + return_complex=True, ) x = torch.view_as_real(x) x = x.permute([0, 3, 1, 2]) - x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) - return x[..., :self.dim_f, :] + x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape( + [*batch_dims, c * 2, -1, x.shape[-1]] + ) + return x[..., : self.dim_f, :] def inverse(self, x): window = self.window.to(x.device) @@ -38,13 +40,9 @@ def inverse(self, x): x = torch.cat([x, f_pad], -2) x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) x = 
x.permute([0, 2, 3, 1]) - x = x[..., 0] + x[..., 1] * 1.j + x = x[..., 0] + x[..., 1] * 1.0j x = torch.istft( - x, - n_fft=self.n_fft, - hop_length=self.hop_length, - window=window, - center=True + x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True ) x = x.reshape([*batch_dims, 2, -1]) return x @@ -52,12 +50,12 @@ def inverse(self, x): def get_norm(norm_type): def norm(c, norm_type): - if norm_type == 'BatchNorm': + if norm_type == "BatchNorm": return nn.BatchNorm2d(c) - elif norm_type == 'InstanceNorm': + elif norm_type == "InstanceNorm": return nn.InstanceNorm2d(c, affine=True) - elif 'GroupNorm' in norm_type: - g = int(norm_type.replace('GroupNorm', '')) + elif "GroupNorm" in norm_type: + g = int(norm_type.replace("GroupNorm", "")) return nn.GroupNorm(num_groups=g, num_channels=c) else: return nn.Identity() @@ -66,12 +64,12 @@ def norm(c, norm_type): def get_act(act_type): - if act_type == 'gelu': + if act_type == "gelu": return nn.GELU() - elif act_type == 'relu': + elif act_type == "relu": return nn.ReLU() - elif act_type[:3] == 'elu': - alpha = float(act_type.replace('elu', '')) + elif act_type[:3] == "elu": + alpha = float(act_type.replace("elu", "")) return nn.ELU(alpha) else: raise Exception @@ -83,7 +81,13 @@ def __init__(self, in_c, out_c, scale, norm, act): self.conv = nn.Sequential( norm(in_c), act, - nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) + nn.ConvTranspose2d( + in_channels=in_c, + out_channels=out_c, + kernel_size=scale, + stride=scale, + bias=False, + ), ) def forward(self, x): @@ -96,7 +100,13 @@ def __init__(self, in_c, out_c, scale, norm, act): self.conv = nn.Sequential( norm(in_c), act, - nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) + nn.Conv2d( + in_channels=in_c, + out_channels=out_c, + kernel_size=scale, + stride=scale, + bias=False, + ), ) def forward(self, x): @@ -151,7 +161,9 @@ def __init__(self, config): act = get_act(act_type=config.model.act) - self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments) + self.num_target_instruments = ( + 1 if config.training.target_instrument else len(config.training.instruments) + ) self.num_subbands = config.model.num_subbands dim_c = self.num_subbands * config.audio.num_channels * 2 @@ -160,16 +172,24 @@ def __init__(self, config): self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False) - self.swin_upernet_model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-swin-large") + self.swin_upernet_model = UperNetForSemanticSegmentation.from_pretrained( + "openmmlab/upernet-swin-large" + ) - self.swin_upernet_model.auxiliary_head.classifier = nn.Conv2d(256, c, kernel_size=(1, 1), stride=(1, 1)) - self.swin_upernet_model.decode_head.classifier = nn.Conv2d(512, c, kernel_size=(1, 1), stride=(1, 1)) - self.swin_upernet_model.backbone.embeddings.patch_embeddings.projection = nn.Conv2d(c, 192, kernel_size=(4, 4), stride=(4, 4)) + self.swin_upernet_model.auxiliary_head.classifier = nn.Conv2d( + 256, c, kernel_size=(1, 1), stride=(1, 1) + ) + self.swin_upernet_model.decode_head.classifier = nn.Conv2d( + 512, c, kernel_size=(1, 1), stride=(1, 1) + ) + self.swin_upernet_model.backbone.embeddings.patch_embeddings.projection = ( + nn.Conv2d(c, 192, kernel_size=(4, 4), stride=(4, 4)) + ) self.final_conv = nn.Sequential( nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), act, - nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, 
bias=False) + nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False), ) self.stft = STFT(config.audio) @@ -217,7 +237,9 @@ def forward(self, x): if __name__ == "__main__": - model = UperNetForSemanticSegmentation.from_pretrained("./results/", ignore_mismatched_sizes=True) + model = UperNetForSemanticSegmentation.from_pretrained( + "./results/", ignore_mismatched_sizes=True + ) print(model) print(model.auxiliary_head.classifier) print(model.decode_head.classifier) @@ -225,4 +247,4 @@ def forward(self, x): x = torch.zeros((2, 16, 512, 512), dtype=torch.float32) res = model(x) print(res.logits.shape) - model.save_pretrained('./results/') \ No newline at end of file + model.save_pretrained("./results/") diff --git a/programs/music_separation_code/utils.py b/programs/music_separation_code/utils.py index 711af16..1daee57 100644 --- a/programs/music_separation_code/utils.py +++ b/programs/music_separation_code/utils.py @@ -1,7 +1,6 @@ # coding: utf-8 -__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' +__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/" -import time import numpy as np import torch import torch.nn as nn @@ -12,64 +11,65 @@ from numpy.typing import NDArray from typing import Dict + def get_model_from_config(model_type, config_path): with open(config_path) as f: - if model_type == 'htdemucs': + if model_type == "htdemucs": config = OmegaConf.load(config_path) else: config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader)) - if model_type == 'mdx23c': + if model_type == "mdx23c": from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net + model = TFC_TDF_net(config) - elif model_type == 'htdemucs': + elif model_type == "htdemucs": from models.demucs4ht import get_model + model = get_model(config) - elif model_type == 'segm_models': + elif model_type == "segm_models": from models.segm_models import Segm_Models_Net + model = Segm_Models_Net(config) - elif model_type == 'torchseg': + elif model_type == "torchseg": from models.torchseg_models import Torchseg_Net + model = Torchseg_Net(config) - elif model_type == 'mel_band_roformer': + elif model_type == "mel_band_roformer": from models.bs_roformer import MelBandRoformer - model = MelBandRoformer( - **dict(config.model) - ) - elif model_type == 'bs_roformer': + + model = MelBandRoformer(**dict(config.model)) + elif model_type == "bs_roformer": from models.bs_roformer import BSRoformer - model = BSRoformer( - **dict(config.model) - ) - elif model_type == 'swin_upernet': + + model = BSRoformer(**dict(config.model)) + elif model_type == "swin_upernet": from models.upernet_swin_transformers import Swin_UperNet_Model + model = Swin_UperNet_Model(config) - elif model_type == 'bandit': + elif model_type == "bandit": from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple - model = MultiMaskMultiSourceBandSplitRNNSimple( - **config.model - ) - elif model_type == 'bandit_v2': + + model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model) + elif model_type == "bandit_v2": from models.bandit_v2.bandit import Bandit - model = Bandit( - **config.kwargs - ) - elif model_type == 'scnet_unofficial': + + model = Bandit(**config.kwargs) + elif model_type == "scnet_unofficial": from models.scnet_unofficial import SCNet - model = SCNet( - **config.model - ) - elif model_type == 'scnet': + + model = SCNet(**config.model) + elif model_type == "scnet": from models.scnet import SCNet - model = SCNet( - **config.model - ) + + model = SCNet(**config.model) else: - print('Unknown model: 
{}'.format(model_type)) + print("Unknown model: {}".format(model_type)) model = None return model, config + def _getWindowingArray(window_size, fade_size): fadein = torch.linspace(0, 1, fade_size) fadeout = torch.linspace(1, 0, fade_size) @@ -91,16 +91,16 @@ def demix_track(config, model, mix, device, pbar=False): # Do pad from the beginning and end to account floating window results better if length_init > 2 * border and (border > 0): - mix = nn.functional.pad(mix, (border, border), mode='reflect') + mix = nn.functional.pad(mix, (border, border), mode="reflect") # windowingArray crossfades at segment boundaries to mitigate clicking artifacts windowingArray = _getWindowingArray(C, fade_size) with torch.cuda.amp.autocast(enabled=config.training.use_amp): - use_amp = getattr(config.training, 'use_amp', False) + use_amp = getattr(config.training, "use_amp", False) with torch.inference_mode(): if config.training.target_instrument is not None: - req_shape = (1, ) + tuple(mix.shape) + req_shape = (1,) + tuple(mix.shape) else: req_shape = (len(config.training.instruments),) + tuple(mix.shape) @@ -109,17 +109,28 @@ def demix_track(config, model, mix, device, pbar=False): i = 0 batch_data = [] batch_locations = [] - progress_bar = tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False) if pbar else None + progress_bar = ( + tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False) + if pbar + else None + ) while i < mix.shape[1]: # print(i, i + C, mix.shape[1]) - part = mix[:, i:i + C].to(device) + part = mix[:, i : i + C].to(device) length = part.shape[-1] if length < C: if length > C // 2 + 1: - part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect') + part = nn.functional.pad( + input=part, pad=(0, C - length), mode="reflect" + ) else: - part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0) + part = nn.functional.pad( + input=part, + pad=(0, C - length, 0, 0), + mode="constant", + value=0, + ) batch_data.append(part) batch_locations.append((i, length)) i += step @@ -136,8 +147,10 @@ def demix_track(config, model, mix, device, pbar=False): for j in range(len(batch_locations)): start, l = batch_locations[j] - result[..., start:start+l] += x[j][..., :l].cpu() * window[..., :l] - counter[..., start:start+l] += window[..., :l] + result[..., start : start + l] += ( + x[j][..., :l].cpu() * window[..., :l] + ) + counter[..., start : start + l] += window[..., :l] batch_data = [] batch_locations = [] @@ -159,7 +172,9 @@ def demix_track(config, model, mix, device, pbar=False): if config.training.target_instrument is None: return {k: v for k, v in zip(config.training.instruments, estimated_sources)} else: - return {k: v for k, v in zip([config.training.target_instrument], estimated_sources)} + return { + k: v for k, v in zip([config.training.target_instrument], estimated_sources) + } def demix_track_demucs(config, model, mix, device, pbar=False): @@ -172,32 +187,37 @@ def demix_track_demucs(config, model, mix, device, pbar=False): with torch.cuda.amp.autocast(enabled=config.training.use_amp): with torch.inference_mode(): - req_shape = (S, ) + tuple(mix.shape) + req_shape = (S,) + tuple(mix.shape) result = torch.zeros(req_shape, dtype=torch.float32) counter = torch.zeros(req_shape, dtype=torch.float32) i = 0 batch_data = [] batch_locations = [] - progress_bar = tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False) if pbar else None + progress_bar = ( + tqdm(total=mix.shape[1], desc="Processing audio chunks", 
leave=False) + if pbar + else None + ) while i < mix.shape[1]: # print(i, i + C, mix.shape[1]) - part = mix[:, i:i + C].to(device) + part = mix[:, i : i + C].to(device) length = part.shape[-1] if length < C: - part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0) + part = nn.functional.pad( + input=part, pad=(0, C - length, 0, 0), mode="constant", value=0 + ) batch_data.append(part) batch_locations.append((i, length)) i += step - if len(batch_data) >= batch_size or (i >= mix.shape[1]): arr = torch.stack(batch_data, dim=0) x = model(arr) for j in range(len(batch_locations)): start, l = batch_locations[j] - result[..., start:start+l] += x[j][..., :l].cpu() - counter[..., start:start+l] += 1. + result[..., start : start + l] += x[j][..., :l].cpu() + counter[..., start : start + l] += 1.0 batch_data = [] batch_locations = [] @@ -226,9 +246,12 @@ def sdr(references, estimates): den += delta return 10 * np.log10(num / den) -def demix(config, model, mix: NDArray, device, pbar=False, model_type: str = None) -> Dict[str, NDArray]: + +def demix( + config, model, mix: NDArray, device, pbar=False, model_type: str = None +) -> Dict[str, NDArray]: mix = torch.tensor(mix, dtype=torch.float32) - if model_type == 'htdemucs': + if model_type == "htdemucs": return demix_track_demucs(config, model, mix, device, pbar=pbar) else: return demix_track(config, model, mix, device, pbar=pbar) diff --git a/tabs/full_inference.py b/tabs/full_inference.py index 08eaafe..bb53dc5 100644 --- a/tabs/full_inference.py +++ b/tabs/full_inference.py @@ -19,191 +19,113 @@ from assets.i18n.i18n import I18nAuto - - - i18n = I18nAuto() - now_dir = os.getcwd() sys.path.append(now_dir) - model_root = os.path.join(now_dir, "logs") audio_root = os.path.join(now_dir, "audio_files", "original_files") - model_root_relative = os.path.relpath(model_root, now_dir) audio_root_relative = os.path.relpath(audio_root, now_dir) - sup_audioext = { - "wav", - "mp3", - "flac", - "ogg", - "opus", - "m4a", - "mp4", - "aac", - "alac", - "wma", - "aiff", - "webm", - "ac3", - } - names = [ - os.path.join(root, file) - for root, _, files in os.walk(model_root_relative, topdown=False) - for file in files - if ( - file.endswith((".pth", ".onnx")) - and not (file.startswith("G_") or file.startswith("D_")) - ) - ] - indexes_list = [ - os.path.join(root, name) - for root, _, files in os.walk(model_root_relative, topdown=False) - for name in files - if name.endswith(".index") and "trained" not in name - ] - audio_paths = [ - os.path.join(root, name) - for root, _, files in os.walk(audio_root_relative, topdown=False) - for name in files - if name.endswith(tuple(sup_audioext)) - and root == audio_root_relative - and "_output" not in name - ] - vocals_model_names = [ - "Mel-Roformer by KimberleyJSN", - "BS-Roformer by ViperX", - "MDX23C", - ] - karaoke_models_names = [ - "Mel-Roformer Karaoke by aufr33 and viperx", - "UVR-BVE", - ] - denoise_models_names = [ - "Mel-Roformer Denoise Normal by aufr33", - "Mel-Roformer Denoise Aggressive by aufr33", - "UVR Denoise", - ] - dereverb_models_names = [ - "MDX23C DeReverb by aufr33 and jarredou", - "UVR-Deecho-Dereverb", - "MDX Reverb HQ by FoxJoy", - "BS-Roformer Dereverb by anvuew", - ] - deeecho_models_names = ["UVR-Deecho-Normal", "UVR-Deecho-Aggressive"] - - - def get_indexes(): indexes_list = [ - os.path.join(dirpath, filename) - for dirpath, _, filenames in os.walk(model_root_relative) - for filename in filenames - if filename.endswith(".index") and "trained" not in filename - ] - 
- return indexes_list if indexes_list else "" - - - def match_index(model_file_value): if model_file_value: @@ -235,15 +157,10 @@ def match_index(model_file_value): return "" - - - def output_path_fn(input_audio_path): original_name_without_extension = os.path.basename(input_audio_path).rsplit(".", 1)[ - 0 - ] new_name = original_name_without_extension + "_output.wav" @@ -253,9 +170,6 @@ def output_path_fn(input_audio_path): return output_path - - - def get_number_of_gpus(): if torch.cuda.is_available(): @@ -269,9 +183,6 @@ def get_number_of_gpus(): return "-" - - - def max_vram_gpu(gpu): if torch.cuda.is_available(): @@ -287,15 +198,10 @@ def max_vram_gpu(gpu): return "0" - - - def format_title(title): formatted_title = ( - unicodedata.normalize("NFKD", title).encode("ascii", "ignore").decode("utf-8") - ) formatted_title = re.sub(r"[\u2500-\u257F]+", "", formatted_title) @@ -307,9 +213,6 @@ def format_title(title): return formatted_title - - - def save_to_wav(upload_audio): file_path = upload_audio @@ -318,14 +221,10 @@ def save_to_wav(upload_audio): target_path = os.path.join(audio_root_relative, formated_name) - - if os.path.exists(target_path): os.remove(target_path) - - os.makedirs(os.path.dirname(target_path), exist_ok=True) shutil.copy(file_path, target_path) @@ -333,9 +232,6 @@ def save_to_wav(upload_audio): return target_path, output_path_fn(target_path) - - - def delete_outputs(): gr.Info(f"Outputs cleared!") @@ -349,115 +245,64 @@ def delete_outputs(): os.remove(os.path.join(root, name)) - - - def change_choices(): names = [ - os.path.join(root, file) - for root, _, files in os.walk(model_root_relative, topdown=False) - for file in files - if ( - file.endswith((".pth", ".onnx")) - and not (file.startswith("G_") or file.startswith("D_")) - ) - ] - - indexes_list = [ - os.path.join(root, name) - for root, _, files in os.walk(model_root_relative, topdown=False) - for name in files - if name.endswith(".index") and "trained" not in name - ] - - audio_paths = [ - os.path.join(root, name) - for root, _, files in os.walk(audio_root_relative, topdown=False) - for name in files - if name.endswith(tuple(sup_audioext)) - and root == audio_root_relative - and "_output" not in name - ] - - return ( - {"choices": sorted(names), "__type__": "update"}, - {"choices": sorted(indexes_list), "__type__": "update"}, - {"choices": sorted(audio_paths), "__type__": "update"}, - ) - - - def download_music_tab(): with gr.Row(): link = gr.Textbox( - label=i18n("Music URL"), - lines=1, - ) output = gr.Textbox( - label=i18n("Output Information"), - info=i18n("The output information will be displayed here."), - ) download = gr.Button(i18n("Download")) - - download.click( + download.click( download_music, - inputs=[link], - outputs=[output], - ) - - - - - def full_inference_tab(): default_weight = names[0] if names else None @@ -467,67 +312,41 @@ def full_inference_tab(): with gr.Row(): model_file = gr.Dropdown( - label=i18n("Voice Model"), - info=i18n("Select the voice model to use for the conversion."), - choices=sorted(names, key=lambda path: os.path.getsize(path)), - interactive=True, - value=default_weight, - allow_custom_value=True, - ) index_file = gr.Dropdown( - label=i18n("Index File"), - info=i18n("Select the index file to use for the conversion."), - choices=get_indexes(), - value=match_index(default_weight) if default_weight else "", - interactive=True, - allow_custom_value=True, - ) with gr.Column(): with gr.Row(): unload_button = gr.Button(i18n("Unload Voice")) - refresh_button = 
gr.Button(i18n("Refresh")) - + refresh_button = gr.Button(i18n("Refresh")) unload_button.click( - fn=lambda: ( - {"value": "", "__type__": "update"}, - {"value": "", "__type__": "update"}, - ), - inputs=[], - outputs=[model_file, index_file], - ) model_file.select( - fn=lambda model_file_value: match_index(model_file_value), - inputs=[model_file], - outputs=[index_file], - ) with gr.Tab(i18n("Single")): @@ -535,33 +354,21 @@ def full_inference_tab(): with gr.Column(): upload_audio = gr.Audio( - label=i18n("Upload Audio"), - type="filepath", - editable=False, - sources="upload", - ) with gr.Row(): audio = gr.Dropdown( - label=i18n("Select Audio"), - info=i18n("Select the audio to convert."), - choices=sorted(audio_paths), - value=audio_paths[0] if audio_paths else "", - interactive=True, - allow_custom_value=True, - ) with gr.Accordion(i18n("Advanced Settings"), open=False): @@ -569,955 +376,538 @@ def full_inference_tab(): with gr.Accordion(i18n("RVC Settings"), open=False): output_path = gr.Textbox( - label=i18n("Output Path"), - placeholder=i18n("Enter output path"), - info=i18n( - "The path where the output audio will be saved, by default in audio_files/rvc/output.wav" - ), - value=os.path.join(now_dir, "audio_files", "rvc"), - interactive=False, - visible=False, - ) infer_backing_vocals = gr.Checkbox( - label=i18n("Infer Backing Vocals"), - info=i18n("Infer the bakcing vocals too."), - visible=True, - value=False, - interactive=True, - ) with gr.Row(): infer_backing_vocals_model = gr.Dropdown( - label=i18n("Backing Vocals Model"), - info=i18n( - "Select the backing vocals model to use for the conversion." - ), - choices=sorted(names, key=lambda path: os.path.getsize(path)), - interactive=True, - value=default_weight, - visible=False, - allow_custom_value=False, - ) infer_backing_vocals_index = gr.Dropdown( - label=i18n("Backing Vocals Index File"), - info=i18n( - "Select the backing vocals index file to use for the conversion." - ), - choices=get_indexes(), - value=match_index(default_weight) if default_weight else "", - interactive=True, - visible=False, - allow_custom_value=True, - ) with gr.Column(): refresh_button_infer_backing_vocals = gr.Button( - i18n("Refresh"), - visible=False, - ) unload_button_infer_backing_vocals = gr.Button( - i18n("Unload Voice"), - visible=False, - ) - - unload_button_infer_backing_vocals.click( - fn=lambda: ( - {"value": "", "__type__": "update"}, - {"value": "", "__type__": "update"}, - ), - inputs=[], - outputs=[ - infer_backing_vocals_model, - infer_backing_vocals_index, - ], - ) infer_backing_vocals_model.select( - fn=lambda model_file_value: match_index(model_file_value), - inputs=[infer_backing_vocals_model], - outputs=[infer_backing_vocals_index], - ) with gr.Accordion( - i18n("RVC Settings for Backing vocals"), open=False, visible=False - ) as back_rvc_settings: export_format_rvc_back = gr.Radio( - label=i18n("Export Format"), - info=i18n("Select the format to export the audio."), - choices=["WAV", "MP3", "FLAC", "OGG", "M4A"], - value="MP3", - interactive=True, - visible=False, - ) split_audio_back = gr.Checkbox( - label=i18n("Split Audio"), - info=i18n( - "Split the audio into chunks for inference to obtain better results in some cases." 
- ), - visible=True, - value=False, - interactive=True, - ) pitch_extract_back = gr.Radio( - label=i18n("Pitch Extractor"), - info=i18n("Pitch extract Algorith."), - choices=["rmvpe", "crepe", "crepe-tiny", "fcpe"], - value="rmvpe", - interactive=True, - ) hop_length_back = gr.Slider( - label=i18n("Hop Length"), - info=i18n("Hop length for pitch extraction."), - minimum=1, - maximum=512, - step=1, - value=64, - visible=False, - ) embedder_model_back = gr.Radio( - label=i18n("Embedder Model"), - info=i18n("Model used for learning speaker embedding."), - choices=[ - "contentvec", - "chinese-hubert-base", - "japanese-hubert-base", - "korean-hubert-base", - ], - value="contentvec", - interactive=True, - ) autotune_back = gr.Checkbox( - label=i18n("Autotune"), - info=i18n( - "Apply a soft autotune to your inferences, recommended for singing conversions." - ), - visible=True, - value=False, - interactive=True, - ) pitch_back = gr.Slider( - label=i18n("Pitch"), - info=i18n("Adjust the pitch of the audio."), - minimum=-12, - maximum=12, - step=1, - value=0, - interactive=True, - ) filter_radius_back = gr.Slider( - minimum=0, - maximum=7, - label=i18n("Filter Radius"), - info=i18n( - "If the number is greater than or equal to three, employing median filtering on the collected tone results has the potential to decrease respiration." - ), - value=3, - step=1, - interactive=True, - ) index_rate_back = gr.Slider( - minimum=0, - maximum=1, - label=i18n("Search Feature Ratio"), - info=i18n( - "Influence exerted by the index file; a higher value corresponds to greater influence. However, opting for lower values can help mitigate artifacts present in the audio." - ), - value=0.75, - interactive=True, - ) rms_mix_rate_back = gr.Slider( - minimum=0, - maximum=1, - label=i18n("Volume Envelope"), - info=i18n( - "Substitute or blend with the volume envelope of the output. The closer the ratio is to 1, the more the output envelope is employed." - ), - value=0.25, - interactive=True, - ) protect_back = gr.Slider( - minimum=0, - maximum=0.5, - label=i18n("Protect Voiceless Consonants"), - info=i18n( - "Safeguard distinct consonants and breathing sounds to prevent electro-acoustic tearing and other artifacts. Pulling the parameter to its maximum value of 0.5 offers comprehensive protection. However, reducing this value might decrease the extent of protection while potentially mitigating the indexing effect." - ), - value=0.33, - interactive=True, - ) clear_outputs_infer = gr.Button( - i18n("Clear Outputs (Deletes all audios in assets/audios)") - ) export_format_rvc = gr.Radio( - label=i18n("Export Format"), - info=i18n("Select the format to export the audio."), - choices=["WAV", "MP3", "FLAC", "OGG", "M4A"], - value="FLAC", - interactive=True, - visible=False, - ) split_audio = gr.Checkbox( - label=i18n("Split Audio"), - info=i18n( - "Split the audio into chunks for inference to obtain better results in some cases." 
- ), - visible=True, - value=False, - interactive=True, - ) pitch_extract = gr.Radio( - label=i18n("Pitch Extractor"), - info=i18n("Pitch extract Algorith."), - choices=["rmvpe", "crepe", "crepe-tiny", "fcpe"], - value="rmvpe", - interactive=True, - ) hop_length = gr.Slider( - label=i18n("Hop Length"), - info=i18n("Hop length for pitch extraction."), - minimum=1, - maximum=512, - step=1, - value=64, - visible=False, - ) embedder_model = gr.Radio( - label=i18n("Embedder Model"), - info=i18n("Model used for learning speaker embedding."), - choices=[ - "contentvec", - "chinese-hubert-base", - "japanese-hubert-base", - "korean-hubert-base", - ], - value="contentvec", - interactive=True, - ) autotune = gr.Checkbox( - label=i18n("Autotune"), - info=i18n( - "Apply a soft autotune to your inferences, recommended for singing conversions." - ), - visible=True, - value=False, - interactive=True, - ) pitch = gr.Slider( - label=i18n("Pitch"), - info=i18n("Adjust the pitch of the audio."), - minimum=-12, - maximum=12, - step=1, - value=0, - interactive=True, - ) filter_radius = gr.Slider( - minimum=0, - maximum=7, - label=i18n("Filter Radius"), - info=i18n( - "If the number is greater than or equal to three, employing median filtering on the collected tone results has the potential to decrease respiration." - ), - value=3, - step=1, - interactive=True, - ) index_rate = gr.Slider( - minimum=0, - maximum=1, - label=i18n("Search Feature Ratio"), - info=i18n( - "Influence exerted by the index file; a higher value corresponds to greater influence. However, opting for lower values can help mitigate artifacts present in the audio." - ), - value=0.75, - interactive=True, - ) rms_mix_rate = gr.Slider( - minimum=0, - maximum=1, - label=i18n("Volume Envelope"), - info=i18n( - "Substitute or blend with the volume envelope of the output. The closer the ratio is to 1, the more the output envelope is employed." - ), - value=0.25, - interactive=True, - ) protect = gr.Slider( - minimum=0, - maximum=0.5, - label=i18n("Protect Voiceless Consonants"), - info=i18n( - "Safeguard distinct consonants and breathing sounds to prevent electro-acoustic tearing and other artifacts. Pulling the parameter to its maximum value of 0.5 offers comprehensive protection. However, reducing this value might decrease the extent of protection while potentially mitigating the indexing effect." 
- ), - value=0.33, - interactive=True, - ) with gr.Accordion(i18n("Audio Separation Settings"), open=False): use_tta = gr.Checkbox( - label=i18n("Use TTA"), - info=i18n("Use Test Time Augmentation."), - visible=True, - value=False, - interactive=True, - ) batch_size = gr.Slider( - minimum=1, - maximum=24, - step=1, - label=i18n("Batch Size"), - info=i18n("Set the batch size for the separation."), - value=1, - interactive=True, - ) vocal_model = gr.Dropdown( - label=i18n("Vocals Model"), - info=i18n("Select the vocals model to use for the separation."), - choices=sorted(vocals_model_names), - interactive=True, - value="Mel-Roformer by KimberleyJSN", - allow_custom_value=False, - ) karaoke_model = gr.Dropdown( - label=i18n("Karaoke Model"), - info=i18n("Select the karaoke model to use for the separation."), - choices=sorted(karaoke_models_names), - interactive=True, - value="Mel-Roformer Karaoke by aufr33 and viperx", - allow_custom_value=False, - ) dereverb_model = gr.Dropdown( - label=i18n("Dereverb Model"), - info=i18n("Select the dereverb model to use for the separation."), - choices=sorted(dereverb_models_names), - interactive=True, - value="UVR-Deecho-Dereverb", - allow_custom_value=False, - ) deecho = gr.Checkbox( - label=i18n("Deeecho"), - info=i18n("Apply deeecho to the audio."), - visible=True, - value=True, - interactive=True, - ) deeecho_model = gr.Dropdown( - label=i18n("Deeecho Model"), - info=i18n("Select the deeecho model to use for the separation."), - choices=sorted(deeecho_models_names), - interactive=True, - value="UVR-Deecho-Normal", - allow_custom_value=False, - ) denoise = gr.Checkbox( - label=i18n("Denoise"), - info=i18n("Apply denoise to the audio."), - visible=True, - value=False, - interactive=True, - ) denoise_model = gr.Dropdown( - label=i18n("Denoise Model"), - info=i18n("Select the denoise model to use for the separation."), - choices=sorted(denoise_models_names), - interactive=True, - value="Mel-Roformer Denoise Normal by aufr33", - allow_custom_value=False, - visible=False, - ) with gr.Accordion(i18n("Audio post-process Settings"), open=False): change_inst_pitch = gr.Slider( - label=i18n("Change Instrumental Pitch"), - info=i18n("Change the pitch of the instrumental."), - minimum=-12, - maximum=12, - step=1, - value=0, - interactive=True, - ) delete_audios = gr.Checkbox( - label=i18n("Delete Audios"), - info=i18n("Delete the audios after the conversion."), - visible=True, - value=False, - interactive=True, - ) reverb = gr.Checkbox( - label=i18n("Reverb"), - info=i18n("Apply reverb to the audio."), - visible=True, - value=False, - interactive=True, - ) reverb_room_size = gr.Slider( - minimum=0, - maximum=1, - label=i18n("Reverb Room Size"), - info=i18n("Set the room size of the reverb."), - value=0.5, - interactive=True, - visible=False, - ) - - reverb_damping = gr.Slider( - minimum=0, - maximum=1, - label=i18n("Reverb Damping"), - info=i18n("Set the damping of the reverb."), - value=0.5, - interactive=True, - visible=False, - ) - - reverb_wet_gain = gr.Slider( - minimum=0, - maximum=1, - label=i18n("Reverb Wet Gain"), - info=i18n("Set the wet gain of the reverb."), - value=0.33, - interactive=True, - visible=False, - ) - - reverb_dry_gain = gr.Slider( - minimum=0, - maximum=1, - label=i18n("Reverb Dry Gain"), - info=i18n("Set the dry gain of the reverb."), - value=0.4, - interactive=True, - visible=False, - ) - - reverb_width = gr.Slider( - minimum=0, - maximum=1, - label=i18n("Reverb Width"), - info=i18n("Set the width of the reverb."), - value=1.0, - 
interactive=True, - visible=False, - ) vocals_volume = gr.Slider( - label=i18n("Vocals Volume"), - info=i18n("Adjust the volume of the vocals."), - minimum=-10, - maximum=0, - step=1, - value=-3, - interactive=True, - ) instrumentals_volume = gr.Slider( - label=i18n("Instrumentals Volume"), - info=i18n("Adjust the volume of the Instrumentals."), - minimum=-10, - maximum=0, - step=1, - value=-3, - interactive=True, - ) backing_vocals_volume = gr.Slider( - label=i18n("Backing Vocals Volume"), - info=i18n("Adjust the volume of the backing vocals."), - minimum=-10, - maximum=0, - step=1, - value=-3, - interactive=True, - ) export_format_final = gr.Radio( - label=i18n("Export Format"), - info=i18n("Select the format to export the audio."), - choices=["WAV", "MP3", "FLAC", "OGG", "M4A"], - value="FLAC", - interactive=True, - ) with gr.Accordion(i18n("Device Settings"), open=False): devices = gr.Textbox( - label=i18n("Device"), - info=i18n( - "Select the device to use for the conversion. 0 to ∞ separated by - and for CPU leave only an -" - ), - value=get_number_of_gpus(), - interactive=True, - ) - - + with gr.Row(): convert_button = gr.Button(i18n("Convert")) - + with gr.Row(): vc_output1 = gr.Textbox( - - label=i18n("Output Information"), - - info=i18n("The output information will be displayed here."), + label=i18n("Output Information"), + info=i18n("The output information will be displayed here."), ) vc_output2 = gr.Audio(label=i18n("Export Audio")) - with gr.Tab(i18n("Download Music")): + with gr.Tab(i18n("Download Music")): download_music_tab() @@ -1525,285 +915,152 @@ def update_dropdown_visibility(checkbox): return gr.update(visible=checkbox) - - def update_reverb_sliders_visibility(reverb_checked): return { - reverb_room_size: gr.update(visible=reverb_checked), - reverb_damping: gr.update(visible=reverb_checked), - reverb_wet_gain: gr.update(visible=reverb_checked), - reverb_dry_gain: gr.update(visible=reverb_checked), - reverb_width: gr.update(visible=reverb_checked), - } - - def update_visibility_infer_backing(infer_backing_vocals): visible = infer_backing_vocals return ( - {"visible": visible, "__type__": "update"}, - {"visible": visible, "__type__": "update"}, - {"visible": visible, "__type__": "update"}, - {"visible": visible, "__type__": "update"}, - {"visible": visible, "__type__": "update"}, - ) - - def update_hop_length_visibility(pitch_extract_value): return gr.update(visible=pitch_extract_value in ["crepe", "crepe-tiny"]) - - - refresh_button.click( - fn=change_choices, - inputs=[], - outputs=[model_file, index_file, audio], - ) refresh_button_infer_backing_vocals.click( - fn=change_choices, - inputs=[], - outputs=[infer_backing_vocals_model, infer_backing_vocals_index], - ) upload_audio.upload( - fn=save_to_wav, - inputs=[upload_audio], - outputs=[audio, output_path], - ) clear_outputs_infer.click( - fn=delete_outputs, - inputs=[], - outputs=[], - ) convert_button.click( - full_inference_program, - inputs=[ - model_file, - index_file, - audio, - output_path, - export_format_rvc, - split_audio, - autotune, - vocal_model, - karaoke_model, - dereverb_model, - deecho, - deeecho_model, - denoise, - denoise_model, - reverb, - vocals_volume, - instrumentals_volume, - backing_vocals_volume, - export_format_final, - devices, - pitch, - filter_radius, - index_rate, - rms_mix_rate, - protect, - pitch_extract, - hop_length, - reverb_room_size, - reverb_damping, - reverb_wet_gain, - reverb_dry_gain, - reverb_width, - embedder_model, - delete_audios, - use_tta, - batch_size, - 
             infer_backing_vocals,
-
             infer_backing_vocals_model,
-
             infer_backing_vocals_index,
-
             change_inst_pitch,
-
             pitch_back,
-
             filter_radius_back,
-
             index_rate_back,
-
             rms_mix_rate_back,
-
             protect_back,
-
             pitch_extract_back,
-
             hop_length_back,
-
             export_format_rvc_back,
-
             split_audio_back,
-
             autotune_back,
-
             embedder_model_back,
-
         ],
-
         outputs=[vc_output1, vc_output2],
-
     )
-
-
     deecho.change(
-
         fn=update_dropdown_visibility,
-
         inputs=deecho,
-
         outputs=deeecho_model,
-
     )
-
-
     denoise.change(
-
         fn=update_dropdown_visibility,
-
         inputs=denoise,
-
         outputs=denoise_model,
-
     )
-
-
     reverb.change(
-
         fn=update_reverb_sliders_visibility,
-
         inputs=reverb,
-
         outputs=[
-
             reverb_room_size,
-
             reverb_damping,
-
             reverb_wet_gain,
-
             reverb_dry_gain,
-
             reverb_width,
-
         ],
-
     )
     pitch_extract.change(
-
         fn=update_hop_length_visibility,
-
         inputs=pitch_extract,
-
         outputs=hop_length,
-
     )
-
-
     infer_backing_vocals.change(
-
         fn=update_visibility_infer_backing,
-
         inputs=[infer_backing_vocals],
-
         outputs=[
-
             infer_backing_vocals_model,
-
             infer_backing_vocals_index,
-
             refresh_button_infer_backing_vocals,
-
             unload_button_infer_backing_vocals,
-
             back_rvc_settings,
-
         ],
-
     )
diff --git a/tabs/settings.py b/tabs/settings.py
index 6ab6bec..2734130 100644
--- a/tabs/settings.py
+++ b/tabs/settings.py
@@ -39,7 +39,6 @@ def save_lang_settings(selected_language):
         json.dump(config, file, indent=2)
-
 def restart_applio():
     if os.name != "nt":
         os.system("clear")
@@ -49,11 +48,6 @@ def restart_applio():
     os.execl(python, python, *sys.argv)
-
-
-
-
-
 def lang_tab():
     with gr.Column():
         selected_language = gr.Dropdown(