@@ -491,6 +491,8 @@ public class MAIN1 {
491
491
492
492
线性随机神经网络相较于 线性神经网络训练模型 来说具有强大的兼容性和较好的性能,但是其牺牲了些精确度,线性随机神经网络模型的使用以及训练出来的模型从保存如下所示。
493
493
494
+ - 1.19 版本中的线性神经网络计算实现案例
495
+
494
496
``` java
495
497
package zhao.algorithmMagic ;
496
498
@@ -550,4 +552,90 @@ public class MAIN1 {
550
552
ASModel . Utils . write(new File (" C:\\ Users\\ zhao\\ Desktop\\ fsDownload\\ MytModel.as" ), model);
551
553
}
552
554
}
553
- ```
555
+ ```
556
+
557
+ - 1.20 版本中的图像分类模型训练案例。
558
+
559
+ ``` java
560
+ package zhao.algorithmMagic ;
561
+
562
+ import zhao.algorithmMagic.core.model.* ;
563
+ import zhao.algorithmMagic.core.model.dataSet.ASDataSet ;
564
+ import zhao.algorithmMagic.operands.matrix.ColorMatrix ;
565
+ import zhao.algorithmMagic.operands.matrix.DoubleMatrix ;
566
+ import zhao.algorithmMagic.operands.matrix.block.DoubleMatrixSpace ;
567
+ import zhao.algorithmMagic.operands.matrix.block.IntegerMatrixSpace ;
568
+ import zhao.algorithmMagic.operands.table.FinalCell ;
569
+ import zhao.algorithmMagic.operands.vector.DoubleVector ;
570
+ import zhao.algorithmMagic.utils.ASMath ;
571
+ import zhao.algorithmMagic.utils.dataContainer.KeyValue ;
572
+
573
+ import java.io.File ;
574
+ import java.util.Arrays ;
575
+
576
+ public class MAIN1 {
577
+
578
+ public static void main (String [] args ) {
579
+
580
+ // 指定图尺寸
581
+ int w = 91 , h = 87 ;
582
+
583
+ // 准备 CNN 神经网络模型
584
+ SingleLayerCNNModel singleLayerCnnModel = ASModel . SINGLE_LAYER_CNN_MODEL ;
585
+ // 设置学习率
586
+ singleLayerCnnModel. setLearningRate(0.1f );
587
+ // 设置训练次数
588
+ singleLayerCnnModel. setLearnCount(200 );
589
+ // 设置激活函数
590
+ singleLayerCnnModel. setActivationFunction(ActivationFunction . LEAKY_RE_LU );
591
+ // 设置损失函数
592
+ singleLayerCnnModel. setLossFunction(LossFunction . MAE );
593
+
594
+ // 准备卷积核,目标为突出图形轮廓
595
+ DoubleMatrix parse = DoubleMatrix . parse(
596
+ new double []{- 1 , - 1 , - 1 },
597
+ new double []{- 1 , 8 , - 1 },
598
+ new double []{- 1 , - 1 , - 1 }
599
+ );
600
+ DoubleMatrixSpace core = DoubleMatrixSpace . parse(parse, parse, parse);
601
+ // 设置 卷积核
602
+ singleLayerCnnModel. setArg(SingleLayerCNNModel . KERNEL , new FinalCell<> (core));
603
+ // 设置 附加任务 池化 然后进行二值化操作 TODO 注意 如果需要模型的保存,请使用 Class 的方式进行设置,使用 lambda 将会导致模型无法反序列化
604
+ // 如果不需要,此处可以不进行设置
605
+ singleLayerCnnModel. setTransformation(
606
+ new PoolBinaryTfTask (2 , 1 , true , 50 , 0x011001 , 0x010101 , ColorMatrix . _R_)
607
+ );
608
+
609
+ // 获取到字母数据集
610
+ ASDataSet load = ASDataSet . Load . LETTER. load(w, h);
611
+ // 将目标数值与权重设置到网络
612
+ singleLayerCnnModel. setWeight(load. getY_train(), load. getImageWeight());
613
+
614
+ // 准备训练时的附加任务 打印信息
615
+ SingleLayerCNNModel . TaskConsumer taskConsumer = (loss, g, weight1) - > {
616
+ System . out. println(" \n 损失函数 = " + loss);
617
+ System . out. println(" 梯度数据 = " + Arrays . toString(g));
618
+ };
619
+
620
+ // 训练出结果模型
621
+ long start = System . currentTimeMillis();
622
+ ClassificationModel<IntegerMatrixSpace > model = singleLayerCnnModel. function(taskConsumer, load. getX_train());
623
+ System . out. println(" 训练模型完成,耗时:" + (System . currentTimeMillis() - start));
624
+ // 保存模型
625
+ ASModel . Utils . write(new File (" C:\\ Users\\ Liming\\ Desktop\\ fsDownload\\ MytModel.as" ), model);
626
+
627
+
628
+ // 提供一个新图 开始进行测试
629
+ IntegerMatrixSpace parse1 = IntegerMatrixSpace . parse(" C:\\ Users\\ Liming\\ Desktop\\ fsDownload\\ 字母5.jpg" , w, h);
630
+ // 放到模型中 获取到结果
631
+ KeyValue<String [], DoubleVector []> function = model. function(new IntegerMatrixSpace []{parse1});
632
+ // 提取结果向量
633
+ DoubleVector [] value = function. getValue();
634
+ // 由于被分类的图像对象只有一个,因此直接查看 0 索引的数据就好 这里是一个向量,其中每一个索引代表对应索引的类别得分值
635
+ System . out. println(value[0 ]);
636
+ // 查看向量中不同维度对应的类别
637
+ System . out. println(Arrays . toString(function. getKey()));
638
+ System . out. println(" 当前图像类别 = " + function. getKey()[ASMath . findMaxIndex(value[0 ]. toArray())]);
639
+ }
640
+ }
641
+ ```
0 commit comments