Skip to content
86 changes: 38 additions & 48 deletions pgvectorscale/src/access_method/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,6 @@ unsafe fn aminsert_internal(
}
let vec = vec.unwrap();

// PgVector is not cloneable, but in some cases we need a second copy of it
// for the insert. This is a bit of a hack to get that second copy.
let spare_vec = LabeledVector::from_datums(values, isnull, &meta_page).unwrap();

let heap_pointer = ItemPointer::with_item_pointer_data(*heap_tid);
let mut storage = meta_page.get_storage_type();
let mut stats = InsertStats::default();
Expand All @@ -209,7 +205,6 @@ unsafe fn aminsert_internal(
&plain,
&index_relation,
vec,
spare_vec,
heap_pointer,
&mut meta_page,
&mut stats,
Expand All @@ -226,7 +221,6 @@ unsafe fn aminsert_internal(
&bq,
&index_relation,
vec,
spare_vec,
heap_pointer,
&mut meta_page,
&mut stats,
Expand All @@ -240,7 +234,6 @@ unsafe fn insert_storage<S: Storage>(
storage: &S,
index_relation: &PgRelation,
vector: LabeledVector,
spare_vector: LabeledVector,
heap_pointer: ItemPointer,
meta_page: &mut MetaPage,
stats: &mut InsertStats,
Expand All @@ -256,14 +249,7 @@ unsafe fn insert_storage<S: Storage>(
);

let mut graph = Graph::new(GraphNeighborStore::Disk, meta_page);
graph.insert(
index_relation,
index_pointer,
vector,
spare_vector,
storage,
stats,
);
graph.insert(index_relation, index_pointer, vector, storage, stats);
}

#[pg_guard]
Expand Down Expand Up @@ -488,33 +474,13 @@ unsafe extern "C" fn build_callback(
StorageBuildState::SbqSpeedup(bq, state) => {
let vec = LabeledVector::from_datums(values, isnull, state.graph.get_meta_page());
if let Some(vec) = vec {
let spare_vec =
LabeledVector::from_datums(values, isnull, state.graph.get_meta_page())
.unwrap();
build_callback_memory_wrapper(
&index_relation,
heap_pointer,
vec,
spare_vec,
state,
*bq,
);
build_callback_memory_wrapper(&index_relation, heap_pointer, vec, state, *bq);
}
}
StorageBuildState::Plain(plain, state) => {
let vec = LabeledVector::from_datums(values, isnull, state.graph.get_meta_page());
if let Some(vec) = vec {
let spare_vec =
LabeledVector::from_datums(values, isnull, state.graph.get_meta_page())
.unwrap();
build_callback_memory_wrapper(
&index_relation,
heap_pointer,
vec,
spare_vec,
state,
*plain,
);
build_callback_memory_wrapper(&index_relation, heap_pointer, vec, state, *plain);
}
}
}
Expand All @@ -525,13 +491,12 @@ unsafe fn build_callback_memory_wrapper<S: Storage>(
index: &PgRelation,
heap_pointer: ItemPointer,
vector: LabeledVector,
spare_vector: LabeledVector,
state: &mut BuildState,
storage: &mut S,
) {
let mut old_context = state.memcxt.set_as_current();

build_callback_internal(index, heap_pointer, vector, spare_vector, state, storage);
build_callback_internal(index, heap_pointer, vector, state, storage);

old_context.set_as_current();
state.memcxt.reset();
Expand All @@ -542,7 +507,6 @@ fn build_callback_internal<S: Storage>(
index: &PgRelation,
heap_pointer: ItemPointer,
vector: LabeledVector,
spare_vector: LabeledVector,
state: &mut BuildState,
storage: &mut S,
) {
Expand Down Expand Up @@ -571,14 +535,9 @@ fn build_callback_internal<S: Storage>(
&mut state.stats,
);

state.graph.insert(
index,
index_pointer,
vector,
spare_vector,
storage,
&mut state.stats,
);
state
.graph
.insert(index, index_pointer, vector, storage, &mut state.stats);
}

const BUILD_PHASE_TRAINING: i64 = 0;
Expand Down Expand Up @@ -1386,4 +1345,35 @@ pub mod tests {
Spi::run("DROP TABLE test_data CASCADE;")?;
Ok(())
}

#[pg_test]
pub unsafe fn test_null_vector_scan() -> spi::Result<()> {
    // Regression test for issue #238: ordering by distance to a NULL vector
    // must not crash the index scan. The scan is expected to hand back every
    // indexed vector, in no particular order.
    let setup_sql = "CREATE TABLE test(embedding vector(3));

CREATE INDEX idxtest
ON test
USING diskann(embedding vector_l2_ops)
WITH (num_neighbors=10, search_list_size=10);

INSERT INTO test(embedding) VALUES ('[1,1,1]'), ('[2,2,2]'), ('[3,3,3]');
";
    Spi::run(setup_sql)?;

    // Sequential scans are disabled in the same statement batch so the
    // ORDER BY ... <-> NULL query goes through the index. The primary
    // goal is that this does not segfault.
    let scan_sql = "set enable_seqscan = 0;
SELECT COUNT(*) FROM (SELECT embedding FROM test ORDER BY embedding <-> NULL LIMIT 3) t;";
    let row_count: Option<i64> = Spi::get_one(scan_sql)?;

    // All three inserted rows come back when the scan completes successfully.
    assert_eq!(row_count, Some(3));

    // Clean up
    Spi::run("DROP TABLE test CASCADE;")?;
    Ok(())
}
}
3 changes: 1 addition & 2 deletions pgvectorscale/src/access_method/graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,6 @@ digraph G {
index: &PgRelation,
index_pointer: IndexPointer,
vec: LabeledVector,
spare_vec: LabeledVector,
storage: &S,
stats: &mut InsertStats,
) {
Expand All @@ -653,7 +652,7 @@ digraph G {

if vec.labels().is_some() {
// Insert starting from label start nodes and apply label filtering
self.insert_internal(index_pointer, spare_vec, false, storage, stats);
self.insert_internal(index_pointer, vec.clone(), false, storage, stats);
}

// Insert starting from default start node and avoid label filtering
Expand Down
6 changes: 5 additions & 1 deletion pgvectorscale/src/access_method/labels/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ impl Debug for ArchivedLabelSet {
}

/// A labeled vector is a vector with an optional set of labels.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct LabeledVector {
vec: PgVector,
labels: Option<LabelSet>,
Expand Down Expand Up @@ -211,6 +211,10 @@ impl LabeledVector {
orderbys: &[ScanKeyData],
meta_page: &MetaPage,
) -> Self {
if orderbys[0].sk_argument.is_null() {
return Self::new(PgVector::zeros(meta_page), None);
}

let query = unsafe {
PgVector::from_datum(
orderbys[0].sk_argument,
Expand Down
114 changes: 113 additions & 1 deletion pgvectorscale/src/access_method/pg_vector.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use pgrx::*;
use std::mem::MaybeUninit;

use crate::access_method::distance::DistanceType;

Expand All @@ -10,7 +11,7 @@ use super::{distance::preprocess_cosine, meta_page};
pub struct PgVectorInternal {
vl_len_: i32, /* varlena header (do not touch directly!) */
pub dim: i16, /* number of dimensions */
unused: i16,
unused: MaybeUninit<i16>,
pub x: pg_sys::__IncompleteArrayField<std::os::raw::c_float>,
}

Expand Down Expand Up @@ -50,6 +51,53 @@ impl Drop for PgVector {
}

impl PgVector {
/// Builds a `PgVector` whose components are all zero, sized from `meta_page`.
///
/// When the number of indexed dimensions equals the full number of
/// dimensions, a single allocation backs both `index_distance` and
/// `full_distance`; otherwise each gets its own zero-filled allocation.
pub fn zeros(meta_page: &meta_page::MetaPage) -> Self {
    let full_dims = meta_page.get_num_dimensions();
    let index_dims = meta_page.get_num_dimensions_to_index();
    let shared = full_dims == index_dims;

    // SAFETY: `create_zeros_inner` only allocates and initializes a fresh
    // buffer; the returned pointer is stored in (and owned by) the new value.
    let index_inner = unsafe { Self::create_zeros_inner(index_dims as i16) };

    let full_inner = if shared {
        // Optimization: same pointer for both index and full distance.
        index_inner
    } else {
        // SAFETY: as above; this is a second, independent allocation.
        unsafe { Self::create_zeros_inner(full_dims as i16) }
    };

    PgVector {
        index_distance: Some(index_inner),
        index_distance_needs_pfree: true,
        full_distance: Some(full_inner),
        // A shared buffer is flagged for pfree only through `index_distance`
        // (presumably so it is released exactly once -- confirm against Drop).
        full_distance_needs_pfree: !shared,
    }
}

unsafe fn create_zeros_inner(dimensions: i16) -> *mut PgVectorInternal {
// Calculate total size needed: header + array of f32s
let header_size = std::mem::size_of::<PgVectorInternal>();
let array_size = dimensions as usize * std::mem::size_of::<f32>();
let total_size = header_size + array_size;

// Allocate PostgreSQL memory
let ptr = pg_sys::palloc0(total_size) as *mut PgVectorInternal;

// Initialize the header
(*ptr).vl_len_ = total_size as i32;
(*ptr).dim = dimensions;
(*ptr).unused = MaybeUninit::new(0);

// The array is already zero-filled due to palloc0
ptr
}

/// # Safety
///
/// TODO
Expand Down Expand Up @@ -117,6 +165,8 @@ impl PgVector {
index_distance: bool,
full_distance: bool,
) -> PgVector {
assert!(!datum.is_null(), "Datum should not be NULL");

if meta_page.get_num_dimensions() == meta_page.get_num_dimensions_to_index() {
/* optimization if the num dimensions are the same */
let inner = Self::create_inner(datum, meta_page, true);
Expand Down Expand Up @@ -156,3 +206,65 @@ impl PgVector {
unsafe { (*self.full_distance.unwrap()).to_slice() }
}
}

impl Clone for PgVector {
    /// Deep-copies the underlying vector buffers.
    ///
    /// If `full_distance` aliases `index_distance` in `self`, the clone keeps
    /// that aliasing: one new buffer is created and referenced from both
    /// fields, and only `index_distance` is flagged as needing `pfree`.
    fn clone(&self) -> Self {
        // SAFETY: assumes every pointer stored in `self` refers to a live,
        // properly initialized `PgVectorInternal` -- the same assumption the
        // accessors on this type rely on.
        let index_copy = self
            .index_distance
            .map(|src| unsafe { Self::clone_inner(src) });

        let (full_copy, full_owned) = match self.full_distance {
            None => (None, false),
            // Same allocation as the index vector: reuse the copy made above.
            Some(src) if self.index_distance == Some(src) => (index_copy, false),
            // Distinct allocation: copy it separately and own it separately.
            // SAFETY: see above.
            Some(src) => (Some(unsafe { Self::clone_inner(src) }), true),
        };

        PgVector {
            index_distance: index_copy,
            index_distance_needs_pfree: index_copy.is_some(),
            full_distance: full_copy,
            full_distance_needs_pfree: full_owned,
        }
    }
}

impl PgVector {
    /// Copies one `PgVectorInternal` (header fields plus `dim` floats) into a
    /// freshly `palloc`-ed buffer and returns the new pointer.
    ///
    /// # Safety
    ///
    /// `original` must point to a valid `PgVectorInternal` whose `dim` field
    /// matches the number of floats stored after the header, and the call
    /// must happen where `pg_sys::palloc` may be used.
    unsafe fn clone_inner(original: *mut PgVectorInternal) -> *mut PgVectorInternal {
        let dim = (*original).dim;
        let source = (*original).to_slice();
        let elem_count = dim as usize;

        // Header followed by `dim` f32 values.
        let total_size =
            std::mem::size_of::<PgVectorInternal>() + elem_count * std::mem::size_of::<f32>();

        let copy = pg_sys::palloc(total_size) as *mut PgVectorInternal;

        // Header fields are carried over verbatim.
        (*copy).vl_len_ = (*original).vl_len_;
        (*copy).dim = dim;
        (*copy).unused = (*original).unused;

        // Then the vector payload itself.
        (*copy).x.as_mut_slice(elem_count).copy_from_slice(source);

        copy
    }
}