Skip to content

Commit 2446064

Browse files
Add (s *Set) From(int) iter.Seq[int]
1 parent fe1772d commit 2446064

File tree

4 files changed

+86
-19
lines changed

4 files changed

+86
-19
lines changed

.travis.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
language: go
22

33
go:
4-
- "1.10.x"
4+
- "1.23.x"
5+
- "1.x"
56
- tip

bitset.go

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@ package bitset
33

44
import (
55
"encoding/binary"
6+
"iter"
67
"math/bits"
78
"strconv"
89
"strings"
910
)
1011

11-
const maxUint = 1<<bits.UintSize - 1
12+
const maxUint = ^uint(0)
1213

1314
// Set represents a set of positive integers. Memory usage is proportional to the largest integer in the Set.
1415
type Set struct {
@@ -44,8 +45,8 @@ func (s *Set) AddRange(low, hi int) {
4445
s.grow(hi)
4546
w0, _ := idx(low)
4647
w1, _ := idx(hi - 1)
47-
leftMask := uint(maxUint) << (uint(low) % bits.UintSize)
48-
rightMask := uint(maxUint) >> (uint(bits.UintSize-hi) % bits.UintSize)
48+
leftMask := maxUint << (uint(low) % bits.UintSize)
49+
rightMask := maxUint >> (uint(bits.UintSize-hi) % bits.UintSize)
4950
if w1 == w0 {
5051
s.s[w0] |= leftMask & rightMask
5152
return
@@ -86,8 +87,8 @@ func (s *Set) RemoveRange(low, hi int) {
8687
hi = len(s.s) * bits.UintSize
8788
w1 = len(s.s) - 1
8889
}
89-
leftMask := uint(maxUint) << (uint(low) % bits.UintSize)
90-
rightMask := uint(maxUint) >> (uint(bits.UintSize-hi) % bits.UintSize)
90+
leftMask := maxUint << (uint(low) % bits.UintSize)
91+
rightMask := maxUint >> (uint(bits.UintSize-hi) % bits.UintSize)
9192
if w1 == w0 {
9293
s.s[w0] &^= leftMask & rightMask
9394
return
@@ -187,19 +188,43 @@ func (s *Set) Equal(ss *Set) bool {
187188
return true
188189
}
189190

191+
// From returns a sequence of integers in s starting at i.
192+
func (s *Set) From(i int) iter.Seq[int] {
193+
return func(yield func(int) bool) {
194+
if i < 0 {
195+
i = 0
196+
}
197+
si := i / bits.UintSize
198+
for idx := si; idx < len(s.s); idx++ {
199+
word := s.s[idx]
200+
if idx == si {
201+
word &= maxUint << (i % bits.UintSize)
202+
}
203+
for word != 0 {
204+
j := bits.TrailingZeros(word)
205+
if !yield(idx*bits.UintSize + j) {
206+
return
207+
}
208+
word &^= 1 << j
209+
}
210+
}
211+
}
212+
}
213+
190214
// NextAfter returns the smallest integer in s greater than or equal to i or -1 if no such integer exists.
191215
func (s *Set) NextAfter(i int) int {
192216
if i < 0 {
193217
// There can be no integers in s less than 0 by definition
194218
i = 0
195219
}
196-
mask := uint(maxUint) << (uint(i) % bits.UintSize)
220+
mask := maxUint << (uint(i) % bits.UintSize)
197221
for j := i / bits.UintSize; j < len(s.s); j++ {
198222
word := s.s[j] & mask
199223
mask = maxUint
200224
if word != 0 {
201225
return j*bits.UintSize + bits.TrailingZeros(word)
202226
}
227+
word &^= 1 << j
203228
}
204229
return -1
205230
}
@@ -221,7 +246,7 @@ func (s *Set) String() string {
221246
var buf strings.Builder
222247
buf.WriteByte('[')
223248
first := true
224-
for i := s.NextAfter(0); i >= 0; i = s.NextAfter(i + 1) {
249+
for i := range s.From(0) {
225250
if !first {
226251
buf.WriteByte(' ')
227252
}
@@ -285,10 +310,3 @@ func idx(i int) (w int, mask uint) {
285310
mask = 1 << (uint(i) % bits.UintSize)
286311
return
287312
}
288-
289-
func min(i, j int) int {
290-
if i < j {
291-
return i
292-
}
293-
return j
294-
}

bitset_test.go

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,24 @@ package bitset
33
import (
44
"bytes"
55
"fmt"
6+
"iter"
67
"math/bits"
78
"math/rand"
89
"reflect"
910
"runtime"
11+
"slices"
1012
"sort"
1113
"testing"
1214
"testing/quick"
1315
)
1416

15-
// NextAfter can be used to iterate over the elements of the set.
16-
func ExampleSet_NextAfter() {
17+
// From can be used to iterate over the elements of the set.
18+
func ExampleSet_From() {
1719
s := new(Set)
1820
s.Add(2)
1921
s.Add(42)
2022
s.Add(13)
21-
for i := s.NextAfter(0); i >= 0; i = s.NextAfter(i + 1) {
23+
for i := range s.From(0) {
2224
fmt.Println(i)
2325
}
2426
// Output:
@@ -211,7 +213,7 @@ func TestNextAfter(t *testing.T) {
211213
}
212214
var n int
213215
var oldi int
214-
for i := b.NextAfter(0); i >= 0; i = b.NextAfter(i + 1) {
216+
for i := b.NextAfter(-1); i >= 0; i = b.NextAfter(i + 1) {
215217
if l[n] != i {
216218
t.Logf("b.NextAfter(%d) = %d, expected %d", oldi, i, l[n])
217219
return false
@@ -226,6 +228,31 @@ func TestNextAfter(t *testing.T) {
226228
}
227229
}
228230

231+
func TestFrom(t *testing.T) {
232+
f := func(l ascendingInts, fstart float64) bool {
233+
b := new(Set)
234+
for _, i := range l {
235+
b.Add(i)
236+
}
237+
start := int(fstart*float64(len(l))) - 1
238+
got := slices.Collect(b.From(start))
239+
var want ascendingInts
240+
for _, num := range l {
241+
if num >= start {
242+
want = append(want, num)
243+
}
244+
}
245+
if !slices.Equal(got, want) {
246+
t.Logf("b.From(%d) = %v, expected %v", start, got, want)
247+
return false
248+
}
249+
return true
250+
}
251+
if err := quick.Check(f, nil); err != nil {
252+
t.Error(err)
253+
}
254+
}
255+
229256
func TestBytes(t *testing.T) {
230257
f := func(data0 []byte) bool {
231258
// Get rid of trailing zero bytes
@@ -429,6 +456,25 @@ func BenchmarkNextAfter(b *testing.B) {
429456
}
430457
}
431458

459+
func BenchmarkFrom(b *testing.B) {
460+
buf := make([]byte, 10000)
461+
rand.Read(buf)
462+
s := new(Set)
463+
s.FromBytes(buf)
464+
next, stop := iter.Pull(s.From(0))
465+
defer stop()
466+
var x int
467+
b.ResetTimer()
468+
for i := 0; i < b.N; i++ {
469+
var ok bool
470+
x, ok = next()
471+
if !ok {
472+
x = 0
473+
}
474+
}
475+
_ = x
476+
}
477+
432478
func bitwiseF(f func(p, q bool) bool, l0, l1 []int) []int {
433479
var x []int
434480
lim := max(l0, l1)

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
module github.com/takeyourhatoff/bitset
2+
3+
go 1.23.0

0 commit comments

Comments
 (0)