Skip to content

Commit 423bdf8

Browse files
committed
fix(inflate): use inputwrapper struct instead of iter to simplify input reading and change some data types for performance
1 parent 9f1fc5e commit 423bdf8

File tree

2 files changed

+106
-66
lines changed

2 files changed

+106
-66
lines changed

miniz_oxide/src/inflate/core.rs

Lines changed: 52 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ use super::*;
44
use crate::shared::{update_adler32, HUFFMAN_LENGTH_ORDER};
55
use ::core::cell::Cell;
66

7+
use ::core::cmp;
78
use ::core::convert::TryInto;
8-
use ::core::{cmp, slice};
99

10-
use self::output_buffer::OutputBuffer;
10+
use self::output_buffer::{InputWrapper, OutputBuffer};
1111

1212
pub const TINFL_LZ_DICT_SIZE: usize = 32_768;
1313

@@ -47,7 +47,7 @@ impl HuffmanTable {
4747

4848
/// Get the symbol and the code length from the huffman tree.
4949
#[inline]
50-
fn tree_lookup(&self, fast_symbol: i32, bit_buf: BitBuffer, mut code_len: u32) -> (i32, u32) {
50+
fn tree_lookup(&self, fast_symbol: i32, bit_buf: BitBuffer, mut code_len: u8) -> (i32, u32) {
5151
let mut symbol = fast_symbol;
5252
// We step through the tree until we encounter a positive value, which indicates a
5353
// symbol.
@@ -65,7 +65,9 @@ impl HuffmanTable {
6565
break;
6666
}
6767
}
68-
(symbol, code_len)
68+
// Note: Using a u8 for code_len inside this function seems to improve performance, but changing it
69+
// in `LocalVars` seems to worsen things, so we convert it to a u32 here.
70+
(symbol, u32::from(code_len))
6971
}
7072

7173
#[inline]
@@ -87,7 +89,7 @@ impl HuffmanTable {
8789
}
8890
} else {
8991
// We didn't get a symbol from the fast lookup table, so check the tree instead.
90-
Some(self.tree_lookup(symbol, bit_buf, FAST_LOOKUP_BITS.into()))
92+
Some(self.tree_lookup(symbol, bit_buf, FAST_LOOKUP_BITS))
9193
}
9294
}
9395
}
@@ -370,10 +372,12 @@ const DIST_BASE: [u16; 30] = [
370372
/// Get the number of extra bits used for a distance code.
371373
/// (Code numbers above `NUM_DISTANCE_CODES` will give some garbage
372374
/// value.)
375+
#[inline(always)]
373376
const fn num_extra_bits_for_distance_code(code: u8) -> u8 {
377+
// TODO: Need to verify that this is faster on all platforms.
374378
// This can be easily calculated without a lookup.
375379
let c = code >> 1;
376-
c - (c != 0) as u8
380+
c.saturating_sub(1)
377381
}
378382

379383
/// The mask used when indexing the base/extra arrays.
@@ -392,27 +396,12 @@ fn memset<T: Copy>(slice: &mut [T], val: T) {
392396
/// # Panics
393397
/// Panics if there are fewer than two bytes left.
394398
#[inline]
395-
fn read_u16_le(iter: &mut slice::Iter<u8>) -> u16 {
399+
fn read_u16_le(iter: &mut InputWrapper) -> u16 {
396400
let ret = {
397-
let two_bytes = iter.as_ref()[..2].try_into().unwrap();
401+
let two_bytes = iter.as_slice()[..2].try_into().unwrap_or_default();
398402
u16::from_le_bytes(two_bytes)
399403
};
400-
iter.nth(1);
401-
ret
402-
}
403-
404-
/// Read an le u32 value from the slice iterator.
405-
///
406-
/// # Panics
407-
/// Panics if there are less than four bytes left.
408-
#[inline(always)]
409-
#[cfg(target_pointer_width = "64")]
410-
fn read_u32_le(iter: &mut slice::Iter<u8>) -> u32 {
411-
let ret = {
412-
let four_bytes: [u8; 4] = iter.as_ref()[..4].try_into().unwrap();
413-
u32::from_le_bytes(four_bytes)
414-
};
415-
iter.nth(3);
404+
iter.advance(2);
416405
ret
417406
}
418407

@@ -423,10 +412,10 @@ fn read_u32_le(iter: &mut slice::Iter<u8>) -> u32 {
423412
/// This function assumes that there are at least 4 bytes left in the input buffer.
424413
#[inline(always)]
425414
#[cfg(target_pointer_width = "64")]
426-
fn fill_bit_buffer(l: &mut LocalVars, in_iter: &mut slice::Iter<u8>) {
415+
fn fill_bit_buffer(l: &mut LocalVars, in_iter: &mut InputWrapper) {
427416
// Read four bytes into the buffer at once.
428417
if l.num_bits < 30 {
429-
l.bit_buf |= BitBuffer::from(read_u32_le(in_iter)) << l.num_bits;
418+
l.bit_buf |= BitBuffer::from(in_iter.read_u32_le()) << l.num_bits;
430419
l.num_bits += 32;
431420
}
432421
}
@@ -435,7 +424,7 @@ fn fill_bit_buffer(l: &mut LocalVars, in_iter: &mut slice::Iter<u8>) {
435424
/// Ensures at least 16 bits are present; requires at least 2 bytes in the input buffer.
436425
#[inline(always)]
437426
#[cfg(not(target_pointer_width = "64"))]
438-
fn fill_bit_buffer(l: &mut LocalVars, in_iter: &mut slice::Iter<u8>) {
427+
fn fill_bit_buffer(l: &mut LocalVars, in_iter: &mut InputWrapper) {
439428
// If the buffer is 32-bit wide, read 2 bytes instead.
440429
if l.num_bits < 15 {
441430
l.bit_buf |= BitBuffer::from(read_u16_le(in_iter)) << l.num_bits;
@@ -491,7 +480,7 @@ fn decode_huffman_code<F>(
491480
l: &mut LocalVars,
492481
table: usize,
493482
flags: u32,
494-
in_iter: &mut slice::Iter<u8>,
483+
in_iter: &mut InputWrapper,
495484
f: F,
496485
) -> Action
497486
where
@@ -501,7 +490,7 @@ where
501490
// ready in the bit buffer to start decoding the next huffman code.
502491
if l.num_bits < 15 {
503492
// First, make sure there is enough data in the bit buffer to decode a huffman code.
504-
if in_iter.len() < 2 {
493+
if in_iter.bytes_left() < 2 {
505494
// If there are fewer than 2 bytes left in the input buffer, we try to look up
506495
// the huffman code with what's available, and return if that doesn't succeed.
507496
// Original explanation in miniz:
@@ -581,7 +570,7 @@ where
581570
// Mask out the length value.
582571
symbol &= 511;
583572
} else {
584-
let res = r.tables[table].tree_lookup(symbol, l.bit_buf, u32::from(FAST_LOOKUP_BITS));
573+
let res = r.tables[table].tree_lookup(symbol, l.bit_buf, FAST_LOOKUP_BITS);
585574
symbol = res.0;
586575
code_len = res.1;
587576
};
@@ -599,13 +588,13 @@ where
599588
/// returning the result.
600589
/// If reading fails, `Action::End` is returned.
601590
#[inline]
602-
fn read_byte<F>(in_iter: &mut slice::Iter<u8>, flags: u32, f: F) -> Action
591+
fn read_byte<F>(in_iter: &mut InputWrapper, flags: u32, f: F) -> Action
603592
where
604593
F: FnOnce(u8) -> Action,
605594
{
606-
match in_iter.next() {
595+
match in_iter.read_byte() {
607596
None => end_of_input(flags),
608-
Some(&byte) => f(byte),
597+
Some(byte) => f(byte),
609598
}
610599
}
611600

@@ -618,7 +607,7 @@ where
618607
fn read_bits<F>(
619608
l: &mut LocalVars,
620609
amount: u32,
621-
in_iter: &mut slice::Iter<u8>,
610+
in_iter: &mut InputWrapper,
622611
flags: u32,
623612
f: F,
624613
) -> Action
@@ -647,7 +636,7 @@ where
647636
}
648637

649638
#[inline]
650-
fn pad_to_bytes<F>(l: &mut LocalVars, in_iter: &mut slice::Iter<u8>, flags: u32, f: F) -> Action
639+
fn pad_to_bytes<F>(l: &mut LocalVars, in_iter: &mut InputWrapper, flags: u32, f: F) -> Action
651640
where
652641
F: FnOnce(&mut LocalVars) -> Action,
653642
{
@@ -854,7 +843,7 @@ struct LocalVars {
854843
pub num_bits: u32,
855844
pub dist: u32,
856845
pub counter: u32,
857-
pub num_extra: u32,
846+
pub num_extra: u8,
858847
}
859848

860849
#[inline]
@@ -981,7 +970,7 @@ fn apply_match(
981970
/// and already improves decompression speed a fair bit.
982971
fn decompress_fast(
983972
r: &mut DecompressorOxide,
984-
in_iter: &mut slice::Iter<u8>,
973+
in_iter: &mut InputWrapper,
985974
out_buf: &mut OutputBuffer,
986975
flags: u32,
987976
local_vars: &mut LocalVars,
@@ -1001,7 +990,7 @@ fn decompress_fast(
1001990
// + 29 + 32 (left in bit buf, including last 13 dist extra) = 111 bits < 14 bytes
1002991
// We need the one extra byte as we may write one length and one full match
1003992
// before checking again.
1004-
if out_buf.bytes_left() < 259 || in_iter.len() < 14 {
993+
if out_buf.bytes_left() < 259 || in_iter.bytes_left() < 14 {
1005994
state = State::DecodeLitlen;
1006995
break 'o TINFLStatus::Done;
1007996
}
@@ -1063,18 +1052,19 @@ fn decompress_fast(
10631052
// The symbol was a length code.
10641053
// # Optimization
10651054
// Mask the value to avoid bounds checks
1066-
// We could use get_unchecked later if can statically verify that
1067-
// this will never go out of bounds.
1068-
l.num_extra = u32::from(LENGTH_EXTRA[(l.counter - 257) as usize & BASE_EXTRA_MASK]);
1055+
// While the maximum is checked, the compiler isn't able to know that the
1056+
// value won't wrap around here.
1057+
l.num_extra = LENGTH_EXTRA[(l.counter - 257) as usize & BASE_EXTRA_MASK];
10691058
l.counter = u32::from(LENGTH_BASE[(l.counter - 257) as usize & BASE_EXTRA_MASK]);
10701059
// Length and distance codes have a number of extra bits depending on
10711060
// the base, which together with the base gives us the exact value.
10721061

1062+
// We need to make sure we have at least 33 bits (so a minimum of 5 bytes) in the buffer at this spot.
10731063
fill_bit_buffer(&mut l, in_iter);
10741064
if l.num_extra != 0 {
10751065
let extra_bits = l.bit_buf & ((1 << l.num_extra) - 1);
10761066
l.bit_buf >>= l.num_extra;
1077-
l.num_bits -= l.num_extra;
1067+
l.num_bits -= u32::from(l.num_extra);
10781068
l.counter += extra_bits as u32;
10791069
}
10801070

@@ -1093,7 +1083,7 @@ fn decompress_fast(
10931083
break 'o TINFLStatus::Failed;
10941084
}
10951085

1096-
l.num_extra = u32::from(num_extra_bits_for_distance_code(symbol as u8));
1086+
l.num_extra = num_extra_bits_for_distance_code(symbol as u8);
10971087
l.dist = u32::from(DIST_BASE[symbol as usize]);
10981088
} else {
10991089
state.begin(InvalidCodeLen);
@@ -1104,7 +1094,7 @@ fn decompress_fast(
11041094
fill_bit_buffer(&mut l, in_iter);
11051095
let extra_bits = l.bit_buf & ((1 << l.num_extra) - 1);
11061096
l.bit_buf >>= l.num_extra;
1107-
l.num_bits -= l.num_extra;
1097+
l.num_bits -= u32::from(l.num_extra);
11081098
l.dist += extra_bits as u32;
11091099
}
11101100

@@ -1194,7 +1184,7 @@ pub fn decompress(
11941184
return (TINFLStatus::BadParam, 0, 0);
11951185
}
11961186

1197-
let mut in_iter = in_buf.iter();
1187+
let mut in_iter = InputWrapper::from_slice(in_buf);
11981188

11991189
let mut state = r.state;
12001190

@@ -1206,7 +1196,7 @@ pub fn decompress(
12061196
num_bits: r.num_bits,
12071197
dist: r.dist,
12081198
counter: r.counter,
1209-
num_extra: r.num_extra,
1199+
num_extra: r.num_extra as u8,
12101200
};
12111201

12121202
let mut status = 'state_machine: loop {
@@ -1351,20 +1341,20 @@ pub fn decompress(
13511341
}),
13521342

13531343
RawMemcpy2 => generate_state!(state, 'state_machine, {
1354-
if in_iter.len() > 0 {
1344+
if in_iter.bytes_left() > 0 {
13551345
// Copy as many raw bytes as possible from the input to the output using memcpy.
13561346
// Raw block lengths are limited to 64 * 1024, so casting through usize and u32
13571347
// is not an issue.
13581348
let space_left = out_buf.bytes_left();
13591349
let bytes_to_copy = cmp::min(cmp::min(
13601350
space_left,
1361-
in_iter.len()),
1351+
in_iter.bytes_left()),
13621352
l.counter as usize
13631353
);
13641354

13651355
out_buf.write_slice(&in_iter.as_slice()[..bytes_to_copy]);
13661356

1367-
in_iter.nth(bytes_to_copy - 1);
1357+
in_iter.advance(bytes_to_copy);
13681358
l.counter -= bytes_to_copy as u32;
13691359
Action::Jump(RawMemcpy1)
13701360
} else {
@@ -1456,7 +1446,7 @@ pub fn decompress(
14561446
}),
14571447

14581448
ReadExtraBitsCodeSize => generate_state!(state, 'state_machine, {
1459-
let num_extra = l.num_extra;
1449+
let num_extra = l.num_extra.into();
14601450
read_bits(&mut l, num_extra, &mut in_iter, flags, |l, mut extra_bits| {
14611451
// Mask to avoid a bounds check.
14621452
extra_bits += [3, 3, 11][(l.dist as usize - 16) & 3];
@@ -1478,7 +1468,7 @@ pub fn decompress(
14781468
}),
14791469

14801470
DecodeLitlen => generate_state!(state, 'state_machine, {
1481-
if in_iter.len() < 4 || out_buf.bytes_left() < 2 {
1471+
if in_iter.bytes_left() < 4 || out_buf.bytes_left() < 2 {
14821472
// See if we can decode a literal with the data we have left.
14831473
// Jumps to next state (WriteSymbol) if successful.
14841474
decode_huffman_code(
@@ -1496,7 +1486,7 @@ pub fn decompress(
14961486
// If there is enough space, use the fast inner decompression
14971487
// function.
14981488
out_buf.bytes_left() >= 259 &&
1499-
in_iter.len() >= 14
1489+
in_iter.bytes_left() >= 14
15001490
{
15011491
let (status, new_state) = decompress_fast(
15021492
r,
@@ -1587,7 +1577,7 @@ pub fn decompress(
15871577
// We could use get_unchecked later if we can statically verify that
15881578
// this will never go out of bounds.
15891579
l.num_extra =
1590-
u32::from(LENGTH_EXTRA[(l.counter - 257) as usize & BASE_EXTRA_MASK]);
1580+
LENGTH_EXTRA[(l.counter - 257) as usize & BASE_EXTRA_MASK];
15911581
l.counter = u32::from(LENGTH_BASE[(l.counter - 257) as usize & BASE_EXTRA_MASK]);
15921582
// Length and distance codes have a number of extra bits depending on
15931583
// the base, which together with the base gives us the exact value.
@@ -1600,7 +1590,7 @@ pub fn decompress(
16001590
}),
16011591

16021592
ReadExtraBitsLitlen => generate_state!(state, 'state_machine, {
1603-
let num_extra = l.num_extra;
1593+
let num_extra = l.num_extra.into();
16041594
read_bits(&mut l, num_extra, &mut in_iter, flags, |l, extra_bits| {
16051595
l.counter += extra_bits as u32;
16061596
Action::Jump(DecodeDistance)
@@ -1622,7 +1612,7 @@ pub fn decompress(
16221612
// Invalid distance code.
16231613
return Action::Jump(InvalidDist)
16241614
}
1625-
l.num_extra = u32::from(num_extra_bits_for_distance_code(symbol as u8));
1615+
l.num_extra = num_extra_bits_for_distance_code(symbol as u8);
16261616
l.dist = u32::from(DIST_BASE[symbol]);
16271617
if l.num_extra != 0 {
16281618
// ReadEXTRA_BITS_DISTANCE
@@ -1634,7 +1624,7 @@ pub fn decompress(
16341624
}),
16351625

16361626
ReadExtraBitsDistance => generate_state!(state, 'state_machine, {
1637-
let num_extra = l.num_extra;
1627+
let num_extra = l.num_extra.into();
16381628
read_bits(&mut l, num_extra, &mut in_iter, flags, |l, extra_bits| {
16391629
l.dist += extra_bits as u32;
16401630
Action::Jump(HuffDecodeOuterLoop2)
@@ -1710,9 +1700,9 @@ pub fn decompress(
17101700
if r.finish != 0 {
17111701
pad_to_bytes(&mut l, &mut in_iter, flags, |_| Action::None);
17121702

1713-
let in_consumed = in_buf.len() - in_iter.len();
1703+
let in_consumed = in_buf.len() - in_iter.bytes_left();
17141704
let undo = undo_bytes(&mut l, in_consumed as u32) as usize;
1715-
in_iter = in_buf[in_consumed - undo..].iter();
1705+
in_iter = InputWrapper::from_slice(in_buf[in_consumed - undo..].iter().as_slice());
17161706

17171707
l.bit_buf &= ((1 as BitBuffer) << l.num_bits) - 1;
17181708
debug_assert_eq!(l.num_bits, 0);
@@ -1765,7 +1755,7 @@ pub fn decompress(
17651755
let in_undo = if status != TINFLStatus::NeedsMoreInput
17661756
&& status != TINFLStatus::FailedCannotMakeProgress
17671757
{
1768-
undo_bytes(&mut l, (in_buf.len() - in_iter.len()) as u32) as usize
1758+
undo_bytes(&mut l, (in_buf.len() - in_iter.bytes_left()) as u32) as usize
17691759
} else {
17701760
0
17711761
};
@@ -1785,7 +1775,7 @@ pub fn decompress(
17851775
r.num_bits = l.num_bits;
17861776
r.dist = l.dist;
17871777
r.counter = l.counter;
1788-
r.num_extra = l.num_extra;
1778+
r.num_extra = l.num_extra.into();
17891779

17901780
r.bit_buf &= ((1 as BitBuffer) << r.num_bits) - 1;
17911781

@@ -1816,7 +1806,7 @@ pub fn decompress(
18161806

18171807
(
18181808
status,
1819-
in_buf.len() - in_iter.len() - in_undo,
1809+
in_buf.len() - in_iter.bytes_left() - in_undo,
18201810
out_buf.position() - out_pos,
18211811
)
18221812
}
@@ -1911,7 +1901,7 @@ mod test {
19111901
num_bits: d.num_bits,
19121902
dist: d.dist,
19131903
counter: d.counter,
1914-
num_extra: d.num_extra,
1904+
num_extra: d.num_extra as u8,
19151905
};
19161906
init_tree(&mut d, &mut l).unwrap();
19171907
let llt = &d.tables[LITLEN_TABLE];

0 commit comments

Comments
 (0)