Skip to content

Commit 4f8b5f8

Browse files
committed
Optimize tree traversal with specialized functions
Replace nodeReader interface with specialized traverseTree functions for each record size. Eliminates interface dispatch overhead and implements branchless offset calculations for improved performance.
1 parent 8230991 commit 4f8b5f8

File tree

3 files changed

+132
-89
lines changed

3 files changed

+132
-89
lines changed

node.go

Lines changed: 0 additions & 58 deletions
This file was deleted.

reader.go

Lines changed: 122 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ var metadataStartMarker = []byte("\xAB\xCD\xEFMaxMind.com")
128128
// All of the methods on Reader are thread-safe. The struct may be safely
129129
// shared across goroutines.
130130
type Reader struct {
131-
nodeReader nodeReader
132131
buffer []byte
133132
decoder decoder.ReflectionDecoder
134133
Metadata Metadata
@@ -312,25 +311,8 @@ func FromBytes(buffer []byte, options ...ReaderOption) (*Reader, error) {
312311
buffer[searchTreeSize+dataSectionSeparatorSize : metadataStart-len(metadataStartMarker)],
313312
)
314313

315-
nodeBuffer := buffer[:searchTreeSize]
316-
var nodeReader nodeReader
317-
switch metadata.RecordSize {
318-
case 24:
319-
nodeReader = nodeReader24{buffer: nodeBuffer}
320-
case 28:
321-
nodeReader = nodeReader28{buffer: nodeBuffer}
322-
case 32:
323-
nodeReader = nodeReader32{buffer: nodeBuffer}
324-
default:
325-
return nil, mmdberrors.NewInvalidDatabaseError(
326-
"unknown record size: %d",
327-
metadata.RecordSize,
328-
)
329-
}
330-
331314
reader := &Reader{
332315
buffer: buffer,
333-
nodeReader: nodeReader,
334316
decoder: d,
335317
Metadata: metadata,
336318
ipv4Start: 0,
@@ -394,7 +376,7 @@ func (r *Reader) setIPv4Start() {
394376
node := uint(0)
395377
i := 0
396378
for ; i < 96 && node < nodeCount; i++ {
397-
node = r.nodeReader.readLeft(node * r.nodeOffsetMult)
379+
node = readNodeBySize(r.buffer, node*r.nodeOffsetMult, 0, r.Metadata.RecordSize)
398380
}
399381
r.ipv4Start = node
400382
r.ipv4StartBitDepth = i
@@ -410,7 +392,10 @@ func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, error) {
410392
)
411393
}
412394

413-
node, prefixLength := r.traverseTree(ip, 0, 128)
395+
node, prefixLength, err := r.traverseTree(ip, 0, 128)
396+
if err != nil {
397+
return 0, 0, err
398+
}
414399

415400
nodeCount := r.Metadata.NodeCount
416401
if node == nodeCount {
@@ -423,25 +408,134 @@ func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, error) {
423408
return 0, prefixLength, mmdberrors.NewInvalidDatabaseError("invalid node in search tree")
424409
}
425410

426-
func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int) {
411+
// readNodeBySize decodes one big-endian record of a search-tree node.
//
// buffer holds the node bytes, offset is the byte position at which the node
// starts, and bit selects the record to read (0 for the left record, 1 for the
// right record). recordSize is the record width in bits:
//
//   - 24: two 3-byte records stored back to back.
//   - 28: two 3-byte groups separated by one shared byte at offset+3 whose
//     high nibble supplies the top four bits of the left record and whose low
//     nibble supplies the top four bits of the right record.
//   - 32: two 4-byte records stored back to back.
//
// Any other recordSize yields 0. No bounds checking is performed here, so
// indexing past the end of buffer panics.
func readNodeBySize(buffer []byte, offset, bit, recordSize uint) uint {
	if recordSize == 24 {
		pos := offset + 3*bit
		return uint(buffer[pos])<<16 |
			uint(buffer[pos+1])<<8 |
			uint(buffer[pos+2])
	}

	if recordSize == 32 {
		pos := offset + 4*bit
		return uint(buffer[pos])<<24 |
			uint(buffer[pos+1])<<16 |
			uint(buffer[pos+2])<<8 |
			uint(buffer[pos+3])
	}

	if recordSize != 28 {
		return 0
	}

	// 28-bit records: the middle byte is split between the two records.
	shared := uint(buffer[offset+3])
	if bit == 0 {
		return (shared&0xF0)<<20 |
			uint(buffer[offset])<<16 |
			uint(buffer[offset+1])<<8 |
			uint(buffer[offset+2])
	}
	return (shared&0x0F)<<24 |
		uint(buffer[offset+4])<<16 |
		uint(buffer[offset+5])<<8 |
		uint(buffer[offset+6])
}
440+
441+
func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int, error) {
442+
switch r.Metadata.RecordSize {
443+
case 24:
444+
n, i := r.traverseTree24(ip, node, stopBit)
445+
return n, i, nil
446+
case 28:
447+
n, i := r.traverseTree28(ip, node, stopBit)
448+
return n, i, nil
449+
case 32:
450+
n, i := r.traverseTree32(ip, node, stopBit)
451+
return n, i, nil
452+
default:
453+
return 0, 0, mmdberrors.NewInvalidDatabaseError(
454+
"unsupported record size: %d",
455+
r.Metadata.RecordSize,
456+
)
457+
}
458+
}
459+
460+
func (r *Reader) traverseTree24(ip netip.Addr, node uint, stopBit int) (uint, int) {
427461
i := 0
428462
if ip.Is4() {
429463
i = r.ipv4StartBitDepth
430464
node = r.ipv4Start
431465
}
432466
nodeCount := r.Metadata.NodeCount
467+
buffer := r.buffer
468+
ip16 := ip.As16()
469+
470+
for ; i < stopBit && node < nodeCount; i++ {
471+
byteIdx := i >> 3
472+
bitPos := 7 - (i & 7)
473+
bit := (uint(ip16[byteIdx]) >> bitPos) & 1
433474

475+
baseOffset := node * 6
476+
offset := baseOffset + bit*3
477+
478+
node = (uint(buffer[offset]) << 16) |
479+
(uint(buffer[offset+1]) << 8) |
480+
uint(buffer[offset+2])
481+
}
482+
483+
return node, i
484+
}
485+
486+
func (r *Reader) traverseTree28(ip netip.Addr, node uint, stopBit int) (uint, int) {
487+
i := 0
488+
if ip.Is4() {
489+
i = r.ipv4StartBitDepth
490+
node = r.ipv4Start
491+
}
492+
nodeCount := r.Metadata.NodeCount
493+
buffer := r.buffer
434494
ip16 := ip.As16()
435495

436496
for ; i < stopBit && node < nodeCount; i++ {
437-
bit := uint(1) & (uint(ip16[i>>3]) >> (7 - (i % 8)))
497+
byteIdx := i >> 3
498+
bitPos := 7 - (i & 7)
499+
bit := (uint(ip16[byteIdx]) >> bitPos) & 1
500+
501+
baseOffset := node * 7
502+
sharedByte := uint(buffer[baseOffset+3])
503+
mask := uint(0xF0 >> (bit * 4))
504+
shift := 20 + bit*4
505+
nibble := ((sharedByte & mask) << shift)
506+
offset := baseOffset + bit*4
507+
508+
node = nibble |
509+
(uint(buffer[offset]) << 16) |
510+
(uint(buffer[offset+1]) << 8) |
511+
uint(buffer[offset+2])
512+
}
438513

439-
offset := node * r.nodeOffsetMult
440-
if bit == 0 {
441-
node = r.nodeReader.readLeft(offset)
442-
} else {
443-
node = r.nodeReader.readRight(offset)
444-
}
514+
return node, i
515+
}
516+
517+
func (r *Reader) traverseTree32(ip netip.Addr, node uint, stopBit int) (uint, int) {
518+
i := 0
519+
if ip.Is4() {
520+
i = r.ipv4StartBitDepth
521+
node = r.ipv4Start
522+
}
523+
nodeCount := r.Metadata.NodeCount
524+
buffer := r.buffer
525+
ip16 := ip.As16()
526+
527+
for ; i < stopBit && node < nodeCount; i++ {
528+
byteIdx := i >> 3
529+
bitPos := 7 - (i & 7)
530+
bit := (uint(ip16[byteIdx]) >> bitPos) & 1
531+
532+
baseOffset := node * 8
533+
offset := baseOffset + bit*4
534+
535+
node = (uint(buffer[offset]) << 24) |
536+
(uint(buffer[offset+1]) << 16) |
537+
(uint(buffer[offset+2]) << 8) |
538+
uint(buffer[offset+3])
445539
}
446540

447541
return node, i

traverse.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,14 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption)
101101
stopBit += 96
102102
}
103103

104-
pointer, bit := r.traverseTree(ip, 0, stopBit)
104+
pointer, bit, err := r.traverseTree(ip, 0, stopBit)
105+
if err != nil {
106+
yield(Result{
107+
ip: ip,
108+
err: err,
109+
})
110+
return
111+
}
105112

106113
prefix, err := netIP.Prefix(bit)
107114
if err != nil {
@@ -182,7 +189,7 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption)
182189
ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8))
183190

184191
offset := node.pointer * r.nodeOffsetMult
185-
rightPointer := r.nodeReader.readRight(offset)
192+
rightPointer := readNodeBySize(r.buffer, offset, 1, r.Metadata.RecordSize)
186193

187194
node.bit++
188195
nodes = append(nodes, netNode{
@@ -191,7 +198,7 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption)
191198
bit: node.bit,
192199
})
193200

194-
node.pointer = r.nodeReader.readLeft(offset)
201+
node.pointer = readNodeBySize(r.buffer, offset, 0, r.Metadata.RecordSize)
195202
}
196203
}
197204
}

0 commit comments

Comments
 (0)