Skip to content

Commit 905d2ef

Browse files
committed
wip
1 parent 6dfd041 commit 905d2ef

File tree

2 files changed

+130
-18
lines changed

2 files changed

+130
-18
lines changed

pgvectorscale/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ name = "pgrx_embed_vectorscale"
1111
path = "./src/bin/pgrx_embed.rs"
1212

1313
[features]
14-
default = ["pg17"]
14+
default = ["pg17", "build_parallel"]
1515
pg13 = ["pgrx/pg13", "pgrx-tests/pg13"]
1616
pg14 = ["pgrx/pg14", "pgrx-tests/pg14"]
1717
pg15 = ["pgrx/pg15", "pgrx-tests/pg15"]

pgvectorscale/src/access_method/build.rs

Lines changed: 129 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
use std::time::Instant;
2+
use std::cell::Cell;
23

34
use pg_sys::{FunctionCall0Coll, InvalidOid};
45
use pgrx::ffi::c_char;
56
use pgrx::pg_sys::{index_getprocinfo, pgstat_progress_update_param, AsPgCStr, Oid};
67
use pgrx::*;
78

9+
thread_local! {
10+
static PARALLEL_SHM_TOC: Cell<*mut pg_sys::shm_toc> = Cell::new(std::ptr::null_mut());
11+
}
12+
813
use crate::access_method::distance::DistanceType;
914
use crate::access_method::graph::neighbor_store::GraphNeighborStore;
1015
use crate::access_method::graph::Graph;
@@ -167,7 +172,7 @@ pub extern "C" fn ambuild(
167172

168173
// TODO: unsafe { (*index_info).ii_ParallelWorkers };
169174
let workers = if cfg!(feature = "build_parallel") {
170-
1
175+
2 // TODO
171176
} else {
172177
0
173178
};
@@ -263,6 +268,7 @@ pub extern "C" fn ambuild(
263268
&index_relation,
264269
meta_page,
265270
write_stats,
271+
false,
266272
)
267273
};
268274

