Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,13 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative(
tracing_enabled != JNI_FALSE,
|| {
// SAFETY: JVM unsafe memory allocation is aligned with long.
debug_assert!(address != 0, "sortRowPartitionsNative: null address");
debug_assert!(size >= 0, "sortRowPartitionsNative: negative size {size}");
debug_assert_eq!(
(address as usize) % std::mem::align_of::<i64>(),
0,
"sortRowPartitionsNative: address not aligned to i64"
);
let array =
unsafe { std::slice::from_raw_parts_mut(address as *mut i64, size as usize) };
array.rdxsort();
Expand Down
6 changes: 6 additions & 0 deletions native/core/src/execution/shuffle/spark_unsafe/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ impl SparkUnsafeArray {
pub fn new(addr: i64) -> Self {
// SAFETY: addr points to valid Spark UnsafeArray data from the JVM.
// The first 8 bytes contain the element count as a little-endian i64.
debug_assert!(addr != 0, "SparkUnsafeArray::new: null address");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) };
let num_elements = i64::from_le_bytes(slice.try_into().unwrap());

Expand Down Expand Up @@ -87,6 +88,11 @@ impl SparkUnsafeArray {
// SAFETY: row_addr points to valid Spark UnsafeArray data. The null bitset starts
// at offset 8 and contains ceil(num_elements/64) * 8 bytes. The caller ensures
// index < num_elements, so word_offset is within the bitset region.
debug_assert!(
index < self.num_elements,
"is_null_at: index {index} >= num_elements {}",
self.num_elements
);
unsafe {
let mask: i64 = 1i64 << (index & 0x3f);
let word_offset = (self.row_addr + 8 + (((index >> 6) as i64) << 3)) as *const i64;
Expand Down
2 changes: 2 additions & 0 deletions native/core/src/execution/shuffle/spark_unsafe/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ impl SparkUnsafeMap {
pub(crate) fn new(addr: i64, size: i32) -> Self {
// SAFETY: addr points to valid Spark UnsafeMap data from the JVM.
// The first 8 bytes contain the key array size as a little-endian i64.
debug_assert!(addr != 0, "SparkUnsafeMap::new: null address");
debug_assert!(size >= 0, "SparkUnsafeMap::new: negative size {size}");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) };
let key_array_size = i64::from_le_bytes(slice.try_into().unwrap());

Expand Down
75 changes: 75 additions & 0 deletions native/core/src/execution/shuffle/spark_unsafe/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,18 @@ pub trait SparkUnsafeObject {
let addr = self.get_element_offset(index, 1);
// SAFETY: addr points to valid element data within the UnsafeRow/UnsafeArray region.
// The caller ensures index is within bounds.
debug_assert!(
!addr.is_null(),
"get_boolean: null pointer at index {index}"
);
unsafe { *addr != 0 }
}

/// Returns the byte (`i8`) element stored at `index`.
fn get_byte(&self, index: usize) -> i8 {
    let ptr = self.get_element_offset(index, 1);
    debug_assert!(!ptr.is_null(), "get_byte: null pointer at index {index}");
    // SAFETY: `ptr` addresses at least one valid byte inside the
    // row/array region; the caller guarantees `index` is in bounds.
    let byte = unsafe { std::ptr::read(ptr) };
    // A single little-endian byte reinterpreted as a signed 8-bit value.
    i8::from_le_bytes([byte])
}
Expand All @@ -107,6 +112,7 @@ pub trait SparkUnsafeObject {
/// Returns the short (`i16`) element stored at `index`.
fn get_short(&self, index: usize) -> i16 {
    let ptr = self.get_element_offset(index, 2);
    debug_assert!(!ptr.is_null(), "get_short: null pointer at index {index}");
    // SAFETY: `ptr` addresses 2 valid bytes inside the row/array region;
    // the caller guarantees `index` is in bounds. Copying into a local
    // buffer avoids any alignment requirement on the source pointer.
    let mut buf = [0u8; 2];
    unsafe { std::ptr::copy_nonoverlapping(ptr, buf.as_mut_ptr(), 2) };
    i16::from_le_bytes(buf)
}
Expand All @@ -115,6 +121,7 @@ pub trait SparkUnsafeObject {
/// Returns the int (`i32`) element stored at `index`.
fn get_int(&self, index: usize) -> i32 {
    let ptr = self.get_element_offset(index, 4);
    debug_assert!(!ptr.is_null(), "get_int: null pointer at index {index}");
    // SAFETY: `ptr` addresses 4 valid bytes inside the row/array region;
    // the caller guarantees `index` is in bounds. Copying into a local
    // buffer avoids any alignment requirement on the source pointer.
    let mut buf = [0u8; 4];
    unsafe { std::ptr::copy_nonoverlapping(ptr, buf.as_mut_ptr(), 4) };
    i32::from_le_bytes(buf)
}
Expand All @@ -123,6 +130,7 @@ pub trait SparkUnsafeObject {
/// Returns the long (`i64`) element stored at `index`.
fn get_long(&self, index: usize) -> i64 {
    let ptr = self.get_element_offset(index, 8);
    debug_assert!(!ptr.is_null(), "get_long: null pointer at index {index}");
    // SAFETY: `ptr` addresses 8 valid bytes inside the row/array region;
    // the caller guarantees `index` is in bounds. Copying into a local
    // buffer avoids any alignment requirement on the source pointer.
    let mut buf = [0u8; 8];
    unsafe { std::ptr::copy_nonoverlapping(ptr, buf.as_mut_ptr(), 8) };
    i64::from_le_bytes(buf)
}
Expand All @@ -131,6 +139,7 @@ pub trait SparkUnsafeObject {
/// Returns the float (`f32`) element stored at `index`.
fn get_float(&self, index: usize) -> f32 {
    let ptr = self.get_element_offset(index, 4);
    debug_assert!(!ptr.is_null(), "get_float: null pointer at index {index}");
    // SAFETY: `ptr` addresses 4 valid bytes inside the row/array region;
    // the caller guarantees `index` is in bounds. Copying into a local
    // buffer avoids any alignment requirement on the source pointer.
    let mut buf = [0u8; 4];
    unsafe { std::ptr::copy_nonoverlapping(ptr, buf.as_mut_ptr(), 4) };
    f32::from_le_bytes(buf)
}
Expand All @@ -139,6 +148,7 @@ pub trait SparkUnsafeObject {
/// Returns the double (`f64`) element stored at `index`.
fn get_double(&self, index: usize) -> f64 {
    let ptr = self.get_element_offset(index, 8);
    debug_assert!(!ptr.is_null(), "get_double: null pointer at index {index}");
    // SAFETY: `ptr` addresses 8 valid bytes inside the row/array region;
    // the caller guarantees `index` is in bounds. Copying into a local
    // buffer avoids any alignment requirement on the source pointer.
    let mut buf = [0u8; 8];
    unsafe { std::ptr::copy_nonoverlapping(ptr, buf.as_mut_ptr(), 8) };
    f64::from_le_bytes(buf)
}
Expand All @@ -149,6 +159,11 @@ pub trait SparkUnsafeObject {
let addr = self.get_row_addr() + offset as i64;
// SAFETY: addr points to valid UTF-8 string data within the variable-length region.
// Offset and length are read from the fixed-length portion of the row/array.
debug_assert!(addr != 0, "get_string: null address at index {index}");
debug_assert!(
len >= 0,
"get_string: negative length {len} at index {index}"
);
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) };

from_utf8(slice).unwrap()
Expand All @@ -160,13 +175,19 @@ pub trait SparkUnsafeObject {
let addr = self.get_row_addr() + offset as i64;
// SAFETY: addr points to valid binary data within the variable-length region.
// Offset and length are read from the fixed-length portion of the row/array.
debug_assert!(addr != 0, "get_binary: null address at index {index}");
debug_assert!(
len >= 0,
"get_binary: negative length {len} at index {index}"
);
unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }
}

/// Returns the date element at `index`, stored as a little-endian `i32`
/// (presumably days since the epoch, per Spark's date encoding —
/// NOTE(review): encoding not visible here, confirm against callers).
fn get_date(&self, index: usize) -> i32 {
    let ptr = self.get_element_offset(index, 4);
    debug_assert!(!ptr.is_null(), "get_date: null pointer at index {index}");
    // SAFETY: `ptr` addresses 4 valid bytes inside the row/array region;
    // the caller guarantees `index` is in bounds. Copying into a local
    // buffer avoids any alignment requirement on the source pointer.
    let mut buf = [0u8; 4];
    unsafe { std::ptr::copy_nonoverlapping(ptr, buf.as_mut_ptr(), 4) };
    i32::from_le_bytes(buf)
}
Expand All @@ -175,6 +196,10 @@ pub trait SparkUnsafeObject {
/// Returns the timestamp element at `index`, stored as a little-endian
/// `i64` (presumably microseconds, per Spark's timestamp encoding —
/// NOTE(review): encoding not visible here, confirm against callers).
fn get_timestamp(&self, index: usize) -> i64 {
    let ptr = self.get_element_offset(index, 8);
    debug_assert!(
        !ptr.is_null(),
        "get_timestamp: null pointer at index {index}"
    );
    // SAFETY: `ptr` addresses 8 valid bytes inside the row/array region;
    // the caller guarantees `index` is in bounds. Copying into a local
    // buffer avoids any alignment requirement on the source pointer.
    let mut buf = [0u8; 8];
    unsafe { std::ptr::copy_nonoverlapping(ptr, buf.as_mut_ptr(), 8) };
    i64::from_le_bytes(buf)
}
Expand Down Expand Up @@ -287,6 +312,7 @@ impl SparkUnsafeRow {
// SAFETY: row_addr points to valid Spark UnsafeRow data with at least
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
// word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
debug_assert!(self.row_addr != -1, "is_null_at: row not initialized");
unsafe {
let mask: i64 = 1i64 << (index & 0x3f);
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64;
Expand All @@ -300,6 +326,7 @@ impl SparkUnsafeRow {
// SAFETY: row_addr points to valid Spark UnsafeRow data with at least
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
// Writing is safe because we have mutable access and the memory is owned by the JVM.
debug_assert!(self.row_addr != -1, "set_not_null_at: row not initialized");
unsafe {
let mask: i64 = 1i64 << (index & 0x3f);
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64;
Expand Down Expand Up @@ -498,6 +525,18 @@ fn append_columns(
for i in row_start..row_end {
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
debug_assert!(
!row_addresses_ptr.is_null(),
"append_columns: null row_addresses_ptr"
);
debug_assert!(
!row_sizes_ptr.is_null(),
"append_columns: null row_sizes_ptr"
);
debug_assert!(
i < row_end,
"append_columns: index {i} out of bounds (row_end={row_end})"
);
let row_addr = unsafe { *row_addresses_ptr.add(i) };
let row_size = unsafe { *row_sizes_ptr.add(i) };
row.point_to(row_addr, row_size);
Expand Down Expand Up @@ -630,6 +669,18 @@ fn append_columns(
for i in row_start..row_end {
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
debug_assert!(
!row_addresses_ptr.is_null(),
"append_columns: null row_addresses_ptr"
);
debug_assert!(
!row_sizes_ptr.is_null(),
"append_columns: null row_sizes_ptr"
);
debug_assert!(
i < row_end,
"append_columns: index {i} out of bounds (row_end={row_end})"
);
let row_addr = unsafe { *row_addresses_ptr.add(i) };
let row_size = unsafe { *row_sizes_ptr.add(i) };
row.point_to(row_addr, row_size);
Expand All @@ -652,6 +703,18 @@ fn append_columns(
for i in row_start..row_end {
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
debug_assert!(
!row_addresses_ptr.is_null(),
"append_columns: null row_addresses_ptr"
);
debug_assert!(
!row_sizes_ptr.is_null(),
"append_columns: null row_sizes_ptr"
);
debug_assert!(
i < row_end,
"append_columns: index {i} out of bounds (row_end={row_end})"
);
let row_addr = unsafe { *row_addresses_ptr.add(i) };
let row_size = unsafe { *row_sizes_ptr.add(i) };
row.point_to(row_addr, row_size);
Expand Down Expand Up @@ -681,6 +744,18 @@ fn append_columns(
for i in row_start..row_end {
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
debug_assert!(
!row_addresses_ptr.is_null(),
"append_columns: null row_addresses_ptr"
);
debug_assert!(
!row_sizes_ptr.is_null(),
"append_columns: null row_sizes_ptr"
);
debug_assert!(
i < row_end,
"append_columns: index {i} out of bounds (row_end={row_end})"
);
let row_addr = unsafe { *row_addresses_ptr.add(i) };
let row_size = unsafe { *row_sizes_ptr.add(i) };
row.point_to(row_addr, row_size);
Expand Down
10 changes: 10 additions & 0 deletions native/core/src/execution/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,16 @@ impl SparkArrowConvert for ArrayData {
}
} else {
// SAFETY: `array_ptr` and `schema_ptr` are aligned correctly.
debug_assert_eq!(
array_ptr.align_offset(array_align),
0,
"move_to_spark: array_ptr not aligned"
);
debug_assert_eq!(
schema_ptr.align_offset(schema_align),
0,
"move_to_spark: schema_ptr not aligned"
);
unsafe {
std::ptr::write(array_ptr, FFI_ArrowArray::new(self));
std::ptr::write(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?);
Expand Down
8 changes: 8 additions & 0 deletions native/core/src/jvm_bridge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,19 @@ impl JVMClasses<'_> {
}

/// Returns the process-wide cached JVM class metadata.
///
/// # Panics
/// In debug builds, panics if the cache has not been initialized before
/// the first call. Release builds skip the check and perform the
/// unchecked access directly.
pub fn get() -> &'static JVMClasses<'static> {
    let initialized = JVM_CLASSES.get().is_some();
    debug_assert!(initialized, "JVMClasses::get: not initialized");
    // SAFETY: soundness relies on `JVM_CLASSES` having been set before any
    // call to `get`; the debug assertion above validates that invariant in
    // debug builds.
    unsafe { JVM_CLASSES.get_unchecked() }
}

/// Gets the JNIEnv for the current thread.
pub fn get_env() -> CometResult<AttachGuard<'static>> {
debug_assert!(
JAVA_VM.get().is_some(),
"JVMClasses::get_env: JAVA_VM not initialized"
);
unsafe {
let java_vm = JAVA_VM.get_unchecked();
java_vm.attach_current_thread().map_err(|e| {
Expand Down
Loading