Skip to content

Commit fb41c47

Browse files
committed
Refactor some code; add more optimizers.
1 parent 58f70ed commit fb41c47

22 files changed

+298
-121
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ small, clean, easy to understand!
99
* All in one: without any dependency, pure c++ implemented.
1010
* Basic layer: data layer, convolution layer, pooling layer, fully connected layer, softmax layer, activation layers (sigmoid, tanh, ReLU)
1111
* Loss function: Cross Entropy, MSE.
12-
* Optimize method: SGD.
12+
* Optimize method: SGD, SGDWithMomentum.
1313

1414
## Examples
1515
* mnist demo, with ConvNet and MLP net, [examples/mnist/mnist_train_test.cpp](./examples/mnist/mnist_train_test.cpp "mnist_train_test.cpp")
@@ -19,7 +19,7 @@ small, clean, easy to understand!
1919
* ~~fix train error when batch > 1 issue.~~
2020
* ~~add load & save model function.~~
2121
* add more layer, such as batch normalization layer, dropout layer, etc.
22-
* add weight regular, gradient momentum.
22+
* add weight regular.
2323
* port to other platforms, such as linux, mac, android, iOS, etc.
2424
* optimize network train/test speed, use cuBLAS/OpenBLAS etc.
2525
* add more optimization methods.

