Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pgvectorscale/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pg15 = ["pgrx/pg15", "pgrx-tests/pg15"]
pg16 = ["pgrx/pg16", "pgrx-tests/pg16"]
pg17 = ["pgrx/pg17", "pgrx-tests/pg17"]
pg_test = []
build_parallel = []

[lints.rust]
unexpected_cfgs = { level = "allow", check-cfg = [
Expand Down
241 changes: 215 additions & 26 deletions pgvectorscale/src/access_method/build.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::time::Instant;

use pg_sys::{FunctionCall0Coll, InvalidOid};
use pgrx::pg_sys::{index_getprocinfo, pgstat_progress_update_param, AsPgCStr};
use pgrx::ffi::c_char;
use pgrx::pg_sys::{index_getprocinfo, pgstat_progress_update_param, AsPgCStr, Oid};
use pgrx::*;

use crate::access_method::distance::DistanceType;
Expand Down Expand Up @@ -29,6 +30,8 @@ use super::plain::storage::PlainStorage;
use super::sbq::SbqMeans;
use super::storage::{Storage, StorageType};

mod parallel;

struct SbqTrainState<'a, 'b> {
quantizer: &'a mut SbqQuantizer,
meta_page: &'b MetaPage,
Expand Down Expand Up @@ -71,24 +74,28 @@ pub const MAX_DIMENSION: u32 = 16000;
/// using the SBQ storage type.
pub const MAX_DIMENSION_NO_SBQ: u32 = 2000;

#[pg_guard]
pub extern "C" fn ambuild(
heaprel: pg_sys::Relation,
indexrel: pg_sys::Relation,
index_info: *mut pg_sys::IndexInfo,
) -> *mut pg_sys::IndexBuildResult {
let heap_relation = unsafe { PgRelation::from_pg(heaprel) };
let index_relation = unsafe { PgRelation::from_pg(indexrel) };
let opt = TSVIndexOptions::from_relation(&index_relation);
/// Data about parallel index build that never changes.
#[derive(Debug, Copy, Clone)]
#[cfg_attr(not(feature = "build_parallel"), allow(dead_code))]
struct ParallelSharedParams {
heaprelid: Oid,
indexrelid: Oid,
is_concurrent: bool,
}

notice!(
"Starting index build with num_neighbors={}, search_list_size={}, max_alpha={}, storage_layout={:?}.",
opt.get_num_neighbors(),
opt.search_list_size,
opt.max_alpha,
opt.get_storage_type(),
);
/// Status data for parallel index builds, shared among all parallel workers.
#[derive(Debug)]
#[cfg_attr(not(feature = "build_parallel"), allow(dead_code))]
struct ParallelShared {
params: ParallelSharedParams,
ntuples: usize,
}

fn get_meta_page(
indexrel: pg_sys::Relation,
index_relation: &PgRelation,
opt: PgBox<TSVIndexOptions>,
) -> MetaPage {
let dimensions = index_relation.tuple_desc().get(0).unwrap().atttypmod;

let distance_type = unsafe {
Expand All @@ -104,7 +111,7 @@ pub extern "C" fn ambuild(
error!("Inner product distance type is not supported with plain storage");
}

let mut meta_page =
let meta_page =
unsafe { MetaPage::create(&index_relation, dimensions as _, distance_type, opt) };

if meta_page.get_num_dimensions_to_index() == 0 {
Expand All @@ -128,15 +135,136 @@ pub extern "C" fn ambuild(
error!("Labeled filtering is not supported with plain storage");
}

meta_page
}

#[pg_guard]
pub extern "C" fn ambuild(
heaprel: pg_sys::Relation,
indexrel: pg_sys::Relation,
index_info: *mut pg_sys::IndexInfo,
) -> *mut pg_sys::IndexBuildResult {
let heap_relation = unsafe { PgRelation::from_pg(heaprel) };
let index_relation = unsafe { PgRelation::from_pg(indexrel) };
let opt = TSVIndexOptions::from_relation(&index_relation);
let mut meta_page = get_meta_page(indexrel, &index_relation, opt);
let opt = TSVIndexOptions::from_relation(&index_relation);

notice!(
"Starting index build with num_neighbors={}, search_list_size={}, max_alpha={}, storage_layout={:?}.",
opt.get_num_neighbors(),
opt.search_list_size,
opt.max_alpha,
opt.get_storage_type(),
);

// Train quantizer before doing anything in parallel
let write_stats =
maybe_train_quantizer(index_info, &heap_relation, &index_relation, &mut meta_page);
let ntuples = do_heap_scan(
index_info,
&heap_relation,
&index_relation,
meta_page,
write_stats,
);
unsafe {
meta_page.store(&index_relation, false);
};

// TODO: unsafe { (*index_info).ii_ParallelWorkers };
let workers = if cfg!(feature = "build_parallel") {
1
} else {
0
};
let is_concurrent = unsafe { (*index_info).ii_Concurrent };
struct ParallelData {
pcxt: *mut pg_sys::ParallelContext,
snapshot: *mut pg_sys::SnapshotData,
}
let parallel_data = if workers > 0 {
notice!("Parallel build with {} workers", workers);
unsafe {
pg_sys::EnterParallelMode();
const EXTENSION_NAME: *const c_char = {
static NAME: &str =
concat!(env!("CARGO_PKG_NAME"), "-", env!("CARGO_PKG_VERSION"), "\0");
NAME.as_ptr() as *const c_char
};

let pcxt = pg_sys::CreateParallelContext(EXTENSION_NAME, PARALLEL_BUILD_MAIN, workers);
let snapshot = if is_concurrent {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have any tests today that use CREATE INDEX CONCURRENT?

pg_sys::RegisterSnapshot(pg_sys::GetTransactionSnapshot())
} else {
&raw mut pg_sys::SnapshotAnyData
};

// Estimate things we put in shared memory
parallel::toc_estimate_single_chunk(pcxt, size_of::<ParallelShared>());
let tablescandesc_size_estimate =
pg_sys::table_parallelscan_estimate(heaprel, snapshot);
parallel::toc_estimate_single_chunk(pcxt, tablescandesc_size_estimate);

pg_sys::InitializeParallelDSM(pcxt);
if (*pcxt).seg.is_null() {
parallel::cleanup_pcxt(pcxt, snapshot);
None
} else {
let parallel_shared =
pg_sys::shm_toc_allocate((*pcxt).toc, size_of::<ParallelShared>())
.cast::<ParallelShared>();
parallel_shared.write(ParallelShared {
params: ParallelSharedParams {
heaprelid: heap_relation.rd_id,
indexrelid: index_relation.rd_id,
is_concurrent,
},
ntuples: 0,
});
let tablescandesc =
pg_sys::shm_toc_allocate((*pcxt).toc, tablescandesc_size_estimate)
.cast::<pg_sys::ParallelTableScanDescData>();
pg_sys::table_parallelscan_initialize(heaprel, tablescandesc, snapshot);

pg_sys::shm_toc_insert(
(*pcxt).toc,
parallel::SHM_TOC_SHARED_KEY,
parallel_shared.cast(),
);
pg_sys::shm_toc_insert(
(*pcxt).toc,
parallel::SHM_TOC_TABLESCANDESC_KEY,
tablescandesc.cast(),
);

pg_sys::LaunchParallelWorkers(pcxt);
if (*pcxt).nworkers_launched == 0 {
warning!("No workers launched");
parallel::cleanup_pcxt(pcxt, snapshot);
None
} else {
pg_sys::WaitForParallelWorkersToAttach(pcxt);
Some(ParallelData { pcxt, snapshot })
}
}
}
} else {
None
};

let ntuples = if let Some(ParallelData { pcxt, snapshot }) = parallel_data {
unsafe {
pg_sys::WaitForParallelWorkersToFinish(pcxt);
let parallel_shared: *mut ParallelShared =
pg_sys::shm_toc_lookup((*pcxt).toc, parallel::SHM_TOC_SHARED_KEY, false)
.cast::<ParallelShared>();
let ntuples = (*parallel_shared).ntuples;
parallel::cleanup_pcxt(pcxt, snapshot);
ntuples
}
} else {
do_heap_scan(
index_info,
&heap_relation,
&index_relation,
meta_page,
write_stats,
)
};

let mut result = unsafe { PgBox::<pg_sys::IndexBuildResult>::alloc0() };
result.heap_tuples = ntuples as f64;
Expand Down Expand Up @@ -322,6 +450,67 @@ fn maybe_train_quantizer(
write_stats
}

const PARALLEL_BUILD_MAIN: *const c_char = c"_vectorscale_build_main".as_ptr();
#[pg_guard]
#[unsafe(no_mangle)]
#[cfg(feature = "build_parallel")]
pub extern "C-unwind" fn _vectorscale_build_main(
_seg: *mut pg_sys::dsm_segment,
shm_toc: *mut pg_sys::shm_toc,
) {
let status_flags = unsafe { (*pg_sys::MyProc).statusFlags };
assert!(
status_flags == 0 || status_flags == pg_sys::PROC_IN_SAFE_IC as u8,
"Status flags for an index build process must be unset or PROC_IN_SAFE_IC (in a safe index creation)"
);

let parallel_shared: *mut ParallelShared = unsafe {
pg_sys::shm_toc_lookup(shm_toc, parallel::SHM_TOC_SHARED_KEY, false)
.cast::<ParallelShared>()
};
let _tablescandesc = unsafe {
pg_sys::shm_toc_lookup(shm_toc, parallel::SHM_TOC_TABLESCANDESC_KEY, false)
.cast::<pg_sys::ParallelTableScanDescData>()
};

let params = unsafe {
// SAFETY: these parameters never change, so no data races
(*parallel_shared).params
};

let (heap_lockmode, index_lockmode) = if params.is_concurrent {
(
pg_sys::ShareLock as pg_sys::LOCKMODE,
pg_sys::AccessExclusiveLock as pg_sys::LOCKMODE,
)
} else {
(
pg_sys::ShareUpdateExclusiveLock as pg_sys::LOCKMODE,
pg_sys::RowExclusiveLock as pg_sys::LOCKMODE,
)
};

let heaprel = unsafe { pg_sys::table_open(params.heaprelid, heap_lockmode) };
let indexrel = unsafe { pg_sys::index_open(params.indexrelid, index_lockmode) };
let index_info = unsafe { pg_sys::BuildIndexInfo(indexrel) };
let heap_relation = unsafe { PgRelation::from_pg(heaprel) };
let index_relation = unsafe { PgRelation::from_pg(indexrel) };
let meta_page = MetaPage::fetch(&index_relation);

let ntuples = do_heap_scan(
index_info,
&heap_relation,
&index_relation,
meta_page,
WriteStats::default(),
);

unsafe {
// SAFETY: nobody reads this until all the parallel workers are done
(*parallel_shared).ntuples = ntuples;
}
}

fn do_heap_scan(
index_info: *mut pg_sys::IndexInfo,
heap_relation: &PgRelation,
Expand Down Expand Up @@ -1196,7 +1385,7 @@ pub mod tests {
id int,
embedding vector ({dimensions})
);

CREATE INDEX idx_diskann_bq ON test_data USING diskann (embedding) WITH ({index_options});

select setseed(0.5);
Expand Down
31 changes: 31 additions & 0 deletions pgvectorscale/src/access_method/build/parallel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use crate::util::ports;
use pgrx::pg_sys;

pub const SHM_TOC_SHARED_KEY: u64 = 0xD000000000000001;
pub const SHM_TOC_TABLESCANDESC_KEY: u64 = 0xD000000000000002;

/// Is a snapshop MVCC-safe? (This should really be a part of pgrx)
pub unsafe fn is_mvcc_snapshot(snapshot: *mut pg_sys::SnapshotData) -> bool {
let typ = (*snapshot).snapshot_type;
typ == pg_sys::SnapshotType::SNAPSHOT_MVCC
|| typ == pg_sys::SnapshotType::SNAPSHOT_HISTORIC_MVCC
}

/// Cleans up a parallel context when we're done with it.
pub unsafe fn cleanup_pcxt(
pcxt: *mut pg_sys::ParallelContext,
snapshot: *mut pg_sys::SnapshotData,
) {
// need DSM segment to do parallel build
if is_mvcc_snapshot(snapshot) {
pg_sys::UnregisterSnapshot(snapshot);
}
pg_sys::DestroyParallelContext(pcxt);
pg_sys::ExitParallelMode();
}

/// Estimate a single chunk in the shared memory TOC.
pub unsafe fn toc_estimate_single_chunk(pcxt: *mut pg_sys::ParallelContext, size: usize) {
(*pcxt).estimator.space_for_chunks += ports::buffer_align(size);
(*pcxt).estimator.number_of_keys += 1;
}
5 changes: 5 additions & 0 deletions pgvectorscale/src/access_method/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ fn amhandler(_fcinfo: pg_sys::FunctionCallInfo) -> PgBox<pg_sys::IndexAmRoutine>

amroutine.ambuildphasename = Some(build::ambuildphasename);

#[cfg(all(feature = "pg17", feature = "build_parallel"))]
{
amroutine.amcanbuildparallel = true;
}

amroutine.into_pg_boxed()
}

Expand Down
6 changes: 5 additions & 1 deletion pgvectorscale/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ pub mod pg_test {
// perform one-off initialization when the pg_test framework starts
}

#[cfg(feature = "build_parallel")]
pub fn postgresql_conf_options() -> Vec<&'static str> {
vec!["maintenance_work_mem = '640MB'"]
}
#[cfg(not(feature = "build_parallel"))]
pub fn postgresql_conf_options() -> Vec<&'static str> {
// return any postgresql.conf settings that are required for your tests
vec![]
}
}
8 changes: 8 additions & 0 deletions pgvectorscale/src/util/ports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,11 @@ pub unsafe fn pgstat_count_index_scan(index_relation: pg_sys::Relation, indexrel
}
}
}

/// Reimplementation of Postgres BUFFERALIGN macro.
pub fn buffer_align(len: usize) -> usize {
unsafe {
// SAFETY: TYPEALIGN is just arithmetic, it shouldn't be marked as unsafe
pg_sys::TYPEALIGN(pg_sys::ALIGNOF_BUFFER as usize, len)
}
}