5
5
#include < Eigen/Dense>
6
6
#include < cassert>
7
7
#include < filesystem>
8
+ #include < iomanip>
8
9
#include < iostream>
9
10
#include < libnyquist/Common.h>
10
11
#include < libnyquist/Decoders.h>
@@ -107,8 +108,8 @@ int main(int argc, const char **argv)
107
108
{
108
109
if (argc != 4 )
109
110
{
110
- std::cerr << " Usage: " << argv[0 ]
111
- << " <model dir> <wav file> <out dir> " << std::endl;
111
+ std::cerr << " Usage: " << argv[0 ] << " <model dir> <wav file> <out dir> "
112
+ << std::endl;
112
113
exit (1 );
113
114
}
114
115
@@ -135,63 +136,100 @@ int main(int argc, const char **argv)
135
136
std::string model_file;
136
137
for (const auto &entry : std::filesystem::directory_iterator (model_dir))
137
138
{
138
- bool ret;
139
+ bool ret = false ;
139
140
140
141
// check if entry contains the name "htdemucs_ft_drums"
141
142
if (entry.path ().string ().find (" htdemucs_ft_drums" ) !=
142
143
std::string::npos)
143
144
{
144
- ret = load_demucs_model (entry.path ().string (), &models[0 ]);
145
- std::cout << " Loading ft model " << entry.path ().string () << " for drums" << std::endl;
146
- } else if (entry.path ().string ().find (" htdemucs_ft_bass" ) !=
147
- std::string::npos)
145
+ ret = load_demucs_model (entry.path ().string (), &models[0 ]);
146
+ std::cout << " Loading ft model " << entry.path ().string ()
147
+ << " for drums" << std::endl;
148
+ }
149
+ else if (entry.path ().string ().find (" htdemucs_ft_bass" ) !=
150
+ std::string::npos)
148
151
{
149
- ret = load_demucs_model (entry.path ().string (), &models[1 ]);
150
- std::cout << " Loading ft model " << entry.path ().string () << " for bass" << std::endl;
151
- } else if (entry.path ().string ().find (" htdemucs_ft_other" ) !=
152
- std::string::npos)
152
+ ret = load_demucs_model (entry.path ().string (), &models[1 ]);
153
+ std::cout << " Loading ft model " << entry.path ().string ()
154
+ << " for bass" << std::endl;
155
+ }
156
+ else if (entry.path ().string ().find (" htdemucs_ft_other" ) !=
157
+ std::string::npos)
153
158
{
154
- ret = load_demucs_model (entry.path ().string (), &models[2 ]);
155
- std::cout << " Loading ft model " << entry.path ().string () << " for other" << std::endl;
156
- } else if (entry.path ().string ().find (" htdemucs_ft_vocals" ) !=
157
- std::string::npos)
159
+ ret = load_demucs_model (entry.path ().string (), &models[2 ]);
160
+ std::cout << " Loading ft model " << entry.path ().string ()
161
+ << " for other" << std::endl;
162
+ }
163
+ else if (entry.path ().string ().find (" htdemucs_ft_vocals" ) !=
164
+ std::string::npos)
158
165
{
159
- ret = load_demucs_model (entry.path ().string (), &models[3 ]);
160
- std::cout << " Loading ft model " << entry.path ().string () << " for vocals" << std::endl;
166
+ ret = load_demucs_model (entry.path ().string (), &models[3 ]);
167
+ std::cout << " Loading ft model " << entry.path ().string ()
168
+ << " for vocals" << std::endl;
161
169
}
162
170
163
171
// report whether the model for this directory entry was loaded
164
172
std::cout << " demucs_model_load returned " << (ret ? " true" : " false" )
165
- << std::endl;
173
+ << std::endl;
166
174
if (!ret)
167
175
{
168
176
std::cerr << " Error loading model" << std::endl;
169
177
exit (1 );
170
178
}
171
-
172
179
}
173
180
174
181
const int nb_sources = 4 ;
175
182
176
183
std::cout << " Starting Demucs fine-tuned (" << std::to_string (nb_sources)
177
184
<< " -source) inference" << std::endl;
178
185
179
- demucscpp::ProgressCallback progressCallback = [](float progress)
180
- { std::cout << " Progress: " << progress * 100 << " %\n " ; };
186
+ // set output precision to 3 decimal places
187
+ std::cout << std::fixed << std::setprecision (3 );
188
+
189
+ demucscpp::ProgressCallback progressCallback1 =
190
+ [](float progress, const std::string &log_message)
191
+ {
192
+ std::cout << " [DRUMS] \t (" << std::setw (3 ) << std::setfill (' ' )
193
+ << progress * 25 .0f << " %) " << log_message << std::endl;
194
+ };
195
+ demucscpp::ProgressCallback progressCallback2 =
196
+ [](float progress, const std::string &log_message)
197
+ {
198
+ std::cout << " [BASS] \t (" << std::setw (3 ) << std::setfill (' ' )
199
+ << 25 .0f + progress * 25 .0f << " %) " << log_message
200
+ << std::endl;
201
+ };
202
+ demucscpp::ProgressCallback progressCallback3 =
203
+ [](float progress, const std::string &log_message)
204
+ {
205
+ std::cout << " [OTHER] \t (" << std::setw (3 ) << std::setfill (' ' )
206
+ << 50 .0f + progress * 25 .0f << " %) " << log_message
207
+ << std::endl;
208
+ };
209
+ demucscpp::ProgressCallback progressCallback4 =
210
+ [](float progress, const std::string &log_message)
211
+ {
212
+ std::cout << " [VOCALS] \t (" << std::setw (3 ) << std::setfill (' ' )
213
+ << 75 .0f + progress * 25 .0f << " %) " << log_message
214
+ << std::endl;
215
+ };
181
216
182
217
// run each of the 4 fine-tuned models on the audio, keeping one output tensor per model
183
218
Eigen::Tensor3dXf drums_targets =
184
- demucscpp::demucs_inference (models[0 ], audio, progressCallback);
219
+ demucscpp::demucs_inference (models[0 ], audio, progressCallback1);
220
+
185
221
Eigen::Tensor3dXf bass_targets =
186
- demucscpp::demucs_inference (models[1 ], audio, progressCallback);
222
+ demucscpp::demucs_inference (models[1 ], audio, progressCallback2);
223
+
187
224
Eigen::Tensor3dXf other_targets =
188
- demucscpp::demucs_inference (models[2 ], audio, progressCallback);
225
+ demucscpp::demucs_inference (models[2 ], audio, progressCallback3);
226
+
189
227
Eigen::Tensor3dXf vocals_targets =
190
- demucscpp::demucs_inference (models[3 ], audio, progressCallback );
228
+ demucscpp::demucs_inference (models[3 ], audio, progressCallback4 );
191
229
192
230
out_targets = Eigen::Tensor3dXf (drums_targets.dimension (0 ),
193
- drums_targets.dimension (1 ),
194
- drums_targets.dimension (2 ));
231
+ drums_targets.dimension (1 ),
232
+ drums_targets.dimension (2 ));
195
233
196
234
// simply use the respective stem from each independent fine-tuned model
197
235
out_targets.chip <0 >(0 ) = drums_targets.chip <0 >(0 );
@@ -246,7 +284,8 @@ int main(int argc, const char **argv)
246
284
247
285
// insert target_name into the path after the digit
248
286
// e.g. target_0_drums.wav
249
- p_target.replace_filename (" target_" + std::to_string (target) + " _" + target_name + " .wav" );
287
+ p_target.replace_filename (" target_" + std::to_string (target) + " _" +
288
+ target_name + " .wav" );
250
289
251
290
std::cout << " Writing wav file " << p_target << std::endl;
252
291
0 commit comments