Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 153 additions & 14 deletions datafusion/spark/src/function/string/format_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,37 @@ fn take_numeric_param(s: &str, zero: bool) -> (NumericParam, &str) {
}
}

/// Convert a `u32` to a [`char`] for the `%c` conversion, returning a SQL
/// error if the value is not a valid Unicode scalar value (i.e. is in the
/// surrogate range `0xD800..=0xDFFF` or above `0x10FFFF`). Java's `Formatter`
/// raises `IllegalFormatCodePointException` in the same situations.
fn codepoint_to_char(value: u32) -> Result<char> {
char::from_u32(value).ok_or_else(|| {
exec_datafusion_err!("invalid Unicode scalar value for %c: {value:#x}")
})
}

/// `%c` codepoint validation for signed integer arguments. Negative values
/// and values above `0x10FFFF` (or in the surrogate range) error out, matching
/// Java's `Character.isValidCodePoint` rather than reinterpreting the bits as
/// unsigned.
fn signed_to_char(value: i64) -> Result<char> {
let codepoint = u32::try_from(value).map_err(|_| {
exec_datafusion_err!("invalid Unicode scalar value for %c: {value}")
})?;
codepoint_to_char(codepoint)
}

/// `%c` codepoint validation for unsigned integer arguments. Errors if the
/// value does not fit in a `u32` or is not a valid Unicode scalar value,
/// instead of silently truncating high bits.
fn unsigned_to_char(value: u64) -> Result<char> {
let codepoint = u32::try_from(value).map_err(|_| {
exec_datafusion_err!("invalid Unicode scalar value for %c: {value:#x}")
})?;
codepoint_to_char(codepoint)
}

