Skip to content

Commit 5cf8cb7

Browse files
authored
Better progress and logging per-layer and vendor gtest (#9)
* Better progress and logging per-layer; also vendor googletest
* Remove libnyquist from the library; use it only in the driver programs (inspired by @adamski's PR #8)
1 parent e88c105 commit 5cf8cb7

16 files changed

+288
-239
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@
88
[submodule "vendor/libnyquist"]
99
path = vendor/libnyquist
1010
url = https://github.com/ddiakopoulos/libnyquist
11+
[submodule "vendor/googletest"]
12+
path = vendor/googletest
13+
url = https://github.com/google/googletest

CMakeLists.txt

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,8 @@ enable_testing()
2828
set(CMAKE_CXX_STANDARD 17)
2929
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
3030

31-
# compile vendored submodule libnyquist
32-
set(LIBNYQUIST_BUILD_EXAMPLE OFF CACHE BOOL "Disable libnyquist example")
33-
add_subdirectory(vendor/libnyquist)
34-
3531
# add library Eigen3
3632
include_directories(vendor/eigen)
37-
include_directories(vendor/libnyquist/include)
3833
include_directories(src)
3934

4035
if (USE_AMD_AOCL)
@@ -80,17 +75,23 @@ file(GLOB SOURCES "src/*.cpp")
8075

8176
add_library(demucs.cpp.lib ${SOURCES})
8277

83-
target_link_libraries(demucs.cpp.lib libnyquist ${LIBRARIES_TO_LINK})
78+
target_link_libraries(demucs.cpp.lib ${LIBRARIES_TO_LINK})
79+
80+
# compile vendored submodule libnyquist for driver programs
81+
set(LIBNYQUIST_BUILD_EXAMPLE OFF CACHE BOOL "Disable libnyquist example" FORCE)
82+
add_subdirectory(vendor/libnyquist)
8483

8584
# Add target to compile demucs.cpp.main, the main driver program for demucs.cpp
8685
add_executable(demucs.cpp.main "cli-apps/demucs.cpp")
87-
target_link_libraries(demucs.cpp.main demucs.cpp.lib)
86+
target_include_directories(demucs.cpp.main PRIVATE vendor/libnyquist/include)
87+
target_link_libraries(demucs.cpp.main demucs.cpp.lib libnyquist)
8888

8989
# Add target to compile demucs_ft.cpp.main, the fine-tuned driver program for demucs.cpp
9090
add_executable(demucs_ft.cpp.main "cli-apps/demucs_ft.cpp")
91-
target_link_libraries(demucs_ft.cpp.main demucs.cpp.lib)
91+
target_include_directories(demucs_ft.cpp.main PRIVATE vendor/libnyquist/include)
92+
target_link_libraries(demucs_ft.cpp.main demucs.cpp.lib libnyquist)
9293

93-
file(GLOB SOURCES_TO_LINT "src/*.cpp" "src/*.hpp" "demucs.cpp" "test/*.cpp")
94+
file(GLOB SOURCES_TO_LINT "src/*.cpp" "src/*.hpp" "cli-apps/*.cpp")
9495

9596
# add target to run standard lints and formatters
9697
add_custom_target(lint
@@ -104,6 +105,10 @@ add_custom_target(lint
104105

105106
# add target to run cpp tests in test/ directory with gtest
106107

108+
# get gtest from vendor/googletest
109+
set(BUILD_GMOCK OFF CACHE BOOL "Disable gmock in googletest" FORCE)
110+
add_subdirectory(vendor/googletest)
111+
107112
# include test/*.cpp as test files
108113
file(GLOB TEST_SOURCES "test/*.cpp")
109114

cli-apps/demucs.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <Eigen/Dense>
66
#include <cassert>
77
#include <filesystem>
8+
#include <iomanip>
89
#include <iostream>
910
#include <libnyquist/Common.h>
1011
#include <libnyquist/Decoders.h>
@@ -146,13 +147,19 @@ int main(int argc, const char **argv)
146147
std::cout << "Starting Demucs (" << std::to_string(nb_sources)
147148
<< "-source) inference" << std::endl;
148149

149-
demucscpp::ProgressCallback progressCallback = [](float progress)
150-
{ std::cout << "Progress: " << progress * 100 << "%\n"; };
150+
// set output precision to 3 decimal places
151+
std::cout << std::fixed << std::setprecision(3);
152+
153+
demucscpp::ProgressCallback progressCallback =
154+
[](float progress, const std::string &log_message)
155+
{
156+
std::cout << "(" << std::setw(3) << std::setfill(' ')
157+
<< progress * 100.0f << "%) " << log_message << std::endl;
158+
};
151159

152160
// create 4 audio matrix same size, to hold output
153161
Eigen::Tensor3dXf audio_targets =
154162
demucscpp::demucs_inference(model, audio, progressCallback);
155-
std::cout << "returned!" << std::endl;
156163

157164
out_targets = audio_targets;
158165

@@ -203,7 +210,8 @@ int main(int argc, const char **argv)
203210

204211
// insert target_name into the path after the digit
205212
// e.g. target_0_drums.wav
206-
p_target.replace_filename("target_" + std::to_string(target) + "_" + target_name + ".wav");
213+
p_target.replace_filename("target_" + std::to_string(target) + "_" +
214+
target_name + ".wav");
207215

208216
std::cout << "Writing wav file " << p_target << std::endl;
209217

cli-apps/demucs_ft.cpp

Lines changed: 67 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <Eigen/Dense>
66
#include <cassert>
77
#include <filesystem>
8+
#include <iomanip>
89
#include <iostream>
910
#include <libnyquist/Common.h>
1011
#include <libnyquist/Decoders.h>
@@ -107,8 +108,8 @@ int main(int argc, const char **argv)
107108
{
108109
if (argc != 4)
109110
{
110-
std::cerr << "Usage: " << argv[0]
111-
<< " <model dir> <wav file> <out dir>" << std::endl;
111+
std::cerr << "Usage: " << argv[0] << " <model dir> <wav file> <out dir>"
112+
<< std::endl;
112113
exit(1);
113114
}
114115

@@ -135,63 +136,100 @@ int main(int argc, const char **argv)
135136
std::string model_file;
136137
for (const auto &entry : std::filesystem::directory_iterator(model_dir))
137138
{
138-
bool ret;
139+
bool ret = false;
139140

140141
// check if entry contains the name "htdemucs_ft_drums"
141142
if (entry.path().string().find("htdemucs_ft_drums") !=
142143
std::string::npos)
143144
{
144-
ret = load_demucs_model(entry.path().string(), &models[0]);
145-
std::cout << "Loading ft model " << entry.path().string() << " for drums" << std::endl;
146-
} else if (entry.path().string().find("htdemucs_ft_bass") !=
147-
std::string::npos)
145+
ret = load_demucs_model(entry.path().string(), &models[0]);
146+
std::cout << "Loading ft model " << entry.path().string()
147+
<< " for drums" << std::endl;
148+
}
149+
else if (entry.path().string().find("htdemucs_ft_bass") !=
150+
std::string::npos)
148151
{
149-
ret = load_demucs_model(entry.path().string(), &models[1]);
150-
std::cout << "Loading ft model " << entry.path().string() << " for bass" << std::endl;
151-
} else if (entry.path().string().find("htdemucs_ft_other") !=
152-
std::string::npos)
152+
ret = load_demucs_model(entry.path().string(), &models[1]);
153+
std::cout << "Loading ft model " << entry.path().string()
154+
<< " for bass" << std::endl;
155+
}
156+
else if (entry.path().string().find("htdemucs_ft_other") !=
157+
std::string::npos)
153158
{
154-
ret = load_demucs_model(entry.path().string(), &models[2]);
155-
std::cout << "Loading ft model " << entry.path().string() << " for other" << std::endl;
156-
} else if (entry.path().string().find("htdemucs_ft_vocals") !=
157-
std::string::npos)
159+
ret = load_demucs_model(entry.path().string(), &models[2]);
160+
std::cout << "Loading ft model " << entry.path().string()
161+
<< " for other" << std::endl;
162+
}
163+
else if (entry.path().string().find("htdemucs_ft_vocals") !=
164+
std::string::npos)
158165
{
159-
ret = load_demucs_model(entry.path().string(), &models[3]);
160-
std::cout << "Loading ft model " << entry.path().string() << " for vocals" << std::endl;
166+
ret = load_demucs_model(entry.path().string(), &models[3]);
167+
std::cout << "Loading ft model " << entry.path().string()
168+
<< " for vocals" << std::endl;
161169
}
162170

163171
// debug some members of model
164172
std::cout << "demucs_model_load returned " << (ret ? "true" : "false")
165-
<< std::endl;
173+
<< std::endl;
166174
if (!ret)
167175
{
168176
std::cerr << "Error loading model" << std::endl;
169177
exit(1);
170178
}
171-
172179
}
173180

174181
const int nb_sources = 4;
175182

176183
std::cout << "Starting Demucs fine-tuned (" << std::to_string(nb_sources)
177184
<< "-source) inference" << std::endl;
178185

179-
demucscpp::ProgressCallback progressCallback = [](float progress)
180-
{ std::cout << "Progress: " << progress * 100 << "%\n"; };
186+
// set output precision to 3 decimal places
187+
std::cout << std::fixed << std::setprecision(3);
188+
189+
demucscpp::ProgressCallback progressCallback1 =
190+
[](float progress, const std::string &log_message)
191+
{
192+
std::cout << "[DRUMS] \t(" << std::setw(3) << std::setfill(' ')
193+
<< progress * 25.0f << "%) " << log_message << std::endl;
194+
};
195+
demucscpp::ProgressCallback progressCallback2 =
196+
[](float progress, const std::string &log_message)
197+
{
198+
std::cout << "[BASS] \t(" << std::setw(3) << std::setfill(' ')
199+
<< 25.0f + progress * 25.0f << "%) " << log_message
200+
<< std::endl;
201+
};
202+
demucscpp::ProgressCallback progressCallback3 =
203+
[](float progress, const std::string &log_message)
204+
{
205+
std::cout << "[OTHER] \t(" << std::setw(3) << std::setfill(' ')
206+
<< 50.0f + progress * 25.0f << "%) " << log_message
207+
<< std::endl;
208+
};
209+
demucscpp::ProgressCallback progressCallback4 =
210+
[](float progress, const std::string &log_message)
211+
{
212+
std::cout << "[VOCALS] \t(" << std::setw(3) << std::setfill(' ')
213+
<< 75.0f + progress * 25.0f << "%) " << log_message
214+
<< std::endl;
215+
};
181216

182217
// create 4 audio matrix same size, to hold output
183218
Eigen::Tensor3dXf drums_targets =
184-
demucscpp::demucs_inference(models[0], audio, progressCallback);
219+
demucscpp::demucs_inference(models[0], audio, progressCallback1);
220+
185221
Eigen::Tensor3dXf bass_targets =
186-
demucscpp::demucs_inference(models[1], audio, progressCallback);
222+
demucscpp::demucs_inference(models[1], audio, progressCallback2);
223+
187224
Eigen::Tensor3dXf other_targets =
188-
demucscpp::demucs_inference(models[2], audio, progressCallback);
225+
demucscpp::demucs_inference(models[2], audio, progressCallback3);
226+
189227
Eigen::Tensor3dXf vocals_targets =
190-
demucscpp::demucs_inference(models[3], audio, progressCallback);
228+
demucscpp::demucs_inference(models[3], audio, progressCallback4);
191229

192230
out_targets = Eigen::Tensor3dXf(drums_targets.dimension(0),
193-
drums_targets.dimension(1),
194-
drums_targets.dimension(2));
231+
drums_targets.dimension(1),
232+
drums_targets.dimension(2));
195233

196234
// simply use the respective stem from each independent fine-tuned model
197235
out_targets.chip<0>(0) = drums_targets.chip<0>(0);
@@ -246,7 +284,8 @@ int main(int argc, const char **argv)
246284

247285
// insert target_name into the path after the digit
248286
// e.g. target_0_drums.wav
249-
p_target.replace_filename("target_" + std::to_string(target) + "_" + target_name + ".wav");
287+
p_target.replace_filename("target_" + std::to_string(target) + "_" +
288+
target_name + ".wav");
250289

251290
std::cout << "Writing wav file " << p_target << std::endl;
252291

src/conv.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,13 @@ Eigen::Tensor3dXf conv2d_gemm(const Eigen::Tensor3dXf &x,
118118
{
119119
for (int h = 0; h < out_height; ++h)
120120
{
121-
for (int w = 0; w < out_width; ++w)
121+
for (int w_ = 0; w_ < out_width; ++w_)
122122
{
123-
int row_idx = h * out_width + w;
123+
int row_idx = h * out_width + w_;
124124
// Assign the value from the GEMM output to the output tensor
125125
if (row_idx < result.rows())
126126
{
127-
y_out(chout, h, w) = result(row_idx, chout);
127+
y_out(chout, h, w_) = result(row_idx, chout);
128128
}
129129
}
130130
}
@@ -187,16 +187,16 @@ Eigen::Tensor3dXf conv2d_gemm_fused_gelu(const Eigen::Tensor3dXf &x,
187187
{
188188
for (int h = 0; h < out_height; ++h)
189189
{
190-
for (int w = 0; w < out_width; ++w)
190+
for (int w_ = 0; w_ < out_width; ++w_)
191191
{
192-
int row_idx = h * out_width + w;
192+
int row_idx = h * out_width + w_;
193193
// Assign the value from the GEMM output to the output tensor
194194
// with gelu
195195
float value = result(row_idx, chout);
196196
float activated_value =
197197
0.5f * value * (1.0f + std::erf(value / std::sqrt(2.0f)));
198198
// Assign the activated value to the output tensor
199-
y_out(chout, h, w) = activated_value;
199+
y_out(chout, h, w_) = activated_value;
200200
}
201201
}
202202
}
@@ -378,15 +378,15 @@ Eigen::Tensor3dXf conv2d_tr_gemm(const Eigen::Tensor3dXf &x,
378378
{
379379
for (int h = 0; h < out_height; ++h)
380380
{
381-
for (int w = 0; w < out_width; ++w)
381+
for (int w_ = 0; w_ < out_width; ++w_)
382382
{
383383
// Calculate the linear index in the GEMM result corresponding
384384
// to this output location
385-
int gemm_row = h * out_width + w;
385+
int gemm_row = h * out_width + w_;
386386
int gemm_col = ch;
387387

388388
// Assign the value from the GEMM result to the output tensor
389-
y_out(ch, h, w) += result(gemm_row, gemm_col);
389+
y_out(ch, h, w_) += result(gemm_row, gemm_col);
390390
}
391391
}
392392
}
@@ -455,11 +455,11 @@ Eigen::Tensor3dXf conv2d_tr_gemm_fused_gelu(const Eigen::Tensor3dXf &x,
455455
{
456456
for (int h = 0; h < out_height; ++h)
457457
{
458-
for (int w = 0; w < out_width; ++w)
458+
for (int w_ = 0; w_ < out_width; ++w_)
459459
{
460460
// Calculate the linear index in the GEMM result corresponding
461461
// to this output location
462-
int gemm_row = h * out_width + w;
462+
int gemm_row = h * out_width + w_;
463463
int gemm_col = ch;
464464

465465
// Compute the value from the GEMM result
@@ -470,7 +470,7 @@ Eigen::Tensor3dXf conv2d_tr_gemm_fused_gelu(const Eigen::Tensor3dXf &x,
470470
0.5f * value * (1.0f + std::erf(value / std::sqrt(2.0f)));
471471

472472
// Assign the activated value to the output tensor
473-
y_out(ch, h, w) += activated_value;
473+
y_out(ch, h, w_) += activated_value;
474474
}
475475
}
476476
}

0 commit comments

Comments
 (0)