Skip to content

Commit 2c45ac0

Browse files
authored
[spark] Compaction add parallelize parallelism to avoid small partitions (#4158)
1 parent 588d7f2 commit 2c45ac0

File tree

2 files changed

+108
-3
lines changed

2 files changed

+108
-3
lines changed

paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CompactProcedure.java

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.apache.spark.sql.Dataset;
5757
import org.apache.spark.sql.PaimonUtils;
5858
import org.apache.spark.sql.Row;
59+
import org.apache.spark.sql.SparkSession;
5960
import org.apache.spark.sql.catalyst.InternalRow;
6061
import org.apache.spark.sql.catalyst.expressions.Expression;
6162
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
@@ -66,6 +67,8 @@
6667
import org.apache.spark.sql.types.Metadata;
6768
import org.apache.spark.sql.types.StructField;
6869
import org.apache.spark.sql.types.StructType;
70+
import org.slf4j.Logger;
71+
import org.slf4j.LoggerFactory;
6972

7073
import javax.annotation.Nullable;
7174

@@ -97,6 +100,8 @@
97100
*/
98101
public class CompactProcedure extends BaseProcedure {
99102

103+
private static final Logger LOG = LoggerFactory.getLogger(CompactProcedure.class);
104+
100105
private static final ProcedureParameter[] PARAMETERS =
101106
new ProcedureParameter[] {
102107
ProcedureParameter.required("table", StringType),
@@ -182,7 +187,6 @@ public InternalRow[] call(InternalRow args) {
182187
dynamicOptions.putAll(ParameterUtils.parseCommaSeparatedKeyValues(options));
183188
}
184189
table = table.copy(dynamicOptions);
185-
186190
InternalRow internalRow =
187191
newInternalRow(
188192
execute(
@@ -279,10 +283,11 @@ private void compactAwareBucketTable(
279283
return;
280284
}
281285

286+
int readParallelism = readParallelism(partitionBuckets, spark());
282287
BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder();
283288
JavaRDD<byte[]> commitMessageJavaRDD =
284289
javaSparkContext
285-
.parallelize(partitionBuckets)
290+
.parallelize(partitionBuckets, readParallelism)
286291
.mapPartitions(
287292
(FlatMapFunction<Iterator<Pair<byte[], Integer>>, byte[]>)
288293
pairIterator -> {
@@ -355,6 +360,7 @@ private void compactUnAwareBucketTable(
355360
.collect(Collectors.toList());
356361
}
357362
if (compactionTasks.isEmpty()) {
363+
System.out.println("compaction task is empty.");
358364
return;
359365
}
360366

@@ -368,10 +374,11 @@ private void compactUnAwareBucketTable(
368374
throw new RuntimeException("serialize compaction task failed");
369375
}
370376

377+
int readParallelism = readParallelism(serializedTasks, spark());
371378
String commitUser = createCommitUser(table.coreOptions().toConfiguration());
372379
JavaRDD<byte[]> commitMessageJavaRDD =
373380
javaSparkContext
374-
.parallelize(serializedTasks)
381+
.parallelize(serializedTasks, readParallelism)
375382
.mapPartitions(
376383
(FlatMapFunction<Iterator<byte[]>, byte[]>)
377384
taskIterator -> {
@@ -485,6 +492,22 @@ private Map<BinaryRow, DataSplit[]> packForSort(List<DataSplit> dataSplits) {
485492
list -> list.toArray(new DataSplit[0]))));
486493
}
487494

495+
private int readParallelism(List<?> groupedTasks, SparkSession spark) {
496+
int sparkParallelism =
497+
Math.max(
498+
spark.sparkContext().defaultParallelism(),
499+
spark.sessionState().conf().numShufflePartitions());
500+
int readParallelism = Math.min(groupedTasks.size(), sparkParallelism);
501+
if (sparkParallelism > readParallelism) {
502+
LOG.warn(
503+
String.format(
504+
"Spark default parallelism (%s) is greater than bucket or task parallelism (%s),"
505+
+ "we use %s as the final read parallelism",
506+
sparkParallelism, readParallelism, readParallelism));
507+
}
508+
return readParallelism;
509+
}
510+
488511
@VisibleForTesting
489512
static String toWhere(String partitions) {
490513
List<Map<String, String>> maps = ParameterUtils.getPartitions(partitions.split(";"));

paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/procedure/CompactProcedureTestBase.scala

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.apache.paimon.spark.PaimonSparkTestBase
2424
import org.apache.paimon.table.FileStoreTable
2525
import org.apache.paimon.table.source.DataSplit
2626

27+
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerStageCompleted, SparkListenerStageSubmitted}
2728
import org.apache.spark.sql.{Dataset, Row}
2829
import org.apache.spark.sql.execution.streaming.MemoryStream
2930
import org.apache.spark.sql.streaming.StreamTest
@@ -648,6 +649,87 @@ abstract class CompactProcedureTestBase extends PaimonSparkTestBase with StreamT
648649
}
649650
}
650651

652+
test("Paimon Procedure: test aware-bucket compaction read parallelism") {
653+
spark.sql(s"""
654+
|CREATE TABLE T (id INT, value STRING)
655+
|TBLPROPERTIES ('primary-key'='id', 'bucket'='3', 'write-only'='true')
656+
|""".stripMargin)
657+
658+
val table = loadTable("T")
659+
for (i <- 1 to 10) {
660+
sql(s"INSERT INTO T VALUES ($i, '$i')")
661+
}
662+
assertResult(10)(table.snapshotManager().snapshotCount())
663+
664+
val buckets = table.newSnapshotReader().bucketEntries().asScala.map(_.bucket()).distinct.size
665+
assertResult(3)(buckets)
666+
667+
val taskBuffer = scala.collection.mutable.ListBuffer.empty[Int]
668+
val listener = new SparkListener {
669+
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
670+
taskBuffer += stageSubmitted.stageInfo.numTasks
671+
}
672+
}
673+
674+
try {
675+
spark.sparkContext.addSparkListener(listener)
676+
677+
// spark.default.parallelism cannot be changed in a Spark session
678+
// sparkParallelism is 2, bucket is 3, use 2 as the read parallelism
679+
spark.conf.set("spark.sql.shuffle.partitions", 2)
680+
spark.sql("CALL sys.compact(table => 'T')")
681+
682+
// sparkParallelism is 5, bucket is 3, use 3 as the read parallelism
683+
spark.conf.set("spark.sql.shuffle.partitions", 5)
684+
spark.sql("CALL sys.compact(table => 'T')")
685+
686+
assertResult(Seq(2, 3))(taskBuffer)
687+
} finally {
688+
spark.sparkContext.removeSparkListener(listener)
689+
}
690+
}
691+
692+
test("Paimon Procedure: test unaware-bucket compaction read parallelism") {
693+
spark.sql(s"""
694+
|CREATE TABLE T (id INT, value STRING)
695+
|TBLPROPERTIES ('bucket'='-1', 'write-only'='true')
696+
|""".stripMargin)
697+
698+
val table = loadTable("T")
699+
for (i <- 1 to 12) {
700+
sql(s"INSERT INTO T VALUES ($i, '$i')")
701+
}
702+
assertResult(12)(table.snapshotManager().snapshotCount())
703+
704+
val buckets = table.newSnapshotReader().bucketEntries().asScala.map(_.bucket()).distinct.size
705+
// only has bucket-0
706+
assertResult(1)(buckets)
707+
708+
val taskBuffer = scala.collection.mutable.ListBuffer.empty[Int]
709+
val listener = new SparkListener {
710+
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
711+
taskBuffer += stageSubmitted.stageInfo.numTasks
712+
}
713+
}
714+
715+
try {
716+
spark.sparkContext.addSparkListener(listener)
717+
718+
// spark.default.parallelism cannot be changed in a Spark session
719+
// sparkParallelism is 2, task groups is 6, use 2 as the read parallelism
720+
spark.conf.set("spark.sql.shuffle.partitions", 2)
721+
spark.sql("CALL sys.compact(table => 'T', options => 'compaction.max.file-num=2')")
722+
723+
// sparkParallelism is 5, task groups is 3, use 3 as the read parallelism
724+
spark.conf.set("spark.sql.shuffle.partitions", 5)
725+
spark.sql("CALL sys.compact(table => 'T', options => 'compaction.max.file-num=2')")
726+
727+
assertResult(Seq(2, 3))(taskBuffer)
728+
} finally {
729+
spark.sparkContext.removeSparkListener(listener)
730+
}
731+
}
732+
651733
def lastSnapshotCommand(table: FileStoreTable): CommitKind = {
652734
table.snapshotManager().latestSnapshot().commitKind()
653735
}

0 commit comments

Comments
 (0)