
Commit 5a04aba

Merge pull request #199 from JuliaAI/dev

For a 0.12.0 release

2 parents: b045bb9 + a072539

File tree: 4 files changed, +61 −18 lines


Project.toml

Lines changed: 1 addition & 1 deletion

````diff
@@ -2,7 +2,7 @@ name = "DecisionTree"
 uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
 license = "MIT"
 desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms"
-version = "0.11.3"
+version = "0.12.0"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
````

README.md

Lines changed: 56 additions & 9 deletions

````diff
@@ -3,6 +3,7 @@
 [![CI](https://github.com/JuliaAI/DecisionTree.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/DecisionTree.jl/actions?query=workflow%3ACI)
 [![Codecov](https://codecov.io/gh/JuliaAI/DecisionTree.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaAI/DecisionTree.jl)
 [![Docs Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliahub.com/docs/DecisionTree/pEDeB/0.10.11/)
+[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.7359268.svg)](https://doi.org/10.5281/zenodo.7359268)
 
 Julia implementation of Decision Tree (CART) and Random Forest algorithms
 
@@ -44,11 +45,15 @@ Available models: `DecisionTreeClassifier, DecisionTreeRegressor, RandomForestCl
 See each model's help (eg. `?DecisionTreeRegressor` at the REPL) for more information
 
 ### Classification Example
+
 Load DecisionTree package
+
 ```julia
 using DecisionTree
 ```
+
 Separate Fisher's Iris dataset features and labels
+
 ```julia
 features, labels = load_data("iris")    # also see "adult" and "digits" datasets
 
@@ -57,7 +62,9 @@ features, labels = load_data("iris")    # also see "adult" and "digits" datasets
 features = float.(features)
 labels   = string.(labels)
 ```
+
 Pruned Tree Classifier
+
 ```julia
 # train depth-truncated classifier
 model = DecisionTreeClassifier(max_depth=2)
@@ -78,8 +85,11 @@ accuracy = cross_val_score(model, features, labels, cv=3)
 Also, have a look at these [classification](https://github.com/cstjean/ScikitLearn.jl/blob/master/examples/Classifier_Comparison_Julia.ipynb) and [regression](https://github.com/cstjean/ScikitLearn.jl/blob/master/examples/Decision_Tree_Regression_Julia.ipynb) notebooks.
 
 ## Native API
+
 ### Classification Example
+
 Decision Tree Classifier
+
 ```julia
 # train full-tree classifier
 model = build_tree(labels, features)
@@ -129,6 +139,7 @@ accuracy = nfoldCV_tree(labels, features,
                         rng = seed)
 ```
 Random Forest Classifier
+
 ```julia
 # train random forest classifier
 # using 2 random features, 10 trees, 0.5 portion of samples per tree, and a maximum tree depth of 6
@@ -176,7 +187,9 @@ accuracy = nfoldCV_forest(labels, features,
                           verbose = true,
                           rng = seed)
 ```
+
 Adaptive-Boosted Decision Stumps Classifier
+
 ```julia
 # train adaptive-boosted stumps, using 7 iterations
 model, coeffs = build_adaboost_stumps(labels, features, 7);
@@ -193,13 +206,15 @@ accuracy = nfoldCV_stumps(labels, features,
 ```
 
 ### Regression Example
+
 ```julia
 n, m = 10^3, 5
 features = randn(n, m)
 weights = rand(-2:2, m)
 labels = features * weights
 ```
 Regression Tree
+
 ```julia
 # train regression tree
 model = build_tree(labels, features)
@@ -238,7 +253,9 @@ r2 = nfoldCV_tree(labels, features,
                   verbose = true,
                   rng = seed)
 ```
+
 Regression Random Forest
+
 ```julia
 # train regression forest, using 2 random features, 10 trees,
 # averaging of 5 samples per leaf, and 0.7 portion of samples per tree
@@ -285,6 +302,14 @@ r2 = nfoldCV_forest(labels, features,
                     rng = seed)
 ```
 
+## Saving Models
+Models can be saved to disk and loaded back with the use of the [JLD2.jl](https://github.com/JuliaIO/JLD2.jl) package.
+```julia
+using JLD2
+@save "model_file.jld2" model
+```
+Note that even though features and labels of type `Array{Any}` are supported, it is highly recommended that data be cast to explicit types (ie with `float.(), string.()`, etc). This significantly improves model training and prediction execution times, and also drastically reduces the size of saved models.
+
 ## MLJ.jl API
 
 To use DecsionTree.jl models in
@@ -318,15 +343,6 @@ The following methods provide measures of feature importance for all models:
 `impurity_importance`, `split_importance`, `permutation_importance`. Query the document
 strings for details.
 
-
-## Saving Models
-Models can be saved to disk and loaded back with the use of the [JLD2.jl](https://github.com/JuliaIO/JLD2.jl) package.
-```julia
-using JLD2
-@save "model_file.jld2" model
-```
-Note that even though features and labels of type `Array{Any}` are supported, it is highly recommended that data be cast to explicit types (ie with `float.(), string.()`, etc). This significantly improves model training and prediction execution times, and also drastically reduces the size of saved models.
-
 ## Visualization
 A `DecisionTree` model can be visualized using the `print_tree`-function of its native interface
 (for an example see above in section 'Classification Example').
@@ -335,3 +351,34 @@ In addition, an abstraction layer using `AbstractTrees.jl` has been implemented
 
 Apart from this, `AbstractTrees.jl` brings its own implementation of `print_tree`.
 
+
+## Citing the package in publications
+
+DOI: [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.7359268.svg)](https://doi.org/10.5281/zenodo.7359268).
+
+BibTeX entry:
+
+```
+@software{ben_sadeghi_2022_7359268,
+  author    = {Ben Sadeghi and
+               Poom Chiarawongse and
+               Kevin Squire and
+               Daniel C. Jones and
+               Andreas Noack and
+               Cédric St-Jean and
+               Rik Huijzer and
+               Roland Schätzle and
+               Ian Butterworth and
+               Yu-Fong Peng and
+               Anthony Blaom},
+  title     = {{DecisionTree.jl - A Julia implementation of the
+                CART Decision Tree and Random Forest algorithms}},
+  month     = nov,
+  year      = 2022,
+  publisher = {Zenodo},
+  version   = {0.11.3},
+  doi       = {10.5281/zenodo.7359268},
+  url       = {https://doi.org/10.5281/zenodo.7359268}
+}
+```
````
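The relocated "Saving Models" section shows only the `@save` half. A minimal save/load round trip, sketched under the assumption that DecisionTree.jl and JLD2.jl are installed (`model_file.jld2` is just the illustrative filename from the README snippet):

```julia
# Train a tree on the bundled iris data, save it with JLD2.jl, reload it,
# and check the reloaded tree predicts identically to the original.
using DecisionTree, JLD2

features, labels = load_data("iris")
features = float.(features)    # cast Array{Any} to concrete types, as the README advises
labels   = string.(labels)

model = build_tree(labels, features)
@save "model_file.jld2" model

# jldopen reads the named object back without rebinding `model` in this scope;
# `@load "model_file.jld2" model` is the macro equivalent.
loaded = jldopen("model_file.jld2") do f
    f["model"]
end

@assert apply_tree(loaded, features) == apply_tree(model, features)
```

Casting to concrete element types before saving matters twice here: it speeds up training and, per the README note, keeps the serialized file small.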

src/classification/main.jl

Lines changed: 2 additions & 4 deletions

````diff
@@ -370,12 +370,10 @@ function build_forest(
     loss = (ns, n) -> util.entropy(ns, n, entropy_terms)
 
     if rng isa Random.AbstractRNG
+        shared_seed = rand(rng, UInt)
         Threads.@threads for i in 1:n_trees
             # The Mersenne Twister (Julia's default) is not thread-safe.
-            _rng = copy(rng)
-            # Take some elements from the ring to have different states for each tree. This
-            # is the only way given that only a `copy` can be expected to exist for RNGs.
-            rand(_rng, i)
+            _rng = Random.seed!(copy(rng), shared_seed + i)
             inds = rand(_rng, 1:t_samples, n_samples)
             forest[i] = build_tree(
                 labels[inds],
````
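The change above replaces the old "copy the RNG, then advance it by `i` draws" trick with deterministic per-tree reseeding from a single `shared_seed`. A standalone sketch of the pattern using only the `Random` stdlib (the loop body is a toy stand-in for building a tree; `shared_seed` and `_rng` follow the diff's naming):

```julia
using Random

rng = MersenneTwister(42)   # not thread-safe: must never be shared across threads
n_trees = 10

# Draw one seed from the parent RNG, then give every iteration its own
# deterministically seeded copy. Unlike advancing a copy by `i` draws,
# this yields distinct, reproducible streams regardless of how
# Threads.@threads schedules the iterations.
shared_seed = rand(rng, UInt)
draws = Vector{Float64}(undef, n_trees)
Threads.@threads for i in 1:n_trees
    _rng = Random.seed!(copy(rng), shared_seed + i)
    draws[i] = rand(_rng)
end

# Reproducibility check: the same parent seed recreates the same draws,
# with or without threading.
rng2 = MersenneTwister(42)
seed2 = rand(rng2, UInt)
draws2 = [rand(Random.seed!(copy(rng2), seed2 + i)) for i in 1:n_trees]
@assert draws == draws2
```

`seed!` fully reinitializes the copied RNG, so each tree's stream depends only on `shared_seed + i`, not on when its thread happened to run.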

src/regression/main.jl

Lines changed: 2 additions & 4 deletions

````diff
@@ -95,12 +95,10 @@ function build_forest(
     forest = impurity_importance ? Vector{Root{S, T}}(undef, n_trees) : Vector{LeafOrNode{S, T}}(undef, n_trees)
 
     if rng isa Random.AbstractRNG
+        shared_seed = rand(rng, UInt)
         Threads.@threads for i in 1:n_trees
             # The Mersenne Twister (Julia's default) is not thread-safe.
-            _rng = copy(rng)
-            # Take some elements from the ring to have different states for each tree.
-            # This is the only way given that only a `copy` can be expected to exist for RNGs.
-            rand(_rng, i)
+            _rng = Random.seed!(copy(rng), shared_seed + i)
             inds = rand(_rng, 1:t_samples, n_samples)
             forest[i] = build_tree(
                 labels[inds],
````
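The same fix in the regression `build_forest` has a user-visible consequence: with a seeded RNG, forest construction is reproducible even though trees are built on multiple threads. A sketch under stated assumptions (positional arguments `2, 10` mean `n_subfeatures`, `n_trees` as in the README examples, and `rng` is assumed to be accepted as a keyword, matching the `rng isa Random.AbstractRNG` check in the diff):

```julia
using DecisionTree, Random

# Synthetic regression data, mirroring the README's Regression Example.
n, m = 10^3, 5
features = randn(n, m)
weights = rand(-2:2, m)
labels = features * weights

# Two forests built from identically seeded RNGs should agree,
# independent of how Threads.@threads schedules the trees.
f1 = build_forest(labels, features, 2, 10; rng=MersenneTwister(3))
f2 = build_forest(labels, features, 2, 10; rng=MersenneTwister(3))
@assert apply_forest(f1, features) == apply_forest(f2, features)
```

Under the pre-0.12.0 code, per-tree RNG states were derived by advancing copies of the parent RNG, which could produce overlapping streams across trees; the shared-seed scheme gives each tree its own well-separated, scheduling-independent stream.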
