diff --git a/Cargo.toml b/Cargo.toml index 2d182e0..aa13898 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,4 +61,4 @@ debug = true [package.metadata.docs.rs] all-features = true -[lints.rust] +[lints.rust] \ No newline at end of file diff --git a/src/algorithms/dit.rs b/src/algorithms/dit.rs index 29b1a16..f15f6fb 100644 --- a/src/algorithms/dit.rs +++ b/src/algorithms/dit.rs @@ -17,7 +17,7 @@ use fearless_simd::{dispatch, Simd}; use crate::algorithms::bravo::{bit_rev_bravo_f32, bit_rev_bravo_f64}; -use crate::kernels::codelets::{fft_dit_codelet_32_f32, fft_dit_codelet_32_f64}; +use crate::kernels::codelets::{fft_dit_codelet_16_f64, fft_dit_codelet_32_f32}; use crate::kernels::dit::*; use crate::options::Options; use crate::parallel::run_maybe_in_parallel; @@ -42,10 +42,11 @@ fn recursive_dit_fft_f64( let log_size = size.ilog2() as usize; if size <= L1_BLOCK_SIZE { - // Use FFT-32 codelet to fuse stages 0-4 into a single pass per 32-element chunk - let start_stage = if planner.use_codelet_32 { - fft_dit_codelet_32_f64(simd, &mut reals[..size], &mut imags[..size]); - 5 + // Use FFT-16 codelet to fuse stages 0-3 into a single pass per 16-element chunk + let codelet_stages = 4; + let start_stage = if stage_twiddle_idx == 0 && size >= power_of_two(codelet_stages) { + fft_dit_codelet_16_f64(simd, &mut reals[..size], &mut imags[..size]); + codelet_stages } else { 0 }; @@ -109,9 +110,10 @@ fn recursive_dit_fft_f32( if size <= L1_BLOCK_SIZE { // Use FFT-32 codelet to fuse stages 0-4 into a single pass per 32-element chunk - let start_stage = if planner.use_codelet_32 { + let codelet_stages = 5; + let start_stage = if stage_twiddle_idx == 0 && size >= power_of_two(codelet_stages) { fft_dit_codelet_32_f32(simd, &mut reals[..size], &mut imags[..size]); - 5 + codelet_stages } else { 0 }; @@ -392,3 +394,10 @@ fn fft_32_dit_with_planner_and_opts_impl( } } } + +#[inline] +fn power_of_two(power: usize) -> usize { + // 2.pow() requires a lot of ugly type annotations so here's a helper function + debug_assert!(power < usize::BITS as usize); + 1 << power +} diff --git a/src/kernels/codelets.rs b/src/kernels/codelets.rs index 1d622fb..765dfb9 100644 --- a/src/kernels/codelets.rs +++ b/src/kernels/codelets.rs @@ -3,247 +3,217 @@ //! A codelet is a self-contained FFT kernel that fuses multiple stages into a single function //! call, eliminating per-stage function call overhead and giving LLVM a wider optimization window. //! -use fearless_simd::{f32x16, f32x4, f32x8, f64x4, f64x8, Simd, SimdBase, SimdFloat, SimdFrom}; +use fearless_simd::{ + f32x4, f32x8, f64x4, Simd, SimdBase, SimdCombine, SimdFloat, SimdFrom, SimdSplit, +}; -/// FFT-32 codelet for `f64`: executes stages 0-4 (chunk_size 2 through 32) in a single function. +/// Equivalent to `a.interleave(b)` — returns `(a.zip_low(b), a.zip_high(b))`. +/// Slow polyfill for +#[inline(always)] +fn interleave_f64x4(a: f64x4, b: f64x4) -> (f64x4, f64x4) { + (a.zip_low(b), a.zip_high(b)) +} + +/// Equivalent to `a.interleave(b)` — returns `(a.zip_low(b), a.zip_high(b))`. +/// Slow polyfill for +#[inline(always)] +fn interleave_f32x4(a: f32x4, b: f32x4) -> (f32x4, f32x4) { + (a.zip_low(b), a.zip_high(b)) +} + +/// FFT-16 codelet for `f64`: executes stages 0-3 (chunk_size 2 through 16) in a single function. +/// +/// Register-resident implementation: all 16 complex values are loaded into f64x4 vectors, +/// all 4 butterfly stages execute in registers with no intermediate memory traffic, +/// then results are stored back. +/// +/// The limiting factor is register pressure: we cannot fuse more stages on AVX2 or NEON. +/// Trying to do so incurs haphazard loads/stores from the stack and ends up being slower +/// than running orderly single-stage passes that load/store predictably. #[inline(never)] -pub fn fft_dit_codelet_32_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { +pub fn fft_dit_codelet_16_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { simd.vectorize( #[inline(always)] - || fft_dit_codelet_32_simd_f64(simd, reals, imags), + || fft_dit_codelet_16_simd_f64(simd, reals, imags), ) } #[inline(always)] -fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { - // Fused stages 0+1: radix-2^2 4-point DIT in a single sweep - reals - .chunks_exact_mut(4) - .zip(imags.chunks_exact_mut(4)) - .for_each(|(re, im)| { - let (a_re, a_im) = (re[0], im[0]); - let (b_re, b_im) = (re[1], im[1]); - let (c_re, c_im) = (re[2], im[2]); - let (d_re, d_im) = (re[3], im[3]); - - // Stage 0: dist=1 butterflies on pairs (a,b) and (c,d) - let (t0_re, t0_im) = (a_re + b_re, a_im + b_im); - let (t1_re, t1_im) = (a_re - b_re, a_im - b_im); - let (t2_re, t2_im) = (c_re + d_re, c_im + d_im); - let (t3_re, t3_im) = (c_re - d_re, c_im - d_im); - - // Stage 1: -j multiply on t3, then dist=2 butterflies - let (t3j_re, t3j_im) = (t3_im, -t3_re); - - re[0] = t0_re + t2_re; - im[0] = t0_im + t2_im; - re[1] = t1_re + t3j_re; - im[1] = t1_im + t3j_im; - re[2] = t0_re - t2_re; - im[2] = t0_im - t2_im; - re[3] = t1_re - t3j_re; - im[3] = t1_im - t3j_im; - }); +fn fft_dit_codelet_16_simd_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { + assert_eq!(reals.len(), imags.len()); + + let two = f64x4::splat(simd, 2.0); + + for (re, im) in reals.chunks_exact_mut(16).zip(imags.chunks_exact_mut(16)) { + macro_rules! transpose4x4_f64 { + ($g0:expr, $g1:expr, $g2:expr, $g3:expr) => {{ + let (t0, t1) = interleave_f64x4($g0, $g2); + let (t2, t3) = interleave_f64x4($g1, $g3); + let (r0, r1) = interleave_f64x4(t0, t2); + let (r2, r3) = interleave_f64x4(t1, t3); + (r0, r1, r2, r3) + }}; + } - // Stage 2: dist=4, chunk_size=8, W_8 twiddles via f64x4 - { - let tw_re = f64x4::simd_from( - simd, - [ - 1.0, // W_8^0 - std::f64::consts::FRAC_1_SQRT_2, // W_8^1 - 0.0, // W_8^2 - -std::f64::consts::FRAC_1_SQRT_2, // W_8^3 - ], - ); - let tw_im = f64x4::simd_from( - simd, - [ - 0.0, // W_8^0 - -std::f64::consts::FRAC_1_SQRT_2, // W_8^1 - -1.0, // W_8^2 - -std::f64::consts::FRAC_1_SQRT_2, // W_8^3 - ], - ); - let two = f64x4::splat(simd, 2.0); - - (reals.as_chunks_mut::<8>().0.iter_mut()) - .zip(imags.as_chunks_mut::<8>().0.iter_mut()) - .for_each(|(re8, im8)| { - let (re_lo, re_hi) = re8.split_at_mut(4); - let (im_lo, im_hi) = im8.split_at_mut(4); - - let in0_re = f64x4::from_slice(simd, re_lo); - let in1_re = f64x4::from_slice(simd, re_hi); - let in0_im = f64x4::from_slice(simd, im_lo); - let in1_im = f64x4::from_slice(simd, im_hi); - - let out0_re = tw_im.mul_add(-in1_im, tw_re.mul_add(in1_re, in0_re)); - let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im)); - let out1_re = two.mul_sub(in0_re, out0_re); - let out1_im = two.mul_sub(in0_im, out0_im); - - out0_re.store_slice(re_lo); - out0_im.store_slice(im_lo); - out1_re.store_slice(re_hi); - out1_im.store_slice(im_hi); - }); - } + macro_rules! radix4_transpose { + ($va_re:expr, $vb_re:expr, $vc_re:expr, $vd_re:expr, + $va_im:expr, $vb_im:expr, $vc_im:expr, $vd_im:expr) => {{ + let (e0_re, e1_re, e2_re, e3_re) = + transpose4x4_f64!($va_re, $vb_re, $vc_re, $vd_re); + let (e0_im, e1_im, e2_im, e3_im) = + transpose4x4_f64!($va_im, $vb_im, $vc_im, $vd_im); + + let s01_re = e0_re + e1_re; + let d01_re = e0_re - e1_re; + let s23_re = e2_re + e3_re; + let d23_re = e2_re - e3_re; + let s01_im = e0_im + e1_im; + let d01_im = e0_im - e1_im; + let s23_im = e2_im + e3_im; + let d23_im = e2_im - e3_im; + + let p0_re = s01_re + s23_re; + let p2_re = s01_re - s23_re; + let p0_im = s01_im + s23_im; + let p2_im = s01_im - s23_im; + + let p1_re = d01_re + d23_im; + let p3_re = d01_re - d23_im; + let p1_im = d01_im - d23_re; + let p3_im = d01_im + d23_re; + + let (r0_re, r1_re, r2_re, r3_re) = transpose4x4_f64!(p0_re, p1_re, p2_re, p3_re); + let (r0_im, r1_im, r2_im, r3_im) = transpose4x4_f64!(p0_im, p1_im, p2_im, p3_im); + + $va_re = r0_re; + $vb_re = r1_re; + $vc_re = r2_re; + $vd_re = r3_re; + $va_im = r0_im; + $vb_im = r1_im; + $vc_im = r2_im; + $vd_im = r3_im; + }}; + } - // Stage 3: dist=8, chunk_size=16, W_16 twiddles via f64x8 - { - let tw_re = f64x8::simd_from( - simd, - [ - 1.0, // W_16^0 - 0.9238795325112867, // W_16^1 - std::f64::consts::FRAC_1_SQRT_2, // W_16^2 - 0.38268343236508984, // W_16^3 - 0.0, // W_16^4 - -0.38268343236508984, // W_16^5 - -std::f64::consts::FRAC_1_SQRT_2, // W_16^6 - -0.9238795325112867, // W_16^7 - ], - ); - let tw_im = f64x8::simd_from( - simd, - [ - 0.0, // W_16^0 - -0.38268343236508984, // W_16^1 - -std::f64::consts::FRAC_1_SQRT_2, // W_16^2 - -0.9238795325112867, // W_16^3 - -1.0, // W_16^4 - -0.9238795325112867, // W_16^5 - -std::f64::consts::FRAC_1_SQRT_2, // W_16^6 - -0.38268343236508984, // W_16^7 - ], - ); - let two = f64x8::splat(simd, 2.0); - - (reals.as_chunks_mut::<16>().0.iter_mut()) - .zip(imags.as_chunks_mut::<16>().0.iter_mut()) - .for_each(|(re16, im16)| { - let (re_lo, re_hi) = re16.split_at_mut(8); - let (im_lo, im_hi) = im16.split_at_mut(8); - - let in0_re = f64x8::from_slice(simd, re_lo); - let in1_re = f64x8::from_slice(simd, re_hi); - let in0_im = f64x8::from_slice(simd, im_lo); - let in1_im = f64x8::from_slice(simd, im_hi); - - let out0_re = tw_im.mul_add(-in1_im, tw_re.mul_add(in1_re, in0_re)); - let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im)); - let out1_re = two.mul_sub(in0_re, out0_re); - let out1_im = two.mul_sub(in0_im, out0_im); - - out0_re.store_slice(re_lo); - out0_im.store_slice(im_lo); - out1_re.store_slice(re_hi); - out1_im.store_slice(im_hi); - }); - } + // ---- Load and do stages 0+1 ---- + let mut v0_re = f64x4::from_slice(simd, &re[0..4]); + let mut v1_re = f64x4::from_slice(simd, &re[4..8]); + let mut v2_re = f64x4::from_slice(simd, &re[8..12]); + let mut v3_re = f64x4::from_slice(simd, &re[12..16]); + let mut v0_im = f64x4::from_slice(simd, &im[0..4]); + let mut v1_im = f64x4::from_slice(simd, &im[4..8]); + let mut v2_im = f64x4::from_slice(simd, &im[8..12]); + let mut v3_im = f64x4::from_slice(simd, &im[12..16]); + + radix4_transpose!(v0_re, v1_re, v2_re, v3_re, v0_im, v1_im, v2_im, v3_im); + + // Butterfly macro: twiddle-multiply hi, then add/sub with lo. + // out_lo = lo + tw*hi, out_hi = lo - tw*hi (via 2*lo - out_lo). + macro_rules! butterfly { + ($lo_re:expr, $lo_im:expr, $hi_re:expr, $hi_im:expr, $tw_re:expr, $tw_im:expr) => {{ + let out_lo_re = $tw_im.mul_add(-$hi_im, $tw_re.mul_add($hi_re, $lo_re)); + let out_lo_im = $tw_im.mul_add($hi_re, $tw_re.mul_add($hi_im, $lo_im)); + let out_hi_re = two.mul_sub($lo_re, out_lo_re); + let out_hi_im = two.mul_sub($lo_im, out_lo_im); + $lo_re = out_lo_re; + $lo_im = out_lo_im; + $hi_re = out_hi_re; + $hi_im = out_hi_im; + }}; + } + + // ---- Stage 2: dist=4, W_8 twiddles ---- + // Butterfly pairs: (v0,v1), (v2,v3) + // All pairs use the same twiddle: W_8^{0,1,2,3} + { + let tw_re = f64x4::simd_from( + simd, + [ + 1.0, // W_8^0 + std::f64::consts::FRAC_1_SQRT_2, // W_8^1 + 0.0, // W_8^2 + -std::f64::consts::FRAC_1_SQRT_2, // W_8^3 + ], + ); + let tw_im = f64x4::simd_from( + simd, + [ + 0.0, // W_8^0 + -std::f64::consts::FRAC_1_SQRT_2, // W_8^1 + -1.0, // W_8^2 + -std::f64::consts::FRAC_1_SQRT_2, // W_8^3 + ], + ); + + butterfly!(v0_re, v0_im, v1_re, v1_im, tw_re, tw_im); + butterfly!(v2_re, v2_im, v3_re, v3_im, tw_re, tw_im); + } - // Stage 4: dist=16, chunk_size=32, W_32 twiddles via 2x f64x8 - { - let tw_re_0_7 = f64x8::simd_from( - simd, - [ - 1.0, // W_32^0 - 0.9807852804032304, // W_32^1 - 0.9238795325112867, // W_32^2 - 0.8314696123025452, // W_32^3 - std::f64::consts::FRAC_1_SQRT_2, // W_32^4 - 0.5555702330196022, // W_32^5 - 0.3826834323650898, // W_32^6 - 0.19509032201612825, // W_32^7 - ], - ); - let tw_im_0_7 = f64x8::simd_from( - simd, - [ - 0.0, // W_32^0 - -0.19509032201612825, // W_32^1 - -0.3826834323650898, // W_32^2 - -0.5555702330196022, // W_32^3 - -std::f64::consts::FRAC_1_SQRT_2, // W_32^4 - -0.8314696123025452, // W_32^5 - -0.9238795325112867, // W_32^6 - -0.9807852804032304, // W_32^7 - ], - ); - let tw_re_8_15 = f64x8::simd_from( - simd, - [ - 0.0, // W_32^8 - -0.19509032201612825, // W_32^9 - -0.3826834323650898, // W_32^10 - -0.5555702330196022, // W_32^11 - -std::f64::consts::FRAC_1_SQRT_2, // W_32^12 - -0.8314696123025452, // W_32^13 - -0.9238795325112867, // W_32^14 - -0.9807852804032304, // W_32^15 - ], - ); - let tw_im_8_15 = f64x8::simd_from( - simd, - [ - -1.0, // W_32^8 - -0.9807852804032304, // W_32^9 - -0.9238795325112867, // W_32^10 - -0.8314696123025452, // W_32^11 - -std::f64::consts::FRAC_1_SQRT_2, // W_32^12 - -0.5555702330196022, // W_32^13 - -0.3826834323650898, // W_32^14 - -0.19509032201612825, // W_32^15 - ], - ); - let two = f64x8::splat(simd, 2.0); - - for (re32, im32) in reals - .as_chunks_mut::<32>() - .0 - .iter_mut() - .zip(imags.as_chunks_mut::<32>().0.iter_mut()) + // ---- Stage 3: dist=8, W_16 twiddles ---- + // Butterfly pairs: (v0,v2), (v1,v3) + // (v0,v2) uses W_16^{0,1,2,3} + // (v1,v3) uses W_16^{4,5,6,7} { - let (re_lo, re_hi) = re32.split_at_mut(16); - let (im_lo, im_hi) = im32.split_at_mut(16); - - // Batch 0: elements [0..8] and [16..24] with W_32^{0..7} - let in0_re = f64x8::from_slice(simd, &re_lo[0..8]); - let in1_re = f64x8::from_slice(simd, &re_hi[0..8]); - let in0_im = f64x8::from_slice(simd, &im_lo[0..8]); - let in1_im = f64x8::from_slice(simd, &im_hi[0..8]); - - let out0_re = tw_im_0_7.mul_add(-in1_im, tw_re_0_7.mul_add(in1_re, in0_re)); - let out0_im = tw_im_0_7.mul_add(in1_re, tw_re_0_7.mul_add(in1_im, in0_im)); - let out1_re = two.mul_sub(in0_re, out0_re); - let out1_im = two.mul_sub(in0_im, out0_im); - - out0_re.store_slice(&mut re_lo[0..8]); - out0_im.store_slice(&mut im_lo[0..8]); - out1_re.store_slice(&mut re_hi[0..8]); - out1_im.store_slice(&mut im_hi[0..8]); - - // Batch 1: elements [8..16] and [24..32] with W_32^{8..15} - let in0_re = f64x8::from_slice(simd, &re_lo[8..16]); - let in1_re = f64x8::from_slice(simd, &re_hi[8..16]); - let in0_im = f64x8::from_slice(simd, &im_lo[8..16]); - let in1_im = f64x8::from_slice(simd, &im_hi[8..16]); - - let out0_re = tw_im_8_15.mul_add(-in1_im, tw_re_8_15.mul_add(in1_re, in0_re)); - let out0_im = tw_im_8_15.mul_add(in1_re, tw_re_8_15.mul_add(in1_im, in0_im)); - let out1_re = two.mul_sub(in0_re, out0_re); - let out1_im = two.mul_sub(in0_im, out0_im); - - out0_re.store_slice(&mut re_lo[8..16]); - out0_im.store_slice(&mut im_lo[8..16]); - out1_re.store_slice(&mut re_hi[8..16]); - out1_im.store_slice(&mut im_hi[8..16]); + let tw_lo_re = f64x4::simd_from( + simd, + [ + 1.0, // W_16^0 + 0.9238795325112867, // W_16^1 + std::f64::consts::FRAC_1_SQRT_2, // W_16^2 + 0.38268343236508984, // W_16^3 + ], + ); + let tw_lo_im = f64x4::simd_from( + simd, + [ + 0.0, // W_16^0 + -0.38268343236508984, // W_16^1 + -std::f64::consts::FRAC_1_SQRT_2, // W_16^2 + -0.9238795325112867, // W_16^3 + ], + ); + let tw_hi_re = f64x4::simd_from( + simd, + [ + 0.0, // W_16^4 + -0.38268343236508984, // W_16^5 + -std::f64::consts::FRAC_1_SQRT_2, // W_16^6 + -0.9238795325112867, // W_16^7 + ], + ); + let tw_hi_im = f64x4::simd_from( + simd, + [ + -1.0, // W_16^4 + -0.9238795325112867, // W_16^5 + -std::f64::consts::FRAC_1_SQRT_2, // W_16^6 + -0.38268343236508984, // W_16^7 + ], + ); + + butterfly!(v0_re, v0_im, v2_re, v2_im, tw_lo_re, tw_lo_im); + butterfly!(v1_re, v1_im, v3_re, v3_im, tw_hi_re, tw_hi_im); } + + // ---- Store all vectors back ---- + v0_re.store_slice(&mut re[0..4]); + v1_re.store_slice(&mut re[4..8]); + v2_re.store_slice(&mut re[8..12]); + v3_re.store_slice(&mut re[12..16]); + + v0_im.store_slice(&mut im[0..4]); + v1_im.store_slice(&mut im[4..8]); + v2_im.store_slice(&mut im[8..12]); + v3_im.store_slice(&mut im[12..16]); } } -/// FFT-32 codelet for f32: executes stages 0-4 in a single function. +/// FFT-32 codelet for `f32`: executes stages 0-4 (chunk_size 2 through 32) in a single function. +/// +/// Register-resident implementation using `f32x8`: all 32 complex values are loaded into +/// 4 `f32x8` re + 4 `f32x8` im vectors, all 5 butterfly stages execute in registers with +/// no intermediate memory traffic, then results are stored back. #[inline(never)] pub fn fft_dit_codelet_32_f32(simd: S, reals: &mut [f32], imags: &mut [f32]) { simd.vectorize( @@ -254,200 +224,276 @@ pub fn fft_dit_codelet_32_f32(simd: S, reals: &mut [f32], imags: &mut [ #[inline(always)] fn fft_dit_codelet_32_simd_f32(simd: S, reals: &mut [f32], imags: &mut [f32]) { - // Fused stages 0+1: radix-2^2 4-point DIT in a single sweep - reals - .chunks_exact_mut(4) - .zip(imags.chunks_exact_mut(4)) - .for_each(|(re, im)| { - let (a_re, a_im) = (re[0], im[0]); - let (b_re, b_im) = (re[1], im[1]); - let (c_re, c_im) = (re[2], im[2]); - let (d_re, d_im) = (re[3], im[3]); - - // Stage 0: dist=1 butterflies on pairs (a,b) and (c,d) - let (t0_re, t0_im) = (a_re + b_re, a_im + b_im); - let (t1_re, t1_im) = (a_re - b_re, a_im - b_im); - let (t2_re, t2_im) = (c_re + d_re, c_im + d_im); - let (t3_re, t3_im) = (c_re - d_re, c_im - d_im); - - // Stage 1: -j multiply on t3, then dist=2 butterflies - let (t3j_re, t3j_im) = (t3_im, -t3_re); - - re[0] = t0_re + t2_re; - im[0] = t0_im + t2_im; - re[1] = t1_re + t3j_re; - im[1] = t1_im + t3j_im; - re[2] = t0_re - t2_re; - im[2] = t0_im - t2_im; - re[3] = t1_re - t3j_re; - im[3] = t1_im - t3j_im; - }); + assert_eq!(reals.len(), imags.len()); + + let two = f32x8::splat(simd, 2.0); + + for (re, im) in reals.chunks_exact_mut(32).zip(imags.chunks_exact_mut(32)) { + // ---- Load into 4 f32x8 register pairs (re + im) ---- + // v_k holds elements [8k .. 8k+7] + let mut v0_re = f32x8::from_slice(simd, &re[0..8]); + let mut v1_re = f32x8::from_slice(simd, &re[8..16]); + let mut v2_re = f32x8::from_slice(simd, &re[16..24]); + let mut v3_re = f32x8::from_slice(simd, &re[24..32]); + + let mut v0_im = f32x8::from_slice(simd, &im[0..8]); + let mut v1_im = f32x8::from_slice(simd, &im[8..16]); + let mut v2_im = f32x8::from_slice(simd, &im[16..24]); + let mut v3_im = f32x8::from_slice(simd, &im[24..32]); + + // ---- Stages 0+1+2 fused: 8-point DIT on all 4 vectors via transpose ---- + // To reduce register pressure, process lo halves (elements 0-3) through + // stages 0+1 first, then hi halves (elements 4-7) through stages 0+1. + // Data stays in transposed (per-element) layout between stages 0+1 and + // stage 2, avoiding redundant transpose pairs. Only one inverse transpose + // at the end. Peak active register usage during stages 0+1: ~8 f32x4. + { + macro_rules! transpose4x4 { + ($g0:expr, $g1:expr, $g2:expr, $g3:expr) => {{ + let (t0, t1) = interleave_f32x4($g0, $g2); + let (t2, t3) = interleave_f32x4($g1, $g3); + let (r0, r1) = interleave_f32x4(t0, t2); + let (r2, r3) = interleave_f32x4(t1, t3); + (r0, r1, r2, r3) + }}; + } - // Stage 2: dist=4, chunk_size=8, W_8 twiddles via f32x4 - { - let tw_re = f32x4::simd_from( - simd, - [ - 1.0_f32, // W_8^0 - std::f32::consts::FRAC_1_SQRT_2, // W_8^1 - 0.0_f32, // W_8^2 - -std::f32::consts::FRAC_1_SQRT_2, // W_8^3 - ], - ); - let tw_im = f32x4::simd_from( - simd, - [ - 0.0_f32, // W_8^0 - -std::f32::consts::FRAC_1_SQRT_2, // W_8^1 - -1.0_f32, // W_8^2 - -std::f32::consts::FRAC_1_SQRT_2, // W_8^3 - ], - ); - let two = f32x4::splat(simd, 2.0); - - (reals.as_chunks_mut::<8>().0.iter_mut()) - .zip(imags.as_chunks_mut::<8>().0.iter_mut()) - .for_each(|(re8, im8)| { - let (re_lo, re_hi) = re8.split_at_mut(4); - let (im_lo, im_hi) = im8.split_at_mut(4); - - let in0_re = f32x4::from_slice(simd, re_lo); - let in1_re = f32x4::from_slice(simd, re_hi); - let in0_im = f32x4::from_slice(simd, im_lo); - let in1_im = f32x4::from_slice(simd, im_hi); - - let out0_re = tw_im.mul_add(-in1_im, tw_re.mul_add(in1_re, in0_re)); - let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im)); - let out1_re = two.mul_sub(in0_re, out0_re); - let out1_im = two.mul_sub(in0_im, out0_im); - - out0_re.store_slice(re_lo); - out0_im.store_slice(im_lo); - out1_re.store_slice(re_hi); - out1_im.store_slice(im_hi); - }); - } + // Stages 0+1 (radix-4 DIT) on transposed data. + // Input: 4 f32x4 in per-group layout. Output: 4 f32x4 in per-element + // (transposed) layout — p0, p1, p2, p3 where each lane is one group. + macro_rules! radix4_transpose_fwd { + ($ga_re:expr, $gb_re:expr, $gc_re:expr, $gd_re:expr, + $ga_im:expr, $gb_im:expr, $gc_im:expr, $gd_im:expr) => {{ + let (e0_re, e1_re, e2_re, e3_re) = + transpose4x4!($ga_re, $gb_re, $gc_re, $gd_re); + let (e0_im, e1_im, e2_im, e3_im) = + transpose4x4!($ga_im, $gb_im, $gc_im, $gd_im); + + // Stage 0 (dist=1) + let s01_re = e0_re + e1_re; + let d01_re = e0_re - e1_re; + let s23_re = e2_re + e3_re; + let d23_re = e2_re - e3_re; + let s01_im = e0_im + e1_im; + let d01_im = e0_im - e1_im; + let s23_im = e2_im + e3_im; + let d23_im = e2_im - e3_im; + + // Stage 1 (dist=2): W4^0=1, W4^1=-j + let p0_re = s01_re + s23_re; + let p2_re = s01_re - s23_re; + let p0_im = s01_im + s23_im; + let p2_im = s01_im - s23_im; + + let p1_re = d01_re + d23_im; + let p3_re = d01_re - d23_im; + let p1_im = d01_im - d23_re; + let p3_im = d01_im + d23_re; + + // Return in per-element (transposed) layout + (p0_re, p1_re, p2_re, p3_re, p0_im, p1_im, p2_im, p3_im) + }}; + } - // Stage 3: dist=8, chunk_size=16, W_16 twiddles via f32x8 - { - let tw_re = f32x8::simd_from( - simd, - [ - 1.0_f32, // W_16^0 - 0.923_879_5_f32, // W_16^1 - std::f32::consts::FRAC_1_SQRT_2, // W_16^2 - 0.382_683_43_f32, // W_16^3 - 0.0_f32, // W_16^4 - -0.382_683_43_f32, // W_16^5 - -std::f32::consts::FRAC_1_SQRT_2, // W_16^6 - -0.923_879_5_f32, // W_16^7 - ], - ); - let tw_im = f32x8::simd_from( - simd, - [ - 0.0_f32, // W_16^0 - -0.382_683_43_f32, // W_16^1 - -std::f32::consts::FRAC_1_SQRT_2, // W_16^2 - -0.923_879_5_f32, // W_16^3 - -1.0_f32, // W_16^4 - -0.923_879_5_f32, // W_16^5 - -std::f32::consts::FRAC_1_SQRT_2, // W_16^6 - -0.382_683_43_f32, // W_16^7 - ], - ); - let two = f32x8::splat(simd, 2.0); - - (reals.as_chunks_mut::<16>().0.iter_mut()) - .zip(imags.as_chunks_mut::<16>().0.iter_mut()) - .for_each(|(re16, im16)| { - let (re_lo, re_hi) = re16.split_at_mut(8); - let (im_lo, im_hi) = im16.split_at_mut(8); - - let in0_re = f32x8::from_slice(simd, re_lo); - let in1_re = f32x8::from_slice(simd, re_hi); - let in0_im = f32x8::from_slice(simd, im_lo); - let in1_im = f32x8::from_slice(simd, im_hi); - - let out0_re = tw_im.mul_add(-in1_im, tw_re.mul_add(in1_re, in0_re)); - let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im)); - let out1_re = two.mul_sub(in0_re, out0_re); - let out1_im = two.mul_sub(in0_im, out0_im); - - out0_re.store_slice(re_lo); - out0_im.store_slice(im_lo); - out1_re.store_slice(re_hi); - out1_im.store_slice(im_hi); - }); - } + // Process lo halves (elements 0-3) — result in transposed layout + let (g0_lo_re, g0_hi_re) = v0_re.split(); + let (g1_lo_re, g1_hi_re) = v1_re.split(); + let (g2_lo_re, g2_hi_re) = v2_re.split(); + let (g3_lo_re, g3_hi_re) = v3_re.split(); + let (g0_lo_im, g0_hi_im) = v0_im.split(); + let (g1_lo_im, g1_hi_im) = v1_im.split(); + let (g2_lo_im, g2_hi_im) = v2_im.split(); + let (g3_lo_im, g3_hi_im) = v3_im.split(); + + let (p0_re, p1_re, p2_re, p3_re, p0_im, p1_im, p2_im, p3_im) = radix4_transpose_fwd!( + g0_lo_re, g1_lo_re, g2_lo_re, g3_lo_re, g0_lo_im, g1_lo_im, g2_lo_im, g3_lo_im + ); + + // Process hi halves (elements 4-7) — result in transposed layout + let (p4_re, p5_re, p6_re, p7_re, p4_im, p5_im, p6_im, p7_im) = radix4_transpose_fwd!( + g0_hi_re, g1_hi_re, g2_hi_re, g3_hi_re, g0_hi_im, g1_hi_im, g2_hi_im, g3_hi_im + ); + + // Stage 2 (dist=4) — W8^k twiddles, already in per-element layout + + // W8^0 = 1+0j: just add/sub + let r0_re = p0_re + p4_re; + let r4_re = p0_re - p4_re; + let r0_im = p0_im + p4_im; + let r4_im = p0_im - p4_im; + + // W8^1 = (1-j)/√2: twiddle * p5, expressed to avoid double-negation. + // tw = (s, -s) where s = 1/√2. + // tw*p5 = (s*p5_re + s*p5_im, s*p5_im - s*p5_re) + const FRAC_1_SQRT_2: f32 = std::f32::consts::FRAC_1_SQRT_2; + let s = f32x4::splat(simd, FRAC_1_SQRT_2); + let tw_p5_re = s.mul_add(p5_im, s * p5_re); + let tw_p5_im = s.mul_sub(p5_im, s * p5_re); + let r1_re = p1_re + tw_p5_re; + let r5_re = p1_re - tw_p5_re; + let r1_im = p1_im + tw_p5_im; + let r5_im = p1_im - tw_p5_im; + + // W8^2 = 0-j: -j*(re+j*im) = (im, -re) + let r2_re = p2_re + p6_im; + let r6_re = p2_re - p6_im; + let r2_im = p2_im - p6_re; + let r6_im = p2_im + p6_re; + + // W8^3 = (-1-j)/√2: twiddle * p7, expressed to avoid double-negation. + // tw = (-s, -s) where s = 1/√2. + // tw*p7_re = s*p7_im - s*p7_re (fmsub, clean) + // tw*p7_im = -(s*p7_im + s*p7_re), but we compute the positive form + // neg_tw_p7_im = s*p7_im + s*p7_re (fmadd, clean) + // and swap the +/- in the butterfly for the im component. + let tw_p7_re = s.mul_sub(p7_im, s * p7_re); + let neg_tw_p7_im = s.mul_add(p7_im, s * p7_re); + let r3_re = p3_re + tw_p7_re; + let r7_re = p3_re - tw_p7_re; + let r3_im = p3_im - neg_tw_p7_im; + let r7_im = p3_im + neg_tw_p7_im; + + // Single inverse transpose and recombine f32x4 → f32x8 + let (g0_lo_re, g1_lo_re, g2_lo_re, g3_lo_re) = + transpose4x4!(r0_re, r1_re, r2_re, r3_re); + let (g0_hi_re, g1_hi_re, g2_hi_re, g3_hi_re) = + transpose4x4!(r4_re, r5_re, r6_re, r7_re); + let (g0_lo_im, g1_lo_im, g2_lo_im, g3_lo_im) = + transpose4x4!(r0_im, r1_im, r2_im, r3_im); + let (g0_hi_im, g1_hi_im, g2_hi_im, g3_hi_im) = + transpose4x4!(r4_im, r5_im, r6_im, r7_im); + + v0_re = g0_lo_re.combine(g0_hi_re); + v1_re = g1_lo_re.combine(g1_hi_re); + v2_re = g2_lo_re.combine(g2_hi_re); + v3_re = g3_lo_re.combine(g3_hi_re); + v0_im = g0_lo_im.combine(g0_hi_im); + v1_im = g1_lo_im.combine(g1_hi_im); + v2_im = g2_lo_im.combine(g2_hi_im); + v3_im = g3_lo_im.combine(g3_hi_im); + } + + // Butterfly macro: twiddle-multiply hi, then add/sub with lo. + // out_lo = lo + tw*hi, out_hi = lo - tw*hi (via 2*lo - out_lo). + macro_rules! butterfly { + ($lo_re:expr, $lo_im:expr, $hi_re:expr, $hi_im:expr, $tw_re:expr, $tw_im:expr) => {{ + let out_lo_re = $tw_im.mul_add(-$hi_im, $tw_re.mul_add($hi_re, $lo_re)); + let out_lo_im = $tw_im.mul_add($hi_re, $tw_re.mul_add($hi_im, $lo_im)); + let out_hi_re = two.mul_sub($lo_re, out_lo_re); + let out_hi_im = two.mul_sub($lo_im, out_lo_im); + $lo_re = out_lo_re; + $lo_im = out_lo_im; + $hi_re = out_hi_re; + $hi_im = out_hi_im; + }}; + } + + // ---- Stage 3: dist=8, W_16 twiddles ---- + // Butterfly pairs: (v0,v1), (v2,v3) + // Both pairs use the same twiddle: W_16^{0..7} + { + let tw_re = f32x8::simd_from( + simd, + [ + 1.0_f32, // W_16^0 + 0.923_879_5_f32, // W_16^1 + std::f32::consts::FRAC_1_SQRT_2, // W_16^2 + 0.382_683_43_f32, // W_16^3 + 0.0_f32, // W_16^4 + -0.382_683_43_f32, // W_16^5 + -std::f32::consts::FRAC_1_SQRT_2, // W_16^6 + -0.923_879_5_f32, // W_16^7 + ], + ); + let tw_im = f32x8::simd_from( + simd, + [ + 0.0_f32, // W_16^0 + -0.382_683_43_f32, // W_16^1 + -std::f32::consts::FRAC_1_SQRT_2, // W_16^2 + -0.923_879_5_f32, // W_16^3 + -1.0_f32, // W_16^4 + -0.923_879_5_f32, // W_16^5 + -std::f32::consts::FRAC_1_SQRT_2, // W_16^6 + -0.382_683_43_f32, // W_16^7 + ], + ); - // Stage 4: dist=16, chunk_size=32, W_32 twiddles via f32x16 - { - let tw_re = f32x16::simd_from( - simd, - [ - 1.0_f32, // W_32^0 - 0.980_785_25_f32, // W_32^1 - 0.923_879_5_f32, // W_32^2 - 0.831_469_6_f32, // W_32^3 - std::f32::consts::FRAC_1_SQRT_2, // W_32^4 - 0.555_570_24_f32, // W_32^5 - 0.382_683_43_f32, // W_32^6 - 0.195_090_32_f32, // W_32^7 - 0.0_f32, // W_32^8 - -0.195_090_32_f32, // W_32^9 - -0.382_683_43_f32, // W_32^10 - -0.555_570_24_f32, // W_32^11 - -std::f32::consts::FRAC_1_SQRT_2, // W_32^12 - -0.831_469_6_f32, // W_32^13 - -0.923_879_5_f32, // W_32^14 - -0.980_785_25_f32, // W_32^15 - ], - ); - let tw_im = f32x16::simd_from( - simd, - [ - 0.0_f32, // W_32^0 - -0.195_090_32_f32, // W_32^1 - -0.382_683_43_f32, // W_32^2 - -0.555_570_24_f32, // W_32^3 - -std::f32::consts::FRAC_1_SQRT_2, // W_32^4 - -0.831_469_6_f32, // W_32^5 - -0.923_879_5_f32, // W_32^6 - -0.980_785_25_f32, // W_32^7 - -1.0_f32, // W_32^8 - -0.980_785_25_f32, // W_32^9 - -0.923_879_5_f32, // W_32^10 - -0.831_469_6_f32, // W_32^11 - -std::f32::consts::FRAC_1_SQRT_2, // W_32^12 - -0.555_570_24_f32, // W_32^13 - -0.382_683_43_f32, // W_32^14 - -0.195_090_32_f32, // W_32^15 - ], - ); - let two = f32x16::splat(simd, 2.0); - - (reals.as_chunks_mut::<32>().0.iter_mut()) - .zip(imags.as_chunks_mut::<32>().0.iter_mut()) - .for_each(|(re32, im32)| { - let (re_lo, re_hi) = re32.split_at_mut(16); - let (im_lo, im_hi) = im32.split_at_mut(16); - - let in0_re = f32x16::from_slice(simd, re_lo); - let in1_re = f32x16::from_slice(simd, re_hi); - let in0_im = f32x16::from_slice(simd, im_lo); - let in1_im = f32x16::from_slice(simd, im_hi); - - let out0_re = tw_im.mul_add(-in1_im, tw_re.mul_add(in1_re, in0_re)); - let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im)); - let out1_re = two.mul_sub(in0_re, out0_re); - let out1_im = two.mul_sub(in0_im, out0_im); - - out0_re.store_slice(re_lo); - out0_im.store_slice(im_lo); - out1_re.store_slice(re_hi); - out1_im.store_slice(im_hi); - }); + butterfly!(v0_re, v0_im, v1_re, v1_im, tw_re, tw_im); + butterfly!(v2_re, v2_im, v3_re, v3_im, tw_re, tw_im); + } + + // ---- Stage 4: dist=16, W_32 twiddles ---- + // Butterfly pairs: (v0,v2), (v1,v3) + // (v0,v2) uses W_32^{0..7}, (v1,v3) uses W_32^{8..15} + { + let tw_lo_re = f32x8::simd_from( + simd, + [ + 1.0_f32, // W_32^0 + 0.980_785_25_f32, // W_32^1 + 0.923_879_5_f32, // W_32^2 + 0.831_469_6_f32, // W_32^3 + std::f32::consts::FRAC_1_SQRT_2, // W_32^4 + 0.555_570_24_f32, // W_32^5 + 0.382_683_43_f32, // W_32^6 + 0.195_090_32_f32, // W_32^7 + ], + ); + let tw_lo_im = f32x8::simd_from( + simd, + [ + 0.0_f32, // W_32^0 + -0.195_090_32_f32, // W_32^1 + -0.382_683_43_f32, // W_32^2 + -0.555_570_24_f32, // W_32^3 + -std::f32::consts::FRAC_1_SQRT_2, // W_32^4 + -0.831_469_6_f32, // W_32^5 + -0.923_879_5_f32, // W_32^6 + -0.980_785_25_f32, // W_32^7 + ], + ); + let tw_hi_re = f32x8::simd_from( + simd, + [ + 0.0_f32, // W_32^8 + -0.195_090_32_f32, // W_32^9 + -0.382_683_43_f32, // W_32^10 + -0.555_570_24_f32, // W_32^11 + -std::f32::consts::FRAC_1_SQRT_2, // W_32^12 + -0.831_469_6_f32, // W_32^13 + -0.923_879_5_f32, // W_32^14 + -0.980_785_25_f32, // W_32^15 + ], + ); + let tw_hi_im = f32x8::simd_from( + simd, + [ + -1.0_f32, // W_32^8 + -0.980_785_25_f32, // W_32^9 + -0.923_879_5_f32, // W_32^10 + -0.831_469_6_f32, // W_32^11 + -std::f32::consts::FRAC_1_SQRT_2, // W_32^12 + -0.555_570_24_f32, // W_32^13 + -0.382_683_43_f32, // W_32^14 + -0.195_090_32_f32, // W_32^15 + ], + ); + + butterfly!(v0_re, v0_im, v2_re, v2_im, tw_lo_re, tw_lo_im); + butterfly!(v1_re, v1_im, v3_re, v3_im, tw_hi_re, tw_hi_im); + } + + // ---- Store all vectors back ---- + v0_re.store_slice(&mut re[0..8]); + v1_re.store_slice(&mut re[8..16]); + v2_re.store_slice(&mut re[16..24]); + v3_re.store_slice(&mut re[24..32]); + + v0_im.store_slice(&mut im[0..8]); + v1_im.store_slice(&mut im[8..16]); + v2_im.store_slice(&mut im[16..24]); + v3_im.store_slice(&mut im[24..32]); } } @@ -456,13 +502,12 @@ mod tests { use super::*; use crate::kernels::dit::*; - fn run_stages_0_to_4_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { - assert_eq!(reals.len(), 32); + fn run_stages_0_to_3_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { + assert_eq!(reals.len(), 16); fft_dit_chunk_2(simd, reals, imags); fft_dit_chunk_4_f64(simd, reals, imags); fft_dit_chunk_8_f64(simd, reals, imags); fft_dit_chunk_16_f64(simd, reals, imags); - fft_dit_chunk_32_f64(simd, reals, imags); } fn run_stages_0_to_4_f32(simd: S, reals: &mut [f32], imags: &mut [f32]) { @@ -475,14 +520,14 @@ mod tests { } #[test] - fn codelet_32_f64_matches_staged() { + fn codelet_16_f64_matches_staged() { use fearless_simd::dispatch; let simd_level = fearless_simd::Level::new(); // Test with a simple impulse signal - let mut re_staged = vec![0.0f64; 32]; - let mut im_staged = vec![0.0f64; 32]; + let mut re_staged = vec![0.0f64; 16]; + let mut im_staged = vec![0.0f64; 16]; re_staged[0] = 1.0; let mut re_codelet = re_staged.clone(); @@ -491,12 +536,12 @@ mod tests { dispatch!(simd_level, simd => { simd.vectorize( #[inline(always)] - || run_stages_0_to_4_f64(simd, &mut re_staged, &mut im_staged), + || run_stages_0_to_3_f64(simd, &mut re_staged, &mut im_staged), ); - fft_dit_codelet_32_f64(simd, &mut re_codelet, &mut im_codelet); + fft_dit_codelet_16_f64(simd, &mut re_codelet, &mut im_codelet); }); - for i in 0..32 { + for i in 0..16 { assert!( (re_staged[i] - re_codelet[i]).abs() < 1e-14, "re[{i}]: staged={}, codelet={}", @@ -511,8 +556,8 @@ mod tests { ); } - let mut re_staged: Vec = (1..=32).map(|i| i as f64).collect(); - let mut im_staged: Vec = (1..=32).map(|i| -(i as f64) * 0.5).collect(); + let mut re_staged: Vec = (1..=16).map(|i| i as f64).collect(); + let mut im_staged: Vec = (1..=16).map(|i| -(i as f64) * 0.5).collect(); let mut re_codelet = re_staged.clone(); let mut im_codelet = im_staged.clone(); @@ -520,12 +565,12 @@ mod tests { dispatch!(simd_level, simd => { simd.vectorize( #[inline(always)] - || run_stages_0_to_4_f64(simd, &mut re_staged, &mut im_staged), + || run_stages_0_to_3_f64(simd, &mut re_staged, &mut im_staged), ); - fft_dit_codelet_32_f64(simd, &mut re_codelet, &mut im_codelet); + fft_dit_codelet_16_f64(simd, &mut re_codelet, &mut im_codelet); }); - for i in 0..32 { + for i in 0..16 { assert!( (re_staged[i] - re_codelet[i]).abs() < 1e-10, "re[{i}]: staged={}, codelet={}", @@ -608,12 +653,12 @@ mod tests { } #[test] - fn codelet_32_f64_multi_chunk() { + fn codelet_16_f64_multi_chunk() { use fearless_simd::dispatch; let simd_level = fearless_simd::Level::new(); - // Test that the codelet correctly processes multiple 32-element chunks + // Test that the codelet correctly processes multiple 16-element chunks let n = 128; let mut re_staged: Vec = (0..n).map(|i| (i as f64) * 0.1).collect(); let mut im_staged: Vec = (0..n).map(|i| -(i as f64) * 0.05).collect(); @@ -623,17 +668,17 @@ mod tests { dispatch!(simd_level, simd => { // Run individual stage kernels on all chunks - for chunk_start in (0..n).step_by(32) { - let re = &mut re_staged[chunk_start..chunk_start + 32]; - let im = &mut im_staged[chunk_start..chunk_start + 32]; + for chunk_start in (0..n).step_by(16) { + let re = &mut re_staged[chunk_start..chunk_start + 16]; + let im = &mut im_staged[chunk_start..chunk_start + 16]; simd.vectorize( #[inline(always)] - || run_stages_0_to_4_f64(simd, re, im), + || run_stages_0_to_3_f64(simd, re, im), ); } // Run codelet on the full array - fft_dit_codelet_32_f64(simd, &mut re_codelet, &mut im_codelet); + fft_dit_codelet_16_f64(simd, &mut re_codelet, &mut im_codelet); }); for i in 0..n { diff --git a/src/lib.rs b/src/lib.rs index 5c8e03a..ecb3d2b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -394,27 +394,6 @@ mod tests { } } - #[test] - fn heuristic_disables_codelet_for_small_sizes() { - let planner = PlannerDit64::new(16, Direction::Forward); // log_n = 4 - assert!(!planner.use_codelet_32); - } - - #[test] - fn heuristic_enables_codelet_at_threshold() { - let planner = PlannerDit64::new(32, Direction::Forward); // log_n = 5 - assert!(planner.use_codelet_32); - - let planner = PlannerDit64::new(8192, Direction::Forward); // log_n = 13 - assert!(planner.use_codelet_32); - } - - #[test] - fn heuristic_disables_codelet_above_threshold() { - let planner = PlannerDit64::new(16384, Direction::Forward); // log_n = 14 - assert!(!planner.use_codelet_32); - } - #[test] fn tune_mode_does_not_panic() { use crate::planner::PlannerMode; @@ -451,66 +430,4 @@ mod tests { } } } - - #[test] - fn codelet_forced_on_above_heuristic_threshold_f64() { - for n in 14..=15 { - let size = 1 << n; - let mut reals_original = vec![0.0f64; size]; - let mut imags_original = vec![0.0f64; size]; - gen_random_signal_f64(&mut reals_original, &mut imags_original); - - let mut reals = reals_original.clone(); - let mut imags = imags_original.clone(); - - let mut fwd = PlannerDit64::new(size, Direction::Forward); - assert!( - !fwd.use_codelet_32, - "heuristic should disable codelet at n={n}" - ); - fwd.use_codelet_32 = true; - - let mut inv = PlannerDit64::new(size, Direction::Reverse); - inv.use_codelet_32 = true; - - fft_64_dit_with_planner(&mut reals, &mut imags, &fwd); - fft_64_dit_with_planner(&mut reals, &mut imags, &inv); - - for i in 0..size { - assert_float_closeness(reals[i], reals_original[i], 1e-10); - assert_float_closeness(imags[i], imags_original[i], 1e-10); - } - } - } - - #[test] - fn codelet_forced_on_above_heuristic_threshold_f32() { - for n in 14..=15 { - let size = 1 << n; - let mut reals_original = vec![0.0f32; size]; - let mut imags_original = vec![0.0f32; size]; - gen_random_signal_f32(&mut reals_original, &mut imags_original); - - let mut reals = reals_original.clone(); - let mut imags = imags_original.clone(); - - let mut fwd = PlannerDit32::new(size, Direction::Forward); - assert!( - !fwd.use_codelet_32, - "heuristic should disable codelet at n={n}" - ); - fwd.use_codelet_32 = true; - - let mut inv = PlannerDit32::new(size, Direction::Reverse); - inv.use_codelet_32 = true; - - fft_32_dit_with_planner(&mut reals, &mut imags, &fwd); - fft_32_dit_with_planner(&mut reals, &mut imags, &inv); - - for i in 0..size { - assert_float_closeness(reals[i], reals_original[i], 1e-4); - assert_float_closeness(imags[i], imags_original[i], 1e-4); - } - } - } } diff --git a/src/planner.rs b/src/planner.rs index 0783dea..b718d39 100644 --- a/src/planner.rs +++ b/src/planner.rs @@ -42,8 +42,6 @@ macro_rules! impl_planner_dit_for { pub(crate) log_n: usize, /// The level of SIMD instruction support, detected at runtime on x86 and hardcoded elsewhere pub(crate) simd_level: fearless_simd::Level, - /// Whether to use the fused 32-point codelet for stages 0-4 in L1 blocks - pub(crate) use_codelet_32: bool, } impl $struct_name { @@ -61,7 +59,7 @@ macro_rules! impl_planner_dit_for { /// leave performance on the table on platforms with large L1i caches. /// - [`PlannerMode::Tune`]: Benchmarks both paths at plan time. Use this /// when you can afford extra planning time (e.g., planner is reused). - pub fn with_mode(num_points: usize, direction: Direction, mode: PlannerMode) -> Self { + pub fn with_mode(num_points: usize, direction: Direction, _mode: PlannerMode) -> Self { assert!(num_points > 0 && num_points.is_power_of_two()); let simd_level = fearless_simd::Level::new(); @@ -91,109 +89,15 @@ macro_rules! impl_planner_dit_for { } } - let use_codelet_32 = Self::estimate_use_codelet_32(log_n); - - let mut planner = Self { + let planner = Self { stage_twiddles, direction, log_n, simd_level, - use_codelet_32, }; - if matches!(mode, PlannerMode::Tune) { - planner.tune_codelet_32(num_points); - } - planner } - - /// Conservative, arch-independent heuristic for whether the 32-point - /// codelet is beneficial. - /// - /// At small sizes (N ≤ 8192) the codelet dominates runtime and - /// cross-block kernel eviction from the µop cache doesn't matter. - /// At large sizes the codelet's code footprint can evict the - /// cross-block kernel on platforms with small L1i caches (e.g., 32KB - /// on x86). Use [`PlannerMode::Tune`] to discover the real threshold - /// on your hardware. - fn estimate_use_codelet_32(log_n: usize) -> bool { - // Codelet needs at least 32 elements (5 stages) - if log_n < 5 { - return false; - } - - // Conservative threshold: enable only for N ≤ 8192 where the - // codelet dominates runtime. On platforms with large L1i (e.g., - // Apple Silicon at 192KB), Tune mode will discover that the - // codelet wins at larger sizes too. - log_n <= 13 - } - - /// Benchmark both paths and set `use_codelet_32` to whichever is faster. - fn tune_codelet_32(&mut self, num_points: usize) { - if self.log_n < 5 { - self.use_codelet_32 = false; - return; - } - - let opts = crate::options::Options { - multithreaded_bit_reversal: false, - smallest_parallel_chunk_size: usize::MAX, - }; - - // Generate random complex signal via xorshift64 (no rand dependency) - let mut rng_state: u64 = 0x517C_C1B7_2722_0A95; - let mut next_f = || -> $precision { - rng_state ^= rng_state << 13; - rng_state ^= rng_state >> 7; - rng_state ^= rng_state << 17; - (rng_state as $precision) / (u64::MAX as $precision) * 2.0 - 1.0 - }; - let reals_orig: Vec<$precision> = (0..num_points).map(|_| next_f()).collect(); - let imags_orig: Vec<$precision> = (0..num_points).map(|_| next_f()).collect(); - let mut reals = reals_orig.clone(); - let mut imags = imags_orig.clone(); - - const WARMUP: usize = 3; - const ITERS: usize = 5; - - // Time WITHOUT codelet - self.use_codelet_32 = false; - for _ in 0..WARMUP { - reals.copy_from_slice(&reals_orig); - imags.copy_from_slice(&imags_orig); - $fft_func(&mut reals, &mut imags, &*self, &opts); - } - - let mut best_without = std::time::Duration::MAX; - for _ in 0..ITERS { - reals.copy_from_slice(&reals_orig); - imags.copy_from_slice(&imags_orig); - let start = std::time::Instant::now(); - $fft_func(&mut reals, &mut imags, &*self, &opts); - best_without = best_without.min(start.elapsed()); - } - - // Time WITH codelet - self.use_codelet_32 = true; - for _ in 0..WARMUP { - reals.copy_from_slice(&reals_orig); - imags.copy_from_slice(&imags_orig); - $fft_func(&mut reals, &mut imags, &*self, &opts); - } - - let mut best_with = std::time::Duration::MAX; - for _ in 0..ITERS { - reals.copy_from_slice(&reals_orig); - imags.copy_from_slice(&imags_orig); - let start = std::time::Instant::now(); - $fft_func(&mut reals, &mut imags, &*self, &opts); - best_with = best_with.min(start.elapsed()); - } - - self.use_codelet_32 = best_with < best_without; - } } }; }