@@ -152,6 +152,7 @@ static EasyCNN::NetWork buildConvNet(const size_t batch,const size_t channels,co
152
152
network.setPhase (EasyCNN::Phase::Train);
153
153
network.setInputSize (EasyCNN::DataSize (batch, channels, width, height));
154
154
network.setLossFunctor (std::make_shared<EasyCNN::CrossEntropyFunctor>());
155
+ network.setOptimizer (std::make_shared<EasyCNN::SGD>(0 .01f ));
155
156
// input data layer 0
156
157
std::shared_ptr<EasyCNN::InputLayer> _0_inputLayer (std::make_shared<EasyCNN::InputLayer>());
157
158
network.addayer (_0_inputLayer);
@@ -197,6 +198,7 @@ static EasyCNN::NetWork buildMLPNet(const size_t batch, const size_t channels, c
197
198
network.setPhase (EasyCNN::Phase::Train);
198
199
network.setInputSize (EasyCNN::DataSize (batch, channels, width, height));
199
200
network.setLossFunctor (std::make_shared<EasyCNN::MSEFunctor>());
201
+ network.setOptimizer (std::make_shared<EasyCNN::SGDWithMomentum>(0 .01f ,0 .9f ));
200
202
// input data layer
201
203
std::shared_ptr<EasyCNN::InputLayer> _0_inputLayer (std::make_shared<EasyCNN::InputLayer>());
202
204
network.addayer (_0_inputLayer);
@@ -277,7 +279,7 @@ static void train(const std::string& mnist_train_images_file,
277
279
EasyCNN::logCritical (" load training data done. train set's size is %d,validate set's size is %d" , train_images.size (), validate_images.size ());
278
280
279
281
float learningRate = 0 .1f ;
280
- const float decayRate = 0 .001f ;
282
+ const float decayRate = 0 .2f ;
281
283
const float minLearningRate = 0 .001f ;
282
284
const size_t testAfterBatches = 200 ;
283
285
const size_t maxBatches = 10000 ;
@@ -291,7 +293,8 @@ static void train(const std::string& mnist_train_images_file,
291
293
EasyCNN::logCritical (" channels:%d , width:%d , height:%d" , channels, width, height);
292
294
293
295
EasyCNN::logCritical (" construct network begin..." );
294
- EasyCNN::NetWork network (buildConvNet (batch, channels, width, height));
296
+ EasyCNN::NetWork network (buildMLPNet (batch, channels, width, height));
297
+ network.setLearningRate (learningRate);
295
298
EasyCNN::logCritical (" construct network done." );
296
299
297
300
// train
@@ -301,18 +304,18 @@ static void train(const std::string& mnist_train_images_file,
301
304
size_t epochIdx = 0 ;
302
305
while (epochIdx < max_epoch)
303
306
{
307
+ // before epoch start, shuffle all train data first
308
+ shuffle_data (images, labels);
304
309
size_t batchIdx = 0 ;
305
310
while (true )
306
311
{
307
312
if (!fetch_data (train_images, inputDataBucket, train_labels, labelDataBucket, batchIdx*batch, batch))
308
313
{
309
314
break ;
310
315
}
311
- const float loss = network.trainBatch (inputDataBucket,labelDataBucket, learningRate );
316
+ const float loss = network.trainBatch (inputDataBucket,labelDataBucket);
312
317
if (batchIdx > 0 && batchIdx % testAfterBatches == 0 )
313
318
{
314
- learningRate -= decayRate;
315
- learningRate = std::max (learningRate, minLearningRate);
316
319
const float accuracy = test (network,128 ,validate_images, validate_labels);
317
320
EasyCNN::logCritical (" sample : %d/%d , learningRate : %f , loss : %f , accuracy : %.4f%%" ,
318
321
batchIdx*batch, train_images.size (), learningRate, loss, accuracy*100 .0f );
@@ -326,8 +329,11 @@ static void train(const std::string& mnist_train_images_file,
326
329
if (batchIdx >= maxBatches)
327
330
{
328
331
break ;
329
- }
332
+ }
330
333
const float accuracy = test (network,128 ,validate_images, validate_labels);
334
+ // update learning rate
335
+ learningRate = std::max (learningRate*decayRate, minLearningRate);
336
+ network.setLearningRate (learningRate);
331
337
EasyCNN::logCritical (" epoch[%d] accuracy : %.4f%%" , epochIdx++, accuracy*100 .0f );
332
338
}
333
339
const float accuracy = test (network, 128 , validate_images, validate_labels);
@@ -376,9 +382,9 @@ static void test(const std::string& mnist_test_images_file,
376
382
EasyCNN::logCritical (" finished test." );
377
383
}
378
384
379
- static std::shared_ptr<EasyCNN::DataBucket> loadImage (const std::vector<std::string>& filePaths )
385
+ static std::shared_ptr<EasyCNN::DataBucket> loadImage (const std::vector<std::pair< int , cv::Mat>>& samples )
380
386
{
381
- const int number = filePaths .size ();
387
+ const int number = samples .size ();
382
388
const int channel = 1 ;
383
389
const int width = 20 ;
384
390
const int height = 20 ;
@@ -387,11 +393,11 @@ static std::shared_ptr<EasyCNN::DataBucket> loadImage(const std::vector<std::str
387
393
const float scaleRate = 1 .0f / 255 .0f ;
388
394
for (size_t i = 0 ; i < (size_t )number; i++)
389
395
{
390
- const cv::Mat srcGrayImg = cv::imread (filePaths [i], cv::IMREAD_GRAYSCALE) ;
396
+ const cv::Mat srcGrayImg = samples [i]. second ;
391
397
cv::Mat normalisedImg;
392
398
cv::resize (srcGrayImg, normalisedImg, cv::Size (width, height));
393
399
cv::Mat binaryImg;
394
- cv::threshold (normalisedImg, binaryImg, 127 , 255 , CV_THRESH_BINARY_INV );
400
+ cv::threshold (normalisedImg, binaryImg, 127 , 255 , CV_THRESH_BINARY );
395
401
396
402
// image data
397
403
float * inputData = result->getData ().get () + i*sizePerImage;
@@ -403,7 +409,7 @@ static std::shared_ptr<EasyCNN::DataBucket> loadImage(const std::vector<std::str
403
409
}
404
410
return result;
405
411
}
406
- static void test_single (const std::vector<std::string>& filePaths , const std::string& modelFilePath)
412
+ static void test_single (const std::vector<std::pair< int , cv::Mat>>& samples , const std::string& modelFilePath)
407
413
{
408
414
bool success = false ;
409
415
@@ -418,26 +424,55 @@ static void test_single(const std::vector<std::string>& filePaths, const std::st
418
424
// test
419
425
EasyCNN::logCritical (" begin test..." );
420
426
421
- const std::shared_ptr<EasyCNN::DataBucket> inputDataBucket = loadImage (filePaths );
427
+ const std::shared_ptr<EasyCNN::DataBucket> inputDataBucket = loadImage (samples );
422
428
const std::shared_ptr<EasyCNN::DataBucket> probDataBucket = network.testBatch (inputDataBucket);
423
429
const size_t labelSize = probDataBucket->getSize ()._3DSize ();
424
430
const float * probData = probDataBucket->getData ().get ();
425
- for (size_t j = 0 ; j < filePaths .size (); j ++)
431
+ for (size_t i = 0 ; i < samples .size (); i ++)
426
432
{
427
- const uint8_t testProb = getMaxIdxInArray (probData + j *labelSize, probData + (j + 1 ) * labelSize);
433
+ const uint8_t testProb = getMaxIdxInArray (probData + i *labelSize, probData + (i + 1 ) * labelSize);
428
434
EasyCNN::logCritical (" label : %d" ,testProb);
429
435
430
- const cv::Mat srcGrayImg = cv::imread (filePaths[j], cv::IMREAD_GRAYSCALE) ;
436
+ const cv::Mat srcGrayImg = samples[i]. second ;
431
437
cv::destroyAllWindows ();
432
438
cv::imshow (" src" , srcGrayImg);
433
439
cv::waitKey (0 );
434
440
}
435
441
EasyCNN::logCritical (" finished test." );
436
442
}
443
+ static cv::Mat image_to_cv (const image_t & img)
444
+ {
445
+ assert (img.channels == 1 );
446
+ cv::Mat result (img.height , img.width ,CV_8UC1,(void *)(&img.data [0 ]),img.width );
447
+ return result.clone ();
448
+ }
449
+ static std::vector<std::pair<int , cv::Mat>> export_random_mnist_image (const std::string& mnist_test_images_file,
450
+ const std::string& mnist_test_labels_file,
451
+ const int test_size)
452
+ {
453
+ std::vector<std::pair<int , cv::Mat>> result;
454
+ bool success = true ;
455
+ std::vector<image_t > images;
456
+ success = load_mnist_images (mnist_test_images_file, images);
457
+ assert (success);
458
+ std::vector<label_t > labels;
459
+ success = load_mnist_labels (mnist_test_labels_file, labels);
460
+ assert (success);
461
+ std::default_random_engine generator;
462
+ std::uniform_int_distribution<int > dis (0 , images.size ());
463
+ for (int i = 0 ; i < test_size;i++)
464
+ {
465
+ const int idx = dis (generator);
466
+ const int label = labels[idx].data ;
467
+ const cv::Mat image = image_to_cv (images[idx]);
468
+ result.push_back (std::make_pair (label, image));
469
+ }
470
+ return result;
471
+ }
437
472
int mnist_main (int argc, char * argv[])
438
473
{
439
- const std::string model_file = " ../../res/model/mnist_conv .model" ;
440
- #if 0
474
+ const std::string model_file = " ../../res/model/mnist_mlp .model" ;
475
+ #if 1
441
476
const std::string mnist_train_images_file = " ../../res/mnist_data/train-images.idx3-ubyte" ;
442
477
const std::string mnist_train_labels_file = " ../../res/mnist_data/train-labels.idx1-ubyte" ;
443
478
train (mnist_train_images_file, mnist_train_labels_file, model_file);
@@ -447,8 +482,11 @@ int mnist_main(int argc, char* argv[])
447
482
const std::string mnist_test_images_file = " ../../res/mnist_data/t10k-images.idx3-ubyte" ;
448
483
const std::string mnist_test_labels_file = " ../../res/mnist_data/t10k-labels.idx1-ubyte" ;
449
484
test (mnist_test_images_file, mnist_test_labels_file, model_file);
485
+ #else
486
+ const std::string mnist_test_images_file = "../../res/mnist_data/t10k-images.idx3-ubyte";
487
+ const std::string mnist_test_labels_file = "../../res/mnist_data/t10k-labels.idx1-ubyte";
488
+ std::vector<std::pair<int, cv::Mat>> samples = export_random_mnist_image(mnist_test_images_file, mnist_test_labels_file, 10);
489
+ test_single(samples, model_file);
450
490
#endif
451
-
452
- test_single (std::vector<std::string>{" d:/0.png" , " d:/1.png" , " d:/2.png" }, model_file);
453
491
return 0 ;
454
492
}
0 commit comments