@@ -497,12 +503,18 @@ pub extern "C-unwind" fn _vectorscale_build_main(
497503
let index_relation = unsafe { PgRelation::from_pg(indexrel) };
498504
let meta_page = MetaPage::fetch(&index_relation);
499505

506+
// Store the shm_toc in a thread-local variable for access during parallel scan
507+
PARALLEL_SHM_TOC.with(|toc_cell| {
508+
toc_cell.set(shm_toc);
509+
});
510+
500511
let ntuples = do_heap_scan(
501512
index_info,
502513
&heap_relation,
503514
&index_relation,
504515
meta_page,
505516
WriteStats::default(),
517+
true,
506518
);
507519

508520
unsafe {
@@ -517,6 +529,7 @@ fn do_heap_scan(
517529
index_relation: &PgRelation,
518530
mut meta_page: MetaPage,
519531
mut write_stats: WriteStats,
532+
parallel: bool,
520533
) -> usize {
521534
unsafe {
522535
pgstat_progress_update_param(PROGRESS_CREATE_IDX_SUBPHASE, BUILD_PHASE_BUILDING_GRAPH);
@@ -540,14 +553,20 @@ fn do_heap_scan(
540553
let mut bs = BuildState::new(index_relation, graph, page_type);
541554
let mut state = StorageBuildState::Plain(&mut plain, &mut bs);
542555

543-
unsafe {
544-
pg_sys::IndexBuildHeapScan(
545-
heap_relation.as_ptr(),
546-
index_relation.as_ptr(),
547-
index_info,
548-
Some(build_callback),
549-
&mut state,
550-
);
556+
if parallel {
557+
unsafe {
558+
do_parallel_heap_scan(heap_relation, index_relation, index_info, &mut state);
559+
}
560+
} else {
561+
unsafe {
562+
pg_sys::IndexBuildHeapScan(
563+
heap_relation.as_ptr(),
564+
index_relation.as_ptr(),
565+
index_info,
566+
Some(build_callback),
567+
&mut state,
568+
);
569+
}
551570
}
552571

553572
finalize_index_build(&mut plain, bs, index_relation, write_stats)
@@ -566,14 +585,20 @@ fn do_heap_scan(
566585
let mut bs = BuildState::new(index_relation, graph, page_type);
567586
let mut state = StorageBuildState::SbqSpeedup(&mut bq, &mut bs);
568587

569-
unsafe {
570-
pg_sys::IndexBuildHeapScan(
571-
heap_relation.as_ptr(),
572-
index_relation.as_ptr(),
573-
index_info,
574-
Some(build_callback),
575-
&mut state,
576-
);
588+
if parallel {
589+
unsafe {
590+
do_parallel_heap_scan(heap_relation, index_relation, index_info, &mut state);
591+
}
592+
} else {
593+
unsafe {
594+
pg_sys::IndexBuildHeapScan(
595+
heap_relation.as_ptr(),
596+
index_relation.as_ptr(),
597+
index_info,
598+
Some(build_callback),
599+
&mut state,
600+
);
601+
}
577602
}
578603

579604
unsafe {
@@ -645,6 +670,93 @@ fn finalize_index_build<S: Storage>(
645670
ntuples
646671
}
647672

673+
#[cfg(feature = "build_parallel")]
674+
unsafe fn do_parallel_heap_scan(
675+
heap_relation: &PgRelation,
676+
index_relation: &PgRelation,
677+
_index_info: *mut pg_sys::IndexInfo,
678+
state: &mut StorageBuildState,
679+
) {
680+
use pgrx::pg_sys::ScanDirection::ForwardScanDirection;
681+
682+
// Get the parallel table scan descriptor from shared memory
683+
let shm_toc = PARALLEL_SHM_TOC.with(|toc_cell| toc_cell.get());
684+
if shm_toc.is_null() {
685+
panic!("No shared memory TOC available for parallel scan");
686+
}
687+
let tablescandesc: *mut pg_sys::ParallelTableScanDescData =
688+
pg_sys::shm_toc_lookup(
689+
shm_toc,
690+
parallel::SHM_TOC_TABLESCANDESC_KEY,
691+
false
692+
).cast::<pg_sys::ParallelTableScanDescData>();
693+
694+
// Begin the parallel table scan
695+
let scan = pg_sys::table_beginscan_parallel(heap_relation.as_ptr(), tablescandesc);
696+
697+
// Create a tuple table slot for receiving tuples
698+
let slot = pg_sys::table_slot_create(heap_relation.as_ptr(), std::ptr::null_mut());
699+
700+
loop {
701+
// Get next tuple from parallel scan
702+
let has_tuple = pg_sys::table_scan_getnextslot(scan, ForwardScanDirection, slot);
703+
704+
if !has_tuple {
705+
break;
706+
}
707+
708+
// Get the tuple's ctid
709+
let ctid = &(*slot).tts_tid;
710+
711+
// Get the number of attributes
712+
let natts = (*(*heap_relation.as_ptr()).rd_att).natts as usize;
713+
714+
// Allocate arrays for values and nulls
715+
let values = pg_sys::palloc(natts * std::mem::size_of::<pg_sys::Datum>()) as *mut pg_sys::Datum;
716+
let isnull = pg_sys::palloc(natts * std::mem::size_of::<bool>()) as *mut bool;
717+
718+
// Extract all attributes from the tuple slot
719+
pg_sys::slot_getallattrs(slot);
720+
721+
// Copy values and nulls from the slot
722+
for i in 0..natts {
723+
*values.add(i) = (*slot).tts_values.add(i).read();
724+
*isnull.add(i) = (*slot).tts_isnull.add(i).read();
725+
}
726+
727+
// Call the build callback for this tuple
728+
build_callback(
729+
index_relation.as_ptr(),
730+
ctid as *const pg_sys::ItemPointerData as *mut pg_sys::ItemPointerData,
731+
values,
732+
isnull,
733+
true, // tuple_is_alive - assume true for parallel scan
734+
state as *mut StorageBuildState as *mut std::os::raw::c_void,
735+
);
736+
737+
// Clean up allocated memory
738+
pg_sys::pfree(values as *mut std::os::raw::c_void);
739+
pg_sys::pfree(isnull as *mut std::os::raw::c_void);
740+
741+
// Clear the slot for next tuple
742+
pg_sys::ExecClearTuple(slot);
743+
}
744+
745+
// Clean up
746+
pg_sys::ExecDropSingleTupleTableSlot(slot);
747+
pg_sys::table_endscan(scan);
748+
}
749+
750+
#[cfg(not(feature = "build_parallel"))]
751+
unsafe fn do_parallel_heap_scan(
752+
_heap_relation: &PgRelation,
753+
_index_relation: &PgRelation,
754+
_index_info: *mut pg_sys::IndexInfo,
755+
_state: &mut StorageBuildState,
756+
) {
757+
panic!("Parallel build not enabled");
758+
}
759+
648760
#[pg_guard]
649761
unsafe extern "C" fn build_callback_bq_train(
650762
_index: pg_sys::Relation,

0 commit comments

Comments
 (0)