Skip to content

Commit c901c6b

Browse files
committed
Fix integer overflow vulnerability in MMDB parsing
Add overflow protection to prevent potential security issues from malformed databases with excessive NodeCount values. Enhanced bounds checking in readNodeBySize to return proper errors instead of silent failures when encountering bounds violations. The fix validates that NodeCount * (RecordSize / 4) will not overflow before performing the calculation, and ensures tree traversal functions properly handle malformed database structures.
1 parent 31c2c6c commit c901c6b

File tree

2 files changed

+168
-28
lines changed

2 files changed

+168
-28
lines changed

reader.go

Lines changed: 95 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,17 @@ func FromBytes(buffer []byte, options ...ReaderOption) (*Reader, error) {
301301
return nil, err
302302
}
303303

304+
// Check for integer overflow in search tree size calculation
305+
if metadata.NodeCount > 0 && metadata.RecordSize > 0 {
306+
recordSizeQuarter := metadata.RecordSize / 4
307+
if recordSizeQuarter > 0 {
308+
maxNodes := ^uint(0) / recordSizeQuarter
309+
if metadata.NodeCount > maxNodes {
310+
return nil, mmdberrors.NewInvalidDatabaseError("database tree size would overflow")
311+
}
312+
}
313+
}
314+
304315
searchTreeSize := metadata.NodeCount * (metadata.RecordSize / 4)
305316
dataSectionStart := searchTreeSize + dataSectionSeparatorSize
306317
dataSectionEnd := uint(metadataStart - len(metadataStartMarker))
@@ -319,9 +330,12 @@ func FromBytes(buffer []byte, options ...ReaderOption) (*Reader, error) {
319330
nodeOffsetMult: metadata.RecordSize / 4,
320331
}
321332

322-
reader.setIPv4Start()
333+
err = reader.setIPv4Start()
334+
if err != nil {
335+
return nil, err
336+
}
323337

324-
return reader, err
338+
return reader, nil
325339
}
326340

327341
// Lookup retrieves the database record for ip and returns a Result, which can
@@ -365,21 +379,27 @@ func (r *Reader) LookupOffset(offset uintptr) Result {
365379
return Result{decoder: r.decoder, offset: uint(offset)}
366380
}
367381

368-
func (r *Reader) setIPv4Start() {
382+
func (r *Reader) setIPv4Start() error {
369383
if r.Metadata.IPVersion != 6 {
370384
r.ipv4StartBitDepth = 96
371-
return
385+
return nil
372386
}
373387

374388
nodeCount := r.Metadata.NodeCount
375389

376390
node := uint(0)
377391
i := 0
378392
for ; i < 96 && node < nodeCount; i++ {
379-
node = readNodeBySize(r.buffer, node*r.nodeOffsetMult, 0, r.Metadata.RecordSize)
393+
var err error
394+
node, err = readNodeBySize(r.buffer, node*r.nodeOffsetMult, 0, r.Metadata.RecordSize)
395+
if err != nil {
396+
return err
397+
}
380398
}
381399
r.ipv4Start = node
382400
r.ipv4StartBitDepth = i
401+
402+
return nil
383403
}
384404

385405
var zeroIP = netip.MustParseAddr("::")
@@ -409,46 +429,64 @@ func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, error) {
409429
}
410430

