Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 97 additions & 2 deletions native/core/src/execution/shuffle/spark_unsafe/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ impl SparkUnsafeArray {
// SAFETY: addr points to valid Spark UnsafeArray data from the JVM.
// The first 8 bytes contain the element count as a little-endian i64.
debug_assert!(addr != 0, "SparkUnsafeArray::new: null address");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) };
let num_elements = i64::from_le_bytes(slice.try_into().unwrap());
// SAFETY: Spark UnsafeArray stores element count as the first 8 bytes.
// Use read_unaligned because nested arrays may not be 8-byte aligned.
let num_elements = unsafe { (addr as *const i64).read_unaligned() };

if num_elements < 0 {
panic!("Negative number of elements: {num_elements}");
Expand Down Expand Up @@ -478,3 +479,97 @@ pub fn append_list_element(

Ok(())
}

#[cfg(test)]
mod test {
use super::*;
use arrow::array::builder::Int32Builder;
use arrow::array::Array;

/// Helper to create a SparkUnsafeArray buffer with i32 elements.
/// Layout: 8 bytes num_elements + null bitset + element data.
fn create_i32_array_buffer(values: &[i32], null_indices: &[usize]) -> Vec<u8> {
let num_elements = values.len();
let null_bitset_words = num_elements.div_ceil(64);
let header_size = 8 + null_bitset_words * 8;
let data_size = num_elements * 4;
let mut buffer = vec![0u8; header_size + data_size];

buffer[0..8].copy_from_slice(&(num_elements as i64).to_le_bytes());

for &idx in null_indices {
let word_offset = 8 + (idx / 64) * 8;
let current =
i64::from_le_bytes(buffer[word_offset..word_offset + 8].try_into().unwrap());
let updated = current | (1i64 << (idx % 64));
buffer[word_offset..word_offset + 8].copy_from_slice(&updated.to_le_bytes());
}

for (i, &val) in values.iter().enumerate() {
let offset = header_size + i * 4;
buffer[offset..offset + 4].copy_from_slice(&val.to_le_bytes());
}

buffer
}

/// Test that SparkUnsafeArray works correctly when placed at a misaligned
/// address. This is a regression test for a bug where `SparkUnsafeArray::new`
/// used a direct pointer dereference `*(addr as *const i64)` which panics
/// on non-8-byte-aligned addresses. Nested arrays within Spark UnsafeRow
/// can be at arbitrary offsets.
#[test]
fn test_misaligned_array_construction() {
let values = vec![10i32, 20, 30, 40, 50];
let buffer = create_i32_array_buffer(&values, &[]);

// Place the array data at a 4-byte-aligned but not 8-byte-aligned offset
// by prepending 4 bytes. This simulates a nested array within a row
// where preceding fields cause misalignment.
let mut misaligned_buf = vec![0u8; 4 + buffer.len()];
misaligned_buf[4..].copy_from_slice(&buffer);
let misaligned_addr = misaligned_buf.as_ptr() as usize + 4;
assert_ne!(
misaligned_addr % 8,
0,
"address should not be 8-byte aligned"
);

let array = SparkUnsafeArray::new(misaligned_addr as i64);
assert_eq!(array.get_num_elements(), 5);

let mut builder = Int32Builder::with_capacity(5);
append_to_builder::<false>(&DataType::Int32, &mut builder, &array).unwrap();
let result = builder.finish();
assert_eq!(result.len(), 5);
for (i, &expected) in values.iter().enumerate() {
assert_eq!(result.value(i), expected);
}
}

/// Test misaligned array with nullable elements.
#[test]
fn test_misaligned_array_with_nulls() {
let values = vec![100i32, 0, 300, 0, 500];
let null_indices = vec![1, 3];
let buffer = create_i32_array_buffer(&values, &null_indices);

let mut misaligned_buf = vec![0u8; 4 + buffer.len()];
misaligned_buf[4..].copy_from_slice(&buffer);
let misaligned_addr = misaligned_buf.as_ptr() as usize + 4;
assert_ne!(misaligned_addr % 8, 0);

let array = SparkUnsafeArray::new(misaligned_addr as i64);
assert_eq!(array.get_num_elements(), 5);

let mut builder = Int32Builder::with_capacity(5);
append_to_builder::<true>(&DataType::Int32, &mut builder, &array).unwrap();
let result = builder.finish();
assert_eq!(result.len(), 5);
assert_eq!(result.value(0), 100);
assert!(result.is_null(1));
assert_eq!(result.value(2), 300);
assert!(result.is_null(3));
assert_eq!(result.value(4), 500);
}
}
24 changes: 8 additions & 16 deletions native/core/src/execution/shuffle/spark_unsafe/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ pub trait SparkUnsafeObject {
let addr = self.get_element_offset(index, 1);
// SAFETY: addr points to valid element data (1 byte) within the row/array region.
debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 1) };
i8::from_le_bytes(slice.try_into().unwrap())
unsafe { *(addr as *const i8) }
}

/// Returns short value at the given index of the object.
Expand All @@ -116,8 +115,7 @@ pub trait SparkUnsafeObject {
let addr = self.get_element_offset(index, 2);
// SAFETY: addr points to valid element data (2 bytes) within the row/array region.
debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 2) };
i16::from_le_bytes(slice.try_into().unwrap())
unsafe { (addr as *const i16).read_unaligned() }
}

/// Returns integer value at the given index of the object.
Expand All @@ -126,8 +124,7 @@ pub trait SparkUnsafeObject {
let addr = self.get_element_offset(index, 4);
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
i32::from_le_bytes(slice.try_into().unwrap())
unsafe { (addr as *const i32).read_unaligned() }
}

/// Returns long value at the given index of the object.
Expand All @@ -136,8 +133,7 @@ pub trait SparkUnsafeObject {
let addr = self.get_element_offset(index, 8);
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
i64::from_le_bytes(slice.try_into().unwrap())
unsafe { (addr as *const i64).read_unaligned() }
}

/// Returns float value at the given index of the object.
Expand All @@ -146,8 +142,7 @@ pub trait SparkUnsafeObject {
let addr = self.get_element_offset(index, 4);
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
f32::from_le_bytes(slice.try_into().unwrap())
unsafe { (addr as *const f32).read_unaligned() }
}

/// Returns double value at the given index of the object.
Expand All @@ -156,8 +151,7 @@ pub trait SparkUnsafeObject {
let addr = self.get_element_offset(index, 8);
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
f64::from_le_bytes(slice.try_into().unwrap())
unsafe { (addr as *const f64).read_unaligned() }
}

/// Returns string value at the given index of the object.
Expand Down Expand Up @@ -196,8 +190,7 @@ pub trait SparkUnsafeObject {
let addr = self.get_element_offset(index, 4);
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
i32::from_le_bytes(slice.try_into().unwrap())
unsafe { (addr as *const i32).read_unaligned() }
}

/// Returns timestamp value at the given index of the object.
Expand All @@ -209,8 +202,7 @@ pub trait SparkUnsafeObject {
!addr.is_null(),
"get_timestamp: null pointer at index {index}"
);
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
i64::from_le_bytes(slice.try_into().unwrap())
unsafe { (addr as *const i64).read_unaligned() }
}

/// Returns decimal value at the given index of the object.
Expand Down
Loading