Skip to content

Commit 3180f75

Browse files
authored
partitions of colored trees (#69)
* update TOOD notes * partition skeleton of colored trees * fix ordering of colored trees with different lengths * fix depwarn in plotting stuff * unsafe_deleteat! and generic implementation of the partition skeleton * generic PartitionForestIterator * generic implementation of PartitionIterator * buffer for PartitionIterator of colored trees
1 parent 01b7473 commit 3180f75

File tree

4 files changed

+428
-63
lines changed

4 files changed

+428
-63
lines changed

src/RootedTrees.jl

Lines changed: 129 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,70 @@ iscanonical(t::RootedTree) = t.iscanonical
8888
#TODO: Validate rooted tree in constructor?
8989

9090
Base.copy(t::RootedTree) = RootedTree(copy(t.level_sequence), t.iscanonical)
91+
Base.similar(t::RootedTree) = RootedTree(similar(t.level_sequence), true)
9192
Base.isempty(t::RootedTree) = isempty(t.level_sequence)
9293
Base.empty(t::RootedTree) = RootedTree(empty(t.level_sequence), iscanonical(t))
9394

95+
@inline function Base.copy!(t_dst::RootedTree, t_src::RootedTree)
96+
copy!(t_dst.level_sequence, t_src.level_sequence)
97+
return t_dst
98+
end
99+
100+
"""
101+
unsafe_deleteat!(t::AbstractRootedTree, i)
102+
103+
Delete the node `i` from the rooted tree `t`. This is an unsafe operation
104+
since the rooted tree will not necessarily be in canonical representation
105+
afterwards, even if the corresponding flag of `t` is set. Use with caution!
106+
107+
!!! warn "Internal interface"
108+
This function is considered to be an internal implementation detail and
109+
will not necessarily be stable.
110+
"""
111+
@inline function unsafe_deleteat!(t::RootedTree, i)
112+
deleteat!(t.level_sequence, i)
113+
return t
114+
end
115+
116+
"""
117+
unsafe_resize!(t::AbstractRootedTree, n::Integer)
118+
119+
Resize the rooted tree `t` to `n` nodes. This is an unsafe operation
120+
since the rooted tree will not necessarily be in canonical representation
121+
afterwards, even if the corresponding flag of `t` is set. Use with caution!
122+
123+
!!! warn "Internal interface"
124+
This function is considered to be an internal implementation detail and
125+
will not necessarily be stable.
126+
"""
127+
@inline function unsafe_resize!(t::RootedTree, n::Integer)
128+
resize!(t.level_sequence, n)
129+
return t
130+
end
131+
132+
"""
133+
unsafe_copyto!(t_dst::AbstractRootedTree, dst_offset,
134+
t_src::AbstractRootedTree, src_offset, N)
135+
136+
Copy `N`` nodes from `t_src` starting at offset `src_offset` to `t_dst`
137+
starting at offset `dst_offset`. The types of the rooted trees must match.
138+
For example, you cannot copy a [`ColoredRootedTree`](@ref) to a
139+
[`RootedTree`](@ref).
140+
141+
This is an unsafe operation since the rooted tree `t_dst` will not necessarily
142+
be in canonical representation afterwards, even if the corresponding flag
143+
of `t_dst` is set. Use with caution!
144+
145+
!!! warn "Internal interface"
146+
This function is considered to be an internal implementation detail and
147+
will not necessarily be stable.
148+
"""
149+
@inline function unsafe_copyto!(t_dst::RootedTree, dst_offset,
150+
t_src::RootedTree, src_offset, N)
151+
copyto!(t_dst.level_sequence, dst_offset, t_src.level_sequence, src_offset, N)
152+
return t_dst
153+
end
154+
94155

95156
# #function RootedTree(sequence::Vector{T}, valid::Bool)
96157
# function RootedTree(sequence::Array{T,1})
@@ -526,7 +587,6 @@ end
526587

527588

528589
# partitions
529-
# TODO: partitions; add documentation in the README to make them public API
530590
"""
531591
partition_forest(t::RootedTree, edge_set)
532592
@@ -582,7 +642,7 @@ end
582642

583643

584644
"""
585-
PartitionForestIterator(t::RootedTree, edge_set)
645+
PartitionForestIterator(t::AbstractRootedTree, edge_set)
586646
587647
Lazy iterator representation of the [`partition_forest`](@ref) of the rooted
588648
tree `t`.
@@ -601,30 +661,30 @@ Section 2.3 of
601661
Foundations of Computational Mathematics
602662
[DOI: 10.1007/s10208-010-9065-1](https://doi.org/10.1007/s10208-010-9065-1)
603663
"""
604-
struct PartitionForestIterator{T, V, Tree<:RootedTree{T, V}}
605-
t::Tree
606-
level_sequence::V
664+
struct PartitionForestIterator{Tree<:AbstractRootedTree}
665+
t_iter::Tree # return value from `iterate`
666+
t_temp::Tree # internal temporary buffer
607667
edge_set::Vector{Bool}
608668
end
609669

610-
function PartitionForestIterator(t::RootedTree, edge_set)
611-
level_sequence = copy(t.level_sequence)
612-
t_iterate = RootedTree(copy(level_sequence), true)
613-
PartitionForestIterator(t_iterate, level_sequence, copy(edge_set))
670+
function PartitionForestIterator(t::AbstractRootedTree, edge_set)
671+
t_iter = copy(t)
672+
t_temp = copy(t)
673+
PartitionForestIterator(t_iter, t_temp, copy(edge_set))
614674
end
615675

616676
Base.IteratorSize(::Type{<:PartitionForestIterator}) = Base.HasLength()
617677
Base.length(forest::PartitionForestIterator) = count(==(false), forest.edge_set) + 1
618-
Base.eltype(::Type{PartitionForestIterator{T, V, Tree}}) where {T, V, Tree} = Tree
678+
Base.eltype(::Type{PartitionForestIterator{Tree}}) where {Tree} = Tree
619679

620680
@inline function Base.iterate(forest::PartitionForestIterator)
621681
iterate(forest, lastindex(forest.edge_set))
622682
end
623683

624684
@inline function Base.iterate(forest::PartitionForestIterator, search_start)
625-
t = forest.t
685+
t_iter = forest.t_iter
686+
t_temp = forest.t_temp
626687
edge_set = forest.edge_set
627-
level_sequence = forest.level_sequence
628688

629689
# We use `search_start = typemin(Int)` to indicate that we have already
630690
# returned the final tree in the previous call.
@@ -636,31 +696,31 @@ end
636696

637697
# There are no further edges to remove and we can return the final tree.
638698
if edge_to_remove === nothing
639-
resize!(t.level_sequence, length(level_sequence))
640-
copy!(t.level_sequence, level_sequence)
641-
canonical_representation!(t)
642-
return (t, typemin(Int))
699+
unsafe_resize!(t_iter, order(t_temp))
700+
copy!(t_iter, t_temp)
701+
canonical_representation!(t_iter)
702+
return (t_iter, typemin(Int))
643703
end
644704

645705
# On to the next subtree
646706
# Remember the convention node = edge + 1
647707
subtree_root_index = edge_to_remove + 1
648-
subtree_last_index = _subtree_last_index(subtree_root_index, level_sequence)
708+
subtree_last_index = _subtree_last_index(subtree_root_index, t_temp.level_sequence)
649709
subtree_length = subtree_last_index - subtree_root_index + 1
650710

651711
# Since we search from the end, there is no additional edge that needs to
652712
# be removed in the current subtree. Thus, we can return it as the next
653713
# iterate of the partition forest
654-
resize!(t.level_sequence, subtree_length)
655-
copyto!(t.level_sequence, 1, level_sequence, subtree_root_index, subtree_length)
656-
canonical_representation!(t)
714+
unsafe_resize!(t_iter, subtree_length)
715+
unsafe_copyto!(t_iter, 1, t_temp, subtree_root_index, subtree_length)
716+
canonical_representation!(t_iter)
657717

658-
# Now, we can remove the next subtree iterate from the active `level_sequence`
659-
# and `edge_set`.
660-
deleteat!(level_sequence, subtree_root_index:subtree_last_index)
718+
# Now, we can remove the next subtree iterate from the active
719+
# level sequence in `t_temp` and the `edge_set`.
720+
unsafe_deleteat!(t_temp, subtree_root_index:subtree_last_index)
661721
deleteat!(edge_set, subtree_root_index-1:subtree_last_index-1)
662722

663-
return (t, edge_to_remove - 1)
723+
return (t_iter, edge_to_remove - 1)
664724
end
665725

666726
# necessary for simple and convenient use since the iterates may be modified
@@ -674,36 +734,37 @@ function Base.collect(forest::PartitionForestIterator)
674734
end
675735

676736

677-
# TODO: partitions; add documentation in the README to make them public API
678737
"""
679-
partition_skeleton(t::RootedTree, edge_set)
738+
partition_skeleton(t::AbstractRootedTree, edge_set)
680739
681-
Form the partition skeleton of the rooted tree `t`, i.e., the rooted tree obtained
682-
by contracting each tree of the partition forest to a single vertex and re-establishing
683-
the edges removed to obtain the partition forest.
740+
Form the partition skeleton of the rooted tree `t`, i.e., the rooted tree
741+
obtained by contracting each tree of the partition forest to a single vertex
742+
and re-establishing the edges removed to obtain the partition forest.
684743
685744
See also [`partition_forest`](@ref) and [`PartitionIterator`](@ref).
686745
687746
# References
688747
689-
Section 2.3 of
748+
Section 2.3 (and Section 6.1 for colored trees) of
690749
- Philippe Chartier, Ernst Hairer, Gilles Vilmart (2010)
691750
Algebraic Structures of B-series.
692751
Foundations of Computational Mathematics
693752
[DOI: 10.1007/s10208-010-9065-1](https://doi.org/10.1007/s10208-010-9065-1)
694753
"""
695-
function partition_skeleton(t::RootedTree, edge_set)
754+
function partition_skeleton(t::AbstractRootedTree, edge_set)
696755
@boundscheck begin
697-
@assert length(t.level_sequence) == length(edge_set) + 1
756+
@assert order(t) == length(edge_set) + 1
698757
end
699758

700759
edge_set_copy = copy(edge_set)
701-
skeleton = RootedTree(copy(t.level_sequence), true)
702-
return partition_skeleton!(skeleton.level_sequence, edge_set_copy)
760+
skeleton = copy(t)
761+
return partition_skeleton!(skeleton, edge_set_copy)
703762
end
704763

705764
# internal in-place version of partition_skeleton modifying the inputs
706-
function partition_skeleton!(level_sequence, edge_set)
765+
function partition_skeleton!(skeleton::AbstractRootedTree, edge_set)
766+
level_sequence = skeleton.level_sequence
767+
707768
# Iterate over all edges that shall be kept/contracted.
708769
# We start the iteration at the end since this will result in less memory
709770
# moves because we have already reduced the size of the vectors when reaching
@@ -725,19 +786,18 @@ function partition_skeleton!(level_sequence, edge_set)
725786
end
726787

727788
# Remove the root node
728-
deleteat!(level_sequence, subtree_root_index)
789+
unsafe_deleteat!(skeleton, subtree_root_index)
729790
deleteat!(edge_set, edge_to_contract)
730791

731792
edge_to_contract = findprev(edge_set, edge_to_contract - 1)
732793
end
733794

734795
# The level sequence `level_sequence` will not automatically be a canonical
735796
# representation.
736-
return rootedtree!(level_sequence)
797+
canonical_representation!(skeleton)
737798
end
738799

739800

740-
# TODO: partitions; add documentation in the README to make them public API
741801
"""
742802
all_partitions(t::RootedTree)
743803
@@ -784,7 +844,7 @@ end
784844

785845

786846
"""
787-
PartitionIterator(t::RootedTree)
847+
PartitionIterator(t::AbstractRootedTree)
788848
789849
Iterator over all partition forests and skeletons of the rooted tree `t`.
790850
This is basically a pure iterator version of [`all_partitions`](@ref).
@@ -804,29 +864,32 @@ Section 2.3 of
804864
Foundations of Computational Mathematics
805865
[DOI: 10.1007/s10208-010-9065-1](https://doi.org/10.1007/s10208-010-9065-1)
806866
"""
807-
struct PartitionIterator{T, Tree<:RootedTree{T}}
808-
t::Tree
809-
forest::PartitionForestIterator{T, Vector{T}, RootedTree{T, Vector{T}}}
810-
skeleton::RootedTree{T, Vector{T}}
867+
struct PartitionIterator{TreeInput<:AbstractRootedTree, TreeOutput<:AbstractRootedTree}
868+
t::TreeInput
869+
forest::PartitionForestIterator{TreeOutput}
870+
skeleton::TreeOutput
811871
edge_set::Vector{Bool}
812872
edge_set_tmp::Vector{Bool}
813873
end
814874

815-
function PartitionIterator(t::Tree) where {T, Tree<:RootedTree{T}}
816-
skeleton = RootedTree(Vector{T}(undef, order(t)), true)
875+
function PartitionIterator(t::AbstractRootedTree)
876+
skeleton = similar(t)
817877
edge_set = Vector{Bool}(undef, order(t) - 1)
818878
edge_set_tmp = similar(edge_set)
819879

820-
t_forest = RootedTree(Vector{T}(undef, order(t)), true)
821-
level_sequence = similar(t_forest.level_sequence)
822-
forest = PartitionForestIterator(t_forest, level_sequence, edge_set_tmp)
823-
PartitionIterator{T, Tree}(t, forest, skeleton, edge_set, edge_set_tmp)
880+
t_forest = similar(t)
881+
t_temp_forest = similar(t)
882+
forest = PartitionForestIterator(t_forest, t_temp_forest, edge_set_tmp)
883+
PartitionIterator{typeof(t), typeof(skeleton)}(t, forest, skeleton, edge_set, edge_set_tmp)
824884
end
825885

826886
# Allocate global buffer for `PartitionIterator` for each thread
827887
const PARTITION_ITERATOR_BUFFER_FOREST_T = Vector{Vector{Int}}()
888+
const PARTITION_ITERATOR_BUFFER_FOREST_T_COLORS = Vector{Vector{Bool}}()
828889
const PARTITION_ITERATOR_BUFFER_FOREST_LEVEL_SEQUENCE = Vector{Vector{Int}}()
890+
const PARTITION_ITERATOR_BUFFER_FOREST_COLOR_SEQUENCE = Vector{Vector{Bool}}()
829891
const PARTITION_ITERATOR_BUFFER_SKELETON = Vector{Vector{Int}}()
892+
const PARTITION_ITERATOR_BUFFER_SKELETON_COLORS = Vector{Vector{Bool}}()
830893
const PARTITION_ITERATOR_BUFFER_EDGE_SET = Vector{Vector{Bool}}()
831894
const PARTITION_ITERATOR_BUFFER_EDGE_SET_TMP = Vector{Vector{Bool}}()
832895

@@ -856,15 +919,16 @@ function PartitionIterator(t::RootedTree{Int, Vector{Int}})
856919

857920
skeleton = RootedTree(buffer_skeleton, true)
858921
t_forest = RootedTree(buffer_forest_t, true)
859-
forest = PartitionForestIterator(t_forest, level_sequence, edge_set_tmp)
860-
PartitionIterator{Int, RootedTree{Int, Vector{Int}}}(
922+
t_temp_forest = RootedTree(level_sequence, true)
923+
forest = PartitionForestIterator(t_forest, t_temp_forest, edge_set_tmp)
924+
PartitionIterator{typeof(t), RootedTree{Int, Vector{Int}}}(
861925
t, forest, skeleton, edge_set, edge_set_tmp)
862926
end
863927

864928

865929
Base.IteratorSize(::Type{<:PartitionIterator}) = Base.HasLength()
866930
Base.length(partitions::PartitionIterator) = 2^length(partitions.edge_set)
867-
Base.eltype(::Type{PartitionIterator{T, Tree}}) where {T, Tree} = Tuple{Vector{RootedTree{T, Vector{T}}}, RootedTree{T, Vector{T}}}
931+
Base.eltype(::Type{PartitionIterator{TreeInput, TreeOutput}}) where {TreeInput, TreeOutput} = Tuple{PartitionForestIterator{TreeOutput}, TreeOutput}
868932

869933
@inline function Base.iterate(partitions::PartitionIterator)
870934
edge_set_value = 0
@@ -888,26 +952,26 @@ end
888952
# avoiding some allocations.
889953
resize!(edge_set_tmp, length(edge_set))
890954
copy!(edge_set_tmp, edge_set)
891-
resize!(skeleton.level_sequence, order(t))
892-
copy!(skeleton.level_sequence, t.level_sequence)
893-
partition_skeleton!(skeleton.level_sequence, edge_set_tmp)
955+
unsafe_resize!(skeleton, order(t))
956+
copy!(skeleton, t)
957+
partition_skeleton!(skeleton, edge_set_tmp)
894958

895959
# Compute the partition forest.
896960
# The following is a more efficient version of
897961
# forest = partition_forest(t, edge_set)
898962
# avoiding some allocations and using a lazy iterator.
899963
resize!(edge_set_tmp, length(edge_set))
900964
copy!(edge_set_tmp, edge_set)
901-
resize!(forest.level_sequence, order(t))
902-
copy!(forest.level_sequence, t.level_sequence)
965+
unsafe_resize!(forest.t_temp, order(t))
966+
copy!(forest.t_temp, t)
903967

904968

905969
((forest, skeleton), edge_set_value + 1)
906970
end
907971

908972
# necessary for simple and convenient use since the iterates may be modified
909-
function Base.collect(partitions::PartitionIterator)
910-
iterates = Vector{eltype(partitions)}()
973+
function Base.collect(partitions::PartitionIterator{TreeInput, TreeOutput}) where {TreeInput, TreeOutput}
974+
iterates = Vector{Tuple{Vector{TreeOutput}, TreeOutput}}()
911975
sizehint!(iterates, length(partitions))
912976
for (forest, skeleton) in partitions
913977
push!(iterates, (collect(forest), copy(skeleton)))
@@ -918,7 +982,6 @@ end
918982

919983

920984
# splittings
921-
# TODO: splittings; add documentation in the README to make them public API
922985
"""
923986
all_splittings(t::RootedTree)
924987
@@ -1290,10 +1353,16 @@ function __init__()
12901353
# PartitionIterator
12911354
Threads.resize_nthreads!(PARTITION_ITERATOR_BUFFER_FOREST_T,
12921355
Vector{Int}(undef, BUFFER_LENGTH))
1356+
Threads.resize_nthreads!(PARTITION_ITERATOR_BUFFER_FOREST_T_COLORS,
1357+
Vector{Int}(undef, BUFFER_LENGTH))
12931358
Threads.resize_nthreads!(PARTITION_ITERATOR_BUFFER_FOREST_LEVEL_SEQUENCE,
12941359
Vector{Int}(undef, BUFFER_LENGTH))
1360+
Threads.resize_nthreads!(PARTITION_ITERATOR_BUFFER_FOREST_COLOR_SEQUENCE,
1361+
Vector{Bool}(undef, BUFFER_LENGTH))
12951362
Threads.resize_nthreads!(PARTITION_ITERATOR_BUFFER_SKELETON,
12961363
Vector{Int}(undef, BUFFER_LENGTH))
1364+
Threads.resize_nthreads!(PARTITION_ITERATOR_BUFFER_SKELETON_COLORS,
1365+
Vector{Bool}(undef, BUFFER_LENGTH))
12971366
Threads.resize_nthreads!(PARTITION_ITERATOR_BUFFER_EDGE_SET,
12981367
Vector{Bool}(undef, BUFFER_LENGTH))
12991368
Threads.resize_nthreads!(PARTITION_ITERATOR_BUFFER_EDGE_SET_TMP,

0 commit comments

Comments
 (0)