Skip to content

Commit f71ebb1

Browse files
authored
Merge pull request #201 from JuliaAI/dev
For a 0.12.1 release
2 parents 5a04aba + 9062bd1 commit f71ebb1

File tree

3 files changed

+37
-8
lines changed

3 files changed

+37
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name = "DecisionTree"
22
uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
33
license = "MIT"
44
desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms"
5-
version = "0.12.0"
5+
version = "0.12.1"
66

77
[deps]
88
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/abstract_trees.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ apart from the two points mentioned.
2828
In analogy to the type definitions of `DecisionTree`, the generic type `S` is
2929
the type of the feature values used within a node as a threshold for the splits
3030
between its children and `T` is the type of the classes given (these might be ids or labels).
31+
32+
!!! note
33+
You may only add lacking class labels. It's not possible to overwrite existing labels
34+
with this mechanism. In case you want add class labels, the generic type `T` must
35+
be a subtype of `Integer`.
3136
"""
3237
struct InfoNode{S, T} <: AbstractTrees.AbstractNode{DecisionTree.Node{S,T}}
3338
node :: DecisionTree.Node{S, T}
@@ -89,8 +94,8 @@ AbstractTrees.children(node::InfoNode) = (
8994
AbstractTrees.children(node::InfoLeaf) = ()
9095

9196
"""
92-
printnode(io::IO, node::InfoNode)
93-
printnode(io::IO, leaf::InfoLeaf)
97+
printnode(io::IO, node::InfoNode; sigdigits=4)
98+
printnode(io::IO, leaf::InfoLeaf; sigdigits=4)
9499
95100
Write a printable representation of `node` or `leaf` to output-stream `io`.
96101
@@ -108,23 +113,28 @@ For the condition of the form `feature < value` which gets printed in the `print
108113
variant for `InfoNode`, the left subtree is the 'yes-branch' and the right subtree
109114
accordingly the 'no-branch'. `AbstractTrees.print_tree` outputs the left subtree first
110115
and then below the right subtree.
116+
117+
`value` gets rounded to `sigdigits` significant digits.
111118
"""
112-
function AbstractTrees.printnode(io::IO, node::InfoNode)
119+
function AbstractTrees.printnode(io::IO, node::InfoNode; sigdigits=4)
120+
featval = round(node.node.featval; sigdigits)
113121
if :featurenames keys(node.info)
114-
print(io, node.info.featurenames[node.node.featid], " < ", node.node.featval)
122+
print(io, node.info.featurenames[node.node.featid], " < ", featval)
115123
else
116-
print(io, "Feature: ", node.node.featid, " < ", node.node.featval)
124+
print(io, "Feature: ", node.node.featid, " < ", featval)
117125
end
118126
end
119127

120-
function AbstractTrees.printnode(io::IO, leaf::InfoLeaf)
128+
function AbstractTrees.printnode(io::IO, leaf::InfoLeaf; sigdigits=4)
121129
dt_leaf = leaf.leaf
122130
matches = findall(dt_leaf.values .== dt_leaf.majority)
123131
match_count = length(matches)
124132
val_count = length(dt_leaf.values)
125133
if :classlabels keys(leaf.info)
134+
@assert dt_leaf.majority isa Integer "classes must be represented as Integers"
126135
print(io, leaf.info.classlabels[dt_leaf.majority], " ($match_count/$val_count)")
127136
else
128-
print(io, "Class: ", dt_leaf.majority, " ($match_count/$val_count)")
137+
print(io, dt_leaf.majority isa Integer ? "Class: " : "",
138+
dt_leaf.majority, " ($match_count/$val_count)")
129139
end
130140
end

test/miscellaneous/abstract_trees_test.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,23 @@ end
8181
traverse_tree(leaf::InfoLeaf) = nothing
8282

8383
traverse_tree(wrapped_tree)
84+
end
85+
86+
@testset "abstract_trees - test misuse" begin
87+
88+
@info("Test misuse of `classlabel` information")
89+
90+
@info("Create test data - a decision tree based on the iris data set")
91+
features, labels = load_data("iris")
92+
features = float.(features)
93+
labels = string.(labels)
94+
model = DecisionTreeClassifier()
95+
fit!(model, features, labels)
96+
97+
@info("Try to replace the exisitng class labels")
98+
class_labels = unique(labels)
99+
dtree = model.root.node
100+
wt = DecisionTree.wrap(dtree, (classlabels = class_labels,))
101+
@test_throws AssertionError AbstractTrees.print_tree(wt)
102+
84103
end

0 commit comments

Comments
 (0)