Skip to content

Commit 98361a3

Browse files
committed
added apply_proba tests to iris suite
1 parent 3fb6c4d commit 98361a3

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

test/classification/iris.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,17 @@
55

66
features, labels = load_data("iris")
77
labels = String.(labels)
8+
classes = sort(unique(labels))
9+
n = length(labels)
810

911
# train a decision stump (depth=1)
1012
model = build_stump(labels, features)
1113
preds = apply_tree(model, features)
1214
cm = confusion_matrix(labels, preds)
1315
@test cm.accuracy > 0.6
1416
@test depth(model) == 1
17+
probs = apply_tree_proba(model, features, classes)
18+
@test reshape(sum(probs, dims=2), n) ones(n)
1519

1620
# train full-tree classifier (over-fit)
1721
model = build_tree(labels, features)
@@ -22,6 +26,8 @@ cm = confusion_matrix(labels, preds)
2226
@test depth(model) == 5
2327
@test typeof(preds) == Vector{String}
2428
print_tree(model)
29+
probs = apply_tree_proba(model, features, classes)
30+
@test reshape(sum(probs, dims=2), n) ones(n)
2531

2632
# prune tree to 8 leaves
2733
pruning_purity = 0.9
@@ -38,6 +44,8 @@ pt = prune_tree(model, pruning_purity)
3844
preds = apply_tree(pt, features)
3945
cm = confusion_matrix(labels, preds)
4046
@test 0.95 < cm.accuracy < 1.0
47+
probs = apply_tree_proba(model, features, classes)
48+
@test reshape(sum(probs, dims=2), n) ones(n)
4149

4250
# prune tree to a stump, 2 leaves
4351
pruning_purity = 0.5
@@ -63,6 +71,8 @@ preds = apply_forest(model, features)
6371
cm = confusion_matrix(labels, preds)
6472
@test cm.accuracy > 0.95
6573
@test typeof(preds) == Vector{String}
74+
probs = apply_forest_proba(model, features, classes)
75+
@test reshape(sum(probs, dims=2), n) ones(n)
6676

6777
# run n-fold cross validation for forests
6878
println("\n##### nfoldCV Classification Forest #####")
@@ -80,6 +90,8 @@ preds = apply_adaboost_stumps(model, coeffs, features)
8090
cm = confusion_matrix(labels, preds)
8191
@test cm.accuracy > 0.9
8292
@test typeof(preds) == Vector{String}
93+
probs = apply_adaboost_stumps_proba(model, coeffs, features, classes)
94+
@test reshape(sum(probs, dims=2), n) ones(n)
8395

8496
# run n-fold cross validation for boosted stumps, using 7 iterations and 3 folds
8597
println("\n##### nfoldCV Classification Adaboosted Stumps #####")

0 commit comments

Comments
 (0)