From 1ce1de7ab6c07352c5e8afc0efb587c2e7096f38 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 10 Mar 2026 08:02:25 -0600 Subject: [PATCH] feat: add debug assertions before unsafe code blocks Add debug_assert! statements between SAFETY comments and unsafe blocks to catch precondition violations during development and testing. Assertions cover: - Null pointer checks before raw pointer dereference - Index bounds checks before array/bitset access - Initialization checks before accessing global singletons - Alignment checks before aligned pointer writes - Non-negative size/length checks before slice construction --- native/core/src/execution/jni_api.rs | 7 ++ .../execution/shuffle/spark_unsafe/list.rs | 6 ++ .../src/execution/shuffle/spark_unsafe/map.rs | 2 + .../src/execution/shuffle/spark_unsafe/row.rs | 75 +++++++++++++++++++ native/core/src/execution/utils.rs | 10 +++ native/core/src/jvm_bridge/mod.rs | 8 ++ 6 files changed, 108 insertions(+) diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 116d8167a4..49f723a8c8 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -849,6 +849,13 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative( tracing_enabled != JNI_FALSE, || { // SAFETY: JVM unsafe memory allocation is aligned with long. 
+ debug_assert!(address != 0, "sortRowPartitionsNative: null address"); + debug_assert!(size >= 0, "sortRowPartitionsNative: negative size {size}"); + debug_assert_eq!( + (address as usize) % std::mem::align_of::<i64>(), + 0, + "sortRowPartitionsNative: address not aligned to i64" + ); let array = unsafe { std::slice::from_raw_parts_mut(address as *mut i64, size as usize) }; array.rdxsort(); diff --git a/native/core/src/execution/shuffle/spark_unsafe/list.rs b/native/core/src/execution/shuffle/spark_unsafe/list.rs index 259ff29a79..9e58c71d30 100644 --- a/native/core/src/execution/shuffle/spark_unsafe/list.rs +++ b/native/core/src/execution/shuffle/spark_unsafe/list.rs @@ -53,6 +53,7 @@ impl SparkUnsafeArray { pub fn new(addr: i64) -> Self { // SAFETY: addr points to valid Spark UnsafeArray data from the JVM. // The first 8 bytes contain the element count as a little-endian i64. + debug_assert!(addr != 0, "SparkUnsafeArray::new: null address"); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) }; let num_elements = i64::from_le_bytes(slice.try_into().unwrap()); @@ -87,6 +88,11 @@ impl SparkUnsafeArray { // SAFETY: row_addr points to valid Spark UnsafeArray data. The null bitset starts // at offset 8 and contains ceil(num_elements/64) * 8 bytes. The caller ensures // index < num_elements, so word_offset is within the bitset region. 
+ debug_assert!( + index < self.num_elements, + "is_null_at: index {index} >= num_elements {}", + self.num_elements + ); unsafe { let mask: i64 = 1i64 << (index & 0x3f); let word_offset = (self.row_addr + 8 + (((index >> 6) as i64) << 3)) as *const i64; diff --git a/native/core/src/execution/shuffle/spark_unsafe/map.rs b/native/core/src/execution/shuffle/spark_unsafe/map.rs index dbb5b404aa..19b67c43dc 100644 --- a/native/core/src/execution/shuffle/spark_unsafe/map.rs +++ b/native/core/src/execution/shuffle/spark_unsafe/map.rs @@ -32,6 +32,8 @@ impl SparkUnsafeMap { pub(crate) fn new(addr: i64, size: i32) -> Self { // SAFETY: addr points to valid Spark UnsafeMap data from the JVM. // The first 8 bytes contain the key array size as a little-endian i64. + debug_assert!(addr != 0, "SparkUnsafeMap::new: null address"); + debug_assert!(size >= 0, "SparkUnsafeMap::new: negative size {size}"); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) }; let key_array_size = i64::from_le_bytes(slice.try_into().unwrap()); diff --git a/native/core/src/execution/shuffle/spark_unsafe/row.rs b/native/core/src/execution/shuffle/spark_unsafe/row.rs index fe80d56d9a..7962caacef 100644 --- a/native/core/src/execution/shuffle/spark_unsafe/row.rs +++ b/native/core/src/execution/shuffle/spark_unsafe/row.rs @@ -92,6 +92,10 @@ pub trait SparkUnsafeObject { let addr = self.get_element_offset(index, 1); // SAFETY: addr points to valid element data within the UnsafeRow/UnsafeArray region. // The caller ensures index is within bounds. + debug_assert!( + !addr.is_null(), + "get_boolean: null pointer at index {index}" + ); unsafe { *addr != 0 } } @@ -99,6 +103,7 @@ pub trait SparkUnsafeObject { fn get_byte(&self, index: usize) -> i8 { let addr = self.get_element_offset(index, 1); // SAFETY: addr points to valid element data (1 byte) within the row/array region. 
+ debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}"); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 1) }; i8::from_le_bytes(slice.try_into().unwrap()) } @@ -107,6 +112,7 @@ pub trait SparkUnsafeObject { fn get_short(&self, index: usize) -> i16 { let addr = self.get_element_offset(index, 2); // SAFETY: addr points to valid element data (2 bytes) within the row/array region. + debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}"); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 2) }; i16::from_le_bytes(slice.try_into().unwrap()) } @@ -115,6 +121,7 @@ pub trait SparkUnsafeObject { fn get_int(&self, index: usize) -> i32 { let addr = self.get_element_offset(index, 4); // SAFETY: addr points to valid element data (4 bytes) within the row/array region. + debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}"); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) }; i32::from_le_bytes(slice.try_into().unwrap()) } @@ -123,6 +130,7 @@ pub trait SparkUnsafeObject { fn get_long(&self, index: usize) -> i64 { let addr = self.get_element_offset(index, 8); // SAFETY: addr points to valid element data (8 bytes) within the row/array region. + debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}"); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) }; i64::from_le_bytes(slice.try_into().unwrap()) } @@ -131,6 +139,7 @@ pub trait SparkUnsafeObject { fn get_float(&self, index: usize) -> f32 { let addr = self.get_element_offset(index, 4); // SAFETY: addr points to valid element data (4 bytes) within the row/array region. 
+ debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}"); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) }; f32::from_le_bytes(slice.try_into().unwrap()) } @@ -139,6 +148,7 @@ pub trait SparkUnsafeObject { fn get_double(&self, index: usize) -> f64 { let addr = self.get_element_offset(index, 8); // SAFETY: addr points to valid element data (8 bytes) within the row/array region. + debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}"); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) }; f64::from_le_bytes(slice.try_into().unwrap()) } @@ -149,6 +159,11 @@ pub trait SparkUnsafeObject { let addr = self.get_row_addr() + offset as i64; // SAFETY: addr points to valid UTF-8 string data within the variable-length region. // Offset and length are read from the fixed-length portion of the row/array. + debug_assert!(addr != 0, "get_string: null address at index {index}"); + debug_assert!( + len >= 0, + "get_string: negative length {len} at index {index}" + ); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }; from_utf8(slice).unwrap() @@ -160,6 +175,11 @@ pub trait SparkUnsafeObject { let addr = self.get_row_addr() + offset as i64; // SAFETY: addr points to valid binary data within the variable-length region. // Offset and length are read from the fixed-length portion of the row/array. + debug_assert!(addr != 0, "get_binary: null address at index {index}"); + debug_assert!( + len >= 0, + "get_binary: negative length {len} at index {index}" + ); unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) } } @@ -167,6 +187,7 @@ pub trait SparkUnsafeObject { fn get_date(&self, index: usize) -> i32 { let addr = self.get_element_offset(index, 4); // SAFETY: addr points to valid element data (4 bytes) within the row/array region. 
+ debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}"); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) }; i32::from_le_bytes(slice.try_into().unwrap()) } @@ -175,6 +196,10 @@ pub trait SparkUnsafeObject { fn get_timestamp(&self, index: usize) -> i64 { let addr = self.get_element_offset(index, 8); // SAFETY: addr points to valid element data (8 bytes) within the row/array region. + debug_assert!( + !addr.is_null(), + "get_timestamp: null pointer at index {index}" + ); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) }; i64::from_le_bytes(slice.try_into().unwrap()) } @@ -287,6 +312,7 @@ impl SparkUnsafeRow { // SAFETY: row_addr points to valid Spark UnsafeRow data with at least // ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields. // word_offset is within the bitset region since (index >> 6) << 3 < bitset size. + debug_assert!(self.row_addr != -1, "is_null_at: row not initialized"); unsafe { let mask: i64 = 1i64 << (index & 0x3f); let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64; @@ -300,6 +326,7 @@ impl SparkUnsafeRow { // SAFETY: row_addr points to valid Spark UnsafeRow data with at least // ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields. // Writing is safe because we have mutable access and the memory is owned by the JVM. + debug_assert!(self.row_addr != -1, "set_not_null_at: row not initialized"); unsafe { let mask: i64 = 1i64 << (index & 0x3f); let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64; @@ -498,6 +525,18 @@ fn append_columns( for i in row_start..row_end { // SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least // row_end elements. i is in [row_start, row_end) so the offset is in bounds. 
+ debug_assert!( + !row_addresses_ptr.is_null(), + "append_columns: null row_addresses_ptr" + ); + debug_assert!( + !row_sizes_ptr.is_null(), + "append_columns: null row_sizes_ptr" + ); + debug_assert!( + i < row_end, + "append_columns: index {i} out of bounds (row_end={row_end})" + ); let row_addr = unsafe { *row_addresses_ptr.add(i) }; let row_size = unsafe { *row_sizes_ptr.add(i) }; row.point_to(row_addr, row_size); @@ -630,6 +669,18 @@ fn append_columns( for i in row_start..row_end { // SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least // row_end elements. i is in [row_start, row_end) so the offset is in bounds. + debug_assert!( + !row_addresses_ptr.is_null(), + "append_columns: null row_addresses_ptr" + ); + debug_assert!( + !row_sizes_ptr.is_null(), + "append_columns: null row_sizes_ptr" + ); + debug_assert!( + i < row_end, + "append_columns: index {i} out of bounds (row_end={row_end})" + ); let row_addr = unsafe { *row_addresses_ptr.add(i) }; let row_size = unsafe { *row_sizes_ptr.add(i) }; row.point_to(row_addr, row_size); @@ -652,6 +703,18 @@ fn append_columns( for i in row_start..row_end { // SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least // row_end elements. i is in [row_start, row_end) so the offset is in bounds. + debug_assert!( + !row_addresses_ptr.is_null(), + "append_columns: null row_addresses_ptr" + ); + debug_assert!( + !row_sizes_ptr.is_null(), + "append_columns: null row_sizes_ptr" + ); + debug_assert!( + i < row_end, + "append_columns: index {i} out of bounds (row_end={row_end})" + ); let row_addr = unsafe { *row_addresses_ptr.add(i) }; let row_size = unsafe { *row_sizes_ptr.add(i) }; row.point_to(row_addr, row_size); @@ -681,6 +744,18 @@ fn append_columns( for i in row_start..row_end { // SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least // row_end elements. i is in [row_start, row_end) so the offset is in bounds. 
+ debug_assert!( + !row_addresses_ptr.is_null(), + "append_columns: null row_addresses_ptr" + ); + debug_assert!( + !row_sizes_ptr.is_null(), + "append_columns: null row_sizes_ptr" + ); + debug_assert!( + i < row_end, + "append_columns: index {i} out of bounds (row_end={row_end})" + ); let row_addr = unsafe { *row_addresses_ptr.add(i) }; let row_size = unsafe { *row_sizes_ptr.add(i) }; row.point_to(row_addr, row_size); diff --git a/native/core/src/execution/utils.rs b/native/core/src/execution/utils.rs index 838c8523bb..9e6f2a56e7 100644 --- a/native/core/src/execution/utils.rs +++ b/native/core/src/execution/utils.rs @@ -97,6 +97,16 @@ impl SparkArrowConvert for ArrayData { } } else { // SAFETY: `array_ptr` and `schema_ptr` are aligned correctly. + debug_assert_eq!( + array_ptr.align_offset(array_align), + 0, + "move_to_spark: array_ptr not aligned" + ); + debug_assert_eq!( + schema_ptr.align_offset(schema_align), + 0, + "move_to_spark: schema_ptr not aligned" + ); unsafe { std::ptr::write(array_ptr, FFI_ArrowArray::new(self)); std::ptr::write(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?); diff --git a/native/core/src/jvm_bridge/mod.rs b/native/core/src/jvm_bridge/mod.rs index aa4e71ea11..00fe7b33c3 100644 --- a/native/core/src/jvm_bridge/mod.rs +++ b/native/core/src/jvm_bridge/mod.rs @@ -263,11 +263,19 @@ impl JVMClasses<'_> { } pub fn get() -> &'static JVMClasses<'static> { + debug_assert!( + JVM_CLASSES.get().is_some(), + "JVMClasses::get: not initialized" + ); unsafe { JVM_CLASSES.get_unchecked() } } /// Gets the JNIEnv for the current thread. pub fn get_env() -> CometResult<AttachGuard<'static>> { + debug_assert!( + JAVA_VM.get().is_some(), + "JVMClasses::get_env: JAVA_VM not initialized" + ); unsafe { let java_vm = JAVA_VM.get_unchecked(); java_vm.attach_current_thread().map_err(|e| {