diff --git a/native/core/src/execution/shuffle/spark_unsafe/list.rs b/native/core/src/execution/shuffle/spark_unsafe/list.rs
index 72610d2d82..5549e0fb7a 100644
--- a/native/core/src/execution/shuffle/spark_unsafe/list.rs
+++ b/native/core/src/execution/shuffle/spark_unsafe/list.rs
@@ -109,8 +109,9 @@ impl SparkUnsafeArray {
-        // SAFETY: addr points to valid Spark UnsafeArray data from the JVM.
-        // The first 8 bytes contain the element count as a little-endian i64.
         debug_assert!(addr != 0, "SparkUnsafeArray::new: null address");
-        let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) };
-        let num_elements = i64::from_le_bytes(slice.try_into().unwrap());
+        // SAFETY: addr points to valid Spark UnsafeArray data from the JVM, whose first
+        // 8 bytes hold the element count. `read_unaligned` is required because nested
+        // arrays may not be 8-byte aligned.
+        // NOTE: unlike the previous `from_le_bytes`, this reads in host byte order.
+        let num_elements = unsafe { (addr as *const i64).read_unaligned() };
 
         if num_elements < 0 {
             panic!("Negative number of elements: {num_elements}");
@@ -478,3 +479,101 @@ pub fn append_list_element(
 
     Ok(())
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use arrow::array::builder::Int32Builder;
+    use arrow::array::Array;
+
+    /// Helper to create a SparkUnsafeArray buffer with i32 elements.
+    /// Layout: 8 bytes num_elements + null bitset + element data.
+    fn create_i32_array_buffer(values: &[i32], null_indices: &[usize]) -> Vec<u8> {
+        let num_elements = values.len();
+        let null_bitset_words = num_elements.div_ceil(64);
+        let header_size = 8 + null_bitset_words * 8;
+        let data_size = num_elements * 4;
+        let mut buffer = vec![0u8; header_size + data_size];
+
+        buffer[0..8].copy_from_slice(&(num_elements as i64).to_le_bytes());
+
+        for &idx in null_indices {
+            let word_offset = 8 + (idx / 64) * 8;
+            let current =
+                i64::from_le_bytes(buffer[word_offset..word_offset + 8].try_into().unwrap());
+            let updated = current | (1i64 << (idx % 64));
+            buffer[word_offset..word_offset + 8].copy_from_slice(&updated.to_le_bytes());
+        }
+
+        for (i, &val) in values.iter().enumerate() {
+            let offset = header_size + i * 4;
+            buffer[offset..offset + 4].copy_from_slice(&val.to_le_bytes());
+        }
+
+        buffer
+    }
+
+    /// Copies `data` into fresh storage at an address that is 4-byte aligned but
+    /// not 8-byte aligned, simulating a nested array within a row where preceding
+    /// fields cause misalignment. Returns the backing storage, which must stay
+    /// alive while the address is in use, together with that address.
+    fn place_misaligned(data: &[u8]) -> (Vec<u8>, usize) {
+        // A `Vec<u8>` only guarantees 1-byte alignment, so a fixed 4-byte prefix
+        // could land on an 8-byte boundary. Pick the offset from the actual base
+        // address instead, so the result is always congruent to 4 modulo 8.
+        let mut backing = vec![0u8; data.len() + 8];
+        let base = backing.as_ptr() as usize;
+        let pad = (12 - base % 8) % 8;
+        backing[pad..pad + data.len()].copy_from_slice(data);
+        let misaligned_addr = base + pad;
+        assert_eq!(misaligned_addr % 8, 4, "address should be 4 mod 8");
+        (backing, misaligned_addr)
+    }
+
+    /// Test that SparkUnsafeArray works correctly when placed at a misaligned
+    /// address. This is a regression test for a bug where `SparkUnsafeArray::new`
+    /// used a direct pointer dereference `*(addr as *const i64)` which panics
+    /// on non-8-byte-aligned addresses. Nested arrays within Spark UnsafeRow
+    /// can be at arbitrary offsets.
+    #[test]
+    fn test_misaligned_array_construction() {
+        let values = vec![10i32, 20, 30, 40, 50];
+        let buffer = create_i32_array_buffer(&values, &[]);
+        // `_backing` keeps the misaligned copy alive for the rest of the test.
+        let (_backing, misaligned_addr) = place_misaligned(&buffer);
+
+        let array = SparkUnsafeArray::new(misaligned_addr as i64);
+        assert_eq!(array.get_num_elements(), 5);
+
+        let mut builder = Int32Builder::with_capacity(5);
+        append_to_builder::<false>(&DataType::Int32, &mut builder, &array).unwrap();
+        let result = builder.finish();
+        assert_eq!(result.len(), 5);
+        for (i, &expected) in values.iter().enumerate() {
+            assert_eq!(result.value(i), expected);
+        }
+    }
+
+    /// Test misaligned array with nullable elements.
+    #[test]
+    fn test_misaligned_array_with_nulls() {
+        let values = vec![100i32, 0, 300, 0, 500];
+        let null_indices = vec![1, 3];
+        let buffer = create_i32_array_buffer(&values, &null_indices);
+        // `_backing` keeps the misaligned copy alive for the rest of the test.
+        let (_backing, misaligned_addr) = place_misaligned(&buffer);
+
+        let array = SparkUnsafeArray::new(misaligned_addr as i64);
+        assert_eq!(array.get_num_elements(), 5);
+
+        let mut builder = Int32Builder::with_capacity(5);
+        append_to_builder::<true>(&DataType::Int32, &mut builder, &array).unwrap();
+        let result = builder.finish();
+        assert_eq!(result.len(), 5);
+        assert_eq!(result.value(0), 100);
+        assert!(result.is_null(1));
+        assert_eq!(result.value(2), 300);
+        assert!(result.is_null(3));
+        assert_eq!(result.value(4), 500);
+    }
+}
diff --git a/native/core/src/execution/shuffle/spark_unsafe/row.rs b/native/core/src/execution/shuffle/spark_unsafe/row.rs
index 6b41afae8d..3a3055e825 100644
--- a/native/core/src/execution/shuffle/spark_unsafe/row.rs
+++ b/native/core/src/execution/shuffle/spark_unsafe/row.rs
@@ -106,8 +106,7 @@ pub trait SparkUnsafeObject {
         let addr = self.get_element_offset(index, 1);
         // SAFETY: addr points to valid element data (1 byte) within the row/array region.
         debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}");
-        let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 1) };
-        i8::from_le_bytes(slice.try_into().unwrap())
+        unsafe { *(addr as *const i8) }
     }
 
     /// Returns short value at the given index of the object.
