@@ -128,7 +128,6 @@ var metadataStartMarker = []byte("\xAB\xCD\xEFMaxMind.com")
128
128
// All of the methods on Reader are thread-safe. The struct may be safely
129
129
// shared across goroutines.
130
130
type Reader struct {
131
- nodeReader nodeReader
132
131
buffer []byte
133
132
decoder decoder.ReflectionDecoder
134
133
Metadata Metadata
@@ -312,25 +311,8 @@ func FromBytes(buffer []byte, options ...ReaderOption) (*Reader, error) {
312
311
buffer [searchTreeSize + dataSectionSeparatorSize : metadataStart - len (metadataStartMarker )],
313
312
)
314
313
315
- nodeBuffer := buffer [:searchTreeSize ]
316
- var nodeReader nodeReader
317
- switch metadata .RecordSize {
318
- case 24 :
319
- nodeReader = nodeReader24 {buffer : nodeBuffer }
320
- case 28 :
321
- nodeReader = nodeReader28 {buffer : nodeBuffer }
322
- case 32 :
323
- nodeReader = nodeReader32 {buffer : nodeBuffer }
324
- default :
325
- return nil , mmdberrors .NewInvalidDatabaseError (
326
- "unknown record size: %d" ,
327
- metadata .RecordSize ,
328
- )
329
- }
330
-
331
314
reader := & Reader {
332
315
buffer : buffer ,
333
- nodeReader : nodeReader ,
334
316
decoder : d ,
335
317
Metadata : metadata ,
336
318
ipv4Start : 0 ,
@@ -394,7 +376,7 @@ func (r *Reader) setIPv4Start() {
394
376
node := uint (0 )
395
377
i := 0
396
378
for ; i < 96 && node < nodeCount ; i ++ {
397
- node = r . nodeReader . readLeft ( node * r .nodeOffsetMult )
379
+ node = readNodeBySize ( r . buffer , node * r .nodeOffsetMult , 0 , r . Metadata . RecordSize )
398
380
}
399
381
r .ipv4Start = node
400
382
r .ipv4StartBitDepth = i
@@ -410,7 +392,10 @@ func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, error) {
410
392
)
411
393
}
412
394
413
- node , prefixLength := r .traverseTree (ip , 0 , 128 )
395
+ node , prefixLength , err := r .traverseTree (ip , 0 , 128 )
396
+ if err != nil {
397
+ return 0 , 0 , err
398
+ }
414
399
415
400
nodeCount := r .Metadata .NodeCount
416
401
if node == nodeCount {
@@ -423,25 +408,134 @@ func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, error) {
423
408
return 0 , prefixLength , mmdberrors .NewInvalidDatabaseError ("invalid node in search tree" )
424
409
}
425
410
426
- func (r * Reader ) traverseTree (ip netip.Addr , node uint , stopBit int ) (uint , int ) {
411
+ // readNodeBySize reads a node value from the buffer based on record size and bit.
412
+ func readNodeBySize (buffer []byte , offset , bit , recordSize uint ) uint {
413
+ switch recordSize {
414
+ case 24 :
415
+ offset += bit * 3
416
+ return (uint (buffer [offset ]) << 16 ) |
417
+ (uint (buffer [offset + 1 ]) << 8 ) |
418
+ uint (buffer [offset + 2 ])
419
+ case 28 :
420
+ if bit == 0 {
421
+ return ((uint (buffer [offset + 3 ]) & 0xF0 ) << 20 ) |
422
+ (uint (buffer [offset ]) << 16 ) |
423
+ (uint (buffer [offset + 1 ]) << 8 ) |
424
+ uint (buffer [offset + 2 ])
425
+ }
426
+ return ((uint (buffer [offset + 3 ]) & 0x0F ) << 24 ) |
427
+ (uint (buffer [offset + 4 ]) << 16 ) |
428
+ (uint (buffer [offset + 5 ]) << 8 ) |
429
+ uint (buffer [offset + 6 ])
430
+ case 32 :
431
+ offset += bit * 4
432
+ return (uint (buffer [offset ]) << 24 ) |
433
+ (uint (buffer [offset + 1 ]) << 16 ) |
434
+ (uint (buffer [offset + 2 ]) << 8 ) |
435
+ uint (buffer [offset + 3 ])
436
+ default :
437
+ return 0
438
+ }
439
+ }
440
+
441
+ func (r * Reader ) traverseTree (ip netip.Addr , node uint , stopBit int ) (uint , int , error ) {
442
+ switch r .Metadata .RecordSize {
443
+ case 24 :
444
+ n , i := r .traverseTree24 (ip , node , stopBit )
445
+ return n , i , nil
446
+ case 28 :
447
+ n , i := r .traverseTree28 (ip , node , stopBit )
448
+ return n , i , nil
449
+ case 32 :
450
+ n , i := r .traverseTree32 (ip , node , stopBit )
451
+ return n , i , nil
452
+ default :
453
+ return 0 , 0 , mmdberrors .NewInvalidDatabaseError (
454
+ "unsupported record size: %d" ,
455
+ r .Metadata .RecordSize ,
456
+ )
457
+ }
458
+ }
459
+
460
+ func (r * Reader ) traverseTree24 (ip netip.Addr , node uint , stopBit int ) (uint , int ) {
427
461
i := 0
428
462
if ip .Is4 () {
429
463
i = r .ipv4StartBitDepth
430
464
node = r .ipv4Start
431
465
}
432
466
nodeCount := r .Metadata .NodeCount
467
+ buffer := r .buffer
468
+ ip16 := ip .As16 ()
469
+
470
+ for ; i < stopBit && node < nodeCount ; i ++ {
471
+ byteIdx := i >> 3
472
+ bitPos := 7 - (i & 7 )
473
+ bit := (uint (ip16 [byteIdx ]) >> bitPos ) & 1
433
474
475
+ baseOffset := node * 6
476
+ offset := baseOffset + bit * 3
477
+
478
+ node = (uint (buffer [offset ]) << 16 ) |
479
+ (uint (buffer [offset + 1 ]) << 8 ) |
480
+ uint (buffer [offset + 2 ])
481
+ }
482
+
483
+ return node , i
484
+ }
485
+
486
+ func (r * Reader ) traverseTree28 (ip netip.Addr , node uint , stopBit int ) (uint , int ) {
487
+ i := 0
488
+ if ip .Is4 () {
489
+ i = r .ipv4StartBitDepth
490
+ node = r .ipv4Start
491
+ }
492
+ nodeCount := r .Metadata .NodeCount
493
+ buffer := r .buffer
434
494
ip16 := ip .As16 ()
435
495
436
496
for ; i < stopBit && node < nodeCount ; i ++ {
437
- bit := uint (1 ) & (uint (ip16 [i >> 3 ]) >> (7 - (i % 8 )))
497
+ byteIdx := i >> 3
498
+ bitPos := 7 - (i & 7 )
499
+ bit := (uint (ip16 [byteIdx ]) >> bitPos ) & 1
500
+
501
+ baseOffset := node * 7
502
+ sharedByte := uint (buffer [baseOffset + 3 ])
503
+ mask := uint (0xF0 >> (bit * 4 ))
504
+ shift := 20 + bit * 4
505
+ nibble := ((sharedByte & mask ) << shift )
506
+ offset := baseOffset + bit * 4
507
+
508
+ node = nibble |
509
+ (uint (buffer [offset ]) << 16 ) |
510
+ (uint (buffer [offset + 1 ]) << 8 ) |
511
+ uint (buffer [offset + 2 ])
512
+ }
438
513
439
- offset := node * r .nodeOffsetMult
440
- if bit == 0 {
441
- node = r .nodeReader .readLeft (offset )
442
- } else {
443
- node = r .nodeReader .readRight (offset )
444
- }
514
+ return node , i
515
+ }
516
+
517
+ func (r * Reader ) traverseTree32 (ip netip.Addr , node uint , stopBit int ) (uint , int ) {
518
+ i := 0
519
+ if ip .Is4 () {
520
+ i = r .ipv4StartBitDepth
521
+ node = r .ipv4Start
522
+ }
523
+ nodeCount := r .Metadata .NodeCount
524
+ buffer := r .buffer
525
+ ip16 := ip .As16 ()
526
+
527
+ for ; i < stopBit && node < nodeCount ; i ++ {
528
+ byteIdx := i >> 3
529
+ bitPos := 7 - (i & 7 )
530
+ bit := (uint (ip16 [byteIdx ]) >> bitPos ) & 1
531
+
532
+ baseOffset := node * 8
533
+ offset := baseOffset + bit * 4
534
+
535
+ node = (uint (buffer [offset ]) << 24 ) |
536
+ (uint (buffer [offset + 1 ]) << 16 ) |
537
+ (uint (buffer [offset + 2 ]) << 8 ) |
538
+ uint (buffer [offset + 3 ])
445
539
}
446
540
447
541
return node , i
0 commit comments