11# coding: utf-8
2- __author__ = ' Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
2+ __author__ = " Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
33
44import os
55import librosa
@@ -81,42 +81,42 @@ def average_waveforms(pred_track, weights, algorithm):
8181
8282 mod_track = []
8383 for i in range (pred_track .shape [0 ]):
84- if algorithm == ' avg_wave' :
84+ if algorithm == " avg_wave" :
8585 mod_track .append (pred_track [i ] * weights [i ])
86- elif algorithm in [' median_wave' , ' min_wave' , ' max_wave' ]:
86+ elif algorithm in [" median_wave" , " min_wave" , " max_wave" ]:
8787 mod_track .append (pred_track [i ])
88- elif algorithm in [' avg_fft' , ' min_fft' , ' max_fft' , ' median_fft' ]:
88+ elif algorithm in [" avg_fft" , " min_fft" , " max_fft" , " median_fft" ]:
8989 spec = stft (pred_track [i ], nfft = 2048 , hl = 1024 )
90- if algorithm in [' avg_fft' ]:
90+ if algorithm in [" avg_fft" ]:
9191 mod_track .append (spec * weights [i ])
9292 else :
9393 mod_track .append (spec )
9494 pred_track = np .array (mod_track )
9595
96- if algorithm in [' avg_wave' ]:
96+ if algorithm in [" avg_wave" ]:
9797 pred_track = pred_track .sum (axis = 0 )
9898 pred_track /= np .array (weights ).sum ().T
99- elif algorithm in [' median_wave' ]:
99+ elif algorithm in [" median_wave" ]:
100100 pred_track = np .median (pred_track , axis = 0 )
101- elif algorithm in [' min_wave' ]:
101+ elif algorithm in [" min_wave" ]:
102102 pred_track = np .array (pred_track )
103103 pred_track = lambda_min (pred_track , axis = 0 , key = np .abs )
104- elif algorithm in [' max_wave' ]:
104+ elif algorithm in [" max_wave" ]:
105105 pred_track = np .array (pred_track )
106106 pred_track = lambda_max (pred_track , axis = 0 , key = np .abs )
107- elif algorithm in [' avg_fft' ]:
107+ elif algorithm in [" avg_fft" ]:
108108 pred_track = pred_track .sum (axis = 0 )
109109 pred_track /= np .array (weights ).sum ()
110110 pred_track = istft (pred_track , 1024 , final_length )
111- elif algorithm in [' min_fft' ]:
111+ elif algorithm in [" min_fft" ]:
112112 pred_track = np .array (pred_track )
113113 pred_track = lambda_min (pred_track , axis = 0 , key = np .abs )
114114 pred_track = istft (pred_track , 1024 , final_length )
115- elif algorithm in [' max_fft' ]:
115+ elif algorithm in [" max_fft" ]:
116116 pred_track = np .array (pred_track )
117117 pred_track = absmax (pred_track , axis = 0 )
118118 pred_track = istft (pred_track , 1024 , final_length )
119- elif algorithm in [' median_fft' ]:
119+ elif algorithm in [" median_fft" ]:
120120 pred_track = np .array (pred_track )
121121 pred_track = np .median (pred_track , axis = 0 )
122122 pred_track = istft (pred_track , 1024 , final_length )
@@ -125,37 +125,58 @@ def average_waveforms(pred_track, weights, algorithm):
125125
126126def ensemble_files (args ):
127127 parser = argparse .ArgumentParser ()
128- parser .add_argument ("--files" , type = str , required = True , nargs = '+' , help = "Path to all audio-files to ensemble" )
129- parser .add_argument ("--type" , type = str , default = 'avg_wave' , help = "One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft" )
130- parser .add_argument ("--weights" , type = float , nargs = '+' , help = "Weights to create ensemble. Number of weights must be equal to number of files" )
131- parser .add_argument ("--output" , default = "res.wav" , type = str , help = "Path to wav file where ensemble result will be stored" )
128+ parser .add_argument (
129+ "--files" ,
130+ type = str ,
131+ required = True ,
132+ nargs = "+" ,
133+ help = "Path to all audio-files to ensemble" ,
134+ )
135+ parser .add_argument (
136+ "--type" ,
137+ type = str ,
138+ default = "avg_wave" ,
139+ help = "One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft" ,
140+ )
141+ parser .add_argument (
142+ "--weights" ,
143+ type = float ,
144+ nargs = "+" ,
145+ help = "Weights to create ensemble. Number of weights must be equal to number of files" ,
146+ )
147+ parser .add_argument (
148+ "--output" ,
149+ default = "res.wav" ,
150+ type = str ,
151+ help = "Path to wav file where ensemble result will be stored" ,
152+ )
132153 if args is None :
133154 args = parser .parse_args ()
134155 else :
135156 args = parser .parse_args (args )
136157
137- print (' Ensemble type: {}' .format (args .type ))
138- print (' Number of input files: {}' .format (len (args .files )))
158+ print (" Ensemble type: {}" .format (args .type ))
159+ print (" Number of input files: {}" .format (len (args .files )))
139160 if args .weights is not None :
140161 weights = args .weights
141162 else :
142163 weights = np .ones (len (args .files ))
143- print (' Weights: {}' .format (weights ))
144- print (' Output file: {}' .format (args .output ))
164+ print (" Weights: {}" .format (weights ))
165+ print (" Output file: {}" .format (args .output ))
145166 data = []
146167 for f in args .files :
147168 if not os .path .isfile (f ):
148- print (' Error. Can\ ' t find file: {}. Check paths.' .format (f ))
169+ print (" Error. Can't find file: {}. Check paths." .format (f ))
149170 exit ()
150- print (' Reading file: {}' .format (f ))
171+ print (" Reading file: {}" .format (f ))
151172 wav , sr = librosa .load (f , sr = None , mono = False )
152173 # wav, sr = sf.read(f)
153174 print ("Waveform shape: {} sample rate: {}" .format (wav .shape , sr ))
154175 data .append (wav )
155176 data = np .array (data )
156177 res = average_waveforms (data , weights , args .type )
157- print (' Result shape: {}' .format (res .shape ))
158- sf .write (args .output , res .T , sr , ' FLOAT' )
178+ print (" Result shape: {}" .format (res .shape ))
179+ sf .write (args .output , res .T , sr , " FLOAT" )
159180
160181
161182if __name__ == "__main__" :
0 commit comments