impl ConversionSpecifier {
/// Validates that the grouping separator flag is not used with scientific
/// notation conversions, matching Java/Spark behavior which throws
Expand Down Expand Up @@ -904,7 +935,7 @@ impl ConversionSpecifier {
Some(value),
) => self.format_unsigned(string, (*value as u8) as u64),
(ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => {
self.format_char(string, *value as u8 as char)
self.format_char(string, signed_to_char(*value as i64)?)
}
(
ConversionType::StringLower | ConversionType::StringUpper,
Expand All @@ -923,10 +954,7 @@ impl ConversionSpecifier {
self.format_signed(string, *value as i64)
}
(ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => {
self.format_char(
string,
char::from_u32((*value as u16) as u32).unwrap(),
)
self.format_char(string, signed_to_char(*value as i64)?)
}
(
ConversionType::HexIntLower
Expand Down Expand Up @@ -957,7 +985,7 @@ impl ConversionSpecifier {
Some(value),
) => self.format_unsigned(string, (*value as u32) as u64),
(ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => {
self.format_char(string, char::from_u32(*value as u32).unwrap())
self.format_char(string, signed_to_char(*value as i64)?)
}
(
ConversionType::StringLower | ConversionType::StringUpper,
Expand All @@ -982,10 +1010,7 @@ impl ConversionSpecifier {
Some(value),
) => self.format_unsigned(string, *value as u64),
(ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => {
self.format_char(
string,
char::from_u32((*value as u64) as u32).unwrap(),
)
self.format_char(string, signed_to_char(*value)?)
}
(
ConversionType::StringLower | ConversionType::StringUpper,
Expand All @@ -1008,7 +1033,7 @@ impl ConversionSpecifier {
Some(value),
) => self.format_unsigned(string, *value as u64),
(ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => {
self.format_char(string, *value as char)
self.format_char(string, unsigned_to_char(*value as u64)?)
}
(
ConversionType::StringLower | ConversionType::StringUpper,
Expand All @@ -1031,7 +1056,7 @@ impl ConversionSpecifier {
Some(value),
) => self.format_unsigned(string, *value as u64),
(ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => {
self.format_char(string, char::from_u32(*value as u32).unwrap())
self.format_char(string, unsigned_to_char(*value as u64)?)
}
(
ConversionType::StringLower | ConversionType::StringUpper,
Expand All @@ -1054,7 +1079,7 @@ impl ConversionSpecifier {
Some(value),
) => self.format_unsigned(string, *value as u64),
(ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => {
self.format_char(string, char::from_u32(*value).unwrap())
self.format_char(string, unsigned_to_char(*value as u64)?)
}
(
ConversionType::StringLower | ConversionType::StringUpper,
Expand All @@ -1077,7 +1102,7 @@ impl ConversionSpecifier {
Some(value),
) => self.format_unsigned(string, *value),
(ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => {
self.format_char(string, char::from_u32(*value as u32).unwrap())
self.format_char(string, unsigned_to_char(*value)?)
}
(
ConversionType::StringLower | ConversionType::StringUpper,
Expand Down Expand Up @@ -2442,6 +2467,120 @@ mod tests {
Ok(())
}

#[test]
fn test_format_char_invalid_codepoint_errors() {
use arrow::datatypes::Field;
use datafusion_common::config::ConfigOptions;

let func = FormatStringFunc::new();
// Spark/Java reject any negative integer or any value outside
// `0..=0x10FFFF` (and the surrogate range) regardless of integer
// width, so all of these inputs must surface a SQL error rather than
// panicking or silently reinterpreting the bits as unsigned.
let cases: Vec<(&str, ScalarValue)> = vec![
("Int8(-1)", ScalarValue::Int8(Some(-1))),
("Int16(-1)", ScalarValue::Int16(Some(-1))),
("Int16(-10000)", ScalarValue::Int16(Some(-10000))),
("Int32(-1)", ScalarValue::Int32(Some(-1))),
("Int32(0x110000)", ScalarValue::Int32(Some(0x110000))),
("Int64(0x1FFFFFFFF)", ScalarValue::Int64(Some(0x1FFFFFFFF))),
("Int64(-1)", ScalarValue::Int64(Some(-1))),
("UInt16(0xD800)", ScalarValue::UInt16(Some(0xD800))),
("UInt32(0x110000)", ScalarValue::UInt32(Some(0x110000))),
(
"UInt64(0x1_0000_0000)",
ScalarValue::UInt64(Some(0x1_0000_0000)),
),
];

for (label, value) in cases {
let fmt = ColumnarValue::Scalar(ScalarValue::Utf8(Some("[%c]".to_string())));
let arg_data_type = value.data_type();
let arg = ColumnarValue::Scalar(value);
let arg_fields = vec![
Arc::new(Field::new("fmt", Utf8, false)),
Arc::new(Field::new("v", arg_data_type, false)),
];
let res = func.invoke_with_args(ScalarFunctionArgs {
args: vec![fmt, arg],
number_rows: 1,
arg_fields,
return_field: Arc::new(Field::new("o", Utf8, false)),
config_options: Arc::new(ConfigOptions::default()),
});
assert!(
res.is_err(),
"format_string('[%c]', {label}) should error, got Ok"
);
let err = res.unwrap_err().to_string();
assert!(
err.contains("invalid Unicode scalar value for %c"),
"unexpected error for {label}: {err}"
);
}
}

#[test]
fn test_format_char_valid_codepoint_succeeds() {
test_scalar_function!(
FormatStringFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Utf8(Some("[%c]".to_string()))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(0x1F680))),
],
Ok(Some("[\u{1F680}]")),
&str,
Utf8,
StringArray
);
test_scalar_function!(
FormatStringFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Utf8(Some("[%c]".to_string()))),
ColumnarValue::Scalar(ScalarValue::UInt32(Some(0x10FFFF))),
],
Ok(Some("[\u{10FFFF}]")),
&str,
Utf8,
StringArray
);
test_scalar_function!(
FormatStringFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Utf8(Some("[%c]".to_string()))),
ColumnarValue::Scalar(ScalarValue::Int16(Some(65))),
],
Ok(Some("[A]")),
&str,
Utf8,
StringArray
);
// Int8 / UInt8 can never produce an invalid codepoint for non-negative
// values, but they must still flow through the validating helper.
test_scalar_function!(
FormatStringFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Utf8(Some("[%c]".to_string()))),
ColumnarValue::Scalar(ScalarValue::Int8(Some(97))),
],
Ok(Some("[a]")),
&str,
Utf8,
StringArray
);
test_scalar_function!(
FormatStringFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Utf8(Some("[%c]".to_string()))),
ColumnarValue::Scalar(ScalarValue::UInt8(Some(255))),
],
Ok(Some("[\u{00FF}]")),
&str,
Utf8,
StringArray
);
}

#[test]
fn test_insert_thousands_separator() {
assert_eq!(insert_thousands_separator("1234567.89"), "1,234,567.89");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,10 @@ Char: A |
statement error
SELECT format_string('Char: %5c', true);

## Character with invalid negative codepoint
statement error
SELECT format_string('Char: %c', -1);

# ================================
# Time formatting tests
# ================================
Expand Down
Loading