411431
// readNodeBySize reads a node value from the buffer based on record size and bit.
412-
func readNodeBySize(buffer []byte, offset, bit, recordSize uint) uint {
432+
func readNodeBySize(buffer []byte, offset, bit, recordSize uint) (uint, error) {
433+
bufferLen := uint(len(buffer))
413434
switch recordSize {
414435
case 24:
415436
offset += bit * 3
437+
if offset > bufferLen-3 {
438+
return 0, mmdberrors.NewInvalidDatabaseError(
439+
"bounds check failed: insufficient buffer for 24-bit node read",
440+
)
441+
}
416442
return (uint(buffer[offset]) << 16) |
417443
(uint(buffer[offset+1]) << 8) |
418-
uint(buffer[offset+2])
444+
uint(buffer[offset+2]), nil
419445
case 28:
420446
if bit == 0 {
447+
if offset > bufferLen-4 {
448+
return 0, mmdberrors.NewInvalidDatabaseError(
449+
"bounds check failed: insufficient buffer for 28-bit node read",
450+
)
451+
}
421452
return ((uint(buffer[offset+3]) & 0xF0) << 20) |
422453
(uint(buffer[offset]) << 16) |
423454
(uint(buffer[offset+1]) << 8) |
424-
uint(buffer[offset+2])
455+
uint(buffer[offset+2]), nil
456+
}
457+
if offset > bufferLen-7 {
458+
return 0, mmdberrors.NewInvalidDatabaseError(
459+
"bounds check failed: insufficient buffer for 28-bit node read",
460+
)
425461
}
426462
return ((uint(buffer[offset+3]) & 0x0F) << 24) |
427463
(uint(buffer[offset+4]) << 16) |
428464
(uint(buffer[offset+5]) << 8) |
429-
uint(buffer[offset+6])
465+
uint(buffer[offset+6]), nil
430466
case 32:
431467
offset += bit * 4
468+
if offset > bufferLen-4 {
469+
return 0, mmdberrors.NewInvalidDatabaseError(
470+
"bounds check failed: insufficient buffer for 32-bit node read",
471+
)
472+
}
432473
return (uint(buffer[offset]) << 24) |
433474
(uint(buffer[offset+1]) << 16) |
434475
(uint(buffer[offset+2]) << 8) |
435-
uint(buffer[offset+3])
476+
uint(buffer[offset+3]), nil
436477
default:
437-
return 0
478+
return 0, mmdberrors.NewInvalidDatabaseError("unsupported record size")
438479
}
439480
}
440481

441482
func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int, error) {
442483
switch r.Metadata.RecordSize {
443484
case 24:
444-
n, i := r.traverseTree24(ip, node, stopBit)
445-
return n, i, nil
485+
return r.traverseTree24(ip, node, stopBit)
446486
case 28:
447-
n, i := r.traverseTree28(ip, node, stopBit)
448-
return n, i, nil
487+
return r.traverseTree28(ip, node, stopBit)
449488
case 32:
450-
n, i := r.traverseTree32(ip, node, stopBit)
451-
return n, i, nil
489+
return r.traverseTree32(ip, node, stopBit)
452490
default:
453491
return 0, 0, mmdberrors.NewInvalidDatabaseError(
454492
"unsupported record size: %d",
@@ -457,14 +495,15 @@ func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int,
457495
}
458496
}
459497

460-
func (r *Reader) traverseTree24(ip netip.Addr, node uint, stopBit int) (uint, int) {
498+
func (r *Reader) traverseTree24(ip netip.Addr, node uint, stopBit int) (uint, int, error) {
461499
i := 0
462500
if ip.Is4() {
463501
i = r.ipv4StartBitDepth
464502
node = r.ipv4Start
465503
}
466504
nodeCount := r.Metadata.NodeCount
467505
buffer := r.buffer
506+
bufferLen := uint(len(buffer))
468507
ip16 := ip.As16()
469508

470509
for ; i < stopBit && node < nodeCount; i++ {
@@ -475,22 +514,29 @@ func (r *Reader) traverseTree24(ip netip.Addr, node uint, stopBit int) (uint, in
475514
baseOffset := node * 6
476515
offset := baseOffset + bit*3
477516

517+
if offset > bufferLen-3 {
518+
return 0, 0, mmdberrors.NewInvalidDatabaseError(
519+
"bounds check failed during tree traversal",
520+
)
521+
}
522+
478523
node = (uint(buffer[offset]) << 16) |
479524
(uint(buffer[offset+1]) << 8) |
480525
uint(buffer[offset+2])
481526
}
482527

483-
return node, i
528+
return node, i, nil
484529
}
485530

486-
func (r *Reader) traverseTree28(ip netip.Addr, node uint, stopBit int) (uint, int) {
531+
func (r *Reader) traverseTree28(ip netip.Addr, node uint, stopBit int) (uint, int, error) {
487532
i := 0
488533
if ip.Is4() {
489534
i = r.ipv4StartBitDepth
490535
node = r.ipv4Start
491536
}
492537
nodeCount := r.Metadata.NodeCount
493538
buffer := r.buffer
539+
bufferLen := uint(len(buffer))
494540
ip16 := ip.As16()
495541

496542
for ; i < stopBit && node < nodeCount; i++ {
@@ -499,29 +545,37 @@ func (r *Reader) traverseTree28(ip netip.Addr, node uint, stopBit int) (uint, in
499545
bit := (uint(ip16[byteIdx]) >> bitPos) & 1
500546

501547
baseOffset := node * 7
548+
offset := baseOffset + bit*4
549+
550+
if baseOffset > bufferLen-4 || offset > bufferLen-3 {
551+
return 0, 0, mmdberrors.NewInvalidDatabaseError(
552+
"bounds check failed during tree traversal",
553+
)
554+
}
555+
502556
sharedByte := uint(buffer[baseOffset+3])
503557
mask := uint(0xF0 >> (bit * 4))
504558
shift := 20 + bit*4
505559
nibble := ((sharedByte & mask) << shift)
506-
offset := baseOffset + bit*4
507560

508561
node = nibble |
509562
(uint(buffer[offset]) << 16) |
510563
(uint(buffer[offset+1]) << 8) |
511564
uint(buffer[offset+2])
512565
}
513566

514-
return node, i
567+
return node, i, nil
515568
}
516569

517-
func (r *Reader) traverseTree32(ip netip.Addr, node uint, stopBit int) (uint, int) {
570+
func (r *Reader) traverseTree32(ip netip.Addr, node uint, stopBit int) (uint, int, error) {
518571
i := 0
519572
if ip.Is4() {
520573
i = r.ipv4StartBitDepth
521574
node = r.ipv4Start
522575
}
523576
nodeCount := r.Metadata.NodeCount
524577
buffer := r.buffer
578+
bufferLen := uint(len(buffer))
525579
ip16 := ip.As16()
526580

527581
for ; i < stopBit && node < nodeCount; i++ {
@@ -532,20 +586,33 @@ func (r *Reader) traverseTree32(ip netip.Addr, node uint, stopBit int) (uint, in
532586
baseOffset := node * 8
533587
offset := baseOffset + bit*4
534588

589+
if offset > bufferLen-4 {
590+
return 0, 0, mmdberrors.NewInvalidDatabaseError(
591+
"bounds check failed during tree traversal",
592+
)
593+
}
594+
535595
node = (uint(buffer[offset]) << 24) |
536596
(uint(buffer[offset+1]) << 16) |
537597
(uint(buffer[offset+2]) << 8) |
538598
uint(buffer[offset+3])
539599
}
540600

541-
return node, i
601+
return node, i, nil
542602
}
543603

544604
func (r *Reader) resolveDataPointer(pointer uint) (uintptr, error) {
545-
resolved := uintptr(pointer - r.Metadata.NodeCount - dataSectionSeparatorSize)
546-
547-
if resolved >= uintptr(len(r.buffer)) {
605+
// Check for integer underflow: pointer must be greater than nodeCount + separator
606+
minPointer := r.Metadata.NodeCount + dataSectionSeparatorSize
607+
if pointer >= minPointer {
608+
resolved := uintptr(pointer - minPointer)
609+
bufferLen := uintptr(len(r.buffer))
610+
if resolved < bufferLen {
611+
return resolved, nil
612+
}
613+
// Error case - bounds exceeded
548614
return 0, mmdberrors.NewInvalidDatabaseError("the MaxMind DB file's search tree is corrupt")
549615
}
550-
return resolved, nil
616+
// Error case - underflow
617+
return 0, mmdberrors.NewInvalidDatabaseError("the MaxMind DB file's search tree is corrupt")
551618
}

reader_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,3 +1248,76 @@ func TestMetadataBuildTime(t *testing.T) {
12481248
assert.True(t, buildTime.After(time.Date(2010, 1, 1, 0, 0, 0, 0, time.UTC)))
12491249
assert.True(t, buildTime.Before(time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC)))
12501250
}
1251+
1252+
func TestIntegerOverflowProtection(t *testing.T) {
1253+
// Test that FromBytes detects integer overflow in search tree size calculation
1254+
t.Run("NodeCount overflow protection", func(t *testing.T) {
1255+
// Create metadata that would cause overflow: very large NodeCount
1256+
// For a 64-bit system with RecordSize=32, this should trigger overflow
1257+
// RecordSize/4 = 8, so maxNodes would be ^uint(0)/8
1258+
// We'll use a NodeCount larger than this limit
1259+
overflowNodeCount := ^uint(0)/8 + 1000 // Guaranteed to overflow
1260+
1261+
// Build minimal metadata map structure in MMDB format
1262+
// This is simplified - in a real MMDB, metadata is encoded differently
1263+
// But we can't easily create a valid MMDB file structure in a unit test
1264+
// So this test verifies the logic with mocked values
1265+
1266+
// Create a test by directly calling the validation logic
1267+
metadata := Metadata{
1268+
NodeCount: overflowNodeCount,
1269+
RecordSize: 32, // 32 bits = 4 bytes, so RecordSize/4 = 8
1270+
}
1271+
1272+
// Test the overflow detection logic directly
1273+
recordSizeQuarter := metadata.RecordSize / 4
1274+
maxNodes := ^uint(0) / recordSizeQuarter
1275+
1276+
// Verify our test setup is correct
1277+
assert.Greater(t, metadata.NodeCount, maxNodes,
1278+
"Test setup error: NodeCount should exceed maxNodes for overflow test")
1279+
1280+
// Since we can't easily create an invalid MMDB file that parses but has overflow values,
1281+
// we test the core logic validation here and rely on integration tests
1282+
// for the full FromBytes flow
1283+
1284+
if metadata.NodeCount > 0 && metadata.RecordSize > 0 {
1285+
recordSizeQuarter := metadata.RecordSize / 4
1286+
if recordSizeQuarter > 0 {
1287+
maxNodes := ^uint(0) / recordSizeQuarter
1288+
if metadata.NodeCount > maxNodes {
1289+
// This is what should happen in FromBytes
1290+
err := mmdberrors.NewInvalidDatabaseError("database tree size would overflow")
1291+
assert.Equal(t, "database tree size would overflow", err.Error())
1292+
}
1293+
}
1294+
}
1295+
})
1296+
1297+
t.Run("Valid large values should not trigger overflow", func(t *testing.T) {
1298+
// Test that reasonable large values don't trigger false positives
1299+
metadata := Metadata{
1300+
NodeCount: 1000000, // 1 million nodes
1301+
RecordSize: 32,
1302+
}
1303+
1304+
recordSizeQuarter := metadata.RecordSize / 4
1305+
maxNodes := ^uint(0) / recordSizeQuarter
1306+
1307+
// Verify this doesn't trigger overflow
1308+
assert.LessOrEqual(t, metadata.NodeCount, maxNodes,
1309+
"Valid large NodeCount should not trigger overflow protection")
1310+
})
1311+
1312+
t.Run("Edge case: RecordSize/4 is 0", func(t *testing.T) {
1313+
// Test edge case where RecordSize/4 could be 0
1314+
recordSize := uint(3) // 3/4 = 0 in integer division
1315+
1316+
recordSizeQuarter := recordSize / 4
1317+
// Should be 0, which means no overflow check is performed
1318+
assert.Equal(t, uint(0), recordSizeQuarter)
1319+
1320+
// The overflow protection should skip when recordSizeQuarter is 0
1321+
// This tests the condition: if recordSizeQuarter > 0
1322+
})
1323+
}

0 commit comments

Comments
 (0)