Skip to content

Commit 8721106

Browse files
authored
buffer for PartitionIterator (#39)
* buffer for PartitionIterator * more tests
1 parent 7700223 commit 8721106

File tree

2 files changed

+73
-12
lines changed

2 files changed

+73
-12
lines changed

src/RootedTrees.jl

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -388,11 +388,11 @@ end
388388
end
389389

390390
# Allocate global buffer for `canonical_representation!` for each thread
391-
const CANONICAL_REPRESENTATION_BUFFER_LENGTH = 64
391+
const BUFFER_LENGTH = 64
392392
const CANONICAL_REPRESENTATION_BUFFER = Vector{Vector{Int}}()
393393

394394
function canonical_representation!(t::RootedTree{Int, Vector{Int}})
395-
if order(t) <= CANONICAL_REPRESENTATION_BUFFER_LENGTH
395+
if order(t) <= BUFFER_LENGTH
396396
buffer = CANONICAL_REPRESENTATION_BUFFER[Threads.threadid()]
397397
else
398398
buffer = similar(t.level_sequence)
@@ -402,8 +402,23 @@ end
402402

403403

404404
function __init__()
405+
# canonical_representation!
405406
Threads.resize_nthreads!(CANONICAL_REPRESENTATION_BUFFER,
406-
Vector{Int}(undef, CANONICAL_REPRESENTATION_BUFFER_LENGTH))
407+
Vector{Int}(undef, BUFFER_LENGTH))
408+
409+
# PartitionIterator
410+
Threads.resize_nthreads!(PARTITION_ITERATOR_BUFFER_FOREST_T,
411+
Vector{Int}(undef, BUFFER_LENGTH))
412+
Threads.resize_nthreads!(PARTITION_ITERATOR_BUFFER_FOREST_LEVEL_SEQUENCE,
413+
Vector{Int}(undef, BUFFER_LENGTH))
414+
Threads.resize_nthreads!(PARTITION_ITERATOR_BUFFER_SKELETON,
415+
Vector{Int}(undef, BUFFER_LENGTH))
416+
Threads.resize_nthreads!(PARTITION_ITERATOR_BUFFER_EDGE_SET,
417+
Vector{Bool}(undef, BUFFER_LENGTH))
418+
Threads.resize_nthreads!(PARTITION_ITERATOR_BUFFER_EDGE_SET_TMP,
419+
Vector{Bool}(undef, BUFFER_LENGTH))
420+
421+
return nothing
407422
end
408423

409424

@@ -832,19 +847,58 @@ struct PartitionIterator{T, Tree<:RootedTree{T}}
832847
skeleton::RootedTree{T, Vector{T}}
833848
edge_set::Vector{Bool}
834849
edge_set_tmp::Vector{Bool}
850+
end
851+
852+
function PartitionIterator(t::Tree) where {T, Tree<:RootedTree{T}}
853+
skeleton = RootedTree(Vector{T}(undef, order(t)), true)
854+
edge_set = Vector{Bool}(undef, order(t) - 1)
855+
edge_set_tmp = similar(edge_set)
835856

836-
function PartitionIterator(t::Tree) where {T, Tree<:RootedTree{T}}
837-
skeleton = RootedTree(Vector{T}(undef, order(t)), true)
838-
edge_set = zeros(Bool, order(t) - 1)
839-
edge_set_tmp = similar(edge_set)
857+
t_forest = RootedTree(Vector{T}(undef, order(t)), true)
858+
level_sequence = similar(t_forest.level_sequence)
859+
forest = PartitionForestIterator(t_forest, level_sequence, edge_set_tmp)
860+
PartitionIterator{T, Tree}(t, forest, skeleton, edge_set, edge_set_tmp)
861+
end
840862

