Skip to content

Commit bfcb756

Browse files
authored
Merge pull request #41 from SciML/hr/performance
2 parents 4c3cd73 + 81d084c commit bfcb756

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

src/RootedTrees.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,8 @@ function Base.iterate(forest::PartitionForestIterator, state)
682682
# and `edge_set`.
683683
deleteat!(level_sequence, subtree_root_index:subtree_last_index)
684684
deleteat!(edge_set, subtree_root_index-1:subtree_last_index-1)
685-
edge_to_remove = findlast(==(false), edge_set)
685+
686+
edge_to_remove = findprev(==(false), edge_set, edge_to_remove - 1)
686687
if edge_to_remove === nothing
687688
edge_to_remove = typemin(Int)
688689
end
@@ -755,7 +756,7 @@ function partition_skeleton!(level_sequence, edge_set)
755756
deleteat!(level_sequence, subtree_root_index)
756757
deleteat!(edge_set, edge_to_contract)
757758

758-
edge_to_contract = findlast(edge_set)
759+
edge_to_contract = findprev(edge_set, edge_to_contract - 1)
759760
end
760761

761762
# The level sequence `level_sequence` will not automatically be a canonical
@@ -788,14 +789,27 @@ function all_partitions(t::RootedTree)
788789
skeletons = [partition_skeleton(t, edge_set)]
789790

790791
for edge_set_value in 1:(2^length(edge_set) - 1)
791-
digits!(edge_set, edge_set_value, base=2)
792+
binary_digits!(edge_set, edge_set_value)
792793
push!(forests, partition_forest(t, edge_set))
793794
push!(skeletons, partition_skeleton(t, edge_set))
794795
end
795796

796797
return (; forests, skeletons)
797798
end
798799

800+
# A helper function to comute the binary representation of an integer `n` as
801+
# a vector of `Bool`s. This is a more efficient version of
802+
# binary_digits!(digits, n) = digits!(digits, n, base=2)
803+
function binary_digits!(digits::Vector{Bool}, n::Int)
804+
bit = 1
805+
for i in eachindex(digits)
806+
digits[i] = n & bit > 0
807+
bit = bit << 1
808+
end
809+
digits
810+
end
811+
812+
799813

800814
"""
801815
PartitionIterator(t::RootedTree)
@@ -855,7 +869,7 @@ function Base.iterate(partitions::PartitionIterator, edge_set_value)
855869
edge_set = partitions.edge_set
856870
edge_set_tmp = partitions.edge_set_tmp
857871

858-
digits!(edge_set, edge_set_value, base=2)
872+
binary_digits!(edge_set, edge_set_value)
859873

860874
# Compute the partition skeleton.
861875
# The following is a more efficient version of
@@ -918,7 +932,7 @@ function all_splittings(t::RootedTree)
918932
subtrees = Vector{RootedTree{T, Vector{T}}}() # ordered subtrees
919933

920934
for node_set_value in 0:(2^order(t) - 1)
921-
digits!(node_set, node_set_value, base=2)
935+
binary_digits!(node_set, node_set_value)
922936

923937
# Check that if a node is removed then all of its descendants are removed
924938
subtree_root_index = 1
@@ -1000,7 +1014,7 @@ function Base.iterate(splittings::SplittingIterator, node_set_value)
10001014
forest = Vector{RootedTree{T, Vector{T}}}()
10011015

10021016
while node_set_value <= splittings.max_node_set_value
1003-
digits!(node_set, node_set_value, base=2)
1017+
binary_digits!(node_set, node_set_value)
10041018

10051019
# Check that if a node is removed then all of its descendants are removed
10061020
subtree_root_index = 1

0 commit comments

Comments
 (0)