examples/digit/digit_train_test.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ static EasyCNN::NetWork buildMLPNet(const size_t batch, const size_t channels, c
197197
network.setPhase(EasyCNN::Phase::Train);
198198
network.setInputSize(EasyCNN::DataSize(batch, channels, width, height));
199199
network.setLossFunctor(std::make_shared<EasyCNN::MSEFunctor>());
200+
network.setOptimizer(std::make_shared<EasyCNN::SGDWithMomentum>(0.01f, 0.9f));
200201
//input data layer
201202
std::shared_ptr<EasyCNN::InputLayer> _0_inputLayer(std::make_shared<EasyCNN::InputLayer>());
202203
network.addayer(_0_inputLayer);
@@ -273,7 +274,7 @@ static void train(const std::string& digit_train_images_dir,
273274
EasyCNN::logCritical("load training data done. train set's size is %d,validate set's size is %d", train_images.size(), validate_images.size());
274275

275276
float learningRate = 0.1f;
276-
const float decayRate = 0.001f;
277+
const float decayRate = 0.2f;
277278
const float minLearningRate = 0.00001f;
278279
const size_t testAfterBatches = 200;
279280
const size_t maxBatches = 100000000;
@@ -288,6 +289,7 @@ static void train(const std::string& digit_train_images_dir,
288289

289290
EasyCNN::logCritical("construct network begin...");
290291
EasyCNN::NetWork network(buildConvNet(batch, channels, width, height));
292+
network.setLearningRate(learningRate);
291293
EasyCNN::logCritical("construct network done.");
292294

293295
//train
@@ -297,18 +299,18 @@ static void train(const std::string& digit_train_images_dir,
297299
size_t epochIdx = 0;
298300
while (epochIdx < max_epoch)
299301
{
302+
//before epoch start, shuffle all train data first
303+
shuffle_data(images, labels);
300304
size_t batchIdx = 0;
301305
while (true)
302306
{
303307
if (!fetch_data(train_images, inputDataBucket, train_labels, labelDataBucket, batchIdx*batch, batch))
304308
{
305309
break;
306310
}
307-
const float loss = network.trainBatch(inputDataBucket,labelDataBucket, learningRate);
311+
const float loss = network.trainBatch(inputDataBucket,labelDataBucket);
308312
if (batchIdx > 0 && batchIdx % testAfterBatches == 0)
309-
{
310-
learningRate -= decayRate;
311-
learningRate = std::max(learningRate, minLearningRate);
313+
{
312314
const float accuracy = test_batch(network,128,validate_images, validate_labels);
313315
EasyCNN::logCritical("sample : %d/%d , learningRate : %f , loss : %f , accuracy : %.4f%%",
314316
batchIdx*batch, train_images.size(), learningRate, loss, accuracy*100.0f);
@@ -325,6 +327,10 @@ static void train(const std::string& digit_train_images_dir,
325327
}
326328
const float accuracy = test_batch(network,128,validate_images, validate_labels);
327329
EasyCNN::logCritical("epoch[%d] accuracy : %.4f%%", epochIdx++, accuracy*100.0f);
330+
//update learning rate
331+
learningRate *= decayRate;
332+
learningRate = std::max(learningRate, minLearningRate);
333+
network.setLearningRate(learningRate);
328334
if (accuracy >= 0.99)
329335
{
330336
break;

examples/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ extern int digit_main(int argc, char* argv[]);
77

88
int main(int argc, char* argv[])
99
{
10-
return digit_main(argc, argv);
10+
return mnist_main(argc, argv);
1111
}

examples/mnist/mnist_train_test.cpp

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ static EasyCNN::NetWork buildConvNet(const size_t batch,const size_t channels,co
152152
network.setPhase(EasyCNN::Phase::Train);
153153
network.setInputSize(EasyCNN::DataSize(batch, channels, width, height));
154154
network.setLossFunctor(std::make_shared<EasyCNN::CrossEntropyFunctor>());
155+
network.setOptimizer(std::make_shared<EasyCNN::SGD>(0.01f));
155156
//input data layer 0
156157
std::shared_ptr<EasyCNN::InputLayer> _0_inputLayer(std::make_shared<EasyCNN::InputLayer>());
157158
network.addayer(_0_inputLayer);
@@ -197,6 +198,7 @@ static EasyCNN::NetWork buildMLPNet(const size_t batch, const size_t channels, c
197198
network.setPhase(EasyCNN::Phase::Train);
198199
network.setInputSize(EasyCNN::DataSize(batch, channels, width, height));
199200
network.setLossFunctor(std::make_shared<EasyCNN::MSEFunctor>());
201+
network.setOptimizer(std::make_shared<EasyCNN::SGDWithMomentum>(0.01f,0.9f));
200202
//input data layer
201203
std::shared_ptr<EasyCNN::InputLayer> _0_inputLayer(std::make_shared<EasyCNN::InputLayer>());
202204
network.addayer(_0_inputLayer);
@@ -277,7 +279,7 @@ static void train(const std::string& mnist_train_images_file,
277279
EasyCNN::logCritical("load training data done. train set's size is %d,validate set's size is %d", train_images.size(), validate_images.size());
278280

279281
float learningRate = 0.1f;
280-
const float decayRate = 0.001f;
282+
const float decayRate = 0.2f;
281283
const float minLearningRate = 0.001f;
282284
const size_t testAfterBatches = 200;
283285
const size_t maxBatches = 10000;
@@ -291,7 +293,8 @@ static void train(const std::string& mnist_train_images_file,
291293
EasyCNN::logCritical("channels:%d , width:%d , height:%d", channels, width, height);
292294

293295
EasyCNN::logCritical("construct network begin...");
294-
EasyCNN::NetWork network(buildConvNet(batch, channels, width, height));
296+
EasyCNN::NetWork network(buildMLPNet(batch, channels, width, height));
297+
network.setLearningRate(learningRate);
295298
EasyCNN::logCritical("construct network done.");
296299

297300
//train
@@ -301,18 +304,18 @@ static void train(const std::string& mnist_train_images_file,
301304
size_t epochIdx = 0;
302305
while (epochIdx < max_epoch)
303306
{
307+
//before epoch start, shuffle all train data first
308+
shuffle_data(images, labels);
304309
size_t batchIdx = 0;
305310
while (true)
306311
{
307312
if (!fetch_data(train_images, inputDataBucket, train_labels, labelDataBucket, batchIdx*batch, batch))
308313
{
309314
break;
310315
}
311-
const float loss = network.trainBatch(inputDataBucket,labelDataBucket, learningRate);
316+
const float loss = network.trainBatch(inputDataBucket,labelDataBucket);
312317
if (batchIdx > 0 && batchIdx % testAfterBatches == 0)
313318
{
314-
learningRate -= decayRate;
315-
learningRate = std::max(learningRate, minLearningRate);
316319
const float accuracy = test(network,128,validate_images, validate_labels);
317320
EasyCNN::logCritical("sample : %d/%d , learningRate : %f , loss : %f , accuracy : %.4f%%",
318321
batchIdx*batch, train_images.size(), learningRate, loss, accuracy*100.0f);
@@ -326,8 +329,11 @@ static void train(const std::string& mnist_train_images_file,
326329
if (batchIdx >= maxBatches)
327330
{
328331
break;
329-
}
332+
}
330333
const float accuracy = test(network,128,validate_images, validate_labels);
334+
//update learning rate
335+
learningRate = std::max(learningRate*decayRate, minLearningRate);
336+
network.setLearningRate(learningRate);
331337
EasyCNN::logCritical("epoch[%d] accuracy : %.4f%%", epochIdx++, accuracy*100.0f);
332338
}
333339
const float accuracy = test(network, 128, validate_images, validate_labels);
@@ -376,9 +382,9 @@ static void test(const std::string& mnist_test_images_file,
376382
EasyCNN::logCritical("finished test.");
377383
}
378384

379-
static std::shared_ptr<EasyCNN::DataBucket> loadImage(const std::vector<std::string>& filePaths)
385+
static std::shared_ptr<EasyCNN::DataBucket> loadImage(const std::vector<std::pair<int, cv::Mat>>& samples)
380386
{
381-
const int number = filePaths.size();
387+
const int number = samples.size();
382388
const int channel = 1;
383389
const int width = 20;
384390
const int height = 20;
@@ -387,11 +393,11 @@ static std::shared_ptr<EasyCNN::DataBucket> loadImage(const std::vector<std::str
387393
const float scaleRate = 1.0f / 255.0f;
388394
for (size_t i = 0; i < (size_t)number; i++)
389395
{
390-
const cv::Mat srcGrayImg = cv::imread(filePaths[i], cv::IMREAD_GRAYSCALE);
396+
const cv::Mat srcGrayImg = samples[i].second;
391397
cv::Mat normalisedImg;
392398
cv::resize(srcGrayImg, normalisedImg, cv::Size(width, height));
393399
cv::Mat binaryImg;
394-
cv::threshold(normalisedImg, binaryImg, 127, 255, CV_THRESH_BINARY_INV);
400+
cv::threshold(normalisedImg, binaryImg, 127, 255, CV_THRESH_BINARY);
395401

396402
//image data
397403
float* inputData = result->getData().get() + i*sizePerImage;
@@ -403,7 +409,7 @@ static std::shared_ptr<EasyCNN::DataBucket> loadImage(const std::vector<std::str
403409
}
404410
return result;
405411
}
406-
static void test_single(const std::vector<std::string>& filePaths, const std::string& modelFilePath)
412+
static void test_single(const std::vector<std::pair<int, cv::Mat>>& samples, const std::string& modelFilePath)
407413
{
408414
bool success = false;
409415

@@ -418,26 +424,55 @@ static void test_single(const std::vector<std::string>& filePaths, const std::st
418424
//train
419425
EasyCNN::logCritical("begin test...");
420426

421-
const std::shared_ptr<EasyCNN::DataBucket> inputDataBucket = loadImage(filePaths);
427+
const std::shared_ptr<EasyCNN::DataBucket> inputDataBucket = loadImage(samples);
422428
const std::shared_ptr<EasyCNN::DataBucket> probDataBucket = network.testBatch(inputDataBucket);
423429
const size_t labelSize = probDataBucket->getSize()._3DSize();
424430
const float* probData = probDataBucket->getData().get();
425-
for (size_t j = 0; j < filePaths.size(); j++)
431+
for (size_t i = 0; i < samples.size(); i++)
426432
{
427-
const uint8_t testProb = getMaxIdxInArray(probData + j*labelSize, probData + (j + 1) * labelSize);
433+
const uint8_t testProb = getMaxIdxInArray(probData + i*labelSize, probData + (i + 1) * labelSize);
428434
EasyCNN::logCritical("label : %d",testProb);
429435

430-
const cv::Mat srcGrayImg = cv::imread(filePaths[j], cv::IMREAD_GRAYSCALE);
436+
const cv::Mat srcGrayImg = samples[i].second;
431437
cv::destroyAllWindows();
432438
cv::imshow("src", srcGrayImg);
433439
cv::waitKey(0);
434440
}
435441
EasyCNN::logCritical("finished test.");
436442
}
443+
static cv::Mat image_to_cv(const image_t& img)
444+
{
445+
assert(img.channels == 1);
446+
cv::Mat result(img.height, img.width,CV_8UC1,(void*)(&img.data[0]),img.width);
447+
return result.clone();
448+
}
449+
static std::vector<std::pair<int, cv::Mat>> export_random_mnist_image(const std::string& mnist_test_images_file,
450+
const std::string& mnist_test_labels_file,
451+
const int test_size)
452+
{
453+
std::vector<std::pair<int, cv::Mat>> result;
454+
bool success = true;
455+
std::vector<image_t> images;
456+
success = load_mnist_images(mnist_test_images_file, images);
457+
assert(success);
458+
std::vector<label_t> labels;
459+
success = load_mnist_labels(mnist_test_labels_file, labels);
460+
assert(success);
461+
std::default_random_engine generator;
462+
std::uniform_int_distribution<int> dis(0, images.size());
463+
for (int i = 0; i < test_size;i++)
464+
{
465+
const int idx = dis(generator);
466+
const int label = labels[idx].data;
467+
const cv::Mat image = image_to_cv(images[idx]);
468+
result.push_back(std::make_pair(label, image));
469+
}
470+
return result;
471+
}
437472
int mnist_main(int argc, char* argv[])
438473
{
439-
const std::string model_file = "../../res/model/mnist_conv.model";
440-
#if 0
474+
const std::string model_file = "../../res/model/mnist_mlp.model";
475+
#if 1
441476
const std::string mnist_train_images_file = "../../res/mnist_data/train-images.idx3-ubyte";
442477
const std::string mnist_train_labels_file = "../../res/mnist_data/train-labels.idx1-ubyte";
443478
train(mnist_train_images_file, mnist_train_labels_file, model_file);
@@ -447,8 +482,11 @@ int mnist_main(int argc, char* argv[])
447482
const std::string mnist_test_images_file = "../../res/mnist_data/t10k-images.idx3-ubyte";
448483
const std::string mnist_test_labels_file = "../../res/mnist_data/t10k-labels.idx1-ubyte";
449484
test(mnist_test_images_file, mnist_test_labels_file, model_file);
485+
#else
486+
const std::string mnist_test_images_file = "../../res/mnist_data/t10k-images.idx3-ubyte";
487+
const std::string mnist_test_labels_file = "../../res/mnist_data/t10k-labels.idx1-ubyte";
488+
std::vector<std::pair<int, cv::Mat>> samples = export_random_mnist_image(mnist_test_images_file, mnist_test_labels_file, 10);
489+
test_single(samples, model_file);
450490
#endif
451-
452-
test_single(std::vector<std::string>{"d:/0.png", "d:/1.png", "d:/2.png"}, model_file);
453491
return 0;
454492
}

header/EasyCNN/ActivationLayer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ namespace EasyCNN
66
{
77
class ActivationLayer : public Layer
88
{
9+
public:
10+
virtual std::string getLayerType() const = 0;
11+
virtual void forward(const std::shared_ptr<DataBucket> prevDataBucket, std::shared_ptr<DataBucket> nextDataBucket) = 0;
12+
virtual void backward(std::shared_ptr<DataBucket> prevDataBucket, const std::shared_ptr<DataBucket> nextDataBucket, std::shared_ptr<DataBucket>& nextDiffBucket) = 0;
913
};
1014

1115
class SigmodLayer : public ActivationLayer

header/EasyCNN/ConvolutionLayer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ namespace EasyCNN
2424
size_t widthStep = 0;
2525
size_t heightStep = 0;
2626
std::shared_ptr<ParamBucket> kernelData;
27+
std::shared_ptr<ParamBucket> kernelDiffData;
2728
bool enabledBias = false;
2829
std::shared_ptr<ParamBucket> biasData;
30+
std::shared_ptr<ParamBucket> biasDiffData;
2931
};
3032
}

header/EasyCNN/DataBucket.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#pragma once
2-
#include <iostream>
32
#include <memory>
43
#include "EasyCNN/Configure.h"
54
#include "EasyCNN/EasyLogger.h"
@@ -13,6 +12,7 @@ namespace EasyCNN
1312
DataSize() = default;
1413
DataSize(const size_t _number, const size_t _channels, const size_t _width, const size_t _height)
1514
:number(_number),channels(_channels), width(_width), height(_height){}
15+
inline size_t totalSize() const { return _4DSize(); }
1616
inline size_t _4DSize() const { return number*channels*width*height; }
1717
inline size_t _3DSize() const { return channels*width*height; }
1818
inline size_t _2DSize() const { return width*height; }

header/EasyCNN/FullconnectLayer.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@ namespace EasyCNN
1919
virtual std::string getLayerType() const override;
2020
virtual void solveInnerParams() override;
2121
virtual void forward(const std::shared_ptr<DataBucket> prevDataBucket, std::shared_ptr<DataBucket> nextDataBucket) override;
22-
virtual void backward(std::shared_ptr<DataBucket> prevDataBucket, const std::shared_ptr<DataBucket> nextDataBucket, std::shared_ptr<DataBucket>& nextDiffBucket) override;
22+
virtual void backward(std::shared_ptr<DataBucket> prevDataBucket, const std::shared_ptr<DataBucket> nextDataBucket, std::shared_ptr<DataBucket>& nextDiffBucket) override;
2323
private:
2424
ParamSize outMapSize;
2525
std::shared_ptr<ParamBucket> weightsData;
26+
std::shared_ptr<ParamBucket> weightsDiffData;
2627
bool enabledBias = false;
27-
std::shared_ptr<ParamBucket> biasData;
28+
std::shared_ptr<ParamBucket> biasData;
29+
std::shared_ptr<ParamBucket> biasDiffData;
2830
};
2931
}

header/EasyCNN/Layer.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22
#include <memory>
33
#include <string>
4+
#include <vector>
45
#include "EasyCNN/Configure.h"
56
#include "EasyCNN/DataBucket.h"
67
#include "EasyCNN/ParamBucket.h"
@@ -29,6 +30,10 @@ namespace EasyCNN
2930
//learning rate
3031
inline void setLearningRate(const float learningRate){ this->learningRate = learningRate; }
3132
inline float getLearningRate() const{ return learningRate; }
33+
//diff
34+
inline std::vector<std::shared_ptr<ParamBucket>> getDiffData() const { return diff; }
35+
//params
36+
inline std::vector<std::shared_ptr<ParamBucket>> getParamData() const { return params; }
3237
//size
3338
inline void setInputBucketSize(const DataSize size){ inputSize = size; }
3439
inline DataSize getInputBucketSize() const{ return inputSize; }
@@ -39,7 +44,11 @@ namespace EasyCNN
3944
//data flow
4045
virtual void forward(const std::shared_ptr<DataBucket> prevDataBucket, std::shared_ptr<DataBucket> nextDataBucket) = 0;
4146
virtual void backward(std::shared_ptr<DataBucket> prevDataBucket, const std::shared_ptr<DataBucket> nextDataBucket, std::shared_ptr<DataBucket>& nextDiffBucket) = 0;
42-
private:
47+
protected:
48+
//subclass must add all diffs to diff
49+
std::vector<std::shared_ptr<DataBucket>> diff;
50+
//subclass must add all weight to params
51+
std::vector<std::shared_ptr<ParamBucket>> params;
4352
Phase phase = Phase::Train;
4453
DataSize inputSize;
4554
DataSize outputSize;

header/EasyCNN/NetWork.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "EasyCNN/Configure.h"
55
#include "EasyCNN/Layer.h"
66
#include "EasyCNN/LossFunction.h"
7+
#include "EasyCNN/Optimizer.h"
78

89
namespace EasyCNN
910
{
@@ -22,17 +23,19 @@ namespace EasyCNN
2223
//train only!
2324
void setInputSize(const DataSize size);
2425
void setLossFunctor(std::shared_ptr<LossFunctor> lossFunctor);
26+
void setOptimizer(std::shared_ptr<Optimizer> optimizer);
27+
void setLearningRate(const float lr);
2528
void addayer(std::shared_ptr<Layer> layer);
2629
float trainBatch(const std::shared_ptr<DataBucket> inputDataBucket,
27-
const std::shared_ptr<DataBucket> labelDataBucket, float learningRate);
30+
const std::shared_ptr<DataBucket> labelDataBucket);
2831
bool saveModel(const std::string& modelFile);
2932
private:
3033
std::string encrypt(const std::string& content);
3134
std::string decrypt(const std::string& content);
3235
private:
3336
//common
3437
std::shared_ptr<EasyCNN::DataBucket> forward(const std::shared_ptr<DataBucket> inputDataBucket);
35-
float backward(const std::shared_ptr<DataBucket> labelDataBucket, float learningRate);
38+
float backward(const std::shared_ptr<DataBucket> labelDataBucket);
3639
std::string serializeToString() const;
3740
std::vector<std::shared_ptr<EasyCNN::Layer>> serializeFromString(const std::string content);
3841
std::shared_ptr<EasyCNN::Layer> createLayerByType(const std::string layerType);
@@ -41,5 +44,6 @@ namespace EasyCNN
4144
std::vector<std::shared_ptr<Layer>> layers;
4245
std::vector<std::shared_ptr<DataBucket>> dataBuckets;
4346
std::shared_ptr<LossFunctor> lossFunctor;
47+
std::shared_ptr<Optimizer> optimizer;
4448
};
4549
}

0 commit comments

Comments
 (0)