5
5
6
6
# Load the iris dataset and normalise the labels to `String`.
features, labels = load_data("iris")
labels = String.(labels)
# Sorted distinct labels; handed to the `*_proba` calls further down.
classes = sort(unique(labels))
# Number of samples; the probability checks expect one row sum per sample.
n = length(labels)
8
10
9
11
# train a decision stump (depth=1)
model = build_stump(labels, features)
preds = apply_tree(model, features)
cm = confusion_matrix(labels, preds)
@test cm.accuracy > 0.6
@test depth(model) == 1
# Summing `probs` along its second dimension must give ≈ 1 for each of the
# `n` rows.
probs = apply_tree_proba(model, features, classes)
@test reshape(sum(probs, dims=2), n) ≈ ones(n)
15
19
16
20
# train full-tree classifier (over-fit)
model = build_tree(labels, features)
@@ -22,6 +26,8 @@ cm = confusion_matrix(labels, preds)
22
26
@test depth(model) == 5
@test typeof(preds) == Vector{String}
print_tree(model)
# Summing `probs` along its second dimension must give ≈ 1 for each of the
# `n` rows.
probs = apply_tree_proba(model, features, classes)
@test reshape(sum(probs, dims=2), n) ≈ ones(n)
25
31
26
32
# prune tree to 8 leaves
pruning_purity = 0.9
@@ -38,6 +44,8 @@ pt = prune_tree(model, pruning_purity)
38
44
preds = apply_tree(pt, features)
cm = confusion_matrix(labels, preds)
@test 0.95 < cm.accuracy < 1.0
# Check the probabilities of the pruned tree `pt` (not the unpruned `model`):
# summing along the second dimension must give ≈ 1 for each of the `n` rows.
probs = apply_tree_proba(pt, features, classes)
@test reshape(sum(probs, dims=2), n) ≈ ones(n)
41
49
42
50
# prune tree to a stump, 2 leaves
pruning_purity = 0.5
@@ -63,6 +71,8 @@ preds = apply_forest(model, features)
63
71
cm = confusion_matrix(labels, preds)
@test cm.accuracy > 0.95
@test typeof(preds) == Vector{String}
# Summing the forest's `probs` along its second dimension must give ≈ 1 for
# each of the `n` rows.
probs = apply_forest_proba(model, features, classes)
@test reshape(sum(probs, dims=2), n) ≈ ones(n)
66
76
67
77
# run n-fold cross validation for forests
println(" \n ##### nfoldCV Classification Forest #####")
@@ -80,6 +90,8 @@ preds = apply_adaboost_stumps(model, coeffs, features)
80
90
cm = confusion_matrix(labels, preds)
@test cm.accuracy > 0.9
@test typeof(preds) == Vector{String}
# Summing the boosted stumps' `probs` along its second dimension must give ≈ 1
# for each of the `n` rows.
probs = apply_adaboost_stumps_proba(model, coeffs, features, classes)
@test reshape(sum(probs, dims=2), n) ≈ ones(n)
83
95
84
96
# run n-fold cross validation for boosted stumps, using 7 iterations and 3 folds
println(" \n ##### nfoldCV Classification Adaboosted Stumps #####")
0 commit comments