|
56 | 56 | is_leaf(l::Leaf) = true
|
57 | 57 | is_leaf(n::Node) = false
|
58 | 58 |
|
59 |
| -zero(::Type{String}) = "" |
60 |
| -convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, zero(S), lf, Leaf(zero(T), [zero(T)])) |
| 59 | +_zero(::Type{String}) = "" |
| 60 | +_zero(x::Any) = zero(x) |
| 61 | +convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, _zero(S), lf, Leaf(_zero(T), [_zero(T)])) |
61 | 62 | convert(::Type{Root{S, T}}, node::LeafOrNode{S, T}) where {S, T} = Root{S, T}(node, 0, Float64[])
|
62 | 63 | convert(::Type{LeafOrNode{S, T}}, tree::Root{S, T}) where {S, T} = tree.node
|
63 | 64 | promote_rule(::Type{Node{S, T}}, ::Type{Leaf{T}}) where {S, T} = Node{S, T}
|
@@ -95,21 +96,21 @@ depth(leaf::Leaf) = 0
|
95 | 96 | depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
|
96 | 97 | depth(tree::Root) = depth(tree.node)
|
97 | 98 |
|
98 |
| -function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; feature_names=nothing) |
| 99 | +function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing) |
99 | 100 | n_matches = count(leaf.values .== leaf.majority)
|
100 | 101 | ratio = string(n_matches, "/", length(leaf.values))
|
101 | 102 | println(io, "$(leaf.majority) : $(ratio)")
|
102 | 103 | end
|
103 |
| -function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing) |
104 |
| - return print_tree(stdout, leaf, depth, indent; feature_names=feature_names) |
| 104 | +function print_tree(leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing) |
| 105 | + return print_tree(stdout, leaf, depth, indent; sigdigits, feature_names) |
105 | 106 | end
|
106 | 107 |
|
107 | 108 |
|
108 |
| -function print_tree(io::IO, tree::Root, depth=-1, indent=0; sigdigits=2, feature_names=nothing) |
109 |
| - return print_tree(io, tree.node, depth, indent; sigdigits=sigdigits, feature_names=feature_names) |
| 109 | +function print_tree(io::IO, tree::Root, depth=-1, indent=0; sigdigits=4, feature_names=nothing) |
| 110 | + return print_tree(io, tree.node, depth, indent; sigdigits, feature_names) |
110 | 111 | end
|
111 |
| -function print_tree(tree::Root, depth=-1, indent=0; sigdigits=2, feature_names=nothing) |
112 |
| - return print_tree(stdout, tree, depth, indent; sigdigits=sigdigits, feature_names=feature_names) |
| 112 | +function print_tree(tree::Root, depth=-1, indent=0; sigdigits=4, feature_names=nothing) |
| 113 | + return print_tree(stdout, tree, depth, indent; sigdigits, feature_names) |
113 | 114 | end
|
114 | 115 |
|
115 | 116 | """
|
@@ -137,26 +138,26 @@ Feature 3 < -28.15 ?
|
137 | 138 |
|
138 | 139 | To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object,
|
139 | 140 | `DecisionTree.Node` object or `DecisionTree.Root` object can be wrapped to obtain a tree structure implementing the
|
140 |
| -AbstractTrees.jl interface. See [`wrap`](@ref)` for details. |
| 141 | +AbstractTrees.jl interface. See [`wrap`](@ref)` for details. |
141 | 142 | """
|
142 | 143 | function print_tree(io::IO, tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
|
143 | 144 | if depth == indent
|
144 | 145 | println(io)
|
145 | 146 | return
|
146 | 147 | end
|
147 |
| - featval = round(tree.featval; sigdigits=sigdigits) |
| 148 | + featval = round(tree.featval; sigdigits) |
148 | 149 | if feature_names === nothing
|
149 | 150 | println(io, "Feature $(tree.featid) < $featval ?")
|
150 | 151 | else
|
151 | 152 | println(io, "Feature $(tree.featid): \"$(feature_names[tree.featid])\" < $featval ?")
|
152 | 153 | end
|
153 | 154 | print(io, " " ^ indent * "├─ ")
|
154 |
| - print_tree(io, tree.left, depth, indent + 1; feature_names=feature_names) |
| 155 | + print_tree(io, tree.left, depth, indent + 1; sigdigits, feature_names) |
155 | 156 | print(io, " " ^ indent * "└─ ")
|
156 |
| - print_tree(io, tree.right, depth, indent + 1; feature_names=feature_names) |
| 157 | + print_tree(io, tree.right, depth, indent + 1; sigdigits, feature_names) |
157 | 158 | end
|
158 |
| -function print_tree(tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing) |
159 |
| - return print_tree(stdout, tree, depth, indent; sigdigits=sigdigits, feature_names=feature_names) |
| 159 | +function print_tree(tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing) |
| 160 | + return print_tree(stdout, tree, depth, indent; sigdigits, feature_names) |
160 | 161 | end
|
161 | 162 |
|
162 | 163 | function show(io::IO, leaf::Leaf)
|
|
0 commit comments