@@ -1,5 +1,3 @@
-__precompile__()
-
 module DecisionTree
 
 import Base: length, show, convert, promote_rule, zero
@@ -9,11 +7,11 @@ using Random
 using Statistics
 import AbstractTrees
 
-export Leaf, Node, Ensemble, print_tree, depth, build_stump, build_tree,
+export Leaf, Node, Root, Ensemble, print_tree, depth, build_stump, build_tree,
       prune_tree, apply_tree, apply_tree_proba, nfoldCV_tree, build_forest,
       apply_forest, apply_forest_proba, nfoldCV_forest, build_adaboost_stumps,
       apply_adaboost_stumps, apply_adaboost_stumps_proba, nfoldCV_stumps,
-      majority_vote, ConfusionMatrix, confusion_matrix, mean_squared_error, R2, load_data
+      load_data, impurity_importance, split_importance, permutation_importance
 
 # ScikitLearn API
 export DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier,
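The rewritten export list swaps the old evaluation helpers (`majority_vote`, `ConfusionMatrix`, `confusion_matrix`, `mean_squared_error`, `R2`) for three feature-importance entry points. Their definitions are not part of this hunk, so the sketch below is only an illustration of how the new exports might be called, assuming each accepts a fitted model as its first argument:

```julia
using DecisionTree

# Toy data, invented purely for illustration.
features = rand(100, 3)
labels = rand(1:2, 100)

model = build_forest(labels, features)

# Assumed one-argument forms; the real signatures live elsewhere in the package.
impurity_importance(model)  # per-feature impurity-based importance
split_importance(model)     # per-feature split-based importance
```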
@@ -42,17 +40,32 @@
 
 const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
 
+struct Root{S, T}
+    node   :: LeafOrNode{S, T}
+    n_feat :: Int
+    featim :: Vector{Float64} # impurity importance
+end
+
 struct Ensemble{S, T}
-    trees :: Vector{LeafOrNode{S, T}}
+    trees  :: Vector{LeafOrNode{S, T}}
+    n_feat :: Int
+    featim :: Vector{Float64}
 end
 
+
 is_leaf(l::Leaf) = true
 is_leaf(n::Node) = false
 
-zero(String) = ""
+zero(::Type{String}) = ""
 convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, zero(S), lf, Leaf(zero(T), [zero(T)]))
+convert(::Type{Root{S, T}}, node::LeafOrNode{S, T}) where {S, T} = Root{S, T}(node, 0, Float64[])
+convert(::Type{LeafOrNode{S, T}}, tree::Root{S, T}) where {S, T} = tree.node
 promote_rule(::Type{Node{S, T}}, ::Type{Leaf{T}}) where {S, T} = Node{S, T}
 promote_rule(::Type{Leaf{T}}, ::Type{Node{S, T}}) where {S, T} = Node{S, T}
+promote_rule(::Type{Root{S, T}}, ::Type{Leaf{T}}) where {S, T} = Root{S, T}
+promote_rule(::Type{Leaf{T}}, ::Type{Root{S, T}}) where {S, T} = Root{S, T}
+promote_rule(::Type{Root{S, T}}, ::Type{Node{S, T}}) where {S, T} = Root{S, T}
+promote_rule(::Type{Node{S, T}}, ::Type{Root{S, T}}) where {S, T} = Root{S, T}
 
 # make a Random Number Generator object
 mk_rng(rng::Random.AbstractRNG) = rng
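A minimal sketch of the wrapper semantics defined above, using an invented one-leaf tree: converting a bare `Leaf`/`Node` to a `Root` hits the fallback that records no feature count and no importances, and converting back simply unwraps the `node` field.

```julia
using DecisionTree
using DecisionTree: LeafOrNode   # the Union alias is not exported

leaf = Leaf(1, [1, 1, 1])        # a trivial tree consisting of one leaf

root = convert(Root{Float64, Int}, leaf)
root.n_feat                      # 0         -- no feature count recorded
root.featim                      # Float64[] -- no importances recorded

convert(LeafOrNode{Float64, Int}, root) === leaf   # true: unwraps `node`
```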
@@ -75,10 +88,12 @@ include("abstract_trees.jl")
 
 length(leaf::Leaf) = 1
 length(tree::Node) = length(tree.left) + length(tree.right)
+length(tree::Root) = length(tree.node)
 length(ensemble::Ensemble) = length(ensemble.trees)
 
 depth(leaf::Leaf) = 0
 depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
+depth(tree::Root) = depth(tree.node)
 
 function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
     n_matches = count(leaf.values .== leaf.majority)
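As the one-line methods above show, measuring a `Root` delegates to the tree it wraps. A quick self-contained check, reusing the same invented one-leaf tree:

```julia
using DecisionTree

root = convert(Root{Float64, Int}, Leaf(1, [1, 1, 1]))

length(root)   # 1 -- forwarded to length(root.node)
depth(root)    # 0 -- forwarded to depth(root.node)
```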
@@ -90,6 +105,13 @@ function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
 end
 
 
+function print_tree(io::IO, tree::Root, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
+    return print_tree(io, tree.node, depth, indent; sigdigits=sigdigits, feature_names=feature_names)
+end
+function print_tree(tree::Root, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
+    return print_tree(stdout, tree, depth, indent; sigdigits=sigdigits, feature_names=feature_names)
+end
+
 """
     print_tree([io::IO,] tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
 
@@ -113,9 +135,9 @@ Feature 3 < -28.15 ?
    └─ 8 : 1227/3508
 ```
 
-To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or
-`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the
-AbstractTrees.jl interface. See [`wrap`](@ref) for details.
+To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object,
+`DecisionTree.Node` object or `DecisionTree.Root` object can be wrapped to obtain a tree structure implementing the
+AbstractTrees.jl interface. See [`wrap`](@ref) for details.
 """
 function print_tree(io::IO, tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
     if depth == indent
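Because the `Root` methods added earlier forward straight to `root.node`, an importance-carrying tree prints exactly like the plain tree it wraps. A small invented example (the `convert` is a no-op if `build_tree` already returns a `Root` in this changeset, which this hunk does not show):

```julia
using DecisionTree

features = [1.0 2.0; 3.0 4.0; 5.0 6.0; 7.0 8.0]
labels = [1, 1, 2, 2]

tree = build_tree(labels, features)
root = convert(Root{Float64, Int}, tree)

print_tree(root)             # same output as print_tree on the wrapped node
print_tree(stdout, root, 3)  # explicit IO and depth limit also supported
```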
@@ -149,6 +171,12 @@ function show(io::IO, tree::Node)
     print(io, "Depth:  $(depth(tree))")
 end
 
+function show(io::IO, tree::Root)
+    println(io, "Decision Tree")
+    println(io, "Leaves: $(length(tree))")
+    print(io, "Depth:  $(depth(tree))")
+end
+
 function show(io::IO, ensemble::Ensemble)
     println(io, "Ensemble of Decision Trees")
     println(io, "Trees: $(length(ensemble))")