diff --git a/datafusion/spark/src/function/math/hex.rs b/datafusion/spark/src/function/math/hex.rs index 06c77f37021bf..26635b64dc90a 100644 --- a/datafusion/spark/src/function/math/hex.rs +++ b/datafusion/spark/src/function/math/hex.rs @@ -21,28 +21,25 @@ use std::sync::Arc; use arrow::array::{Array, ArrayRef, StringBuilder}; use arrow::datatypes::DataType; -use arrow::{ - array::{as_dictionary_array, as_largestring_array, as_string_array}, - datatypes::Int32Type, +use arrow::downcast_dictionary_array; +use datafusion_common::cast::{ + as_binary_array, as_binary_view_array, as_fixed_size_binary_array, as_int64_array, + as_large_binary_array, as_large_string_array, as_string_array, as_string_view_array, }; -use datafusion_common::cast::as_large_binary_array; -use datafusion_common::cast::as_string_view_array; use datafusion_common::types::{NativeType, logical_int64, logical_string}; use datafusion_common::utils::take_function_args; -use datafusion_common::{ - DataFusionError, - cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array}, - exec_err, -}; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::function::Hint; use datafusion_expr::{ Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, }; +use datafusion_functions::utils::make_scalar_function; + /// #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkHex { signature: Signature, - aliases: Vec, } impl Default for SparkHex { @@ -74,7 +71,6 @@ impl SparkHex { Self { signature: Signature::one_of(variants, Volatility::Immutable), - aliases: vec![], } } } @@ -92,7 +88,7 @@ impl ScalarUDFImpl for SparkHex { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(match &arg_types[0] { DataType::Dictionary(key_type, _) => { DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8)) @@ -101,63 +97,81 @@ impl ScalarUDFImpl for SparkHex { }) } - fn invoke_with_args( - &self, - args: ScalarFunctionArgs, - ) -> datafusion_common::Result { - spark_hex(&args.args) - } - - fn aliases(&self) -> &[String] { - &self.aliases + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(compute_hex, vec![Hint::AcceptsSingular])(&args.args) } } -/// Hex encoding lookup tables for fast byte-to-hex conversion -const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef"; -const HEX_CHARS_UPPER: &[u8; 16] = b"0123456789ABCDEF"; - -#[inline] -fn hex_int64(num: i64, buffer: &mut [u8; 16]) -> &[u8] { - if num == 0 { - return b"0"; +fn compute_hex(args: &[ArrayRef]) -> Result { + let [array] = take_function_args("hex", args)?; + downcast_dictionary_array! { + array => { + let values = hex_values(array.values())?; + Ok(Arc::new(array.with_values(values))) + }, + _ => { + hex_values(array) + } } +} - let mut n = num as u64; - let mut i = 16; - while n != 0 { - i -= 1; - buffer[i] = HEX_CHARS_UPPER[(n & 0xF) as usize]; - n >>= 4; +fn hex_values(array: &ArrayRef) -> Result { + match array.data_type() { + DataType::Int64 => { + let array = as_int64_array(array)?; + hex_encode_int64(array.iter()) + } + DataType::Utf8 => { + let array = as_string_array(array)?; + hex_encode_bytes(array.iter()) + } + DataType::Utf8View => { + let array = as_string_view_array(array)?; + hex_encode_bytes(array.iter()) + } + DataType::LargeUtf8 => { + let array = as_large_string_array(array)?; + hex_encode_bytes(array.iter()) + } + DataType::Binary => { + let array = as_binary_array(array)?; + hex_encode_bytes(array.iter()) + } + DataType::LargeBinary => { + let array = as_large_binary_array(array)?; + hex_encode_bytes(array.iter()) + } + DataType::BinaryView => { + let array = as_binary_view_array(array)?; + hex_encode_bytes(array.iter()) + } + DataType::FixedSizeBinary(_) => { + let array = as_fixed_size_binary_array(array)?; + hex_encode_bytes(array.iter()) + } + dt => internal_err!("Unexpected data type for hex: {dt}"), } - &buffer[i..] } +/// Hex encoding lookup tables for fast byte-to-hex conversion +const HEX_CHARS_UPPER: &[u8; 16] = b"0123456789ABCDEF"; + /// Generic hex encoding for byte array types -fn hex_encode_bytes<'a, I, T>( - iter: I, - lowercase: bool, - len: usize, -) -> Result +fn hex_encode_bytes<'a, I, T>(iter: I) -> Result where - I: Iterator>, + I: ExactSizeIterator>, T: AsRef<[u8]> + 'a, { - let mut builder = StringBuilder::with_capacity(len, len * 64); + let mut builder = StringBuilder::with_capacity(iter.len(), iter.len() * 64); let mut buffer = Vec::with_capacity(64); - let hex_chars = if lowercase { - HEX_CHARS_LOWER - } else { - HEX_CHARS_UPPER - }; for v in iter { if let Some(b) = v { buffer.clear(); let bytes = b.as_ref(); for &byte in bytes { - buffer.push(hex_chars[(byte >> 4) as usize]); - buffer.push(hex_chars[(byte & 0x0f) as usize]); + buffer.push(HEX_CHARS_UPPER[(byte >> 4) as usize]); + buffer.push(HEX_CHARS_UPPER[(byte & 0x0f) as usize]); } // SAFETY: buffer contains only ASCII hex digests, which are valid UTF-8 unsafe { @@ -171,12 +185,27 @@ where Ok(Arc::new(builder.finish())) } +#[inline] +fn hex_int64(num: i64, buffer: &mut [u8; 16]) -> &[u8] { + if num == 0 { + return b"0"; + } + + let mut n = num as u64; + let mut i = 16; + while n != 0 { + i -= 1; + buffer[i] = HEX_CHARS_UPPER[(n & 0xF) as usize]; + n >>= 4; + } + &buffer[i..] +} + /// Generic hex encoding for int64 type fn hex_encode_int64( - iter: impl Iterator>, - len: usize, -) -> Result { - let mut builder = StringBuilder::with_capacity(len, len * 16); + iter: impl ExactSizeIterator>, +) -> Result { + let mut builder = StringBuilder::with_capacity(iter.len(), iter.len() * 16); for v in iter { if let Some(num) = v { @@ -194,241 +223,16 @@ fn hex_encode_int64( Ok(Arc::new(builder.finish())) } -/// Spark-compatible `hex` function -pub fn spark_hex(args: &[ColumnarValue]) -> Result { - compute_hex(args, false) -} - -/// Spark-compatible `sha2` function -pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result { - compute_hex(args, true) -} - -pub fn compute_hex( - args: &[ColumnarValue], - lowercase: bool, -) -> Result { - let input = match take_function_args("hex", args)? { - [ColumnarValue::Scalar(value)] => ColumnarValue::Array(value.to_array()?), - [ColumnarValue::Array(arr)] => ColumnarValue::Array(Arc::clone(arr)), - }; - - match &input { - ColumnarValue::Array(array) => match array.data_type() { - DataType::Int64 => { - let array = as_int64_array(array)?; - Ok(ColumnarValue::Array(hex_encode_int64( - array.iter(), - array.len(), - )?)) - } - DataType::Utf8 => { - let array = as_string_array(array); - Ok(ColumnarValue::Array(hex_encode_bytes( - array.iter(), - lowercase, - array.len(), - )?)) - } - DataType::Utf8View => { - let array = as_string_view_array(array)?; - Ok(ColumnarValue::Array(hex_encode_bytes( - array.iter(), - lowercase, - array.len(), - )?)) - } - DataType::LargeUtf8 => { - let array = as_largestring_array(array); - Ok(ColumnarValue::Array(hex_encode_bytes( - array.iter(), - lowercase, - array.len(), - )?)) - } - DataType::Binary => { - let array = as_binary_array(array)?; - Ok(ColumnarValue::Array(hex_encode_bytes( - array.iter(), - lowercase, - array.len(), - )?)) - } - DataType::LargeBinary => { - let array = as_large_binary_array(array)?; - Ok(ColumnarValue::Array(hex_encode_bytes( - array.iter(), - lowercase, - array.len(), - )?)) - } - DataType::FixedSizeBinary(_) => { - let array = as_fixed_size_binary_array(array)?; - Ok(ColumnarValue::Array(hex_encode_bytes( - array.iter(), - lowercase, - array.len(), - )?)) - } - DataType::Dictionary(key_type, _) => { - if **key_type != DataType::Int32 { - return exec_err!( - "hex only supports Int32 dictionary keys, get: {}", - key_type - ); - } - - let dict = as_dictionary_array::(&array); - let dict_values = dict.values(); - - let encoded_values = match dict_values.data_type() { - DataType::Int64 => { - let arr = as_int64_array(dict_values)?; - hex_encode_int64(arr.iter(), arr.len())? - } - DataType::Utf8 => { - let arr = as_string_array(dict_values); - hex_encode_bytes(arr.iter(), lowercase, arr.len())? - } - DataType::LargeUtf8 => { - let arr = as_largestring_array(dict_values); - hex_encode_bytes(arr.iter(), lowercase, arr.len())? - } - DataType::Utf8View => { - let arr = as_string_view_array(dict_values)?; - hex_encode_bytes(arr.iter(), lowercase, arr.len())? - } - DataType::Binary => { - let arr = as_binary_array(dict_values)?; - hex_encode_bytes(arr.iter(), lowercase, arr.len())? - } - DataType::LargeBinary => { - let arr = as_large_binary_array(dict_values)?; - hex_encode_bytes(arr.iter(), lowercase, arr.len())? - } - DataType::FixedSizeBinary(_) => { - let arr = as_fixed_size_binary_array(dict_values)?; - hex_encode_bytes(arr.iter(), lowercase, arr.len())? - } - _ => { - return exec_err!( - "hex got an unexpected argument type: {}", - dict_values.data_type() - ); - } - }; - - let new_dict = dict.with_values(encoded_values); - Ok(ColumnarValue::Array(Arc::new(new_dict))) - } - _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()), - }, - _ => exec_err!("native hex does not support scalar values at this time"), - } -} - #[cfg(test)] mod test { use std::str::from_utf8_unchecked; use std::sync::Arc; use arrow::array::{DictionaryArray, Int32Array, Int64Array, StringArray}; - use arrow::{ - array::{ - BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder, - as_string_array, - }, - datatypes::{Int32Type, Int64Type}, - }; - use datafusion_common::cast::as_dictionary_array; - use datafusion_expr::ColumnarValue; - - #[test] - fn test_dictionary_hex_utf8() { - let mut input_builder = StringDictionaryBuilder::::new(); - input_builder.append_value("hi"); - input_builder.append_value("bye"); - input_builder.append_null(); - input_builder.append_value("rust"); - let input = input_builder.finish(); - - let mut expected_builder = StringDictionaryBuilder::::new(); - expected_builder.append_value("6869"); - expected_builder.append_value("627965"); - expected_builder.append_null(); - expected_builder.append_value("72757374"); - let expected = expected_builder.finish(); - - let columnar_value = ColumnarValue::Array(Arc::new(input)); - let result = super::spark_hex(&[columnar_value]).unwrap(); - - let result = match result { - ColumnarValue::Array(array) => array, - _ => panic!("Expected array"), - }; - let result = as_dictionary_array(&result).unwrap(); - - assert_eq!(result, &expected); - } - - #[test] - fn test_dictionary_hex_int64() { - let mut input_builder = PrimitiveDictionaryBuilder::::new(); - input_builder.append_value(1); - input_builder.append_value(2); - input_builder.append_null(); - input_builder.append_value(3); - let input = input_builder.finish(); - - let mut expected_builder = StringDictionaryBuilder::::new(); - expected_builder.append_value("1"); - expected_builder.append_value("2"); - expected_builder.append_null(); - expected_builder.append_value("3"); - let expected = expected_builder.finish(); - - let columnar_value = ColumnarValue::Array(Arc::new(input)); - let result = super::spark_hex(&[columnar_value]).unwrap(); - - let result = match result { - ColumnarValue::Array(array) => array, - _ => panic!("Expected array"), - }; - - let result = as_dictionary_array(&result).unwrap(); - - assert_eq!(result, &expected); - } - - #[test] - fn test_dictionary_hex_binary() { - let mut input_builder = BinaryDictionaryBuilder::::new(); - input_builder.append_value("1"); - input_builder.append_value("j"); - input_builder.append_null(); - input_builder.append_value("3"); - let input = input_builder.finish(); - - let mut expected_builder = StringDictionaryBuilder::::new(); - expected_builder.append_value("31"); - expected_builder.append_value("6A"); - expected_builder.append_null(); - expected_builder.append_value("33"); - let expected = expected_builder.finish(); - - let columnar_value = ColumnarValue::Array(Arc::new(input)); - let result = super::spark_hex(&[columnar_value]).unwrap(); - - let result = match result { - ColumnarValue::Array(array) => array, - _ => panic!("Expected array"), - }; - - let result = as_dictionary_array(&result).unwrap(); + use datafusion_common::cast::as_dictionary_array; - assert_eq!(result, &expected); - } + use super::*; #[test] fn test_hex_int64() { @@ -436,55 +240,23 @@ mod test { for (num, expected) in test_cases { let mut cache = [0u8; 16]; - let slice = super::hex_int64(num, &mut cache); + let slice = hex_int64(num, &mut cache); - unsafe { - let result = from_utf8_unchecked(slice); - assert_eq!(expected, result); - } + let result = unsafe { from_utf8_unchecked(slice) }; + assert_eq!(expected, result); } } - #[test] - fn test_spark_hex_int64() { - let int_array = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]); - let columnar_value = ColumnarValue::Array(Arc::new(int_array)); - - let result = super::spark_hex(&[columnar_value]).unwrap(); - let result = match result { - ColumnarValue::Array(array) => array, - _ => panic!("Expected array"), - }; - - let string_array = as_string_array(&result); - let expected_array = StringArray::from(vec![ - Some("1".to_string()), - Some("2".to_string()), - None, - Some("3".to_string()), - ]); - - assert_eq!(string_array, &expected_array); - } - #[test] fn test_dict_values_null() { let keys = Int32Array::from(vec![Some(0), None, Some(1)]); let vals = Int64Array::from(vec![Some(32), None]); // [32, null, null] - let dict = DictionaryArray::new(keys, Arc::new(vals)); - - let columnar_value = ColumnarValue::Array(Arc::new(dict)); - let result = super::spark_hex(&[columnar_value]).unwrap(); - - let result = match result { - ColumnarValue::Array(array) => array, - _ => panic!("Expected array"), - }; + let dict = Arc::new(DictionaryArray::new(keys.clone(), Arc::new(vals))); + let result = compute_hex(&[dict]).unwrap(); let result = as_dictionary_array(&result).unwrap(); - let keys = Int32Array::from(vec![Some(0), None, Some(1)]); let vals = StringArray::from(vec![Some("20"), None]); let expected = DictionaryArray::new(keys, Arc::new(vals)); diff --git a/datafusion/sqllogictest/test_files/spark/math/hex.slt b/datafusion/sqllogictest/test_files/spark/math/hex.slt index 17e9ff432890d..60ea335d93d9d 100644 --- a/datafusion/sqllogictest/test_files/spark/math/hex.slt +++ b/datafusion/sqllogictest/test_files/spark/math/hex.slt @@ -83,3 +83,25 @@ query T SELECT arrow_typeof(hex(dict_col)) FROM t_dict_binary LIMIT 1; ---- Dictionary(Int32, Utf8) + +query TT +WITH values AS ( + SELECT arrow_cast(column1, 'Dictionary(Int32, Utf8)') AS column1 + FROM VALUES ('hi'), ('bye'), (NULL), ('rust') +) SELECT hex(column1), arrow_typeof(hex(column1)) FROM values +---- +6869 Utf8 +627965 Utf8 +NULL Utf8 +72757374 Utf8 + +query TT +WITH values AS ( + SELECT arrow_cast(column1, 'Dictionary(Int32, Int64)') AS column1 + FROM VALUES (1), (2), (NULL), (3) +) SELECT hex(column1), arrow_typeof(hex(column1)) FROM values +---- +1 Utf8 +2 Utf8 +NULL Utf8 +3 Utf8