@@ -116,8 +115,7 @@ pub trait SparkUnsafeObject {
         let addr = self.get_element_offset(index, 2);
         // SAFETY: addr points to valid element data (2 bytes) within the row/array region.
         debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}");
-        let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 2) };
-        i16::from_le_bytes(slice.try_into().unwrap())
+        unsafe { (addr as *const i16).read_unaligned() }
     }
 
     /// Returns integer value at the given index of the object.
@@ -126,8 +124,7 @@ pub trait SparkUnsafeObject {
         let addr = self.get_element_offset(index, 4);
         // SAFETY: addr points to valid element data (4 bytes) within the row/array region.
         debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}");
-        let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
-        i32::from_le_bytes(slice.try_into().unwrap())
+        unsafe { (addr as *const i32).read_unaligned() }
     }
 
     /// Returns long value at the given index of the object.
@@ -136,8 +133,7 @@ pub trait SparkUnsafeObject {
         let addr = self.get_element_offset(index, 8);
         // SAFETY: addr points to valid element data (8 bytes) within the row/array region.
         debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}");
-        let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
-        i64::from_le_bytes(slice.try_into().unwrap())
+        unsafe { (addr as *const i64).read_unaligned() }
     }
 
     /// Returns float value at the given index of the object.
@@ -146,8 +142,7 @@ pub trait SparkUnsafeObject {
         let addr = self.get_element_offset(index, 4);
         // SAFETY: addr points to valid element data (4 bytes) within the row/array region.
         debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}");
-        let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
-        f32::from_le_bytes(slice.try_into().unwrap())
+        unsafe { (addr as *const f32).read_unaligned() }
     }
 
     /// Returns double value at the given index of the object.
@@ -156,8 +151,7 @@ pub trait SparkUnsafeObject {
         let addr = self.get_element_offset(index, 8);
         // SAFETY: addr points to valid element data (8 bytes) within the row/array region.
         debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}");
-        let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
-        f64::from_le_bytes(slice.try_into().unwrap())
+        unsafe { (addr as *const f64).read_unaligned() }
     }
 
     /// Returns string value at the given index of the object.
@@ -196,8 +190,7 @@ pub trait SparkUnsafeObject {
         let addr = self.get_element_offset(index, 4);
         // SAFETY: addr points to valid element data (4 bytes) within the row/array region.
         debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}");
-        let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
-        i32::from_le_bytes(slice.try_into().unwrap())
+        unsafe { (addr as *const i32).read_unaligned() }
     }
 
     /// Returns timestamp value at the given index of the object.
@@ -209,8 +202,7 @@ pub trait SparkUnsafeObject {
             !addr.is_null(),
             "get_timestamp: null pointer at index {index}"
         );
-        let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
-        i64::from_le_bytes(slice.try_into().unwrap())
+        unsafe { (addr as *const i64).read_unaligned() }
     }
 
     /// Returns decimal value at the given index of the object.