Skip to content

Commit 4bccc80

Browse files
authored
Merge pull request #42 from SciML/hr/iterate
simplify iterate for the PartitionForestIterator
1 parent bfcb756 commit 4bccc80

File tree

1 file changed

+10
-24
lines changed

1 file changed

+10
-24
lines changed

src/RootedTrees.jl

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -632,37 +632,28 @@ Base.length(forest::PartitionForestIterator) = count(==(false), forest.edge_set)
632632
Base.eltype(::Type{PartitionForestIterator{T, V, Tree}}) where {T, V, Tree} = Tree
633633

634634
function Base.iterate(forest::PartitionForestIterator)
635-
edge_to_remove = findlast(==(false), forest.edge_set)
636-
# If `edge_to_remove` can only be inferred as `Union{Nothing, Int}`, the
637-
# return value of `iterate` can only be inferred as a double union, which
638-
# is too much for the compiler right now and introduces type instabilities.
639-
# Thus, we set `edge_to_remove` to a sensible value indicating that there
640-
# is no edge to remove.
641-
if edge_to_remove === nothing
642-
edge_to_remove = typemin(Int)
643-
end
644-
finished = false
645-
iterate(forest, (edge_to_remove, finished))
635+
iterate(forest, lastindex(forest.edge_set))
646636
end
647637

648-
function Base.iterate(forest::PartitionForestIterator, state)
638+
function Base.iterate(forest::PartitionForestIterator, search_start)
649639
t = forest.t
650640
edge_set = forest.edge_set
651641
level_sequence = forest.level_sequence
652-
edge_to_remove, finished = state
653642

654-
# We have already returned the final tree.
655-
if finished
643+
# We use `search_start = typemin(Int)` to indicate that we have already
644+
# returned the final tree in the previous call.
645+
if search_start == typemin(Int)
656646
return nothing
657647
end
658648

649+
edge_to_remove = findprev(==(false), edge_set, search_start)
650+
659651
# There are no further edges to remove and we can return the final tree.
660-
if edge_to_remove == typemin(Int)
652+
if edge_to_remove === nothing
661653
resize!(t.level_sequence, length(level_sequence))
662654
copy!(t.level_sequence, level_sequence)
663655
canonical_representation!(t)
664-
finished = true
665-
return (t, (edge_to_remove, finished))
656+
return (t, typemin(Int))
666657
end
667658

668659
# On to the next subtree
@@ -683,12 +674,7 @@ function Base.iterate(forest::PartitionForestIterator, state)
683674
deleteat!(level_sequence, subtree_root_index:subtree_last_index)
684675
deleteat!(edge_set, subtree_root_index-1:subtree_last_index-1)
685676

686-
edge_to_remove = findprev(==(false), edge_set, edge_to_remove - 1)
687-
if edge_to_remove === nothing
688-
edge_to_remove = typemin(Int)
689-
end
690-
finished = false
691-
return (t, (edge_to_remove, finished))
677+
return (t, edge_to_remove - 1)
692678
end
693679

694680
# necessary for simple and convenient use since the iterates may be modified

0 commit comments

Comments
 (0)