Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 34 additions & 23 deletions datafusion/expr/src/window_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,11 @@ impl WindowFrameStateRange {
length: usize,
) -> Result<usize> {
let current_row_values = get_row_at_idx(range_columns, idx)?;
let search_start = if SIDE {
last_range.start
} else {
last_range.end
};
let end_range = if let Some(delta) = delta {
let is_descending: bool = self
.sort_options
Expand All @@ -407,34 +412,40 @@ impl WindowFrameStateRange {
})?
.descending;

current_row_values
.iter()
.map(|value| {
if value.is_null() {
return Ok(value.clone());
// On overflow the boundary exceeds the type's range and is
// effectively unbounded within the partition. Collapse to the
// partition edge rather than feeding `search_in_slice` a
// wrapped-around target: PRECEDING searches reach `search_start`,
// FOLLOWING searches reach `length`.
let unbounded_edge = if SEARCH_SIDE { search_start } else { length };
let mut targets = Vec::with_capacity(current_row_values.len());
for value in &current_row_values {
if value.is_null() {
targets.push(value.clone());
continue;
}
let target = if SEARCH_SIDE == is_descending {
match value.add_checked(delta) {
Ok(v) => v,
Err(_) => return Ok(unbounded_edge),
}
if SEARCH_SIDE == is_descending {
// TODO: Handle positive overflows.
value.add(delta)
} else if value.is_unsigned() && value < delta {
// NOTE: This gets a polymorphic zero without having long coercion code for ScalarValue.
// If we decide to implement a "default" construction mechanism for ScalarValue,
// change the following statement to use that.
value.sub(value)
} else {
// TODO: Handle negative overflows.
value.sub(delta)
} else if value.is_unsigned() && value < delta {
// NOTE: This gets a polymorphic zero without having long coercion code for ScalarValue.
// If we decide to implement a "default" construction mechanism for ScalarValue,
// change the following statement to use that.
value.sub(value)?
} else {
match value.sub_checked(delta) {
Ok(v) => v,
Err(_) => return Ok(unbounded_edge),
}
})
.collect::<Result<Vec<ScalarValue>>>()?
};
targets.push(target);
}
targets
} else {
current_row_values
};
let search_start = if SIDE {
last_range.start
} else {
last_range.end
};
let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| {
let cmp = compare_rows(current, target, &self.sort_options)?;
Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() })
Expand Down
50 changes: 35 additions & 15 deletions datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,9 +519,16 @@ impl Accumulator for AvgAccumulator {
}

fn evaluate(&mut self) -> Result<ScalarValue> {
    // In sliding-window mode `retract_batch` can bring `count` back to 0
    // while `sum` remains `Some(..)` (possibly zero or a floating-point
    // residual). Guard against that so a frame with no non-NULL values
    // yields NULL rather than NaN / ±Inf from dividing by zero.
    let avg = if self.count == 0 {
        None
    } else {
        self.sum.map(|f| f / self.count as f64)
    };
    Ok(ScalarValue::Float64(avg))
}

fn size(&self) -> usize {
Expand Down Expand Up @@ -584,17 +591,23 @@ impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumu
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let v = self
.sum
.map(|v| {
DecimalAverager::<T>::try_new(
self.sum_scale,
self.target_precision,
self.target_scale,
)?
.avg(v, T::Native::from_usize(self.count as usize).unwrap())
})
.transpose()?;
// `count == 0` can occur in sliding-window mode after `retract_batch`
// removes every contributing value. Return NULL rather than dividing
// by zero (which would panic for integer decimal types).
let v = if self.count == 0 {
None
} else {
self.sum
.map(|v| {
DecimalAverager::<T>::try_new(
self.sum_scale,
self.target_precision,
self.target_scale,
)?
.avg(v, T::Native::from_usize(self.count as usize).unwrap())
})
.transpose()?
};

ScalarValue::new_primitive::<T>(
v,
Expand Down Expand Up @@ -670,7 +683,14 @@ impl Accumulator for DurationAvgAccumulator {
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let avg = self.sum.map(|sum| sum / self.count as i64);
// Guard against `count == 0` which can happen in sliding-window mode
// after every contributing value has been retracted. Without this
// check we would integer-divide by zero.
let avg = if self.count == 0 {
None
} else {
self.sum.map(|sum| sum / self.count as i64)
};

match self.result_unit {
TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)),
Expand Down
24 changes: 16 additions & 8 deletions datafusion/functions-window/src/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,22 @@ impl WindowUDFImpl for NthValue {
}

fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
    // Take the first input field when present so its metadata survives;
    // fall back to a NULL-typed placeholder when the function has no
    // input fields at all.
    let input_field = field_args
        .input_fields()
        .first()
        .cloned()
        .unwrap_or_else(|| {
            Arc::new(Field::new(field_args.name(), DataType::Null, true))
        });

    // Clone the input field to preserve its metadata, then update the
    // name and force nullability: NTH_VALUE returns NULL when the frame
    // has fewer than `n` rows, regardless of input nullability.
    Ok(input_field
        .as_ref()
        .clone()
        .with_name(field_args.name())
        .with_nullable(true)
        .into())
}

fn reverse_expr(&self) -> ReversedUDWF {
Expand Down
46 changes: 46 additions & 0 deletions datafusion/sqllogictest/test_files/metadata.slt
Original file line number Diff line number Diff line change
Expand Up @@ -472,5 +472,51 @@ select arrow_metadata(with_metadata(id, 'unit', ''), 'unit') from table_with_met
----
(empty)

# Regression test: window functions should preserve field metadata
# Test FIRST_VALUE window function preserves metadata
query IT
select
first_value(id) over (order by id asc nulls last) as fv,
arrow_metadata(first_value(id) over (order by id asc nulls last), 'metadata_key') as meta
from table_with_metadata limit 1;
----
1 the id field

# Test LAST_VALUE window function preserves metadata
query IT
select
last_value(id) over (order by id asc nulls last rows between unbounded preceding and unbounded following) as lv,
arrow_metadata(last_value(id) over (order by id asc nulls last rows between unbounded preceding and unbounded following), 'metadata_key') as meta
from table_with_metadata limit 1;
----
NULL the id field

# Test NTH_VALUE window function preserves metadata
query IT
select
nth_value(id, 2) over (order by id asc nulls last) as nv,
arrow_metadata(nth_value(id, 2) over (order by id asc nulls last), 'metadata_key') as meta
from table_with_metadata limit 1;
----
NULL the id field

# Test LEAD window function preserves metadata
query IT
select
lead(id) over (order by id asc nulls last) as ld,
arrow_metadata(lead(id) over (order by id asc nulls last), 'metadata_key') as meta
from table_with_metadata limit 1;
----
3 the id field

# Test LAG window function preserves metadata
query IT
select
lag(id) over (order by id asc nulls last) as lg,
arrow_metadata(lag(id) over (order by id asc nulls last), 'metadata_key') as meta
from table_with_metadata limit 1;
----
NULL the id field

statement ok
drop table table_with_metadata;
Loading
Loading