Commit 66f99b8

Merge pull request #178 from JuliaAI/dev

For a 0.10.13 release

2 parents 7e090bb + bfe6ac5

19 files changed: +225 -163 lines

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        version:
+        version:
         - '1.0'
         - '1.6'
         - '1' # automatically expands to the latest stable 1.x release of Julia

Project.toml

Lines changed: 8 additions & 3 deletions
@@ -2,19 +2,24 @@ name = "DecisionTree"
 uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
 license = "MIT"
 desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms"
-version = "0.10.12"
+version = "0.10.13"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
-Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ScikitLearnBase = "6e75b9c4-186b-50bd-896f-2d2496a4843e"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [compat]
 AbstractTrees = "0.3"
 ScikitLearnBase = "0.5"
 julia = "1"
+
+[extras]
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[targets]
+test = ["StableRNGs", "Test"]
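
Test-only dependencies now live under `[extras]` and `[targets]` rather than `[deps]`. StableRNGs is added because the stream of Julia's built-in `MersenneTwister` can change between Julia releases, while a `StableRNG` produces the same numbers everywhere. A minimal sketch (not from the diff) of the property the test suite relies on:

```julia
# Why the tests seed with StableRNG instead of MersenneTwister.
using StableRNGs

x = rand(StableRNG(1), 5)
# Re-seeding reproduces the stream exactly, on every Julia version,
# so accuracy thresholds checked in the tests stay stable.
@assert rand(StableRNG(1), 5) == x
```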

src/DecisionTree.jl

Lines changed: 40 additions & 34 deletions
@@ -1,6 +1,6 @@
 __precompile__()
 
-module DecisionTree
+module DecisionTree
 
 import Base: length, show, convert, promote_rule, zero
 using DelimitedFiles
@@ -80,55 +80,61 @@ length(ensemble::Ensemble) = length(ensemble.trees)
 depth(leaf::Leaf) = 0
 depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
 
+function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
+    n_matches = count(leaf.values .== leaf.majority)
+    ratio = string(n_matches, "/", length(leaf.values))
+    println(io, "$(leaf.majority) : $(ratio)")
+end
 function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
-    matches = findall(leaf.values .== leaf.majority)
-    ratio = string(length(matches)) * "/" * string(length(leaf.values))
-    println("$(leaf.majority) : $(ratio)")
+    return print_tree(stdout, leaf, depth, indent; feature_names=feature_names)
 end
 
+
 """
-    print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing)
-
-Print a textual visualization of the given decision tree `tree`.
-In the example output below, the top node considers whether
-"Feature 3" is above or below the threshold -28.156052806422238.
-If the value of "Feature 3" is strictly below the threshold for some input to be classified,
-we move to the `L->` part underneath, which is a node
-looking at if "Feature 2" is above or below -161.04351901384842.
-If the value of "Feature 2" is strictly below the threshold for some input to be classified,
-we end up at `L-> 5 : 842/3650`. This is to be read as "In the left split,
-the tree will classify the input as class 5, as 842 of the 3650 datapoints
-in the training data that ended up here were of class 5."
+    print_tree([io::IO,] tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
+
+Print a textual visualization of the specified `tree`. For example, if
+for some input pattern the value of "Feature 3" is "-30" and the value
+of "Feature 2" is "100", then, according to the sample output below,
+the majority class prediction is 7. Moreover, one can see that of the
+10555 training samples that terminate at the same leaf as this input
+data, 2493 of these predict the majority class, leading to a
+probabilistic prediction for class 7 of `2493/10555`. Ratios for
+non-majority classes are not shown.
 
 # Example output:
 ```
-Feature 3, Threshold -28.156052806422238
-L-> Feature 2, Threshold -161.04351901384842
-    L-> 5 : 842/3650
-    R-> 7 : 2493/10555
-R-> Feature 7, Threshold 108.1408338577021
-    L-> 2 : 2434/15287
-    R-> 8 : 1227/3508
+Feature 3 < -28.15 ?
+├─ Feature 2 < -161.0 ?
+   ├─ 5 : 842/3650
+   └─ 7 : 2493/10555
+└─ Feature 7 < 108.1 ?
+   ├─ 2 : 2434/15287
+   └─ 8 : 1227/3508
 ```
 
-To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or
-`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the
-AbstractTrees.jl interface. See [`wrap`](@ref) for details.
+To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or
+`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the
+AbstractTrees.jl interface. See [`wrap`](@ref) for details.
 """
-function print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing)
+function print_tree(io::IO, tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
     if depth == indent
-        println()
+        println(io)
         return
     end
+    featval = round(tree.featval; sigdigits=sigdigits)
     if feature_names === nothing
-        println("Feature $(tree.featid), Threshold $(tree.featval)")
+        println(io, "Feature $(tree.featid) < $featval ?")
    else
-        println("Feature $(tree.featid): \"$(feature_names[tree.featid])\", Threshold $(tree.featval)")
+        println(io, "Feature $(tree.featid): \"$(feature_names[tree.featid])\" < $featval ?")
     end
-    print("    " ^ indent * "L-> ")
-    print_tree(tree.left, depth, indent + 1; feature_names = feature_names)
-    print("    " ^ indent * "R-> ")
-    print_tree(tree.right, depth, indent + 1; feature_names = feature_names)
+    print(io, "    " ^ indent * "├─ ")
+    print_tree(io, tree.left, depth, indent + 1; feature_names=feature_names)
+    print(io, "    " ^ indent * "└─ ")
+    print_tree(io, tree.right, depth, indent + 1; feature_names=feature_names)
+end
+function print_tree(tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
+    return print_tree(stdout, tree, depth, indent; sigdigits=sigdigits, feature_names=feature_names)
 end
 
 function show(io::IO, leaf::Leaf)
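
The net effect of this refactor is that every `print_tree` method takes an optional leading `IO` argument, with the old signatures delegating to `stdout`. A short usage sketch of the updated API (toy data, not from the package tests):

```julia
using DecisionTree

features = randn(100, 3)
labels = rand(["a", "b"], 100)
tree = build_tree(labels, features)

print_tree(tree)            # prints to stdout, as before
io = IOBuffer()
print_tree(io, tree, 3)     # new: print (here depth-limited) to any IO
captured = String(take!(io))
```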

src/classification/main.jl

Lines changed: 13 additions & 9 deletions
@@ -37,7 +37,8 @@ end
 function _convert(
         node   :: treeclassifier.NodeMeta{S},
         list   :: AbstractVector{T},
-        labels :: AbstractVector{T}) where {S, T}
+        labels :: AbstractVector{T}
+    ) where {S, T}
 
     if node.is_leaf
         return Leaf{T}(list[node.label], labels[node.region])
@@ -138,7 +139,7 @@ function prune_tree(tree::LeafOrNode{S, T}, purity_thresh=1.0) where {S, T}
 end
 
 
-apply_tree(leaf::Leaf{T}, feature::AbstractVector{S}) where {S, T} = leaf.majority
+apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.majority
 
 function apply_tree(tree::Node{S, T}, features::AbstractVector{S}) where {S, T}
     if tree.featid == 0
@@ -197,7 +198,7 @@ function build_forest(
         min_samples_leaf    = 1,
         min_samples_split   = 2,
         min_purity_increase = 0.0;
-        rng                 = Random.GLOBAL_RNG) where {S, T}
+        rng::Union{Integer,AbstractRNG} = Random.GLOBAL_RNG) where {S, T}
 
     if n_trees < 1
         throw("the number of trees must be >= 1")
@@ -221,7 +222,12 @@ function build_forest(
 
     if rng isa Random.AbstractRNG
         Threads.@threads for i in 1:n_trees
-            inds = rand(rng, 1:t_samples, n_samples)
+            # The Mersenne Twister (Julia's default) is not thread-safe.
+            _rng = copy(rng)
+            # Take some elements from the ring to have different states for each tree.
+            # This is the only way given that only a `copy` can be expected to exist for RNGs.
+            rand(_rng, i)
+            inds = rand(_rng, 1:t_samples, n_samples)
             forest[i] = build_tree(
                 labels[inds],
                 features[inds,:],
@@ -231,9 +237,9 @@ function build_forest(
                 min_samples_split,
                 min_purity_increase,
                 loss = loss,
-                rng = rng)
+                rng = _rng)
         end
-    elseif rng isa Integer # each thread gets its own seeded rng
+    else # each thread gets its own seeded rng
         Threads.@threads for i in 1:n_trees
             Random.seed!(rng + i)
             inds = rand(1:t_samples, n_samples)
@@ -247,8 +253,6 @@ function build_forest(
                 min_purity_increase,
                 loss = loss)
         end
-    else
-        throw("rng must of be type Integer or Random.AbstractRNG")
     end
 
     return Ensemble{S, T}(forest)
@@ -298,7 +302,7 @@ function build_adaboost_stumps(
         labels       :: AbstractVector{T},
         features     :: AbstractMatrix{S},
         n_iterations :: Integer;
-        rng          = Random.GLOBAL_RNG) where {S, T}
+        rng          = Random.GLOBAL_RNG) where {S, T}
     N = length(labels)
     weights = ones(N) / N
     stumps = Node{S, T}[]
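
The pattern introduced here, copy the caller's RNG into each thread and burn `i` draws so tree `i` starts from a distinct state, is worth seeing in isolation. A standalone sketch (the names `base_rng` and `n_tasks` are illustrative, not package API):

```julia
using Random

base_rng = MersenneTwister(42)
n_tasks = 4
results = Vector{Vector{Int}}(undef, n_tasks)
Threads.@threads for i in 1:n_tasks
    _rng = copy(base_rng)   # private copy: no race on shared RNG state
    rand(_rng, i)           # discard i draws so each task's stream differs
    results[i] = rand(_rng, 1:100, 5)
end
```

Note the copies are offset rather than statistically independent streams; the diff relies on bootstrap sampling for tree diversity, and on the copies only for thread safety and determinism.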

src/measures.jl

Lines changed: 2 additions & 1 deletion
@@ -135,7 +135,7 @@ function _nfoldCV(classifier::Symbol, labels::AbstractVector{T}, features::Abstr
         predictions = apply_forest(model, test_features)
     elseif classifier == :stumps
         model, coeffs = build_adaboost_stumps(
-            train_labels, train_features, n_iterations)
+            train_labels, train_features, n_iterations; rng=rng)
         predictions = apply_adaboost_stumps(model, coeffs, test_features)
     end
     cm = confusion_matrix(test_labels, predictions)
@@ -186,6 +186,7 @@ function nfoldCV_stumps(
         n_iterations ::Integer = 10;
         verbose      :: Bool = true,
         rng          = Random.GLOBAL_RNG) where {S, T}
+    rng = mk_rng(rng)::Random.AbstractRNG
     _nfoldCV(:stumps, labels, features, n_folds, n_iterations; verbose=verbose, rng=rng)
 end
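
The added `mk_rng` call normalizes the user-supplied `rng` (either a seed or an RNG) to a `Random.AbstractRNG` before it is forwarded to `_nfoldCV` and, via the first hunk, into `build_adaboost_stumps`. The helper is defined elsewhere in the package source; what follows is an illustrative reconstruction of what such a normalizer does, and its exact definition is an assumption here:

```julia
# Illustrative reconstruction, not taken from the diff.
using Random

mk_rng(rng::Random.AbstractRNG) = rng                 # pass RNGs through
mk_rng(seed::Integer) = Random.MersenneTwister(seed)  # promote seeds to RNGs

mk_rng(7)                  # -> seeded MersenneTwister
mk_rng(Random.GLOBAL_RNG)  # -> returned unchanged
```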

src/regression/main.jl

Lines changed: 9 additions & 6 deletions
@@ -56,7 +56,7 @@ function build_forest(
         min_samples_leaf    = 5,
         min_samples_split   = 2,
         min_purity_increase = 0.0;
-        rng                 = Random.GLOBAL_RNG) where {S, T <: Float64}
+        rng::Union{Integer,AbstractRNG} = Random.GLOBAL_RNG) where {S, T <: Float64}
 
     if n_trees < 1
         throw("the number of trees must be >= 1")
@@ -77,7 +77,12 @@ function build_forest(
 
     if rng isa Random.AbstractRNG
         Threads.@threads for i in 1:n_trees
-            inds = rand(rng, 1:t_samples, n_samples)
+            # The Mersenne Twister (Julia's default) is not thread-safe.
+            _rng = copy(rng)
+            # Take some elements from the ring to have different states for each tree.
+            # This is the only way given that only a `copy` can be expected to exist for RNGs.
+            rand(_rng, i)
+            inds = rand(_rng, 1:t_samples, n_samples)
             forest[i] = build_tree(
                 labels[inds],
                 features[inds,:],
@@ -86,9 +91,9 @@ function build_forest(
                 min_samples_leaf,
                 min_samples_split,
                 min_purity_increase,
-                rng = rng)
+                rng = _rng)
         end
-    elseif rng isa Integer # each thread gets its own seeded rng
+    else # each thread gets its own seeded rng
         Threads.@threads for i in 1:n_trees
             Random.seed!(rng + i)
             inds = rand(1:t_samples, n_samples)
@@ -101,8 +106,6 @@ function build_forest(
                 min_samples_split,
                 min_purity_increase)
         end
-    else
-        throw("rng must of be type Integer or Random.AbstractRNG")
     end
 
     return Ensemble{S, T}(forest)
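
With the `elseif`/`else` cleanup, any `rng` that is not an `AbstractRNG` is treated as an integer seed, and the new `Union{Integer,AbstractRNG}` annotation rejects everything else at dispatch time instead of via the old `throw`. Passing a seeded RNG now gives reproducible forests; a hedged usage sketch with toy data (names are illustrative):

```julia
using DecisionTree, Random

X = randn(MersenneTwister(3), 200, 4)
y = X[:, 1] .+ 0.1 .* randn(MersenneTwister(4), 200)

f1 = build_forest(y, X, 2, 10; rng=MersenneTwister(1))
f2 = build_forest(y, X, 2, 10; rng=MersenneTwister(1))
@assert apply_forest(f1, X) == apply_forest(f2, X)  # same seed, same forest
```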

src/scikitlearnAPI.jl

Lines changed: 4 additions & 2 deletions
@@ -386,5 +386,7 @@ length(dt::DecisionTreeClassifier) = length(dt.root)
 length(dt::DecisionTreeRegressor) = length(dt.root)
 
 print_tree(dt::DecisionTreeClassifier, depth=-1; kwargs...) = print_tree(dt.root, depth; kwargs...)
-print_tree(dt::DecisionTreeRegressor, depth=-1; kwargs...) = print_tree(dt.root, depth; kwargs...)
-print_tree(n::Nothing, depth=-1; kwargs...) = show(n)
+print_tree(io::IO, dt::DecisionTreeClassifier, depth=-1; kwargs...) = print_tree(io, dt.root, depth; kwargs...)
+print_tree(dt::DecisionTreeRegressor, depth=-1; kwargs...) = print_tree(dt.root, depth; kwargs...)
+print_tree(io::IO, dt::DecisionTreeRegressor, depth=-1; kwargs...) = print_tree(io, dt.root, depth; kwargs...)
+print_tree(n::Nothing, depth=-1; kwargs...) = show(n)
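
The two new methods let the scikit-learn style wrappers print to an arbitrary stream instead of only `stdout`. A sketch of the use case (toy data; assumes nothing beyond the exported wrapper API):

```julia
using DecisionTree

X = randn(100, 3)
y = rand(["yes", "no"], 100)
clf = DecisionTreeClassifier(max_depth=3)
fit!(clf, X, y)

open("tree.txt", "w") do io
    print_tree(io, clf)   # previously only print_tree(clf) was available
end
```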

test/classification/adult.jl

Lines changed: 6 additions & 6 deletions
@@ -5,7 +5,7 @@
 
 features, labels = load_data("adult")
 
-model = build_tree(labels, features)
+model = build_tree(labels, features; rng=StableRNG(1))
 preds = apply_tree(model, features)
 cm = confusion_matrix(labels, preds)
 @test cm.accuracy > 0.99
@@ -15,35 +15,35 @@ labels = string.(labels)
 
 n_subfeatures = 3
 n_trees = 5
-model = build_forest(labels, features, n_subfeatures, n_trees)
+model = build_forest(labels, features, n_subfeatures, n_trees; rng=StableRNG(1))
 preds = apply_forest(model, features)
 cm = confusion_matrix(labels, preds)
 @test cm.accuracy > 0.9
 
 n_iterations = 15
-model, coeffs = build_adaboost_stumps(labels, features, n_iterations);
+model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1));
 preds = apply_adaboost_stumps(model, coeffs, features);
 cm = confusion_matrix(labels, preds);
 @test cm.accuracy > 0.8
 
 println("\n##### 3 foldCV Classification Tree #####")
 pruning_purity = 0.9
 nfolds = 3
-accuracy = nfoldCV_tree(labels, features, nfolds, pruning_purity; verbose=false);
+accuracy = nfoldCV_tree(labels, features, nfolds, pruning_purity; rng=StableRNG(1), verbose=false);
 @test mean(accuracy) > 0.8
 
 println("\n##### 3 foldCV Classification Forest #####")
 n_subfeatures = 2
 n_trees = 10
 n_folds = 3
 partial_sampling = 0.5
-accuracy = nfoldCV_forest(labels, features, n_folds, n_subfeatures, n_trees, partial_sampling; verbose=false)
+accuracy = nfoldCV_forest(labels, features, n_folds, n_subfeatures, n_trees, partial_sampling; rng=StableRNG(1), verbose=false)
 @test mean(accuracy) > 0.8
 
 println("\n##### nfoldCV Classification Adaboosted Stumps #####")
 n_iterations = 15
 n_folds = 3
-accuracy = nfoldCV_stumps(labels, features, n_folds, n_iterations; verbose=false);
+accuracy = nfoldCV_stumps(labels, features, n_folds, n_iterations; rng=StableRNG(1), verbose=false);
 @test mean(accuracy) > 0.8
 
 end # @testset

test/classification/digits.jl

Lines changed: 2 additions & 1 deletion
@@ -69,7 +69,8 @@ model = DecisionTree.build_forest(
     max_depth,
     min_samples_leaf,
     min_samples_split,
-    min_purity_increase)
+    min_purity_increase;
+    rng=StableRNG(1))
 preds = apply_forest(model, X)
 cm = confusion_matrix(Y, preds)
 @test cm.accuracy > 0.95

test/classification/heterogeneous.jl

Lines changed: 6 additions & 6 deletions
@@ -5,29 +5,29 @@
 m, n = 10^2, 5
 
 tf = [trues(Int(m/2)) falses(Int(m/2))]
-inds = Random.randperm(m)
+inds = Random.randperm(StableRNG(1), m)
 labels = string.(tf[inds])
 
 features = Array{Any}(undef, m, n)
-features[:,:] = randn(m, n)
-features[:,2] = string.(tf[Random.randperm(m)])
+features[:,:] = randn(StableRNG(1), m, n)
+features[:,2] = string.(tf[Random.randperm(StableRNG(1), m)])
 features[:,3] = map(t -> round.(Int, t), features[:,3])
 features[:,4] = tf[inds]
 
-model = build_tree(labels, features)
+model = build_tree(labels, features; rng=StableRNG(1))
 preds = apply_tree(model, features)
 cm = confusion_matrix(labels, preds)
 @test cm.accuracy > 0.9
 
 n_subfeatures = 2
 n_trees = 3
-model = build_forest(labels, features, n_subfeatures, n_trees)
+model = build_forest(labels, features, n_subfeatures, n_trees; rng=StableRNG(1))
 preds = apply_forest(model, features)
 cm = confusion_matrix(labels, preds)
 @test cm.accuracy > 0.9
 
 n_subfeatures = 7
-model, coeffs = build_adaboost_stumps(labels, features, n_subfeatures)
+model, coeffs = build_adaboost_stumps(labels, features, n_subfeatures; rng=StableRNG(1))
 preds = apply_adaboost_stumps(model, coeffs, features)
 cm = confusion_matrix(labels, preds)
 @test cm.accuracy > 0.9
