Skip to content

Commit 9eb6ad9

Browse files
authored
Merge pull request #184 from JuliaAI/dev
For a 0.11 release
2 parents 66f99b8 + 2e5be13 commit 9eb6ad9

File tree

21 files changed

+973
-358
lines changed

21 files changed

+973
-358
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
version:
20-
- '1.0'
2120
- '1.6'
2221
- '1' # automatically expands to the latest stable 1.x release of Julia
2322
os:

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name = "DecisionTree"
22
uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
33
license = "MIT"
44
desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms"
5-
version = "0.10.13"
5+
version = "0.11.0"
66

77
[deps]
88
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -13,9 +13,9 @@ ScikitLearnBase = "6e75b9c4-186b-50bd-896f-2d2496a4843e"
1313
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1414

1515
[compat]
16-
AbstractTrees = "0.3"
16+
AbstractTrees = "0.3, 0.4"
1717
ScikitLearnBase = "0.5"
18-
julia = "1"
18+
julia = "1.6"
1919

2020
[extras]
2121
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

README.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Available via:
1919
* pre-pruning (max depth, min leaf size)
2020
* post-pruning (pessimistic pruning)
2121
* multi-threaded bagging (random forests)
22-
* adaptive boosting (decision stumps)
22+
* adaptive boosting (decision stumps), using [SAMME](https://www.intlpress.com/site/pub/pages/journals/items/sii/content/vols/0002/0003/a008/)
2323
* cross validation (n-fold)
2424
* support for ordered features (encoded as `Real`s or `String`s)
2525

@@ -92,7 +92,7 @@ apply_tree(model, [5.9,3.0,5.1,1.9])
9292
# apply model to all the samples
9393
preds = apply_tree(model, features)
9494
# generate confusion matrix, along with accuracy and kappa scores
95-
confusion_matrix(labels, preds)
95+
DecisionTree.confusion_matrix(labels, preds)
9696
# get the probability of each label
9797
apply_tree_proba(model, [5.9,3.0,5.1,1.9], ["Iris-setosa", "Iris-versicolor", "Iris-virginica"])
9898
# run 3-fold cross validation of pruned tree,
@@ -312,6 +312,13 @@ Available models are: `AdaBoostStumpClassifier`,
312312
`RandomForestClassifier`, `RandomForestRegressor`.
313313

314314

315+
## Feature Importances
316+
317+
The following methods provide measures of feature importance for all models:
318+
`impurity_importance`, `split_importance`, `permutation_importance`. Query the document
319+
strings for details.
320+
321+
315322
## Saving Models
316323
Models can be saved to disk and loaded back with the use of the [JLD2.jl](https://github.com/JuliaIO/JLD2.jl) package.
317324
```julia

src/DecisionTree.jl

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
__precompile__()
2-
31
module DecisionTree
42

53
import Base: length, show, convert, promote_rule, zero
@@ -9,11 +7,11 @@ using Random
97
using Statistics
108
import AbstractTrees
119

12-
export Leaf, Node, Ensemble, print_tree, depth, build_stump, build_tree,
10+
export Leaf, Node, Root, Ensemble, print_tree, depth, build_stump, build_tree,
1311
prune_tree, apply_tree, apply_tree_proba, nfoldCV_tree, build_forest,
1412
apply_forest, apply_forest_proba, nfoldCV_forest, build_adaboost_stumps,
1513
apply_adaboost_stumps, apply_adaboost_stumps_proba, nfoldCV_stumps,
16-
majority_vote, ConfusionMatrix, confusion_matrix, mean_squared_error, R2, load_data
14+
load_data, impurity_importance, split_importance, permutation_importance
1715

1816
# ScikitLearn API
1917
export DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier,
@@ -42,17 +40,32 @@ end
4240

4341
const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
4442

43+
struct Root{S, T}
44+
node :: LeafOrNode{S, T}
45+
n_feat :: Int
46+
featim :: Vector{Float64} # impurity importance
47+
end
48+
4549
struct Ensemble{S, T}
46-
trees :: Vector{LeafOrNode{S, T}}
50+
trees :: Vector{LeafOrNode{S, T}}
51+
n_feat :: Int
52+
featim :: Vector{Float64}
4753
end
4854

55+
4956
is_leaf(l::Leaf) = true
5057
is_leaf(n::Node) = false
5158

52-
zero(String) = ""
59+
zero(::Type{String}) = ""
5360
convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, zero(S), lf, Leaf(zero(T), [zero(T)]))
61+
convert(::Type{Root{S, T}}, node::LeafOrNode{S, T}) where {S, T} = Root{S, T}(node, 0, Float64[])
62+
convert(::Type{LeafOrNode{S, T}}, tree::Root{S, T}) where {S, T} = tree.node
5463
promote_rule(::Type{Node{S, T}}, ::Type{Leaf{T}}) where {S, T} = Node{S, T}
5564
promote_rule(::Type{Leaf{T}}, ::Type{Node{S, T}}) where {S, T} = Node{S, T}
65+
promote_rule(::Type{Root{S, T}}, ::Type{Leaf{T}}) where {S, T} = Root{S, T}
66+
promote_rule(::Type{Leaf{T}}, ::Type{Root{S, T}}) where {S, T} = Root{S, T}
67+
promote_rule(::Type{Root{S, T}}, ::Type{Node{S, T}}) where {S, T} = Root{S, T}
68+
promote_rule(::Type{Node{S, T}}, ::Type{Root{S, T}}) where {S, T} = Root{S, T}
5669

5770
# make a Random Number Generator object
5871
mk_rng(rng::Random.AbstractRNG) = rng
@@ -75,10 +88,12 @@ include("abstract_trees.jl")
7588

7689
length(leaf::Leaf) = 1
7790
length(tree::Node) = length(tree.left) + length(tree.right)
91+
length(tree::Root) = length(tree.node)
7892
length(ensemble::Ensemble) = length(ensemble.trees)
7993

8094
depth(leaf::Leaf) = 0
8195
depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
96+
depth(tree::Root) = depth(tree.node)
8297

8398
function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
8499
n_matches = count(leaf.values .== leaf.majority)
@@ -90,6 +105,13 @@ function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
90105
end
91106

92107

108+
function print_tree(io::IO, tree::Root, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
109+
return print_tree(io, tree.node, depth, indent; sigdigits=sigdigits, feature_names=feature_names)
110+
end
111+
function print_tree(tree::Root, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
112+
return print_tree(stdout, tree, depth, indent; sigdigits=sigdigits, feature_names=feature_names)
113+
end
114+
93115
"""
94116
print_tree([io::IO,] tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
95117
@@ -113,9 +135,9 @@ Feature 3 < -28.15 ?
113135
└─ 8 : 1227/3508
114136
```
115137
116-
To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or
117-
`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the
118-
AbstractTrees.jl interface. See [`wrap`](@ref) for details.
138+
To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object,
139+
`DecisionTree.Node` object or `DecisionTree.Root` object can be wrapped to obtain a tree structure implementing the
140+
AbstractTrees.jl interface. See [`wrap`](@ref) for details.
119141
"""
120142
function print_tree(io::IO, tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
121143
if depth == indent
@@ -149,6 +171,12 @@ function show(io::IO, tree::Node)
149171
print(io, "Depth: $(depth(tree))")
150172
end
151173

174+
function show(io::IO, tree::Root)
175+
println(io, "Decision Tree")
176+
println(io, "Leaves: $(length(tree))")
177+
print(io, "Depth: $(depth(tree))")
178+
end
179+
152180
function show(io::IO, ensemble::Ensemble)
153181
println(io, "Ensemble of Decision Trees")
154182
println(io, "Trees: $(length(ensemble))")

src/abstract_trees.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ In the first case `dc` gets just wrapped, no information is added. No. 2 adds fe
6666
as well as class labels. In the last two cases either of this information is added (Note the
6767
trailing comma; it's needed to make it a tuple).
6868
"""
69+
wrap(tree::DecisionTree.Root, info::NamedTuple = NamedTuple()) = wrap(tree.node, info)
6970
wrap(node::DecisionTree.Node, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
7071
wrap(leaf::DecisionTree.Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)
7172

0 commit comments

Comments
 (0)