841-
t_forest = RootedTree(Vector{T}(undef, order(t)), true)
842-
level_sequence = similar(t_forest.level_sequence)
843-
forest = PartitionForestIterator(t_forest, level_sequence, edge_set_tmp)
844-
new{T, Tree}(t, forest, skeleton, edge_set, edge_set_tmp)
863+
# Allocate global buffer for `PartitionIterator` for each thread
864+
const PARTITION_ITERATOR_BUFFER_FOREST_T = Vector{Vector{Int}}()
865+
const PARTITION_ITERATOR_BUFFER_FOREST_LEVEL_SEQUENCE = Vector{Vector{Int}}()
866+
const PARTITION_ITERATOR_BUFFER_SKELETON = Vector{Vector{Int}}()
867+
const PARTITION_ITERATOR_BUFFER_EDGE_SET = Vector{Vector{Bool}}()
868+
const PARTITION_ITERATOR_BUFFER_EDGE_SET_TMP = Vector{Vector{Bool}}()
869+
870+
function PartitionIterator(t::RootedTree{Int, Vector{Int}})
871+
order_t = order(t)
872+
873+
if order_t <= BUFFER_LENGTH
874+
id = Threads.threadid()
875+
876+
buffer_forest_t = PARTITION_ITERATOR_BUFFER_FOREST_T[id]
877+
resize!(buffer_forest_t, order_t)
878+
level_sequence = PARTITION_ITERATOR_BUFFER_FOREST_LEVEL_SEQUENCE[id]
879+
resize!(level_sequence, order_t)
880+
buffer_skeleton = PARTITION_ITERATOR_BUFFER_SKELETON[id]
881+
resize!(buffer_skeleton, order_t)
882+
edge_set = PARTITION_ITERATOR_BUFFER_EDGE_SET[id]
883+
resize!(edge_set, order_t - 1)
884+
edge_set_tmp = PARTITION_ITERATOR_BUFFER_EDGE_SET_TMP[id]
885+
resize!(edge_set_tmp, order_t - 1)
886+
else
887+
buffer_forest_t = Vector{Int}(undef, order_t)
888+
level_sequence = similar(buffer_forest_t)
889+
buffer_skeleton = similar(buffer_forest_t)
890+
edge_set = Vector{Bool}(undef, order_t - 1)
891+
edge_set_tmp = similar(edge_set)
845892
end
893+
894+
skeleton = RootedTree(buffer_skeleton, true)
895+
t_forest = RootedTree(buffer_forest_t, true)
896+
forest = PartitionForestIterator(t_forest, level_sequence, edge_set_tmp)
897+
PartitionIterator{Int, RootedTree{Int, Vector{Int}}}(
898+
t, forest, skeleton, edge_set, edge_set_tmp)
846899
end
847900

901+
848902
Base.IteratorSize(::Type{<:PartitionIterator}) = Base.HasLength()
849903
Base.length(partitions::PartitionIterator) = 2^length(partitions.edge_set)
850904
Base.eltype(::Type{PartitionIterator{T, Tree}}) where {T, Tree} = Tuple{Vector{RootedTree{T, Vector{T}}}, RootedTree{T, Vector{T}}}

test/runtests.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ Plots.unicodeplots()
7272
t = rootedtree([1, 2, 3, 2, 3, 3, 2, 3])
7373
@test t.level_sequence == [1, 2, 3, 3, 2, 3, 2, 3]
7474

75-
level_sequence = zeros(Int, RootedTrees.CANONICAL_REPRESENTATION_BUFFER_LENGTH + 1)
75+
level_sequence = zeros(Int, RootedTrees.BUFFER_LENGTH + 1)
7676
level_sequence[1] -= 1
7777
@inferred rootedtree(level_sequence)
7878
end
@@ -496,6 +496,13 @@ end
496496
@test collect(zip(forests, skeletons)) == collect(PartitionIterator(t))
497497
end
498498
end
499+
500+
level_sequence = zeros(Int, RootedTrees.BUFFER_LENGTH + 1)
501+
level_sequence[1] -= 1
502+
t = rootedtree(level_sequence)
503+
@inferred PartitionIterator(t)
504+
t = @inferred rootedtree!(view(level_sequence, :))
505+
@inferred PartitionIterator(t)
499506
end
500507
end
501508

0 commit comments

Comments
 (0)