Skip to content

Commit 1cdbbbe

Browse files
authored
Merge pull request #187 from JuliaAI/dev
For a 0.11.1 release
2 parents 9eb6ad9 + 33bbec4 commit 1cdbbbe

File tree

4 files changed

+30
-19
lines changed

4 files changed

+30
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name = "DecisionTree"
22
uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
33
license = "MIT"
44
desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms"
5-
version = "0.11.0"
5+
version = "0.11.1"
66

77
[deps]
88
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/DecisionTree.jl

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,9 @@ end
5656
is_leaf(l::Leaf) = true
5757
is_leaf(n::Node) = false
5858

59-
zero(::Type{String}) = ""
60-
convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, zero(S), lf, Leaf(zero(T), [zero(T)]))
59+
_zero(::Type{String}) = ""
60+
_zero(x::Any) = zero(x)
61+
convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, _zero(S), lf, Leaf(_zero(T), [_zero(T)]))
6162
convert(::Type{Root{S, T}}, node::LeafOrNode{S, T}) where {S, T} = Root{S, T}(node, 0, Float64[])
6263
convert(::Type{LeafOrNode{S, T}}, tree::Root{S, T}) where {S, T} = tree.node
6364
promote_rule(::Type{Node{S, T}}, ::Type{Leaf{T}}) where {S, T} = Node{S, T}
@@ -95,21 +96,21 @@ depth(leaf::Leaf) = 0
9596
depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
9697
depth(tree::Root) = depth(tree.node)
9798

98-
function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
99+
function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
99100
n_matches = count(leaf.values .== leaf.majority)
100101
ratio = string(n_matches, "/", length(leaf.values))
101102
println(io, "$(leaf.majority) : $(ratio)")
102103
end
103-
function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
104-
return print_tree(stdout, leaf, depth, indent; feature_names=feature_names)
104+
function print_tree(leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
105+
return print_tree(stdout, leaf, depth, indent; sigdigits, feature_names)
105106
end
106107

107108

108-
function print_tree(io::IO, tree::Root, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
109-
return print_tree(io, tree.node, depth, indent; sigdigits=sigdigits, feature_names=feature_names)
109+
function print_tree(io::IO, tree::Root, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
110+
return print_tree(io, tree.node, depth, indent; sigdigits, feature_names)
110111
end
111-
function print_tree(tree::Root, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
112-
return print_tree(stdout, tree, depth, indent; sigdigits=sigdigits, feature_names=feature_names)
112+
function print_tree(tree::Root, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
113+
return print_tree(stdout, tree, depth, indent; sigdigits, feature_names)
113114
end
114115

115116
"""
@@ -137,26 +138,26 @@ Feature 3 < -28.15 ?
137138
138139
To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object,
139140
`DecisionTree.Node` object or `DecisionTree.Root` object can be wrapped to obtain a tree structure implementing the
140-
AbstractTrees.jl interface. See [`wrap`](@ref)` for details.
141+
AbstractTrees.jl interface. See [`wrap`](@ref)` for details.
141142
"""
142143
function print_tree(io::IO, tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
143144
if depth == indent
144145
println(io)
145146
return
146147
end
147-
featval = round(tree.featval; sigdigits=sigdigits)
148+
featval = round(tree.featval; sigdigits)
148149
if feature_names === nothing
149150
println(io, "Feature $(tree.featid) < $featval ?")
150151
else
151152
println(io, "Feature $(tree.featid): \"$(feature_names[tree.featid])\" < $featval ?")
152153
end
153154
print(io, " " ^ indent * "├─ ")
154-
print_tree(io, tree.left, depth, indent + 1; feature_names=feature_names)
155+
print_tree(io, tree.left, depth, indent + 1; sigdigits, feature_names)
155156
print(io, " " ^ indent * "└─ ")
156-
print_tree(io, tree.right, depth, indent + 1; feature_names=feature_names)
157+
print_tree(io, tree.right, depth, indent + 1; sigdigits, feature_names)
157158
end
158-
function print_tree(tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
159-
return print_tree(stdout, tree, depth, indent; sigdigits=sigdigits, feature_names=feature_names)
159+
function print_tree(tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
160+
return print_tree(stdout, tree, depth, indent; sigdigits, feature_names)
160161
end
161162

162163
function show(io::IO, leaf::Leaf)

test/miscellaneous/convert.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,14 @@ push!(rv, nv[1])
3232
@test apply_tree(rv[1], [0]) == "A"
3333
@test apply_tree(rv[2], [0]) == "A"
3434

35-
end # @testset
35+
end
36+
37+
@testset "convert to text" begin
38+
n, m = 10^3, 5;
39+
features = rand(StableRNG(1), n, m);
40+
weights = rand(StableRNG(1), -1:1, m);
41+
labels = features * weights;
42+
model = fit!(DecisionTreeRegressor(; rng=StableRNG(1)), features, labels)
43+
# Smoke test.
44+
print_tree(devnull, model)
45+
end

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ using Statistics
99
using Test
1010
using LinearAlgebra
1111

12-
import DecisionTree: accuracy, R2, majority_vote, mean_squared_error
13-
import DecisionTree: confusion_matrix, ConfusionMatrix
12+
using DecisionTree: accuracy, R2, majority_vote, mean_squared_error
13+
using DecisionTree: confusion_matrix, ConfusionMatrix
1414

1515
println("Julia version: ", VERSION)
1616

0 commit comments

Comments
 (0)