Skip to content

Commit 52c090e

Browse files
ZacBlancoyingsu00
authored andcommitted
[parquet] Support 64-bit RLE-encoded ShortDecimal
Previously, in the parquet writer short decimals could be written as RLE-encoded with an Int64 logical type. However, we lacked support in the reader to decode this type properly back into a short decimal. This commit adds support for the RLE-encoded 64-bit short decimals.
1 parent 97d88cf commit 52c090e

File tree

4 files changed

+188
-1
lines changed

4 files changed

+188
-1
lines changed

presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,23 @@ public void testDecimalBackedByINT64()
939939
}
940940
}
941941

942+
@Test
943+
public void testRLEDecimalBackedByINT64()
944+
throws Exception
945+
{
946+
int[] scales = {9, 9, 9, 9, 9, 9, 9, 9, 9};
947+
for (int precision = MAX_PRECISION_INT32 + 1; precision <= MAX_PRECISION_INT64; precision++) {
948+
int scale = scales[precision - MAX_PRECISION_INT32 - 1];
949+
MessageType parquetSchema = parseMessageType(format("message hive_decimal { optional INT64 test (DECIMAL(%d, %d)); }", precision, scale));
950+
ContiguousSet<Long> longValues = longsBetween(1, 1_000);
951+
ImmutableList.Builder<SqlDecimal> expectedValues = new ImmutableList.Builder<>();
952+
for (Long value : longValues) {
953+
expectedValues.add(SqlDecimal.of(value, precision, scale));
954+
}
955+
tester.testRoundTrip(javaLongObjectInspector, longValues, expectedValues.build(), createDecimalType(precision, scale), Optional.of(parquetSchema));
956+
}
957+
}
958+
942959
private void testDecimal(int precision, int scale, Optional<MessageType> parquetSchema)
943960
throws Exception
944961
{

presto-parquet/src/main/java/com/facebook/presto/parquet/batchreader/decoders/Decoders.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@ private static ValuesDecoder createValuesDecoder(ColumnDescriptor columnDescript
179179
if (isTimeStampMicrosType(columnDescriptor) || isTimeMicrosType(columnDescriptor)) {
180180
return new Int64TimeAndTimestampMicrosRLEDictionaryValuesDecoder(bitWidth, inputStream, (LongDictionary) dictionary);
181181
}
182+
if (isDecimalType(columnDescriptor) && isShortDecimalType(columnDescriptor)) {
183+
return new Int64RLEDictionaryValuesDecoder(bitWidth, inputStream, (LongDictionary) dictionary);
184+
}
182185
}
183186
case DOUBLE: {
184187
return new Int64RLEDictionaryValuesDecoder(bitWidth, inputStream, (LongDictionary) dictionary);

presto-parquet/src/main/java/com/facebook/presto/parquet/batchreader/decoders/rle/Int64RLEDictionaryValuesDecoder.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
package com.facebook.presto.parquet.batchreader.decoders.rle;
1515

1616
import com.facebook.presto.parquet.batchreader.decoders.ValuesDecoder.Int64ValuesDecoder;
17+
import com.facebook.presto.parquet.batchreader.decoders.ValuesDecoder.ShortDecimalValuesDecoder;
1718
import com.facebook.presto.parquet.dictionary.LongDictionary;
1819
import org.apache.parquet.io.ParquetDecodingException;
1920
import org.openjdk.jol.info.ClassLayout;
@@ -27,7 +28,7 @@
2728

2829
public class Int64RLEDictionaryValuesDecoder
2930
extends BaseRLEBitPackedDecoder
30-
implements Int64ValuesDecoder
31+
implements Int64ValuesDecoder, ShortDecimalValuesDecoder
3132
{
3233
private static final int INSTANCE_SIZE = ClassLayout.parseClass(Int64RLEDictionaryValuesDecoder.class).instanceSize();
3334

presto-parquet/src/test/java/com/facebook/presto/parquet/batchreader/decoders/TestValuesDecoders.java

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,20 @@
1919
import com.facebook.presto.parquet.batchreader.decoders.ValuesDecoder.Int32ValuesDecoder;
2020
import com.facebook.presto.parquet.batchreader.decoders.ValuesDecoder.Int64TimeAndTimestampMicrosValuesDecoder;
2121
import com.facebook.presto.parquet.batchreader.decoders.ValuesDecoder.Int64ValuesDecoder;
22+
import com.facebook.presto.parquet.batchreader.decoders.ValuesDecoder.ShortDecimalValuesDecoder;
2223
import com.facebook.presto.parquet.batchreader.decoders.ValuesDecoder.TimestampValuesDecoder;
2324
import com.facebook.presto.parquet.batchreader.decoders.plain.BinaryPlainValuesDecoder;
2425
import com.facebook.presto.parquet.batchreader.decoders.plain.BooleanPlainValuesDecoder;
2526
import com.facebook.presto.parquet.batchreader.decoders.plain.Int32PlainValuesDecoder;
27+
import com.facebook.presto.parquet.batchreader.decoders.plain.Int32ShortDecimalPlainValuesDecoder;
2628
import com.facebook.presto.parquet.batchreader.decoders.plain.Int64PlainValuesDecoder;
29+
import com.facebook.presto.parquet.batchreader.decoders.plain.Int64ShortDecimalPlainValuesDecoder;
2730
import com.facebook.presto.parquet.batchreader.decoders.plain.Int64TimeAndTimestampMicrosPlainValuesDecoder;
2831
import com.facebook.presto.parquet.batchreader.decoders.plain.TimestampPlainValuesDecoder;
2932
import com.facebook.presto.parquet.batchreader.decoders.rle.BinaryRLEDictionaryValuesDecoder;
3033
import com.facebook.presto.parquet.batchreader.decoders.rle.BooleanRLEValuesDecoder;
3134
import com.facebook.presto.parquet.batchreader.decoders.rle.Int32RLEDictionaryValuesDecoder;
35+
import com.facebook.presto.parquet.batchreader.decoders.rle.Int32ShortDecimalRLEDictionaryValuesDecoder;
3236
import com.facebook.presto.parquet.batchreader.decoders.rle.Int64RLEDictionaryValuesDecoder;
3337
import com.facebook.presto.parquet.batchreader.decoders.rle.Int64TimeAndTimestampMicrosRLEDictionaryValuesDecoder;
3438
import com.facebook.presto.parquet.batchreader.decoders.rle.TimestampRLEDictionaryValuesDecoder;
@@ -118,6 +122,26 @@ private static BooleanValuesDecoder booleanRLE(byte[] pageBytes)
118122
return new BooleanRLEValuesDecoder(ByteBuffer.wrap(pageBytes));
119123
}
120124

125+
private static ShortDecimalValuesDecoder int32ShortDecimalPlain(byte[] pageBytes)
126+
{
127+
return new Int32ShortDecimalPlainValuesDecoder(pageBytes, 0, pageBytes.length);
128+
}
129+
130+
private static ShortDecimalValuesDecoder int64ShortDecimalPlain(byte[] pageBytes)
131+
{
132+
return new Int64ShortDecimalPlainValuesDecoder(pageBytes, 0, pageBytes.length);
133+
}
134+
135+
private static ShortDecimalValuesDecoder int32ShortDecimalRLE(byte[] pageBytes, int dictionarySize, IntegerDictionary dictionary)
136+
{
137+
return new Int32ShortDecimalRLEDictionaryValuesDecoder(getWidthFromMaxInt(dictionarySize), new ByteArrayInputStream(pageBytes), dictionary);
138+
}
139+
140+
private static ShortDecimalValuesDecoder int64ShortDecimalRLE(byte[] pageBytes, int dictionarySize, LongDictionary dictionary)
141+
{
142+
return new Int64RLEDictionaryValuesDecoder(getWidthFromMaxInt(dictionarySize), new ByteArrayInputStream(pageBytes), dictionary);
143+
}
144+
121145
private static void int32BatchReadWithSkipHelper(int batchSize, int skipSize, int valueCount, Int32ValuesDecoder decoder, List<Object> expectedValues)
122146
throws IOException
123147
{
@@ -213,6 +237,52 @@ private static void int64BatchReadWithSkipHelper(int batchSize, int skipSize, in
213237
}
214238
}
215239

240+
private static void int32ShortDecimalBatchReadWithSkipHelper(int batchSize, int skipSize, int valueCount, ShortDecimalValuesDecoder decoder, List<Object> expectedValues)
241+
throws IOException
242+
{
243+
long[] actualValues = new long[valueCount];
244+
int inputOffset = 0;
245+
int outputOffset = 0;
246+
while (inputOffset < valueCount) {
247+
int readBatchSize = min(batchSize, valueCount - inputOffset);
248+
decoder.readNext(actualValues, outputOffset, readBatchSize);
249+
250+
for (int i = 0; i < readBatchSize; i++) {
251+
assertEquals(actualValues[outputOffset + i], (int) expectedValues.get(inputOffset + i));
252+
}
253+
254+
inputOffset += readBatchSize;
255+
outputOffset += readBatchSize;
256+
257+
int skipBatchSize = min(skipSize, valueCount - inputOffset);
258+
decoder.skip(skipBatchSize);
259+
inputOffset += skipBatchSize;
260+
}
261+
}
262+
263+
private static void int64ShortDecimalBatchReadWithSkipHelper(int batchSize, int skipSize, int valueCount, ShortDecimalValuesDecoder decoder, List<Object> expectedValues)
264+
throws IOException
265+
{
266+
long[] actualValues = new long[valueCount];
267+
int inputOffset = 0;
268+
int outputOffset = 0;
269+
while (inputOffset < valueCount) {
270+
int readBatchSize = min(batchSize, valueCount - inputOffset);
271+
decoder.readNext(actualValues, outputOffset, readBatchSize);
272+
273+
for (int i = 0; i < readBatchSize; i++) {
274+
assertEquals(actualValues[outputOffset + i], expectedValues.get(inputOffset + i));
275+
}
276+
277+
inputOffset += readBatchSize;
278+
outputOffset += readBatchSize;
279+
280+
int skipBatchSize = min(skipSize, valueCount - inputOffset);
281+
decoder.skip(skipBatchSize);
282+
inputOffset += skipBatchSize;
283+
}
284+
}
285+
216286
private static void timestampBatchReadWithSkipHelper(int batchSize, int skipSize, int valueCount, TimestampValuesDecoder decoder, List<Object> expectedValues)
217287
throws IOException
218288
{
@@ -515,4 +585,100 @@ public void testBooleanRLE()
515585
booleanBatchReadWithSkipHelper(89, 29, valueCount, booleanRLE(dataPage), expectedValues);
516586
booleanBatchReadWithSkipHelper(1024, 1024, valueCount, booleanRLE(dataPage), expectedValues);
517587
}
588+
589+
@Test
590+
public void testInt32ShortDecimalPlain()
591+
throws IOException
592+
{
593+
int valueCount = 2048;
594+
List<Object> expectedValues = new ArrayList<>();
595+
596+
byte[] pageBytes = generatePlainValuesPage(valueCount, 32, new Random(83), expectedValues);
597+
int32ShortDecimalBatchReadWithSkipHelper(valueCount, 0, valueCount, int32ShortDecimalPlain(pageBytes), expectedValues); // read all values in one batch
598+
int32ShortDecimalBatchReadWithSkipHelper(29, 0, valueCount, int32ShortDecimalPlain(pageBytes), expectedValues);
599+
int32ShortDecimalBatchReadWithSkipHelper(89, 0, valueCount, int32ShortDecimalPlain(pageBytes), expectedValues);
600+
int32ShortDecimalBatchReadWithSkipHelper(1024, 0, valueCount, int32ShortDecimalPlain(pageBytes), expectedValues);
601+
602+
int32ShortDecimalBatchReadWithSkipHelper(256, 29, valueCount, int32ShortDecimalPlain(pageBytes), expectedValues);
603+
int32ShortDecimalBatchReadWithSkipHelper(89, 29, valueCount, int32ShortDecimalPlain(pageBytes), expectedValues);
604+
int32ShortDecimalBatchReadWithSkipHelper(1024, 1024, valueCount, int32ShortDecimalPlain(pageBytes), expectedValues);
605+
}
606+
607+
@Test
608+
public void testInt64ShortDecimalPlain()
609+
throws IOException
610+
{
611+
int valueCount = 2048;
612+
List<Object> expectedValues = new ArrayList<>();
613+
614+
byte[] pageBytes = generatePlainValuesPage(valueCount, 64, new Random(83), expectedValues);
615+
int64ShortDecimalBatchReadWithSkipHelper(valueCount, 0, valueCount, int64ShortDecimalPlain(pageBytes), expectedValues); // read all values in one batch
616+
int64ShortDecimalBatchReadWithSkipHelper(29, 0, valueCount, int64ShortDecimalPlain(pageBytes), expectedValues);
617+
int64ShortDecimalBatchReadWithSkipHelper(89, 0, valueCount, int64ShortDecimalPlain(pageBytes), expectedValues);
618+
int64ShortDecimalBatchReadWithSkipHelper(1024, 0, valueCount, int64ShortDecimalPlain(pageBytes), expectedValues);
619+
620+
int64ShortDecimalBatchReadWithSkipHelper(256, 29, valueCount, int64ShortDecimalPlain(pageBytes), expectedValues);
621+
int64ShortDecimalBatchReadWithSkipHelper(89, 29, valueCount, int64ShortDecimalPlain(pageBytes), expectedValues);
622+
int64ShortDecimalBatchReadWithSkipHelper(1024, 1024, valueCount, int64ShortDecimalPlain(pageBytes), expectedValues);
623+
}
624+
625+
@Test
626+
public void testInt32ShortDecimalRLE()
627+
throws IOException
628+
{
629+
Random random = new Random(83);
630+
int valueCount = 2048;
631+
int dictionarySize = 29;
632+
List<Object> dictionary = new ArrayList<>();
633+
List<Integer> dictionaryIds = new ArrayList<>();
634+
635+
byte[] dictionaryPage = generatePlainValuesPage(dictionarySize, 32, random, dictionary);
636+
byte[] dataPage = generateDictionaryIdPage2048(dictionarySize - 1, random, dictionaryIds);
637+
638+
List<Object> expectedValues = new ArrayList<>();
639+
for (Integer dictionaryId : dictionaryIds) {
640+
expectedValues.add(dictionary.get(dictionaryId));
641+
}
642+
643+
IntegerDictionary integerDictionary = new IntegerDictionary(new DictionaryPage(Slices.wrappedBuffer(dictionaryPage), dictionarySize, PLAIN_DICTIONARY));
644+
645+
int32ShortDecimalBatchReadWithSkipHelper(valueCount, 0, valueCount, int32ShortDecimalRLE(dataPage, dictionarySize, integerDictionary), expectedValues); // read all values in one batch
646+
int32ShortDecimalBatchReadWithSkipHelper(29, 0, valueCount, int32ShortDecimalRLE(dataPage, dictionarySize, integerDictionary), expectedValues);
647+
int32ShortDecimalBatchReadWithSkipHelper(89, 0, valueCount, int32ShortDecimalRLE(dataPage, dictionarySize, integerDictionary), expectedValues);
648+
int32ShortDecimalBatchReadWithSkipHelper(1024, 0, valueCount, int32ShortDecimalRLE(dataPage, dictionarySize, integerDictionary), expectedValues);
649+
650+
int32ShortDecimalBatchReadWithSkipHelper(256, 29, valueCount, int32ShortDecimalRLE(dataPage, dictionarySize, integerDictionary), expectedValues);
651+
int32ShortDecimalBatchReadWithSkipHelper(89, 29, valueCount, int32ShortDecimalRLE(dataPage, dictionarySize, integerDictionary), expectedValues);
652+
int32ShortDecimalBatchReadWithSkipHelper(1024, 1024, valueCount, int32ShortDecimalRLE(dataPage, dictionarySize, integerDictionary), expectedValues);
653+
}
654+
655+
@Test
656+
public void testInt64ShortDecimalRLE()
657+
throws IOException
658+
{
659+
Random random = new Random(83);
660+
int valueCount = 2048;
661+
int dictionarySize = 29;
662+
List<Object> dictionary = new ArrayList<>();
663+
List<Integer> dictionaryIds = new ArrayList<>();
664+
665+
byte[] dictionaryPage = generatePlainValuesPage(dictionarySize, 64, random, dictionary);
666+
byte[] dataPage = generateDictionaryIdPage2048(dictionarySize - 1, random, dictionaryIds);
667+
668+
List<Object> expectedValues = new ArrayList<>();
669+
for (Integer dictionaryId : dictionaryIds) {
670+
expectedValues.add(dictionary.get(dictionaryId));
671+
}
672+
673+
LongDictionary longDictionary = new LongDictionary(new DictionaryPage(Slices.wrappedBuffer(dictionaryPage), dictionarySize, PLAIN_DICTIONARY));
674+
675+
int64ShortDecimalBatchReadWithSkipHelper(valueCount, 0, valueCount, int64ShortDecimalRLE(dataPage, dictionarySize, longDictionary), expectedValues); // read all values in one batch
676+
int64ShortDecimalBatchReadWithSkipHelper(29, 0, valueCount, int64ShortDecimalRLE(dataPage, dictionarySize, longDictionary), expectedValues);
677+
int64ShortDecimalBatchReadWithSkipHelper(89, 0, valueCount, int64ShortDecimalRLE(dataPage, dictionarySize, longDictionary), expectedValues);
678+
int64ShortDecimalBatchReadWithSkipHelper(1024, 0, valueCount, int64ShortDecimalRLE(dataPage, dictionarySize, longDictionary), expectedValues);
679+
680+
int64ShortDecimalBatchReadWithSkipHelper(256, 29, valueCount, int64ShortDecimalRLE(dataPage, dictionarySize, longDictionary), expectedValues);
681+
int64ShortDecimalBatchReadWithSkipHelper(89, 29, valueCount, int64ShortDecimalRLE(dataPage, dictionarySize, longDictionary), expectedValues);
682+
int64ShortDecimalBatchReadWithSkipHelper(1024, 1024, valueCount, int64ShortDecimalRLE(dataPage, dictionarySize, longDictionary), expectedValues);
683+
}
518684
}

0 commit comments

Comments
 (0)