@@ -60,26 +60,27 @@ def gen_csv(part_idx: int, cachedir: str, scale_factor: float, num_parts: int) -
 def pipelined_data_generation(
     scratch_dir: str,
     scale_factor: float,
-    num_parts: int,
+    num_batches: int,
     aws_s3_sync_location: str,
     parallelism: int = 4,
     rows_per_file: int = 500_000,
 ) -> None:
-    assert num_parts > 1, "script should only be used if num_parts > 1"
-
     if aws_s3_sync_location.endswith("/"):
         aws_s3_sync_location = aws_s3_sync_location[:-1]
 
-    base_path = pathlib.Path(scratch_dir) / str(num_parts)
+    base_path = pathlib.Path(scratch_dir) / str(num_batches)
     base_path.mkdir(parents=True, exist_ok=True)
 
-    for i, part_indices in enumerate(batch(range(1, num_parts + 1), n=parallelism)):
+    num_dbgen_partitions = num_batches * parallelism
+    for batch_idx, part_indices in enumerate(
+        batch(range(1, num_dbgen_partitions + 1), n=parallelism)
+    ):
         logger.info("Partition %s: Generating CSV files", part_indices)
         with Pool(parallelism) as process_pool:
             process_pool.starmap(
                 gen_csv,
                 [
-                    (part_idx, base_path, scale_factor, num_parts)
+                    (part_idx, base_path, scale_factor, num_dbgen_partitions)
                     for part_idx in part_indices
                 ],
             )
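
For reference, the loop above relies on a batch() helper defined elsewhere in the script; a minimal sketch of the behaviour it is assumed to have (chunking an iterable into lists of at most n items, with the last chunk possibly shorter):

# Assumed shape of the batch() helper used by pipelined_data_generation
# (not part of this diff); names and typing are illustrative.
from collections.abc import Iterable, Iterator


def batch(iterable: Iterable[int], n: int) -> Iterator[list[int]]:
    chunk: list[int] = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) == n:
            yield chunk
            chunk = []
    if chunk:
        yield chunk


# With num_batches=2 and parallelism=4, dbgen partitions 1..8 are generated
# in two batches: [1, 2, 3, 4] then [5, 6, 7, 8].
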
@@ -88,20 +89,13 @@ def pipelined_data_generation(
         for f in csv_files:
             shutil.move(f, base_path / pathlib.Path(f).name)
 
-        gen_parquet(base_path, rows_per_file, partitioned=True, iteration_offset=i)
+        gen_parquet(base_path, rows_per_file, partitioned=True, batch_idx=batch_idx)
         parquet_files = glob.glob(f"{base_path}/*.parquet")  # noqa: PTH207
 
-        # Exclude static tables except for first iteration
-        exclude_static_tables = (
-            ""
-            if i == 0
-            else " ".join([f'--exclude "*/{tbl}/*"' for tbl in STATIC_TABLES])
-        )
-
         if len(aws_s3_sync_location):
             subprocess.check_output(
                 shlex.split(
-                    f'aws s3 sync {scratch_dir} {aws_s3_sync_location}/scale-factor-{scale_factor} --exclude "*" --include "*.parquet" {exclude_static_tables}'
+                    f'aws s3 sync {scratch_dir} {aws_s3_sync_location}/scale-{scale_factor} --exclude "*" --include "*.parquet"'
                 )
             )
         for parquet_file in parquet_files:
@@ -197,9 +191,12 @@ def gen_parquet(
     base_path: pathlib.Path,
     rows_per_file: int = 500_000,
     partitioned: bool = False,
-    iteration_offset: int = 0,
+    batch_idx: int = 0,
 ) -> None:
     for table_name, columns in table_columns.items():
+        if table_name in STATIC_TABLES and batch_idx != 0:
+            continue
+
         path = base_path / f"{table_name}.tbl*"
 
         lf = pl.scan_csv(
@@ -214,9 +211,18 @@ def gen_parquet(
         lf = lf.select(columns)
 
         if partitioned:
-            (base_path / table_name).mkdir(parents=True, exist_ok=True)
-            path = base_path / table_name / f"{iteration_offset}_{{part}}.parquet"
-            lf.sink_parquet(pl.PartitionMaxSize(path, max_size=rows_per_file))
+
+            def partition_file_name(ctx: pl.BasePartitionContext) -> pathlib.Path:
+                partition = f"{batch_idx}_{ctx.file_idx}"
+                (base_path / table_name / partition).mkdir(parents=True, exist_ok=True)  # noqa: B023
+                return pathlib.Path(partition) / "part.parquet"
+
+            path = base_path / table_name
+            lf.sink_parquet(
+                pl.PartitionMaxSize(
+                    path, file_path=partition_file_name, max_size=rows_per_file
+                )
+            )
         else:
             path = base_path / f"{table_name}.parquet"
             lf.sink_parquet(path)
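
As a standalone illustration of the partitioned sink introduced above, here is a toy sketch (hypothetical paths and data; assumes a Polars release that ships pl.PartitionMaxSize with the file_path callback, as this diff relies on) showing how each output file lands in its own <batch_idx>_<file_idx>/ directory:

import pathlib

import polars as pl

base_path = pathlib.Path("/tmp/partition-demo")  # illustrative scratch dir
table_name = "lineitem"
batch_idx = 0

lf = pl.LazyFrame({"l_orderkey": range(10)})


def partition_file_name(ctx: pl.BasePartitionContext) -> pathlib.Path:
    # Route every output file into its own <batch_idx>_<file_idx>/ directory.
    partition = f"{batch_idx}_{ctx.file_idx}"
    (base_path / table_name / partition).mkdir(parents=True, exist_ok=True)
    return pathlib.Path(partition) / "part.parquet"


lf.sink_parquet(
    pl.PartitionMaxSize(
        base_path / table_name, file_path=partition_file_name, max_size=4
    )
)
# Expected layout (at most max_size=4 rows per file):
#   /tmp/partition-demo/lineitem/0_0/part.parquet
#   /tmp/partition-demo/lineitem/0_1/part.parquet
#   /tmp/partition-demo/lineitem/0_2/part.parquet
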
@@ -242,7 +248,11 @@ def gen_parquet(
     type=int,
 )
 parser.add_argument(
-    "--num-parts", default=32, help="Number of parts to generate", type=int
+    "--num-batches",
+    default=None,
+    help="Number of batches used to generate the data",
+    type=int,
+    nargs="?",
 )
 parser.add_argument(
     "--aws-s3-sync-location",
@@ -257,7 +267,7 @@ def gen_parquet(
 )
 args = parser.parse_args()
 
-if args.num_parts == 1:
+if args.num_batches is None:
     # Assumes the tables are already created by the Makefile
     gen_parquet(
         pathlib.Path(args.tpch_gen_folder),
@@ -268,7 +278,7 @@ def gen_parquet(
     pipelined_data_generation(
         args.tpch_gen_folder,
         args.scale_factor,
-        args.num_parts,
+        args.num_batches,
         args.aws_s3_sync_location,
         parallelism=args.parallelism,
         rows_per_file=args.rows_per_file,