From 9d1e37a2221985ea421ff3b04106c1020c8d0bac Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 11 Apr 2026 19:26:38 +0100 Subject: [PATCH 01/17] Make the f64 codelet operate entirely in registers --- src/kernels/codelets.rs | 386 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 382 insertions(+), 4 deletions(-) diff --git a/src/kernels/codelets.rs b/src/kernels/codelets.rs index 1d622fb..9c3291f 100644 --- a/src/kernels/codelets.rs +++ b/src/kernels/codelets.rs @@ -5,17 +5,17 @@ //! use fearless_simd::{f32x16, f32x4, f32x8, f64x4, f64x8, Simd, SimdBase, SimdFloat, SimdFrom}; -/// FFT-32 codelet for `f64`: executes stages 0-4 (chunk_size 2 through 32) in a single function. +/// Legacy FFT-32 codelet for `f64`: stage-by-stage in-place with intermediate stores. #[inline(never)] -pub fn fft_dit_codelet_32_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { +pub fn fft_dit_codelet_32_staged_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { simd.vectorize( #[inline(always)] - || fft_dit_codelet_32_simd_f64(simd, reals, imags), + || fft_dit_codelet_32_staged_simd_f64(simd, reals, imags), ) } #[inline(always)] -fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { +fn fft_dit_codelet_32_staged_simd_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { // Fused stages 0+1: radix-2^2 4-point DIT in a single sweep reals .chunks_exact_mut(4) @@ -243,6 +243,288 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut } } +/// FFT-32 codelet for `f64`: executes stages 0-4 (chunk_size 2 through 32) in a single function. +/// +/// Register-resident implementation: all 32 complex values are loaded into f64x4 vectors, +/// all 5 butterfly stages execute in registers with no intermediate memory traffic, +/// then results are stored back. +#[inline(never)] +pub fn fft_dit_codelet_32_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { + simd.vectorize( + #[inline(always)] + || fft_dit_codelet_32_simd_f64(simd, reals, imags), + ) +} + +#[inline(always)] +fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { + assert_eq!(reals.len(), imags.len()); + + let two = f64x4::splat(simd, 2.0); + + for (re, im) in reals.chunks_exact_mut(32).zip(imags.chunks_exact_mut(32)) { + // ---- Load into 8 f64x4 register pairs (re + im) ---- + // v_k holds elements [4k .. 4k+3] + let mut v0_re = f64x4::from_slice(simd, &re[0..4]); + let mut v1_re = f64x4::from_slice(simd, &re[4..8]); + let mut v2_re = f64x4::from_slice(simd, &re[8..12]); + let mut v3_re = f64x4::from_slice(simd, &re[12..16]); + let mut v4_re = f64x4::from_slice(simd, &re[16..20]); + let mut v5_re = f64x4::from_slice(simd, &re[20..24]); + let mut v6_re = f64x4::from_slice(simd, &re[24..28]); + let mut v7_re = f64x4::from_slice(simd, &re[28..32]); + + let mut v0_im = f64x4::from_slice(simd, &im[0..4]); + let mut v1_im = f64x4::from_slice(simd, &im[4..8]); + let mut v2_im = f64x4::from_slice(simd, &im[8..12]); + let mut v3_im = f64x4::from_slice(simd, &im[12..16]); + let mut v4_im = f64x4::from_slice(simd, &im[16..20]); + let mut v5_im = f64x4::from_slice(simd, &im[20..24]); + let mut v6_im = f64x4::from_slice(simd, &im[24..28]); + let mut v7_im = f64x4::from_slice(simd, &im[28..32]); + + // ---- Stages 0+1 fused: radix-4 DIT on each group of 4 (within each vector) ---- + // Extract scalars, compute radix-4 butterfly, repack into f64x4. + macro_rules! 
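+        // For one group (a, b, c, d) this computes the 4-point DFT directly:
+        //   X0 = (a+b) + (c+d)      X2 = (a+b) - (c+d)
+        //   X1 = (a-b) - j(c-d)     X3 = (a-b) + j(c-d)
+        // e.g. an impulse (1, 0, 0, 0) maps to the flat spectrum (1, 1, 1, 1).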
radix4_inplace { + ($v_re:expr, $v_im:expr) => {{ + let mut re_arr = [0.0f64; 4]; + let mut im_arr = [0.0f64; 4]; + $v_re.store_slice(&mut re_arr); + $v_im.store_slice(&mut im_arr); + + let (a_re, a_im) = (re_arr[0], im_arr[0]); + let (b_re, b_im) = (re_arr[1], im_arr[1]); + let (c_re, c_im) = (re_arr[2], im_arr[2]); + let (d_re, d_im) = (re_arr[3], im_arr[3]); + + // Stage 0: dist=1 butterflies + let (t0_re, t0_im) = (a_re + b_re, a_im + b_im); + let (t1_re, t1_im) = (a_re - b_re, a_im - b_im); + let (t2_re, t2_im) = (c_re + d_re, c_im + d_im); + let (t3_re, t3_im) = (c_re - d_re, c_im - d_im); + + // Stage 1: -j multiply on t3, then dist=2 butterflies + let (t3j_re, t3j_im) = (t3_im, -t3_re); + + $v_re = f64x4::simd_from( + simd, + [t0_re + t2_re, t1_re + t3j_re, t0_re - t2_re, t1_re - t3j_re], + ); + $v_im = f64x4::simd_from( + simd, + [t0_im + t2_im, t1_im + t3j_im, t0_im - t2_im, t1_im - t3j_im], + ); + }}; + } + + radix4_inplace!(v0_re, v0_im); + radix4_inplace!(v1_re, v1_im); + radix4_inplace!(v2_re, v2_im); + radix4_inplace!(v3_re, v3_im); + radix4_inplace!(v4_re, v4_im); + radix4_inplace!(v5_re, v5_im); + radix4_inplace!(v6_re, v6_im); + radix4_inplace!(v7_re, v7_im); + + // Butterfly macro: twiddle-multiply hi, then add/sub with lo. + // out_lo = lo + tw*hi, out_hi = lo - tw*hi (via 2*lo - out_lo). + macro_rules! butterfly { + ($lo_re:expr, $lo_im:expr, $hi_re:expr, $hi_im:expr, $tw_re:expr, $tw_im:expr) => {{ + let out_lo_re = $tw_im.mul_add(-$hi_im, $tw_re.mul_add($hi_re, $lo_re)); + let out_lo_im = $tw_im.mul_add($hi_re, $tw_re.mul_add($hi_im, $lo_im)); + let out_hi_re = two.mul_sub($lo_re, out_lo_re); + let out_hi_im = two.mul_sub($lo_im, out_lo_im); + $lo_re = out_lo_re; + $lo_im = out_lo_im; + $hi_re = out_hi_re; + $hi_im = out_hi_im; + }}; + } + + // ---- Stage 2: dist=4, W_8 twiddles ---- + // Butterfly pairs: (v0,v1), (v2,v3), (v4,v5), (v6,v7) + // All pairs use the same twiddle: W_8^{0,1,2,3} + { + let tw_re = f64x4::simd_from( + simd, + [ + 1.0, // W_8^0 + std::f64::consts::FRAC_1_SQRT_2, // W_8^1 + 0.0, // W_8^2 + -std::f64::consts::FRAC_1_SQRT_2, // W_8^3 + ], + ); + let tw_im = f64x4::simd_from( + simd, + [ + 0.0, // W_8^0 + -std::f64::consts::FRAC_1_SQRT_2, // W_8^1 + -1.0, // W_8^2 + -std::f64::consts::FRAC_1_SQRT_2, // W_8^3 + ], + ); + + butterfly!(v0_re, v0_im, v1_re, v1_im, tw_re, tw_im); + butterfly!(v2_re, v2_im, v3_re, v3_im, tw_re, tw_im); + butterfly!(v4_re, v4_im, v5_re, v5_im, tw_re, tw_im); + butterfly!(v6_re, v6_im, v7_re, v7_im, tw_re, tw_im); + } + + // ---- Stage 3: dist=8, W_16 twiddles ---- + // Butterfly pairs: (v0,v2), (v1,v3), (v4,v6), (v5,v7) + // (v0,v2) and (v4,v6) use W_16^{0,1,2,3} + // (v1,v3) and (v5,v7) use W_16^{4,5,6,7} + { + let tw_lo_re = f64x4::simd_from( + simd, + [ + 1.0, // W_16^0 + 0.9238795325112867, // W_16^1 + std::f64::consts::FRAC_1_SQRT_2, // W_16^2 + 0.38268343236508984, // W_16^3 + ], + ); + let tw_lo_im = f64x4::simd_from( + simd, + [ + 0.0, // W_16^0 + -0.38268343236508984, // W_16^1 + -std::f64::consts::FRAC_1_SQRT_2, // W_16^2 + -0.9238795325112867, // W_16^3 + ], + ); + let tw_hi_re = f64x4::simd_from( + simd, + [ + 0.0, // W_16^4 + -0.38268343236508984, // W_16^5 + -std::f64::consts::FRAC_1_SQRT_2, // W_16^6 + -0.9238795325112867, // W_16^7 + ], + ); + let tw_hi_im = f64x4::simd_from( + simd, + [ + -1.0, // W_16^4 + -0.9238795325112867, // W_16^5 + -std::f64::consts::FRAC_1_SQRT_2, // W_16^6 + -0.38268343236508984, // W_16^7 + ], + ); + + butterfly!(v0_re, v0_im, v2_re, v2_im, tw_lo_re, tw_lo_im); + 
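+            // v1/v3 hold elements 4..7 and 12..15 of their 16-point half, which
+            // pair at dist=8 and therefore take the W_16^{4..7} twiddles.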
butterfly!(v1_re, v1_im, v3_re, v3_im, tw_hi_re, tw_hi_im); + butterfly!(v4_re, v4_im, v6_re, v6_im, tw_lo_re, tw_lo_im); + butterfly!(v5_re, v5_im, v7_re, v7_im, tw_hi_re, tw_hi_im); + } + + // ---- Stage 4: dist=16, W_32 twiddles ---- + // Butterfly pairs: (v0,v4), (v1,v5), (v2,v6), (v3,v7) + // Each pair uses its own twiddle: W_32^{0..3}, W_32^{4..7}, W_32^{8..11}, W_32^{12..15} + { + let tw0_re = f64x4::simd_from( + simd, + [ + 1.0, // W_32^0 + 0.9807852804032304, // W_32^1 + 0.9238795325112867, // W_32^2 + 0.8314696123025452, // W_32^3 + ], + ); + let tw0_im = f64x4::simd_from( + simd, + [ + 0.0, // W_32^0 + -0.19509032201612825, // W_32^1 + -0.3826834323650898, // W_32^2 + -0.5555702330196022, // W_32^3 + ], + ); + + let tw1_re = f64x4::simd_from( + simd, + [ + std::f64::consts::FRAC_1_SQRT_2, // W_32^4 + 0.5555702330196022, // W_32^5 + 0.3826834323650898, // W_32^6 + 0.19509032201612825, // W_32^7 + ], + ); + let tw1_im = f64x4::simd_from( + simd, + [ + -std::f64::consts::FRAC_1_SQRT_2, // W_32^4 + -0.8314696123025452, // W_32^5 + -0.9238795325112867, // W_32^6 + -0.9807852804032304, // W_32^7 + ], + ); + + let tw2_re = f64x4::simd_from( + simd, + [ + 0.0, // W_32^8 + -0.19509032201612825, // W_32^9 + -0.3826834323650898, // W_32^10 + -0.5555702330196022, // W_32^11 + ], + ); + let tw2_im = f64x4::simd_from( + simd, + [ + -1.0, // W_32^8 + -0.9807852804032304, // W_32^9 + -0.9238795325112867, // W_32^10 + -0.8314696123025452, // W_32^11 + ], + ); + + let tw3_re = f64x4::simd_from( + simd, + [ + -std::f64::consts::FRAC_1_SQRT_2, // W_32^12 + -0.8314696123025452, // W_32^13 + -0.9238795325112867, // W_32^14 + -0.9807852804032304, // W_32^15 + ], + ); + let tw3_im = f64x4::simd_from( + simd, + [ + -std::f64::consts::FRAC_1_SQRT_2, // W_32^12 + -0.5555702330196022, // W_32^13 + -0.3826834323650898, // W_32^14 + -0.19509032201612825, // W_32^15 + ], + ); + + butterfly!(v0_re, v0_im, v4_re, v4_im, tw0_re, tw0_im); + butterfly!(v1_re, v1_im, v5_re, v5_im, tw1_re, tw1_im); + butterfly!(v2_re, v2_im, v6_re, v6_im, tw2_re, tw2_im); + butterfly!(v3_re, v3_im, v7_re, v7_im, tw3_re, tw3_im); + } + + // ---- Store all vectors back ---- + v0_re.store_slice(&mut re[0..4]); + v1_re.store_slice(&mut re[4..8]); + v2_re.store_slice(&mut re[8..12]); + v3_re.store_slice(&mut re[12..16]); + v4_re.store_slice(&mut re[16..20]); + v5_re.store_slice(&mut re[20..24]); + v6_re.store_slice(&mut re[24..28]); + v7_re.store_slice(&mut re[28..32]); + + v0_im.store_slice(&mut im[0..4]); + v1_im.store_slice(&mut im[4..8]); + v2_im.store_slice(&mut im[8..12]); + v3_im.store_slice(&mut im[12..16]); + v4_im.store_slice(&mut im[16..20]); + v5_im.store_slice(&mut im[20..24]); + v6_im.store_slice(&mut im[24..28]); + v7_im.store_slice(&mut im[28..32]); + } +} + /// FFT-32 codelet for f32: executes stages 0-4 in a single function. 
#[inline(never)] pub fn fft_dit_codelet_32_f32(simd: S, reals: &mut [f32], imags: &mut [f32]) { @@ -607,6 +889,102 @@ mod tests { } } + #[test] + fn codelet_32_f64_matches_legacy() { + use fearless_simd::dispatch; + + let simd_level = fearless_simd::Level::new(); + + // Test with impulse signal + let mut re_new = vec![0.0f64; 32]; + let mut im_new = vec![0.0f64; 32]; + re_new[0] = 1.0; + + let mut re_legacy = re_new.clone(); + let mut im_legacy = im_new.clone(); + + dispatch!(simd_level, simd => { + fft_dit_codelet_32_f64(simd, &mut re_new, &mut im_new); + fft_dit_codelet_32_staged_f64(simd, &mut re_legacy, &mut im_legacy); + }); + + for i in 0..32 { + assert!( + (re_new[i] - re_legacy[i]).abs() < 1e-14, + "re[{i}]: new={}, legacy={}", + re_new[i], + re_legacy[i] + ); + assert!( + (im_new[i] - im_legacy[i]).abs() < 1e-14, + "im[{i}]: new={}, legacy={}", + im_new[i], + im_legacy[i] + ); + } + + // Test with non-trivial signal + let mut re_new: Vec = (1..=32).map(|i| i as f64).collect(); + let mut im_new: Vec = (1..=32).map(|i| -(i as f64) * 0.5).collect(); + + let mut re_legacy = re_new.clone(); + let mut im_legacy = im_new.clone(); + + dispatch!(simd_level, simd => { + fft_dit_codelet_32_f64(simd, &mut re_new, &mut im_new); + fft_dit_codelet_32_staged_f64(simd, &mut re_legacy, &mut im_legacy); + }); + + for i in 0..32 { + assert!( + (re_new[i] - re_legacy[i]).abs() < 1e-10, + "re[{i}]: new={}, legacy={}", + re_new[i], + re_legacy[i] + ); + assert!( + (im_new[i] - im_legacy[i]).abs() < 1e-10, + "im[{i}]: new={}, legacy={}", + im_new[i], + im_legacy[i] + ); + } + } + + #[test] + fn codelet_32_f64_matches_legacy_multi_chunk() { + use fearless_simd::dispatch; + + let simd_level = fearless_simd::Level::new(); + + let n = 128; + let mut re_new: Vec = (0..n).map(|i| (i as f64) * 0.1).collect(); + let mut im_new: Vec = (0..n).map(|i| -(i as f64) * 0.05).collect(); + + let mut re_legacy = re_new.clone(); + let mut im_legacy = im_new.clone(); + + dispatch!(simd_level, simd => { + fft_dit_codelet_32_f64(simd, &mut re_new, &mut im_new); + fft_dit_codelet_32_staged_f64(simd, &mut re_legacy, &mut im_legacy); + }); + + for i in 0..n { + assert!( + (re_new[i] - re_legacy[i]).abs() < 1e-10, + "re[{i}]: new={}, legacy={}", + re_new[i], + re_legacy[i] + ); + assert!( + (im_new[i] - im_legacy[i]).abs() < 1e-10, + "im[{i}]: new={}, legacy={}", + im_new[i], + im_legacy[i] + ); + } + } + #[test] fn codelet_32_f64_multi_chunk() { use fearless_simd::dispatch; From eb26222269023a0ccbc72a2edf4d997a62af55d3 Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 11 Apr 2026 19:51:23 +0100 Subject: [PATCH 02/17] Make f32 codelet operate entirely in registers --- src/kernels/codelets.rs | 328 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 324 insertions(+), 4 deletions(-) diff --git a/src/kernels/codelets.rs b/src/kernels/codelets.rs index 9c3291f..c35798b 100644 --- a/src/kernels/codelets.rs +++ b/src/kernels/codelets.rs @@ -525,17 +525,17 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut } } -/// FFT-32 codelet for f32: executes stages 0-4 in a single function. +/// Legacy FFT-32 codelet for `f32`: stage-by-stage in-place with intermediate stores. 
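+/// Retained as the reference implementation for the `codelet_32_f32_matches_legacy*` differential tests.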
#[inline(never)] -pub fn fft_dit_codelet_32_f32(simd: S, reals: &mut [f32], imags: &mut [f32]) { +pub fn fft_dit_codelet_32_staged_f32(simd: S, reals: &mut [f32], imags: &mut [f32]) { simd.vectorize( #[inline(always)] - || fft_dit_codelet_32_simd_f32(simd, reals, imags), + || fft_dit_codelet_32_staged_simd_f32(simd, reals, imags), ) } #[inline(always)] -fn fft_dit_codelet_32_simd_f32(simd: S, reals: &mut [f32], imags: &mut [f32]) { +fn fft_dit_codelet_32_staged_simd_f32(simd: S, reals: &mut [f32], imags: &mut [f32]) { // Fused stages 0+1: radix-2^2 4-point DIT in a single sweep reals .chunks_exact_mut(4) @@ -733,6 +733,230 @@ fn fft_dit_codelet_32_simd_f32(simd: S, reals: &mut [f32], imags: &mut } } +/// FFT-32 codelet for `f32`: executes stages 0-4 (chunk_size 2 through 32) in a single function. +/// +/// Register-resident implementation using `f32x8`: all 32 complex values are loaded into +/// 4 `f32x8` re + 4 `f32x8` im vectors, all 5 butterfly stages execute in registers with +/// no intermediate memory traffic, then results are stored back. +#[inline(never)] +pub fn fft_dit_codelet_32_f32(simd: S, reals: &mut [f32], imags: &mut [f32]) { + simd.vectorize( + #[inline(always)] + || fft_dit_codelet_32_simd_f32(simd, reals, imags), + ) +} + +#[inline(always)] +fn fft_dit_codelet_32_simd_f32(simd: S, reals: &mut [f32], imags: &mut [f32]) { + assert_eq!(reals.len(), imags.len()); + + let two = f32x8::splat(simd, 2.0); + + for (re, im) in reals.chunks_exact_mut(32).zip(imags.chunks_exact_mut(32)) { + // ---- Load into 4 f32x8 register pairs (re + im) ---- + // v_k holds elements [8k .. 8k+7] + let mut v0_re = f32x8::from_slice(simd, &re[0..8]); + let mut v1_re = f32x8::from_slice(simd, &re[8..16]); + let mut v2_re = f32x8::from_slice(simd, &re[16..24]); + let mut v3_re = f32x8::from_slice(simd, &re[24..32]); + + let mut v0_im = f32x8::from_slice(simd, &im[0..8]); + let mut v1_im = f32x8::from_slice(simd, &im[8..16]); + let mut v2_im = f32x8::from_slice(simd, &im[16..24]); + let mut v3_im = f32x8::from_slice(simd, &im[24..32]); + + // ---- Stages 0+1+2 fused: 8-point DIT on each vector (scalar) ---- + // Each f32x8 holds 8 consecutive elements. Stages 0 (dist=1), 1 (dist=2), + // and 2 (dist=4) all have butterfly pairs within the same 8-element group. + // We extract to scalars, do the full 8-point DIT, and repack. + macro_rules! 
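+        // Twiddle convention: W_N^k = e^(-j*2πk/N), so the W_8 values used in
+        // stage 2 below are { 1, (1-j)/√2, -j, (-1-j)/√2 } for k = 0..3.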
dit8_inplace { + ($v_re:expr, $v_im:expr) => {{ + let mut r = [0.0f32; 8]; + let mut i = [0.0f32; 8]; + $v_re.store_slice(&mut r); + $v_im.store_slice(&mut i); + + // Stage 0+1 fused: radix-4 on elements [0,1,2,3] + let (a_re, a_im) = (r[0] + r[1], i[0] + i[1]); + let (b_re, b_im) = (r[0] - r[1], i[0] - i[1]); + let (c_re, c_im) = (r[2] + r[3], i[2] + i[3]); + let (d_re, d_im) = (r[2] - r[3], i[2] - i[3]); + let (dj_re, dj_im) = (d_im, -d_re); // -j * d + r[0] = a_re + c_re; + i[0] = a_im + c_im; + r[1] = b_re + dj_re; + i[1] = b_im + dj_im; + r[2] = a_re - c_re; + i[2] = a_im - c_im; + r[3] = b_re - dj_re; + i[3] = b_im - dj_im; + + // Stage 0+1 fused: radix-4 on elements [4,5,6,7] + let (a_re, a_im) = (r[4] + r[5], i[4] + i[5]); + let (b_re, b_im) = (r[4] - r[5], i[4] - i[5]); + let (c_re, c_im) = (r[6] + r[7], i[6] + i[7]); + let (d_re, d_im) = (r[6] - r[7], i[6] - i[7]); + let (dj_re, dj_im) = (d_im, -d_re); + r[4] = a_re + c_re; + i[4] = a_im + c_im; + r[5] = b_re + dj_re; + i[5] = b_im + dj_im; + r[6] = a_re - c_re; + i[6] = a_im - c_im; + r[7] = b_re - dj_re; + i[7] = b_im - dj_im; + + // Stage 2: dist=4 butterflies between (k, k+4) with W_8 twiddles + // W_8^0 = 1, W_8^1 = (1-j)/√2, W_8^2 = -j, W_8^3 = (-1-j)/√2 + const FRAC_1_SQRT_2: f32 = std::f32::consts::FRAC_1_SQRT_2; + let tw_re: [f32; 4] = [1.0, FRAC_1_SQRT_2, 0.0, -FRAC_1_SQRT_2]; + let tw_im: [f32; 4] = [0.0, -FRAC_1_SQRT_2, -1.0, -FRAC_1_SQRT_2]; + + for k in 0..4 { + let hi_re = tw_re[k] * r[k + 4] - tw_im[k] * i[k + 4]; + let hi_im = tw_re[k] * i[k + 4] + tw_im[k] * r[k + 4]; + let lo_re = r[k]; + let lo_im = i[k]; + r[k] = lo_re + hi_re; + i[k] = lo_im + hi_im; + r[k + 4] = lo_re - hi_re; + i[k + 4] = lo_im - hi_im; + } + + $v_re = f32x8::simd_from(simd, r); + $v_im = f32x8::simd_from(simd, i); + }}; + } + + dit8_inplace!(v0_re, v0_im); + dit8_inplace!(v1_re, v1_im); + dit8_inplace!(v2_re, v2_im); + dit8_inplace!(v3_re, v3_im); + + // Butterfly macro: twiddle-multiply hi, then add/sub with lo. + // out_lo = lo + tw*hi, out_hi = lo - tw*hi (via 2*lo - out_lo). + macro_rules! 
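+        // Why 2*lo - out_lo: lo - tw*hi = 2*lo - (lo + tw*hi), so the hi lane
+        // reuses out_lo and costs one mul_sub instead of recomputing tw*hi,
+        // bringing each butterfly component to three FMA-class ops instead of four.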
butterfly { + ($lo_re:expr, $lo_im:expr, $hi_re:expr, $hi_im:expr, $tw_re:expr, $tw_im:expr) => {{ + let out_lo_re = $tw_im.mul_add(-$hi_im, $tw_re.mul_add($hi_re, $lo_re)); + let out_lo_im = $tw_im.mul_add($hi_re, $tw_re.mul_add($hi_im, $lo_im)); + let out_hi_re = two.mul_sub($lo_re, out_lo_re); + let out_hi_im = two.mul_sub($lo_im, out_lo_im); + $lo_re = out_lo_re; + $lo_im = out_lo_im; + $hi_re = out_hi_re; + $hi_im = out_hi_im; + }}; + } + + // ---- Stage 3: dist=8, W_16 twiddles ---- + // Butterfly pairs: (v0,v1), (v2,v3) + // Both pairs use the same twiddle: W_16^{0..7} + { + let tw_re = f32x8::simd_from( + simd, + [ + 1.0_f32, // W_16^0 + 0.923_879_5_f32, // W_16^1 + std::f32::consts::FRAC_1_SQRT_2, // W_16^2 + 0.382_683_43_f32, // W_16^3 + 0.0_f32, // W_16^4 + -0.382_683_43_f32, // W_16^5 + -std::f32::consts::FRAC_1_SQRT_2, // W_16^6 + -0.923_879_5_f32, // W_16^7 + ], + ); + let tw_im = f32x8::simd_from( + simd, + [ + 0.0_f32, // W_16^0 + -0.382_683_43_f32, // W_16^1 + -std::f32::consts::FRAC_1_SQRT_2, // W_16^2 + -0.923_879_5_f32, // W_16^3 + -1.0_f32, // W_16^4 + -0.923_879_5_f32, // W_16^5 + -std::f32::consts::FRAC_1_SQRT_2, // W_16^6 + -0.382_683_43_f32, // W_16^7 + ], + ); + + butterfly!(v0_re, v0_im, v1_re, v1_im, tw_re, tw_im); + butterfly!(v2_re, v2_im, v3_re, v3_im, tw_re, tw_im); + } + + // ---- Stage 4: dist=16, W_32 twiddles ---- + // Butterfly pairs: (v0,v2), (v1,v3) + // (v0,v2) uses W_32^{0..7}, (v1,v3) uses W_32^{8..15} + { + let tw_lo_re = f32x8::simd_from( + simd, + [ + 1.0_f32, // W_32^0 + 0.980_785_25_f32, // W_32^1 + 0.923_879_5_f32, // W_32^2 + 0.831_469_6_f32, // W_32^3 + std::f32::consts::FRAC_1_SQRT_2, // W_32^4 + 0.555_570_24_f32, // W_32^5 + 0.382_683_43_f32, // W_32^6 + 0.195_090_32_f32, // W_32^7 + ], + ); + let tw_lo_im = f32x8::simd_from( + simd, + [ + 0.0_f32, // W_32^0 + -0.195_090_32_f32, // W_32^1 + -0.382_683_43_f32, // W_32^2 + -0.555_570_24_f32, // W_32^3 + -std::f32::consts::FRAC_1_SQRT_2, // W_32^4 + -0.831_469_6_f32, // W_32^5 + -0.923_879_5_f32, // W_32^6 + -0.980_785_25_f32, // W_32^7 + ], + ); + let tw_hi_re = f32x8::simd_from( + simd, + [ + 0.0_f32, // W_32^8 + -0.195_090_32_f32, // W_32^9 + -0.382_683_43_f32, // W_32^10 + -0.555_570_24_f32, // W_32^11 + -std::f32::consts::FRAC_1_SQRT_2, // W_32^12 + -0.831_469_6_f32, // W_32^13 + -0.923_879_5_f32, // W_32^14 + -0.980_785_25_f32, // W_32^15 + ], + ); + let tw_hi_im = f32x8::simd_from( + simd, + [ + -1.0_f32, // W_32^8 + -0.980_785_25_f32, // W_32^9 + -0.923_879_5_f32, // W_32^10 + -0.831_469_6_f32, // W_32^11 + -std::f32::consts::FRAC_1_SQRT_2, // W_32^12 + -0.555_570_24_f32, // W_32^13 + -0.382_683_43_f32, // W_32^14 + -0.195_090_32_f32, // W_32^15 + ], + ); + + butterfly!(v0_re, v0_im, v2_re, v2_im, tw_lo_re, tw_lo_im); + butterfly!(v1_re, v1_im, v3_re, v3_im, tw_hi_re, tw_hi_im); + } + + // ---- Store all vectors back ---- + v0_re.store_slice(&mut re[0..8]); + v1_re.store_slice(&mut re[8..16]); + v2_re.store_slice(&mut re[16..24]); + v3_re.store_slice(&mut re[24..32]); + + v0_im.store_slice(&mut im[0..8]); + v1_im.store_slice(&mut im[8..16]); + v2_im.store_slice(&mut im[16..24]); + v3_im.store_slice(&mut im[24..32]); + } +} + #[cfg(test)] mod tests { use super::*; @@ -1029,4 +1253,100 @@ mod tests { ); } } + + #[test] + fn codelet_32_f32_matches_legacy() { + use fearless_simd::dispatch; + + let simd_level = fearless_simd::Level::new(); + + // Test with impulse signal + let mut re_new = vec![0.0f32; 32]; + let mut im_new = vec![0.0f32; 32]; + re_new[0] = 1.0; + + let mut re_legacy = 
re_new.clone(); + let mut im_legacy = im_new.clone(); + + dispatch!(simd_level, simd => { + fft_dit_codelet_32_f32(simd, &mut re_new, &mut im_new); + fft_dit_codelet_32_staged_f32(simd, &mut re_legacy, &mut im_legacy); + }); + + for i in 0..32 { + assert!( + (re_new[i] - re_legacy[i]).abs() < 1e-5, + "re[{i}]: new={}, legacy={}", + re_new[i], + re_legacy[i] + ); + assert!( + (im_new[i] - im_legacy[i]).abs() < 1e-5, + "im[{i}]: new={}, legacy={}", + im_new[i], + im_legacy[i] + ); + } + + // Test with non-trivial signal + let mut re_new: Vec = (1..=32).map(|i| i as f32).collect(); + let mut im_new: Vec = (1..=32).map(|i| -(i as f32) * 0.5).collect(); + + let mut re_legacy = re_new.clone(); + let mut im_legacy = im_new.clone(); + + dispatch!(simd_level, simd => { + fft_dit_codelet_32_f32(simd, &mut re_new, &mut im_new); + fft_dit_codelet_32_staged_f32(simd, &mut re_legacy, &mut im_legacy); + }); + + for i in 0..32 { + assert!( + (re_new[i] - re_legacy[i]).abs() < 1e-4, + "re[{i}]: new={}, legacy={}", + re_new[i], + re_legacy[i] + ); + assert!( + (im_new[i] - im_legacy[i]).abs() < 1e-4, + "im[{i}]: new={}, legacy={}", + im_new[i], + im_legacy[i] + ); + } + } + + #[test] + fn codelet_32_f32_matches_legacy_multi_chunk() { + use fearless_simd::dispatch; + + let simd_level = fearless_simd::Level::new(); + + let n = 128; + let mut re_new: Vec = (0..n).map(|i| (i as f32) * 0.1).collect(); + let mut im_new: Vec = (0..n).map(|i| -(i as f32) * 0.05).collect(); + + let mut re_legacy = re_new.clone(); + let mut im_legacy = im_new.clone(); + + dispatch!(simd_level, simd => { + fft_dit_codelet_32_f32(simd, &mut re_new, &mut im_new); + fft_dit_codelet_32_staged_f32(simd, &mut re_legacy, &mut im_legacy); + }); + + for i in 0..n { + assert!( + (re_new[i] - re_legacy[i]).abs() < 1e-4, + "re[{i}]: new={}, legacy={}", + re_new[i], + re_legacy[i] + ); + assert!( + (im_new[i] - im_legacy[i]).abs() < 1e-4, + "im[{i}]: new={}, legacy={}", + im_new[i], + im_legacy[i] + ); + } + } } From 3aedc257c07451e776bcf8c55b9fe7a8882989f8 Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 11 Apr 2026 19:55:17 +0100 Subject: [PATCH 03/17] Delete old, superseded codelets --- src/kernels/codelets.rs | 640 +--------------------------------------- 1 file changed, 1 insertion(+), 639 deletions(-) diff --git a/src/kernels/codelets.rs b/src/kernels/codelets.rs index c35798b..86b9dfc 100644 --- a/src/kernels/codelets.rs +++ b/src/kernels/codelets.rs @@ -3,245 +3,7 @@ //! A codelet is a self-contained FFT kernel that fuses multiple stages into a single function //! call, eliminating per-stage function call overhead and giving LLVM a wider optimization window. //! -use fearless_simd::{f32x16, f32x4, f32x8, f64x4, f64x8, Simd, SimdBase, SimdFloat, SimdFrom}; - -/// Legacy FFT-32 codelet for `f64`: stage-by-stage in-place with intermediate stores. 
-#[inline(never)] -pub fn fft_dit_codelet_32_staged_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { - simd.vectorize( - #[inline(always)] - || fft_dit_codelet_32_staged_simd_f64(simd, reals, imags), - ) -} - -#[inline(always)] -fn fft_dit_codelet_32_staged_simd_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { - // Fused stages 0+1: radix-2^2 4-point DIT in a single sweep - reals - .chunks_exact_mut(4) - .zip(imags.chunks_exact_mut(4)) - .for_each(|(re, im)| { - let (a_re, a_im) = (re[0], im[0]); - let (b_re, b_im) = (re[1], im[1]); - let (c_re, c_im) = (re[2], im[2]); - let (d_re, d_im) = (re[3], im[3]); - - // Stage 0: dist=1 butterflies on pairs (a,b) and (c,d) - let (t0_re, t0_im) = (a_re + b_re, a_im + b_im); - let (t1_re, t1_im) = (a_re - b_re, a_im - b_im); - let (t2_re, t2_im) = (c_re + d_re, c_im + d_im); - let (t3_re, t3_im) = (c_re - d_re, c_im - d_im); - - // Stage 1: -j multiply on t3, then dist=2 butterflies - let (t3j_re, t3j_im) = (t3_im, -t3_re); - - re[0] = t0_re + t2_re; - im[0] = t0_im + t2_im; - re[1] = t1_re + t3j_re; - im[1] = t1_im + t3j_im; - re[2] = t0_re - t2_re; - im[2] = t0_im - t2_im; - re[3] = t1_re - t3j_re; - im[3] = t1_im - t3j_im; - }); - - // Stage 2: dist=4, chunk_size=8, W_8 twiddles via f64x4 - { - let tw_re = f64x4::simd_from( - simd, - [ - 1.0, // W_8^0 - std::f64::consts::FRAC_1_SQRT_2, // W_8^1 - 0.0, // W_8^2 - -std::f64::consts::FRAC_1_SQRT_2, // W_8^3 - ], - ); - let tw_im = f64x4::simd_from( - simd, - [ - 0.0, // W_8^0 - -std::f64::consts::FRAC_1_SQRT_2, // W_8^1 - -1.0, // W_8^2 - -std::f64::consts::FRAC_1_SQRT_2, // W_8^3 - ], - ); - let two = f64x4::splat(simd, 2.0); - - (reals.as_chunks_mut::<8>().0.iter_mut()) - .zip(imags.as_chunks_mut::<8>().0.iter_mut()) - .for_each(|(re8, im8)| { - let (re_lo, re_hi) = re8.split_at_mut(4); - let (im_lo, im_hi) = im8.split_at_mut(4); - - let in0_re = f64x4::from_slice(simd, re_lo); - let in1_re = f64x4::from_slice(simd, re_hi); - let in0_im = f64x4::from_slice(simd, im_lo); - let in1_im = f64x4::from_slice(simd, im_hi); - - let out0_re = tw_im.mul_add(-in1_im, tw_re.mul_add(in1_re, in0_re)); - let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im)); - let out1_re = two.mul_sub(in0_re, out0_re); - let out1_im = two.mul_sub(in0_im, out0_im); - - out0_re.store_slice(re_lo); - out0_im.store_slice(im_lo); - out1_re.store_slice(re_hi); - out1_im.store_slice(im_hi); - }); - } - - // Stage 3: dist=8, chunk_size=16, W_16 twiddles via f64x8 - { - let tw_re = f64x8::simd_from( - simd, - [ - 1.0, // W_16^0 - 0.9238795325112867, // W_16^1 - std::f64::consts::FRAC_1_SQRT_2, // W_16^2 - 0.38268343236508984, // W_16^3 - 0.0, // W_16^4 - -0.38268343236508984, // W_16^5 - -std::f64::consts::FRAC_1_SQRT_2, // W_16^6 - -0.9238795325112867, // W_16^7 - ], - ); - let tw_im = f64x8::simd_from( - simd, - [ - 0.0, // W_16^0 - -0.38268343236508984, // W_16^1 - -std::f64::consts::FRAC_1_SQRT_2, // W_16^2 - -0.9238795325112867, // W_16^3 - -1.0, // W_16^4 - -0.9238795325112867, // W_16^5 - -std::f64::consts::FRAC_1_SQRT_2, // W_16^6 - -0.38268343236508984, // W_16^7 - ], - ); - let two = f64x8::splat(simd, 2.0); - - (reals.as_chunks_mut::<16>().0.iter_mut()) - .zip(imags.as_chunks_mut::<16>().0.iter_mut()) - .for_each(|(re16, im16)| { - let (re_lo, re_hi) = re16.split_at_mut(8); - let (im_lo, im_hi) = im16.split_at_mut(8); - - let in0_re = f64x8::from_slice(simd, re_lo); - let in1_re = f64x8::from_slice(simd, re_hi); - let in0_im = f64x8::from_slice(simd, im_lo); - let in1_im = f64x8::from_slice(simd, im_hi); - 
- let out0_re = tw_im.mul_add(-in1_im, tw_re.mul_add(in1_re, in0_re)); - let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im)); - let out1_re = two.mul_sub(in0_re, out0_re); - let out1_im = two.mul_sub(in0_im, out0_im); - - out0_re.store_slice(re_lo); - out0_im.store_slice(im_lo); - out1_re.store_slice(re_hi); - out1_im.store_slice(im_hi); - }); - } - - // Stage 4: dist=16, chunk_size=32, W_32 twiddles via 2x f64x8 - { - let tw_re_0_7 = f64x8::simd_from( - simd, - [ - 1.0, // W_32^0 - 0.9807852804032304, // W_32^1 - 0.9238795325112867, // W_32^2 - 0.8314696123025452, // W_32^3 - std::f64::consts::FRAC_1_SQRT_2, // W_32^4 - 0.5555702330196022, // W_32^5 - 0.3826834323650898, // W_32^6 - 0.19509032201612825, // W_32^7 - ], - ); - let tw_im_0_7 = f64x8::simd_from( - simd, - [ - 0.0, // W_32^0 - -0.19509032201612825, // W_32^1 - -0.3826834323650898, // W_32^2 - -0.5555702330196022, // W_32^3 - -std::f64::consts::FRAC_1_SQRT_2, // W_32^4 - -0.8314696123025452, // W_32^5 - -0.9238795325112867, // W_32^6 - -0.9807852804032304, // W_32^7 - ], - ); - let tw_re_8_15 = f64x8::simd_from( - simd, - [ - 0.0, // W_32^8 - -0.19509032201612825, // W_32^9 - -0.3826834323650898, // W_32^10 - -0.5555702330196022, // W_32^11 - -std::f64::consts::FRAC_1_SQRT_2, // W_32^12 - -0.8314696123025452, // W_32^13 - -0.9238795325112867, // W_32^14 - -0.9807852804032304, // W_32^15 - ], - ); - let tw_im_8_15 = f64x8::simd_from( - simd, - [ - -1.0, // W_32^8 - -0.9807852804032304, // W_32^9 - -0.9238795325112867, // W_32^10 - -0.8314696123025452, // W_32^11 - -std::f64::consts::FRAC_1_SQRT_2, // W_32^12 - -0.5555702330196022, // W_32^13 - -0.3826834323650898, // W_32^14 - -0.19509032201612825, // W_32^15 - ], - ); - let two = f64x8::splat(simd, 2.0); - - for (re32, im32) in reals - .as_chunks_mut::<32>() - .0 - .iter_mut() - .zip(imags.as_chunks_mut::<32>().0.iter_mut()) - { - let (re_lo, re_hi) = re32.split_at_mut(16); - let (im_lo, im_hi) = im32.split_at_mut(16); - - // Batch 0: elements [0..8] and [16..24] with W_32^{0..7} - let in0_re = f64x8::from_slice(simd, &re_lo[0..8]); - let in1_re = f64x8::from_slice(simd, &re_hi[0..8]); - let in0_im = f64x8::from_slice(simd, &im_lo[0..8]); - let in1_im = f64x8::from_slice(simd, &im_hi[0..8]); - - let out0_re = tw_im_0_7.mul_add(-in1_im, tw_re_0_7.mul_add(in1_re, in0_re)); - let out0_im = tw_im_0_7.mul_add(in1_re, tw_re_0_7.mul_add(in1_im, in0_im)); - let out1_re = two.mul_sub(in0_re, out0_re); - let out1_im = two.mul_sub(in0_im, out0_im); - - out0_re.store_slice(&mut re_lo[0..8]); - out0_im.store_slice(&mut im_lo[0..8]); - out1_re.store_slice(&mut re_hi[0..8]); - out1_im.store_slice(&mut im_hi[0..8]); - - // Batch 1: elements [8..16] and [24..32] with W_32^{8..15} - let in0_re = f64x8::from_slice(simd, &re_lo[8..16]); - let in1_re = f64x8::from_slice(simd, &re_hi[8..16]); - let in0_im = f64x8::from_slice(simd, &im_lo[8..16]); - let in1_im = f64x8::from_slice(simd, &im_hi[8..16]); - - let out0_re = tw_im_8_15.mul_add(-in1_im, tw_re_8_15.mul_add(in1_re, in0_re)); - let out0_im = tw_im_8_15.mul_add(in1_re, tw_re_8_15.mul_add(in1_im, in0_im)); - let out1_re = two.mul_sub(in0_re, out0_re); - let out1_im = two.mul_sub(in0_im, out0_im); - - out0_re.store_slice(&mut re_lo[8..16]); - out0_im.store_slice(&mut im_lo[8..16]); - out1_re.store_slice(&mut re_hi[8..16]); - out1_im.store_slice(&mut im_hi[8..16]); - } - } -} +use fearless_simd::{f32x16, f32x4, f32x8, f64x4, Simd, SimdBase, SimdFloat, SimdFrom}; /// FFT-32 codelet for `f64`: executes stages 0-4 (chunk_size 2 through 
32) in a single function. /// @@ -525,214 +287,6 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut } } -/// Legacy FFT-32 codelet for `f32`: stage-by-stage in-place with intermediate stores. -#[inline(never)] -pub fn fft_dit_codelet_32_staged_f32(simd: S, reals: &mut [f32], imags: &mut [f32]) { - simd.vectorize( - #[inline(always)] - || fft_dit_codelet_32_staged_simd_f32(simd, reals, imags), - ) -} - -#[inline(always)] -fn fft_dit_codelet_32_staged_simd_f32(simd: S, reals: &mut [f32], imags: &mut [f32]) { - // Fused stages 0+1: radix-2^2 4-point DIT in a single sweep - reals - .chunks_exact_mut(4) - .zip(imags.chunks_exact_mut(4)) - .for_each(|(re, im)| { - let (a_re, a_im) = (re[0], im[0]); - let (b_re, b_im) = (re[1], im[1]); - let (c_re, c_im) = (re[2], im[2]); - let (d_re, d_im) = (re[3], im[3]); - - // Stage 0: dist=1 butterflies on pairs (a,b) and (c,d) - let (t0_re, t0_im) = (a_re + b_re, a_im + b_im); - let (t1_re, t1_im) = (a_re - b_re, a_im - b_im); - let (t2_re, t2_im) = (c_re + d_re, c_im + d_im); - let (t3_re, t3_im) = (c_re - d_re, c_im - d_im); - - // Stage 1: -j multiply on t3, then dist=2 butterflies - let (t3j_re, t3j_im) = (t3_im, -t3_re); - - re[0] = t0_re + t2_re; - im[0] = t0_im + t2_im; - re[1] = t1_re + t3j_re; - im[1] = t1_im + t3j_im; - re[2] = t0_re - t2_re; - im[2] = t0_im - t2_im; - re[3] = t1_re - t3j_re; - im[3] = t1_im - t3j_im; - }); - - // Stage 2: dist=4, chunk_size=8, W_8 twiddles via f32x4 - { - let tw_re = f32x4::simd_from( - simd, - [ - 1.0_f32, // W_8^0 - std::f32::consts::FRAC_1_SQRT_2, // W_8^1 - 0.0_f32, // W_8^2 - -std::f32::consts::FRAC_1_SQRT_2, // W_8^3 - ], - ); - let tw_im = f32x4::simd_from( - simd, - [ - 0.0_f32, // W_8^0 - -std::f32::consts::FRAC_1_SQRT_2, // W_8^1 - -1.0_f32, // W_8^2 - -std::f32::consts::FRAC_1_SQRT_2, // W_8^3 - ], - ); - let two = f32x4::splat(simd, 2.0); - - (reals.as_chunks_mut::<8>().0.iter_mut()) - .zip(imags.as_chunks_mut::<8>().0.iter_mut()) - .for_each(|(re8, im8)| { - let (re_lo, re_hi) = re8.split_at_mut(4); - let (im_lo, im_hi) = im8.split_at_mut(4); - - let in0_re = f32x4::from_slice(simd, re_lo); - let in1_re = f32x4::from_slice(simd, re_hi); - let in0_im = f32x4::from_slice(simd, im_lo); - let in1_im = f32x4::from_slice(simd, im_hi); - - let out0_re = tw_im.mul_add(-in1_im, tw_re.mul_add(in1_re, in0_re)); - let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im)); - let out1_re = two.mul_sub(in0_re, out0_re); - let out1_im = two.mul_sub(in0_im, out0_im); - - out0_re.store_slice(re_lo); - out0_im.store_slice(im_lo); - out1_re.store_slice(re_hi); - out1_im.store_slice(im_hi); - }); - } - - // Stage 3: dist=8, chunk_size=16, W_16 twiddles via f32x8 - { - let tw_re = f32x8::simd_from( - simd, - [ - 1.0_f32, // W_16^0 - 0.923_879_5_f32, // W_16^1 - std::f32::consts::FRAC_1_SQRT_2, // W_16^2 - 0.382_683_43_f32, // W_16^3 - 0.0_f32, // W_16^4 - -0.382_683_43_f32, // W_16^5 - -std::f32::consts::FRAC_1_SQRT_2, // W_16^6 - -0.923_879_5_f32, // W_16^7 - ], - ); - let tw_im = f32x8::simd_from( - simd, - [ - 0.0_f32, // W_16^0 - -0.382_683_43_f32, // W_16^1 - -std::f32::consts::FRAC_1_SQRT_2, // W_16^2 - -0.923_879_5_f32, // W_16^3 - -1.0_f32, // W_16^4 - -0.923_879_5_f32, // W_16^5 - -std::f32::consts::FRAC_1_SQRT_2, // W_16^6 - -0.382_683_43_f32, // W_16^7 - ], - ); - let two = f32x8::splat(simd, 2.0); - - (reals.as_chunks_mut::<16>().0.iter_mut()) - .zip(imags.as_chunks_mut::<16>().0.iter_mut()) - .for_each(|(re16, im16)| { - let (re_lo, re_hi) = re16.split_at_mut(8); - let (im_lo, 
im_hi) = im16.split_at_mut(8); - - let in0_re = f32x8::from_slice(simd, re_lo); - let in1_re = f32x8::from_slice(simd, re_hi); - let in0_im = f32x8::from_slice(simd, im_lo); - let in1_im = f32x8::from_slice(simd, im_hi); - - let out0_re = tw_im.mul_add(-in1_im, tw_re.mul_add(in1_re, in0_re)); - let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im)); - let out1_re = two.mul_sub(in0_re, out0_re); - let out1_im = two.mul_sub(in0_im, out0_im); - - out0_re.store_slice(re_lo); - out0_im.store_slice(im_lo); - out1_re.store_slice(re_hi); - out1_im.store_slice(im_hi); - }); - } - - // Stage 4: dist=16, chunk_size=32, W_32 twiddles via f32x16 - { - let tw_re = f32x16::simd_from( - simd, - [ - 1.0_f32, // W_32^0 - 0.980_785_25_f32, // W_32^1 - 0.923_879_5_f32, // W_32^2 - 0.831_469_6_f32, // W_32^3 - std::f32::consts::FRAC_1_SQRT_2, // W_32^4 - 0.555_570_24_f32, // W_32^5 - 0.382_683_43_f32, // W_32^6 - 0.195_090_32_f32, // W_32^7 - 0.0_f32, // W_32^8 - -0.195_090_32_f32, // W_32^9 - -0.382_683_43_f32, // W_32^10 - -0.555_570_24_f32, // W_32^11 - -std::f32::consts::FRAC_1_SQRT_2, // W_32^12 - -0.831_469_6_f32, // W_32^13 - -0.923_879_5_f32, // W_32^14 - -0.980_785_25_f32, // W_32^15 - ], - ); - let tw_im = f32x16::simd_from( - simd, - [ - 0.0_f32, // W_32^0 - -0.195_090_32_f32, // W_32^1 - -0.382_683_43_f32, // W_32^2 - -0.555_570_24_f32, // W_32^3 - -std::f32::consts::FRAC_1_SQRT_2, // W_32^4 - -0.831_469_6_f32, // W_32^5 - -0.923_879_5_f32, // W_32^6 - -0.980_785_25_f32, // W_32^7 - -1.0_f32, // W_32^8 - -0.980_785_25_f32, // W_32^9 - -0.923_879_5_f32, // W_32^10 - -0.831_469_6_f32, // W_32^11 - -std::f32::consts::FRAC_1_SQRT_2, // W_32^12 - -0.555_570_24_f32, // W_32^13 - -0.382_683_43_f32, // W_32^14 - -0.195_090_32_f32, // W_32^15 - ], - ); - let two = f32x16::splat(simd, 2.0); - - (reals.as_chunks_mut::<32>().0.iter_mut()) - .zip(imags.as_chunks_mut::<32>().0.iter_mut()) - .for_each(|(re32, im32)| { - let (re_lo, re_hi) = re32.split_at_mut(16); - let (im_lo, im_hi) = im32.split_at_mut(16); - - let in0_re = f32x16::from_slice(simd, re_lo); - let in1_re = f32x16::from_slice(simd, re_hi); - let in0_im = f32x16::from_slice(simd, im_lo); - let in1_im = f32x16::from_slice(simd, im_hi); - - let out0_re = tw_im.mul_add(-in1_im, tw_re.mul_add(in1_re, in0_re)); - let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im)); - let out1_re = two.mul_sub(in0_re, out0_re); - let out1_im = two.mul_sub(in0_im, out0_im); - - out0_re.store_slice(re_lo); - out0_im.store_slice(im_lo); - out1_re.store_slice(re_hi); - out1_im.store_slice(im_hi); - }); - } -} - /// FFT-32 codelet for `f32`: executes stages 0-4 (chunk_size 2 through 32) in a single function. 
/// /// Register-resident implementation using `f32x8`: all 32 complex values are loaded into @@ -1113,102 +667,6 @@ mod tests { } } - #[test] - fn codelet_32_f64_matches_legacy() { - use fearless_simd::dispatch; - - let simd_level = fearless_simd::Level::new(); - - // Test with impulse signal - let mut re_new = vec![0.0f64; 32]; - let mut im_new = vec![0.0f64; 32]; - re_new[0] = 1.0; - - let mut re_legacy = re_new.clone(); - let mut im_legacy = im_new.clone(); - - dispatch!(simd_level, simd => { - fft_dit_codelet_32_f64(simd, &mut re_new, &mut im_new); - fft_dit_codelet_32_staged_f64(simd, &mut re_legacy, &mut im_legacy); - }); - - for i in 0..32 { - assert!( - (re_new[i] - re_legacy[i]).abs() < 1e-14, - "re[{i}]: new={}, legacy={}", - re_new[i], - re_legacy[i] - ); - assert!( - (im_new[i] - im_legacy[i]).abs() < 1e-14, - "im[{i}]: new={}, legacy={}", - im_new[i], - im_legacy[i] - ); - } - - // Test with non-trivial signal - let mut re_new: Vec = (1..=32).map(|i| i as f64).collect(); - let mut im_new: Vec = (1..=32).map(|i| -(i as f64) * 0.5).collect(); - - let mut re_legacy = re_new.clone(); - let mut im_legacy = im_new.clone(); - - dispatch!(simd_level, simd => { - fft_dit_codelet_32_f64(simd, &mut re_new, &mut im_new); - fft_dit_codelet_32_staged_f64(simd, &mut re_legacy, &mut im_legacy); - }); - - for i in 0..32 { - assert!( - (re_new[i] - re_legacy[i]).abs() < 1e-10, - "re[{i}]: new={}, legacy={}", - re_new[i], - re_legacy[i] - ); - assert!( - (im_new[i] - im_legacy[i]).abs() < 1e-10, - "im[{i}]: new={}, legacy={}", - im_new[i], - im_legacy[i] - ); - } - } - - #[test] - fn codelet_32_f64_matches_legacy_multi_chunk() { - use fearless_simd::dispatch; - - let simd_level = fearless_simd::Level::new(); - - let n = 128; - let mut re_new: Vec = (0..n).map(|i| (i as f64) * 0.1).collect(); - let mut im_new: Vec = (0..n).map(|i| -(i as f64) * 0.05).collect(); - - let mut re_legacy = re_new.clone(); - let mut im_legacy = im_new.clone(); - - dispatch!(simd_level, simd => { - fft_dit_codelet_32_f64(simd, &mut re_new, &mut im_new); - fft_dit_codelet_32_staged_f64(simd, &mut re_legacy, &mut im_legacy); - }); - - for i in 0..n { - assert!( - (re_new[i] - re_legacy[i]).abs() < 1e-10, - "re[{i}]: new={}, legacy={}", - re_new[i], - re_legacy[i] - ); - assert!( - (im_new[i] - im_legacy[i]).abs() < 1e-10, - "im[{i}]: new={}, legacy={}", - im_new[i], - im_legacy[i] - ); - } - } - #[test] fn codelet_32_f64_multi_chunk() { use fearless_simd::dispatch; @@ -1253,100 +711,4 @@ mod tests { ); } } - - #[test] - fn codelet_32_f32_matches_legacy() { - use fearless_simd::dispatch; - - let simd_level = fearless_simd::Level::new(); - - // Test with impulse signal - let mut re_new = vec![0.0f32; 32]; - let mut im_new = vec![0.0f32; 32]; - re_new[0] = 1.0; - - let mut re_legacy = re_new.clone(); - let mut im_legacy = im_new.clone(); - - dispatch!(simd_level, simd => { - fft_dit_codelet_32_f32(simd, &mut re_new, &mut im_new); - fft_dit_codelet_32_staged_f32(simd, &mut re_legacy, &mut im_legacy); - }); - - for i in 0..32 { - assert!( - (re_new[i] - re_legacy[i]).abs() < 1e-5, - "re[{i}]: new={}, legacy={}", - re_new[i], - re_legacy[i] - ); - assert!( - (im_new[i] - im_legacy[i]).abs() < 1e-5, - "im[{i}]: new={}, legacy={}", - im_new[i], - im_legacy[i] - ); - } - - // Test with non-trivial signal - let mut re_new: Vec = (1..=32).map(|i| i as f32).collect(); - let mut im_new: Vec = (1..=32).map(|i| -(i as f32) * 0.5).collect(); - - let mut re_legacy = re_new.clone(); - let mut im_legacy = im_new.clone(); - - 
dispatch!(simd_level, simd => { - fft_dit_codelet_32_f32(simd, &mut re_new, &mut im_new); - fft_dit_codelet_32_staged_f32(simd, &mut re_legacy, &mut im_legacy); - }); - - for i in 0..32 { - assert!( - (re_new[i] - re_legacy[i]).abs() < 1e-4, - "re[{i}]: new={}, legacy={}", - re_new[i], - re_legacy[i] - ); - assert!( - (im_new[i] - im_legacy[i]).abs() < 1e-4, - "im[{i}]: new={}, legacy={}", - im_new[i], - im_legacy[i] - ); - } - } - - #[test] - fn codelet_32_f32_matches_legacy_multi_chunk() { - use fearless_simd::dispatch; - - let simd_level = fearless_simd::Level::new(); - - let n = 128; - let mut re_new: Vec = (0..n).map(|i| (i as f32) * 0.1).collect(); - let mut im_new: Vec = (0..n).map(|i| -(i as f32) * 0.05).collect(); - - let mut re_legacy = re_new.clone(); - let mut im_legacy = im_new.clone(); - - dispatch!(simd_level, simd => { - fft_dit_codelet_32_f32(simd, &mut re_new, &mut im_new); - fft_dit_codelet_32_staged_f32(simd, &mut re_legacy, &mut im_legacy); - }); - - for i in 0..n { - assert!( - (re_new[i] - re_legacy[i]).abs() < 1e-4, - "re[{i}]: new={}, legacy={}", - re_new[i], - re_legacy[i] - ); - assert!( - (im_new[i] - im_legacy[i]).abs() < 1e-4, - "im[{i}]: new={}, legacy={}", - im_new[i], - im_legacy[i] - ); - } - } } From 2345b58b9f2f66a7b92314a9f097def087d14979 Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 11 Apr 2026 20:23:43 +0100 Subject: [PATCH 04/17] Properly vectorize f32 codelet --- src/kernels/codelets.rs | 206 +++++++++++++++++++++++++++------------- src/planner.rs | 3 +- 2 files changed, 141 insertions(+), 68 deletions(-) diff --git a/src/kernels/codelets.rs b/src/kernels/codelets.rs index 86b9dfc..8144de4 100644 --- a/src/kernels/codelets.rs +++ b/src/kernels/codelets.rs @@ -3,7 +3,9 @@ //! A codelet is a self-contained FFT kernel that fuses multiple stages into a single function //! call, eliminating per-stage function call overhead and giving LLVM a wider optimization window. //! -use fearless_simd::{f32x16, f32x4, f32x8, f64x4, Simd, SimdBase, SimdFloat, SimdFrom}; +use fearless_simd::{ + f32x16, f32x4, f32x8, f64x4, Simd, SimdBase, SimdCombine, SimdFloat, SimdFrom, SimdSplit, +}; /// FFT-32 codelet for `f64`: executes stages 0-4 (chunk_size 2 through 32) in a single function. /// @@ -319,73 +321,143 @@ fn fft_dit_codelet_32_simd_f32(simd: S, reals: &mut [f32], imags: &mut let mut v2_im = f32x8::from_slice(simd, &im[16..24]); let mut v3_im = f32x8::from_slice(simd, &im[24..32]); - // ---- Stages 0+1+2 fused: 8-point DIT on each vector (scalar) ---- - // Each f32x8 holds 8 consecutive elements. Stages 0 (dist=1), 1 (dist=2), - // and 2 (dist=4) all have butterfly pairs within the same 8-element group. - // We extract to scalars, do the full 8-point DIT, and repack. - macro_rules! 
dit8_inplace { - ($v_re:expr, $v_im:expr) => {{ - let mut r = [0.0f32; 8]; - let mut i = [0.0f32; 8]; - $v_re.store_slice(&mut r); - $v_im.store_slice(&mut i); - - // Stage 0+1 fused: radix-4 on elements [0,1,2,3] - let (a_re, a_im) = (r[0] + r[1], i[0] + i[1]); - let (b_re, b_im) = (r[0] - r[1], i[0] - i[1]); - let (c_re, c_im) = (r[2] + r[3], i[2] + i[3]); - let (d_re, d_im) = (r[2] - r[3], i[2] - i[3]); - let (dj_re, dj_im) = (d_im, -d_re); // -j * d - r[0] = a_re + c_re; - i[0] = a_im + c_im; - r[1] = b_re + dj_re; - i[1] = b_im + dj_im; - r[2] = a_re - c_re; - i[2] = a_im - c_im; - r[3] = b_re - dj_re; - i[3] = b_im - dj_im; - - // Stage 0+1 fused: radix-4 on elements [4,5,6,7] - let (a_re, a_im) = (r[4] + r[5], i[4] + i[5]); - let (b_re, b_im) = (r[4] - r[5], i[4] - i[5]); - let (c_re, c_im) = (r[6] + r[7], i[6] + i[7]); - let (d_re, d_im) = (r[6] - r[7], i[6] - i[7]); - let (dj_re, dj_im) = (d_im, -d_re); - r[4] = a_re + c_re; - i[4] = a_im + c_im; - r[5] = b_re + dj_re; - i[5] = b_im + dj_im; - r[6] = a_re - c_re; - i[6] = a_im - c_im; - r[7] = b_re - dj_re; - i[7] = b_im - dj_im; - - // Stage 2: dist=4 butterflies between (k, k+4) with W_8 twiddles - // W_8^0 = 1, W_8^1 = (1-j)/√2, W_8^2 = -j, W_8^3 = (-1-j)/√2 - const FRAC_1_SQRT_2: f32 = std::f32::consts::FRAC_1_SQRT_2; - let tw_re: [f32; 4] = [1.0, FRAC_1_SQRT_2, 0.0, -FRAC_1_SQRT_2]; - let tw_im: [f32; 4] = [0.0, -FRAC_1_SQRT_2, -1.0, -FRAC_1_SQRT_2]; - - for k in 0..4 { - let hi_re = tw_re[k] * r[k + 4] - tw_im[k] * i[k + 4]; - let hi_im = tw_re[k] * i[k + 4] + tw_im[k] * r[k + 4]; - let lo_re = r[k]; - let lo_im = i[k]; - r[k] = lo_re + hi_re; - i[k] = lo_im + hi_im; - r[k + 4] = lo_re - hi_re; - i[k + 4] = lo_im - hi_im; - } - - $v_re = f32x8::simd_from(simd, r); - $v_im = f32x8::simd_from(simd, i); - }}; - } + // ---- Stages 0+1+2 fused: 8-point DIT on all 4 vectors via transpose ---- + // Split each f32x8 into two f32x4 (lo=elems 0-3, hi=elems 4-7), transpose + // so each f32x4 holds one element position from all 4 groups, then do all + // butterfly stages as vertical f32x4 adds/subs/FMA. + { + // Step 1: Split f32x8 → f32x4 pairs + let (g0_lo_re, g0_hi_re) = v0_re.split(); + let (g1_lo_re, g1_hi_re) = v1_re.split(); + let (g2_lo_re, g2_hi_re) = v2_re.split(); + let (g3_lo_re, g3_hi_re) = v3_re.split(); + let (g0_lo_im, g0_hi_im) = v0_im.split(); + let (g1_lo_im, g1_hi_im) = v1_im.split(); + let (g2_lo_im, g2_hi_im) = v2_im.split(); + let (g3_lo_im, g3_hi_im) = v3_im.split(); + + // Step 2: 4×4 transpose on lo halves (re) + // After transpose, e_k_re[lane] = group lane's element k + macro_rules! 
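+            // Two rounds of zips give the 4×4 transpose in 8 shuffles: t0..t3
+            // interleave row pairs (0,2) and (1,3); the second round interleaves those.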
transpose4x4 { + ($g0:expr, $g1:expr, $g2:expr, $g3:expr) => {{ + let t0 = $g0.zip_low($g2); // [g0[0], g2[0], g0[1], g2[1]] + let t1 = $g0.zip_high($g2); // [g0[2], g2[2], g0[3], g2[3]] + let t2 = $g1.zip_low($g3); // [g1[0], g3[0], g1[1], g3[1]] + let t3 = $g1.zip_high($g3); // [g1[2], g3[2], g1[3], g3[3]] + ( + t0.zip_low(t2), // [g0[0], g1[0], g2[0], g3[0]] + t0.zip_high(t2), // [g0[1], g1[1], g2[1], g3[1]] + t1.zip_low(t3), // [g0[2], g1[2], g2[2], g3[2]] + t1.zip_high(t3), // [g0[3], g1[3], g2[3], g3[3]] + ) + }}; + } - dit8_inplace!(v0_re, v0_im); - dit8_inplace!(v1_re, v1_im); - dit8_inplace!(v2_re, v2_im); - dit8_inplace!(v3_re, v3_im); + let (e0_re, e1_re, e2_re, e3_re) = + transpose4x4!(g0_lo_re, g1_lo_re, g2_lo_re, g3_lo_re); + let (e4_re, e5_re, e6_re, e7_re) = + transpose4x4!(g0_hi_re, g1_hi_re, g2_hi_re, g3_hi_re); + let (e0_im, e1_im, e2_im, e3_im) = + transpose4x4!(g0_lo_im, g1_lo_im, g2_lo_im, g3_lo_im); + let (e4_im, e5_im, e6_im, e7_im) = + transpose4x4!(g0_hi_im, g1_hi_im, g2_hi_im, g3_hi_im); + + // Step 3: Stage 0 (dist=1) — butterfly between adjacent elements + let s01_re = e0_re + e1_re; + let d01_re = e0_re - e1_re; + let s23_re = e2_re + e3_re; + let d23_re = e2_re - e3_re; + let s45_re = e4_re + e5_re; + let d45_re = e4_re - e5_re; + let s67_re = e6_re + e7_re; + let d67_re = e6_re - e7_re; + + let s01_im = e0_im + e1_im; + let d01_im = e0_im - e1_im; + let s23_im = e2_im + e3_im; + let d23_im = e2_im - e3_im; + let s45_im = e4_im + e5_im; + let d45_im = e4_im - e5_im; + let s67_im = e6_im + e7_im; + let d67_im = e6_im - e7_im; + + // Step 4: Stage 1 (dist=2) — W4^0=1, W4^1=-j twiddles + // Twiddle=1: butterfly(s01, s23), butterfly(s45, s67) + let p0_re = s01_re + s23_re; + let p2_re = s01_re - s23_re; + let p4_re = s45_re + s67_re; + let p6_re = s45_re - s67_re; + let p0_im = s01_im + s23_im; + let p2_im = s01_im - s23_im; + let p4_im = s45_im + s67_im; + let p6_im = s45_im - s67_im; + + // Twiddle=-j: -j*(re+j*im) = (im, -re) + // butterfly(d01, d23*(-j)), butterfly(d45, d67*(-j)) + let p1_re = d01_re + d23_im; + let p3_re = d01_re - d23_im; + let p1_im = d01_im - d23_re; + let p3_im = d01_im + d23_re; + let p5_re = d45_re + d67_im; + let p7_re = d45_re - d67_im; + let p5_im = d45_im - d67_re; + let p7_im = d45_im + d67_re; + + // Step 5: Stage 2 (dist=4) — W8^k twiddles + // W8^0 = 1+0j: just add/sub + let r0_re = p0_re + p4_re; + let r4_re = p0_re - p4_re; + let r0_im = p0_im + p4_im; + let r4_im = p0_im - p4_im; + + // W8^1 = (1-j)/√2: FMA twiddle butterfly + const FRAC_1_SQRT_2: f32 = std::f32::consts::FRAC_1_SQRT_2; + let tw1_re = f32x4::splat(simd, FRAC_1_SQRT_2); + let tw1_im = f32x4::splat(simd, -FRAC_1_SQRT_2); + // twiddled = tw * p5 = (tw_re*p5_re - tw_im*p5_im, tw_re*p5_im + tw_im*p5_re) + let tw_p5_re = tw1_im.mul_add(-p5_im, tw1_re * p5_re); + let tw_p5_im = tw1_im.mul_add(p5_re, tw1_re * p5_im); + let r1_re = p1_re + tw_p5_re; + let r5_re = p1_re - tw_p5_re; + let r1_im = p1_im + tw_p5_im; + let r5_im = p1_im - tw_p5_im; + + // W8^2 = 0-j: -j*(re+j*im) = (im, -re) + let r2_re = p2_re + p6_im; + let r6_re = p2_re - p6_im; + let r2_im = p2_im - p6_re; + let r6_im = p2_im + p6_re; + + // W8^3 = (-1-j)/√2: FMA twiddle butterfly + let tw3_re = f32x4::splat(simd, -FRAC_1_SQRT_2); + let tw3_im = f32x4::splat(simd, -FRAC_1_SQRT_2); + let tw_p7_re = tw3_im.mul_add(-p7_im, tw3_re * p7_re); + let tw_p7_im = tw3_im.mul_add(p7_re, tw3_re * p7_im); + let r3_re = p3_re + tw_p7_re; + let r7_re = p3_re - tw_p7_re; + let r3_im = p3_im + tw_p7_im; + let r7_im = 
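+            // After r7_im, r0..r7 carry the finished 8-point DITs of all four
+            // groups, still in element-major order; Step 6 transposes them back.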
p3_im - tw_p7_im; + + // Step 6: 4×4 transpose back to per-group layout + let (g0_lo_re, g1_lo_re, g2_lo_re, g3_lo_re) = + transpose4x4!(r0_re, r1_re, r2_re, r3_re); + let (g0_hi_re, g1_hi_re, g2_hi_re, g3_hi_re) = + transpose4x4!(r4_re, r5_re, r6_re, r7_re); + let (g0_lo_im, g1_lo_im, g2_lo_im, g3_lo_im) = + transpose4x4!(r0_im, r1_im, r2_im, r3_im); + let (g0_hi_im, g1_hi_im, g2_hi_im, g3_hi_im) = + transpose4x4!(r4_im, r5_im, r6_im, r7_im); + + // Step 7: Recombine f32x4 → f32x8 + v0_re = g0_lo_re.combine(g0_hi_re); + v1_re = g1_lo_re.combine(g1_hi_re); + v2_re = g2_lo_re.combine(g2_hi_re); + v3_re = g3_lo_re.combine(g3_hi_re); + v0_im = g0_lo_im.combine(g0_hi_im); + v1_im = g1_lo_im.combine(g1_hi_im); + v2_im = g2_lo_im.combine(g2_hi_im); + v3_im = g3_lo_im.combine(g3_hi_im); + } // Butterfly macro: twiddle-multiply hi, then add/sub with lo. // out_lo = lo + tw*hi, out_hi = lo - tw*hi (via 2*lo - out_lo). diff --git a/src/planner.rs b/src/planner.rs index 0783dea..f9869cf 100644 --- a/src/planner.rs +++ b/src/planner.rs @@ -127,7 +127,8 @@ macro_rules! impl_planner_dit_for { // codelet dominates runtime. On platforms with large L1i (e.g., // Apple Silicon at 192KB), Tune mode will discover that the // codelet wins at larger sizes too. - log_n <= 13 + // log_n <= 13 + true } /// Benchmark both paths and set `use_codelet_32` to whichever is faster. From 636e5011ad25a56c5c8666e96a76c6e617ac2a71 Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 11 Apr 2026 20:42:35 +0100 Subject: [PATCH 05/17] Reduce live set in f64 codelet to reduce register pressure --- src/kernels/codelets.rs | 120 ++++++++++++++++++++++++---------------- 1 file changed, 71 insertions(+), 49 deletions(-) diff --git a/src/kernels/codelets.rs b/src/kernels/codelets.rs index 8144de4..3e8acb2 100644 --- a/src/kernels/codelets.rs +++ b/src/kernels/codelets.rs @@ -27,68 +27,90 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut let two = f64x4::splat(simd, 2.0); for (re, im) in reals.chunks_exact_mut(32).zip(imags.chunks_exact_mut(32)) { - // ---- Load into 8 f64x4 register pairs (re + im) ---- - // v_k holds elements [4k .. 4k+3] + macro_rules! transpose4x4_f64 { + ($g0:expr, $g1:expr, $g2:expr, $g3:expr) => {{ + let t0 = $g0.zip_low($g2); + let t1 = $g0.zip_high($g2); + let t2 = $g1.zip_low($g3); + let t3 = $g1.zip_high($g3); + ( + t0.zip_low(t2), + t0.zip_high(t2), + t1.zip_low(t3), + t1.zip_high(t3), + ) + }}; + } + + macro_rules! 
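+        // Transpose in, run stages 0+1 as lane-uniform vertical ops on all four
+        // vectors at once, then transpose back to the per-group element layout.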
radix4_transpose { + ($va_re:expr, $vb_re:expr, $vc_re:expr, $vd_re:expr, + $va_im:expr, $vb_im:expr, $vc_im:expr, $vd_im:expr) => {{ + let (e0_re, e1_re, e2_re, e3_re) = + transpose4x4_f64!($va_re, $vb_re, $vc_re, $vd_re); + let (e0_im, e1_im, e2_im, e3_im) = + transpose4x4_f64!($va_im, $vb_im, $vc_im, $vd_im); + + let s01_re = e0_re + e1_re; + let d01_re = e0_re - e1_re; + let s23_re = e2_re + e3_re; + let d23_re = e2_re - e3_re; + let s01_im = e0_im + e1_im; + let d01_im = e0_im - e1_im; + let s23_im = e2_im + e3_im; + let d23_im = e2_im - e3_im; + + let p0_re = s01_re + s23_re; + let p2_re = s01_re - s23_re; + let p0_im = s01_im + s23_im; + let p2_im = s01_im - s23_im; + + let p1_re = d01_re + d23_im; + let p3_re = d01_re - d23_im; + let p1_im = d01_im - d23_re; + let p3_im = d01_im + d23_re; + + let (r0_re, r1_re, r2_re, r3_re) = + transpose4x4_f64!(p0_re, p1_re, p2_re, p3_re); + let (r0_im, r1_im, r2_im, r3_im) = + transpose4x4_f64!(p0_im, p1_im, p2_im, p3_im); + + $va_re = r0_re; + $vb_re = r1_re; + $vc_re = r2_re; + $vd_re = r3_re; + $va_im = r0_im; + $vb_im = r1_im; + $vc_im = r2_im; + $vd_im = r3_im; + }}; + } + + // ---- Load first group and do stages 0+1 ---- + // Deferring v4-v7 loads reduces peak register pressure during transpose. let mut v0_re = f64x4::from_slice(simd, &re[0..4]); let mut v1_re = f64x4::from_slice(simd, &re[4..8]); let mut v2_re = f64x4::from_slice(simd, &re[8..12]); let mut v3_re = f64x4::from_slice(simd, &re[12..16]); - let mut v4_re = f64x4::from_slice(simd, &re[16..20]); - let mut v5_re = f64x4::from_slice(simd, &re[20..24]); - let mut v6_re = f64x4::from_slice(simd, &re[24..28]); - let mut v7_re = f64x4::from_slice(simd, &re[28..32]); - let mut v0_im = f64x4::from_slice(simd, &im[0..4]); let mut v1_im = f64x4::from_slice(simd, &im[4..8]); let mut v2_im = f64x4::from_slice(simd, &im[8..12]); let mut v3_im = f64x4::from_slice(simd, &im[12..16]); + + radix4_transpose!(v0_re, v1_re, v2_re, v3_re, + v0_im, v1_im, v2_im, v3_im); + + // ---- Load second group and do stages 0+1 ---- + let mut v4_re = f64x4::from_slice(simd, &re[16..20]); + let mut v5_re = f64x4::from_slice(simd, &re[20..24]); + let mut v6_re = f64x4::from_slice(simd, &re[24..28]); + let mut v7_re = f64x4::from_slice(simd, &re[28..32]); let mut v4_im = f64x4::from_slice(simd, &im[16..20]); let mut v5_im = f64x4::from_slice(simd, &im[20..24]); let mut v6_im = f64x4::from_slice(simd, &im[24..28]); let mut v7_im = f64x4::from_slice(simd, &im[28..32]); - // ---- Stages 0+1 fused: radix-4 DIT on each group of 4 (within each vector) ---- - // Extract scalars, compute radix-4 butterfly, repack into f64x4. - macro_rules! 
radix4_inplace { - ($v_re:expr, $v_im:expr) => {{ - let mut re_arr = [0.0f64; 4]; - let mut im_arr = [0.0f64; 4]; - $v_re.store_slice(&mut re_arr); - $v_im.store_slice(&mut im_arr); - - let (a_re, a_im) = (re_arr[0], im_arr[0]); - let (b_re, b_im) = (re_arr[1], im_arr[1]); - let (c_re, c_im) = (re_arr[2], im_arr[2]); - let (d_re, d_im) = (re_arr[3], im_arr[3]); - - // Stage 0: dist=1 butterflies - let (t0_re, t0_im) = (a_re + b_re, a_im + b_im); - let (t1_re, t1_im) = (a_re - b_re, a_im - b_im); - let (t2_re, t2_im) = (c_re + d_re, c_im + d_im); - let (t3_re, t3_im) = (c_re - d_re, c_im - d_im); - - // Stage 1: -j multiply on t3, then dist=2 butterflies - let (t3j_re, t3j_im) = (t3_im, -t3_re); - - $v_re = f64x4::simd_from( - simd, - [t0_re + t2_re, t1_re + t3j_re, t0_re - t2_re, t1_re - t3j_re], - ); - $v_im = f64x4::simd_from( - simd, - [t0_im + t2_im, t1_im + t3j_im, t0_im - t2_im, t1_im - t3j_im], - ); - }}; - } - - radix4_inplace!(v0_re, v0_im); - radix4_inplace!(v1_re, v1_im); - radix4_inplace!(v2_re, v2_im); - radix4_inplace!(v3_re, v3_im); - radix4_inplace!(v4_re, v4_im); - radix4_inplace!(v5_re, v5_im); - radix4_inplace!(v6_re, v6_im); - radix4_inplace!(v7_re, v7_im); + radix4_transpose!(v4_re, v5_re, v6_re, v7_re, + v4_im, v5_im, v6_im, v7_im); // Butterfly macro: twiddle-multiply hi, then add/sub with lo. // out_lo = lo + tw*hi, out_hi = lo - tw*hi (via 2*lo - out_lo). From a2f7c2ac2821262cd4ef46e6131142f6ec752ebd Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 11 Apr 2026 20:49:26 +0100 Subject: [PATCH 06/17] Another attempt at massaging f64 assembly; didn't work --- src/kernels/codelets.rs | 192 ++++++++++++++++++++++++++++------------ 1 file changed, 134 insertions(+), 58 deletions(-) diff --git a/src/kernels/codelets.rs b/src/kernels/codelets.rs index 3e8acb2..bef357f 100644 --- a/src/kernels/codelets.rs +++ b/src/kernels/codelets.rs @@ -27,66 +27,142 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut let two = f64x4::splat(simd, 2.0); for (re, im) in reals.chunks_exact_mut(32).zip(imags.chunks_exact_mut(32)) { - macro_rules! transpose4x4_f64 { - ($g0:expr, $g1:expr, $g2:expr, $g3:expr) => {{ - let t0 = $g0.zip_low($g2); - let t1 = $g0.zip_high($g2); - let t2 = $g1.zip_low($g3); - let t3 = $g1.zip_high($g3); - ( - t0.zip_low(t2), - t0.zip_high(t2), - t1.zip_low(t3), - t1.zip_high(t3), - ) - }}; - } - - macro_rules! radix4_transpose { + // ---- Stages 0+1 fused via f64x2 split + transposed butterflies ---- + // Split f64x4 into f64x2 halves (free: vextractf128). + // Use f64x2 zip (single SSE instr) to transpose pairs of vectors. + // Keep data transposed through both stages so all lanes get the same + // operation, then transpose back and combine to f64x4 (free: vinsertf128). + // + // Processing 2 vectors at a time (pairs a,b and c,d): + // After split, each vector's lo half has [elem0, elem1] and hi has [elem2, elem3]. + // Transpose a_lo with b_lo: sums_lo = [a_sum, b_sum], diffs_lo = [a_diff, b_diff] + // where sum = elem0+elem1, diff = elem0-elem1 (stage 0). + // Then stage 1 butterflies sums_lo ↔ sums_hi (twiddle=1) and + // diffs_lo ↔ diffs_hi (twiddle=-j), all lanes uniform. + + macro_rules! 
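+        // Concretely, with a_lo = [a0, a1] and b_lo = [b0, b1]:
+        //   zip_low(a_lo, b_lo)  = [a0, b0]   (element 0 of both vectors)
+        //   zip_high(a_lo, b_lo) = [a1, b1]   (element 1 of both vectors)
+        // so a single add/sub pair yields the stage-0 sums and diffs of both.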
stages01 { ($va_re:expr, $vb_re:expr, $vc_re:expr, $vd_re:expr, $va_im:expr, $vb_im:expr, $vc_im:expr, $vd_im:expr) => {{ - let (e0_re, e1_re, e2_re, e3_re) = - transpose4x4_f64!($va_re, $vb_re, $vc_re, $vd_re); - let (e0_im, e1_im, e2_im, e3_im) = - transpose4x4_f64!($va_im, $vb_im, $vc_im, $vd_im); - - let s01_re = e0_re + e1_re; - let d01_re = e0_re - e1_re; - let s23_re = e2_re + e3_re; - let d23_re = e2_re - e3_re; - let s01_im = e0_im + e1_im; - let d01_im = e0_im - e1_im; - let s23_im = e2_im + e3_im; - let d23_im = e2_im - e3_im; - - let p0_re = s01_re + s23_re; - let p2_re = s01_re - s23_re; - let p0_im = s01_im + s23_im; - let p2_im = s01_im - s23_im; - - let p1_re = d01_re + d23_im; - let p3_re = d01_re - d23_im; - let p1_im = d01_im - d23_re; - let p3_im = d01_im + d23_re; - - let (r0_re, r1_re, r2_re, r3_re) = - transpose4x4_f64!(p0_re, p1_re, p2_re, p3_re); - let (r0_im, r1_im, r2_im, r3_im) = - transpose4x4_f64!(p0_im, p1_im, p2_im, p3_im); - - $va_re = r0_re; - $vb_re = r1_re; - $vc_re = r2_re; - $vd_re = r3_re; - $va_im = r0_im; - $vb_im = r1_im; - $vc_im = r2_im; - $vd_im = r3_im; + // Split f64x4 → f64x2 halves + let (a_lo_re, a_hi_re) = $va_re.split(); + let (b_lo_re, b_hi_re) = $vb_re.split(); + let (c_lo_re, c_hi_re) = $vc_re.split(); + let (d_lo_re, d_hi_re) = $vd_re.split(); + let (a_lo_im, a_hi_im) = $va_im.split(); + let (b_lo_im, b_hi_im) = $vb_im.split(); + let (c_lo_im, c_hi_im) = $vc_im.split(); + let (d_lo_im, d_hi_im) = $vd_im.split(); + + // ---- Stage 0 via 2×2 transpose ---- + // For pair (a,b) lo halves: + // evens = zip_low(a_lo, b_lo) = [a[0], b[0]] + // odds = zip_high(a_lo, b_lo) = [a[1], b[1]] + // sums = evens + odds (stage-0 sum for both vectors) + // diffs = evens - odds (stage-0 diff for both vectors) + // Data stays transposed: sums_ab_lo[0] = a's sum, sums_ab_lo[1] = b's sum + + // Pair (a,b) lo halves + let ab_lo_e0_re = a_lo_re.zip_low(b_lo_re); + let ab_lo_e1_re = a_lo_re.zip_high(b_lo_re); + let ab_lo_e0_im = a_lo_im.zip_low(b_lo_im); + let ab_lo_e1_im = a_lo_im.zip_high(b_lo_im); + let sums_ab_lo_re = ab_lo_e0_re + ab_lo_e1_re; + let diffs_ab_lo_re = ab_lo_e0_re - ab_lo_e1_re; + let sums_ab_lo_im = ab_lo_e0_im + ab_lo_e1_im; + let diffs_ab_lo_im = ab_lo_e0_im - ab_lo_e1_im; + + // Pair (a,b) hi halves + let ab_hi_e0_re = a_hi_re.zip_low(b_hi_re); + let ab_hi_e1_re = a_hi_re.zip_high(b_hi_re); + let ab_hi_e0_im = a_hi_im.zip_low(b_hi_im); + let ab_hi_e1_im = a_hi_im.zip_high(b_hi_im); + let sums_ab_hi_re = ab_hi_e0_re + ab_hi_e1_re; + let diffs_ab_hi_re = ab_hi_e0_re - ab_hi_e1_re; + let sums_ab_hi_im = ab_hi_e0_im + ab_hi_e1_im; + let diffs_ab_hi_im = ab_hi_e0_im - ab_hi_e1_im; + + // Pair (c,d) lo halves + let cd_lo_e0_re = c_lo_re.zip_low(d_lo_re); + let cd_lo_e1_re = c_lo_re.zip_high(d_lo_re); + let cd_lo_e0_im = c_lo_im.zip_low(d_lo_im); + let cd_lo_e1_im = c_lo_im.zip_high(d_lo_im); + let sums_cd_lo_re = cd_lo_e0_re + cd_lo_e1_re; + let diffs_cd_lo_re = cd_lo_e0_re - cd_lo_e1_re; + let sums_cd_lo_im = cd_lo_e0_im + cd_lo_e1_im; + let diffs_cd_lo_im = cd_lo_e0_im - cd_lo_e1_im; + + // Pair (c,d) hi halves + let cd_hi_e0_re = c_hi_re.zip_low(d_hi_re); + let cd_hi_e1_re = c_hi_re.zip_high(d_hi_re); + let cd_hi_e0_im = c_hi_im.zip_low(d_hi_im); + let cd_hi_e1_im = c_hi_im.zip_high(d_hi_im); + let sums_cd_hi_re = cd_hi_e0_re + cd_hi_e1_re; + let diffs_cd_hi_re = cd_hi_e0_re - cd_hi_e1_re; + let sums_cd_hi_im = cd_hi_e0_im + cd_hi_e1_im; + let diffs_cd_hi_im = cd_hi_e0_im - cd_hi_e1_im; + + // ---- Stage 1: butterfly lo↔hi (dist=2) ---- + 
// sums butterfly with twiddle=1: just add/sub + let p0_ab_re = sums_ab_lo_re + sums_ab_hi_re; + let p2_ab_re = sums_ab_lo_re - sums_ab_hi_re; + let p0_ab_im = sums_ab_lo_im + sums_ab_hi_im; + let p2_ab_im = sums_ab_lo_im - sums_ab_hi_im; + + let p0_cd_re = sums_cd_lo_re + sums_cd_hi_re; + let p2_cd_re = sums_cd_lo_re - sums_cd_hi_re; + let p0_cd_im = sums_cd_lo_im + sums_cd_hi_im; + let p2_cd_im = sums_cd_lo_im - sums_cd_hi_im; + + // diffs butterfly with twiddle=-j: -j*(re+j*im) = (im, -re) + let p1_ab_re = diffs_ab_lo_re + diffs_ab_hi_im; + let p3_ab_re = diffs_ab_lo_re - diffs_ab_hi_im; + let p1_ab_im = diffs_ab_lo_im - diffs_ab_hi_re; + let p3_ab_im = diffs_ab_lo_im + diffs_ab_hi_re; + + let p1_cd_re = diffs_cd_lo_re + diffs_cd_hi_im; + let p3_cd_re = diffs_cd_lo_re - diffs_cd_hi_im; + let p1_cd_im = diffs_cd_lo_im - diffs_cd_hi_re; + let p3_cd_im = diffs_cd_lo_im + diffs_cd_hi_re; + + // ---- Transpose back (2×2) and combine ---- + // Transpose back: zip_low/zip_high to go from [a_val, b_val] + // back to per-vector layout. + // For vector a: lo = [p0, p1], hi = [p2, p3] + // p0_ab = [a_p0, b_p0], p1_ab = [a_p1, b_p1] + // zip_low(p0_ab, p1_ab) = [a_p0, a_p1] = a's lo half + // zip_high(p0_ab, p1_ab) = [b_p0, b_p1] = b's lo half + + let a_lo_re = p0_ab_re.zip_low(p1_ab_re); + let b_lo_re = p0_ab_re.zip_high(p1_ab_re); + let a_hi_re = p2_ab_re.zip_low(p3_ab_re); + let b_hi_re = p2_ab_re.zip_high(p3_ab_re); + let a_lo_im = p0_ab_im.zip_low(p1_ab_im); + let b_lo_im = p0_ab_im.zip_high(p1_ab_im); + let a_hi_im = p2_ab_im.zip_low(p3_ab_im); + let b_hi_im = p2_ab_im.zip_high(p3_ab_im); + + let c_lo_re = p0_cd_re.zip_low(p1_cd_re); + let d_lo_re = p0_cd_re.zip_high(p1_cd_re); + let c_hi_re = p2_cd_re.zip_low(p3_cd_re); + let d_hi_re = p2_cd_re.zip_high(p3_cd_re); + let c_lo_im = p0_cd_im.zip_low(p1_cd_im); + let d_lo_im = p0_cd_im.zip_high(p1_cd_im); + let c_hi_im = p2_cd_im.zip_low(p3_cd_im); + let d_hi_im = p2_cd_im.zip_high(p3_cd_im); + + // Combine f64x2 → f64x4 + $va_re = a_lo_re.combine(a_hi_re); + $vb_re = b_lo_re.combine(b_hi_re); + $vc_re = c_lo_re.combine(c_hi_re); + $vd_re = d_lo_re.combine(d_hi_re); + $va_im = a_lo_im.combine(a_hi_im); + $vb_im = b_lo_im.combine(b_hi_im); + $vc_im = c_lo_im.combine(c_hi_im); + $vd_im = d_lo_im.combine(d_hi_im); }}; } // ---- Load first group and do stages 0+1 ---- - // Deferring v4-v7 loads reduces peak register pressure during transpose. let mut v0_re = f64x4::from_slice(simd, &re[0..4]); let mut v1_re = f64x4::from_slice(simd, &re[4..8]); let mut v2_re = f64x4::from_slice(simd, &re[8..12]); @@ -96,8 +172,8 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut let mut v2_im = f64x4::from_slice(simd, &im[8..12]); let mut v3_im = f64x4::from_slice(simd, &im[12..16]); - radix4_transpose!(v0_re, v1_re, v2_re, v3_re, - v0_im, v1_im, v2_im, v3_im); + stages01!(v0_re, v1_re, v2_re, v3_re, + v0_im, v1_im, v2_im, v3_im); // ---- Load second group and do stages 0+1 ---- let mut v4_re = f64x4::from_slice(simd, &re[16..20]); @@ -109,8 +185,8 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut let mut v6_im = f64x4::from_slice(simd, &im[24..28]); let mut v7_im = f64x4::from_slice(simd, &im[28..32]); - radix4_transpose!(v4_re, v5_re, v6_re, v7_re, - v4_im, v5_im, v6_im, v7_im); + stages01!(v4_re, v5_re, v6_re, v7_re, + v4_im, v5_im, v6_im, v7_im); // Butterfly macro: twiddle-multiply hi, then add/sub with lo. // out_lo = lo + tw*hi, out_hi = lo - tw*hi (via 2*lo - out_lo). 
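For reference, the algebra behind that `butterfly!` comment, as a minimal scalar sketch (illustrative only; the function and names below are not part of any patch in this series). Computing the hi output as 2*lo - out_lo means the complex twiddle product tw*hi is evaluated once, and each hi component then costs a single fused multiply-subtract:

// Scalar model of the SIMD `butterfly!` macro; (re, im) tuples stand in for vectors.
fn butterfly_scalar(lo: (f64, f64), hi: (f64, f64), tw: (f64, f64)) -> ((f64, f64), (f64, f64)) {
    // Complex twiddle product: tw*hi = (tw_re*hi_re - tw_im*hi_im) + j*(tw_re*hi_im + tw_im*hi_re)
    let tw_hi_re = tw.0 * hi.0 - tw.1 * hi.1;
    let tw_hi_im = tw.0 * hi.1 + tw.1 * hi.0;
    let out_lo = (lo.0 + tw_hi_re, lo.1 + tw_hi_im);
    // 2*lo - out_lo == lo - tw*hi, so the second half of the butterfly needs
    // no second twiddle multiply; in the SIMD code this is `two.mul_sub(lo, out_lo)`.
    let out_hi = (2.0 * lo.0 - out_lo.0, 2.0 * lo.1 - out_lo.1);
    (out_lo, out_hi)
}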
From 1356ff229bf1b0cd37c9cefb11df1ca1d4428a2f Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 11 Apr 2026 20:49:35 +0100 Subject: [PATCH 07/17] Revert "Another attempt at massaging f64 assembly; didn't work" This reverts commit a2f7c2ac2821262cd4ef46e6131142f6ec752ebd. --- src/kernels/codelets.rs | 192 ++++++++++++---------------------------- 1 file changed, 58 insertions(+), 134 deletions(-) diff --git a/src/kernels/codelets.rs b/src/kernels/codelets.rs index bef357f..3e8acb2 100644 --- a/src/kernels/codelets.rs +++ b/src/kernels/codelets.rs @@ -27,142 +27,66 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut let two = f64x4::splat(simd, 2.0); for (re, im) in reals.chunks_exact_mut(32).zip(imags.chunks_exact_mut(32)) { - // ---- Stages 0+1 fused via f64x2 split + transposed butterflies ---- - // Split f64x4 into f64x2 halves (free: vextractf128). - // Use f64x2 zip (single SSE instr) to transpose pairs of vectors. - // Keep data transposed through both stages so all lanes get the same - // operation, then transpose back and combine to f64x4 (free: vinsertf128). - // - // Processing 2 vectors at a time (pairs a,b and c,d): - // After split, each vector's lo half has [elem0, elem1] and hi has [elem2, elem3]. - // Transpose a_lo with b_lo: sums_lo = [a_sum, b_sum], diffs_lo = [a_diff, b_diff] - // where sum = elem0+elem1, diff = elem0-elem1 (stage 0). - // Then stage 1 butterflies sums_lo ↔ sums_hi (twiddle=1) and - // diffs_lo ↔ diffs_hi (twiddle=-j), all lanes uniform. - - macro_rules! stages01 { + macro_rules! transpose4x4_f64 { + ($g0:expr, $g1:expr, $g2:expr, $g3:expr) => {{ + let t0 = $g0.zip_low($g2); + let t1 = $g0.zip_high($g2); + let t2 = $g1.zip_low($g3); + let t3 = $g1.zip_high($g3); + ( + t0.zip_low(t2), + t0.zip_high(t2), + t1.zip_low(t3), + t1.zip_high(t3), + ) + }}; + } + + macro_rules! 
radix4_transpose { ($va_re:expr, $vb_re:expr, $vc_re:expr, $vd_re:expr, $va_im:expr, $vb_im:expr, $vc_im:expr, $vd_im:expr) => {{ - // Split f64x4 → f64x2 halves - let (a_lo_re, a_hi_re) = $va_re.split(); - let (b_lo_re, b_hi_re) = $vb_re.split(); - let (c_lo_re, c_hi_re) = $vc_re.split(); - let (d_lo_re, d_hi_re) = $vd_re.split(); - let (a_lo_im, a_hi_im) = $va_im.split(); - let (b_lo_im, b_hi_im) = $vb_im.split(); - let (c_lo_im, c_hi_im) = $vc_im.split(); - let (d_lo_im, d_hi_im) = $vd_im.split(); - - // ---- Stage 0 via 2×2 transpose ---- - // For pair (a,b) lo halves: - // evens = zip_low(a_lo, b_lo) = [a[0], b[0]] - // odds = zip_high(a_lo, b_lo) = [a[1], b[1]] - // sums = evens + odds (stage-0 sum for both vectors) - // diffs = evens - odds (stage-0 diff for both vectors) - // Data stays transposed: sums_ab_lo[0] = a's sum, sums_ab_lo[1] = b's sum - - // Pair (a,b) lo halves - let ab_lo_e0_re = a_lo_re.zip_low(b_lo_re); - let ab_lo_e1_re = a_lo_re.zip_high(b_lo_re); - let ab_lo_e0_im = a_lo_im.zip_low(b_lo_im); - let ab_lo_e1_im = a_lo_im.zip_high(b_lo_im); - let sums_ab_lo_re = ab_lo_e0_re + ab_lo_e1_re; - let diffs_ab_lo_re = ab_lo_e0_re - ab_lo_e1_re; - let sums_ab_lo_im = ab_lo_e0_im + ab_lo_e1_im; - let diffs_ab_lo_im = ab_lo_e0_im - ab_lo_e1_im; - - // Pair (a,b) hi halves - let ab_hi_e0_re = a_hi_re.zip_low(b_hi_re); - let ab_hi_e1_re = a_hi_re.zip_high(b_hi_re); - let ab_hi_e0_im = a_hi_im.zip_low(b_hi_im); - let ab_hi_e1_im = a_hi_im.zip_high(b_hi_im); - let sums_ab_hi_re = ab_hi_e0_re + ab_hi_e1_re; - let diffs_ab_hi_re = ab_hi_e0_re - ab_hi_e1_re; - let sums_ab_hi_im = ab_hi_e0_im + ab_hi_e1_im; - let diffs_ab_hi_im = ab_hi_e0_im - ab_hi_e1_im; - - // Pair (c,d) lo halves - let cd_lo_e0_re = c_lo_re.zip_low(d_lo_re); - let cd_lo_e1_re = c_lo_re.zip_high(d_lo_re); - let cd_lo_e0_im = c_lo_im.zip_low(d_lo_im); - let cd_lo_e1_im = c_lo_im.zip_high(d_lo_im); - let sums_cd_lo_re = cd_lo_e0_re + cd_lo_e1_re; - let diffs_cd_lo_re = cd_lo_e0_re - cd_lo_e1_re; - let sums_cd_lo_im = cd_lo_e0_im + cd_lo_e1_im; - let diffs_cd_lo_im = cd_lo_e0_im - cd_lo_e1_im; - - // Pair (c,d) hi halves - let cd_hi_e0_re = c_hi_re.zip_low(d_hi_re); - let cd_hi_e1_re = c_hi_re.zip_high(d_hi_re); - let cd_hi_e0_im = c_hi_im.zip_low(d_hi_im); - let cd_hi_e1_im = c_hi_im.zip_high(d_hi_im); - let sums_cd_hi_re = cd_hi_e0_re + cd_hi_e1_re; - let diffs_cd_hi_re = cd_hi_e0_re - cd_hi_e1_re; - let sums_cd_hi_im = cd_hi_e0_im + cd_hi_e1_im; - let diffs_cd_hi_im = cd_hi_e0_im - cd_hi_e1_im; - - // ---- Stage 1: butterfly lo↔hi (dist=2) ---- - // sums butterfly with twiddle=1: just add/sub - let p0_ab_re = sums_ab_lo_re + sums_ab_hi_re; - let p2_ab_re = sums_ab_lo_re - sums_ab_hi_re; - let p0_ab_im = sums_ab_lo_im + sums_ab_hi_im; - let p2_ab_im = sums_ab_lo_im - sums_ab_hi_im; - - let p0_cd_re = sums_cd_lo_re + sums_cd_hi_re; - let p2_cd_re = sums_cd_lo_re - sums_cd_hi_re; - let p0_cd_im = sums_cd_lo_im + sums_cd_hi_im; - let p2_cd_im = sums_cd_lo_im - sums_cd_hi_im; - - // diffs butterfly with twiddle=-j: -j*(re+j*im) = (im, -re) - let p1_ab_re = diffs_ab_lo_re + diffs_ab_hi_im; - let p3_ab_re = diffs_ab_lo_re - diffs_ab_hi_im; - let p1_ab_im = diffs_ab_lo_im - diffs_ab_hi_re; - let p3_ab_im = diffs_ab_lo_im + diffs_ab_hi_re; - - let p1_cd_re = diffs_cd_lo_re + diffs_cd_hi_im; - let p3_cd_re = diffs_cd_lo_re - diffs_cd_hi_im; - let p1_cd_im = diffs_cd_lo_im - diffs_cd_hi_re; - let p3_cd_im = diffs_cd_lo_im + diffs_cd_hi_re; - - // ---- Transpose back (2×2) and combine ---- - // Transpose back: 
zip_low/zip_high to go from [a_val, b_val] - // back to per-vector layout. - // For vector a: lo = [p0, p1], hi = [p2, p3] - // p0_ab = [a_p0, b_p0], p1_ab = [a_p1, b_p1] - // zip_low(p0_ab, p1_ab) = [a_p0, a_p1] = a's lo half - // zip_high(p0_ab, p1_ab) = [b_p0, b_p1] = b's lo half - - let a_lo_re = p0_ab_re.zip_low(p1_ab_re); - let b_lo_re = p0_ab_re.zip_high(p1_ab_re); - let a_hi_re = p2_ab_re.zip_low(p3_ab_re); - let b_hi_re = p2_ab_re.zip_high(p3_ab_re); - let a_lo_im = p0_ab_im.zip_low(p1_ab_im); - let b_lo_im = p0_ab_im.zip_high(p1_ab_im); - let a_hi_im = p2_ab_im.zip_low(p3_ab_im); - let b_hi_im = p2_ab_im.zip_high(p3_ab_im); - - let c_lo_re = p0_cd_re.zip_low(p1_cd_re); - let d_lo_re = p0_cd_re.zip_high(p1_cd_re); - let c_hi_re = p2_cd_re.zip_low(p3_cd_re); - let d_hi_re = p2_cd_re.zip_high(p3_cd_re); - let c_lo_im = p0_cd_im.zip_low(p1_cd_im); - let d_lo_im = p0_cd_im.zip_high(p1_cd_im); - let c_hi_im = p2_cd_im.zip_low(p3_cd_im); - let d_hi_im = p2_cd_im.zip_high(p3_cd_im); - - // Combine f64x2 → f64x4 - $va_re = a_lo_re.combine(a_hi_re); - $vb_re = b_lo_re.combine(b_hi_re); - $vc_re = c_lo_re.combine(c_hi_re); - $vd_re = d_lo_re.combine(d_hi_re); - $va_im = a_lo_im.combine(a_hi_im); - $vb_im = b_lo_im.combine(b_hi_im); - $vc_im = c_lo_im.combine(c_hi_im); - $vd_im = d_lo_im.combine(d_hi_im); + let (e0_re, e1_re, e2_re, e3_re) = + transpose4x4_f64!($va_re, $vb_re, $vc_re, $vd_re); + let (e0_im, e1_im, e2_im, e3_im) = + transpose4x4_f64!($va_im, $vb_im, $vc_im, $vd_im); + + let s01_re = e0_re + e1_re; + let d01_re = e0_re - e1_re; + let s23_re = e2_re + e3_re; + let d23_re = e2_re - e3_re; + let s01_im = e0_im + e1_im; + let d01_im = e0_im - e1_im; + let s23_im = e2_im + e3_im; + let d23_im = e2_im - e3_im; + + let p0_re = s01_re + s23_re; + let p2_re = s01_re - s23_re; + let p0_im = s01_im + s23_im; + let p2_im = s01_im - s23_im; + + let p1_re = d01_re + d23_im; + let p3_re = d01_re - d23_im; + let p1_im = d01_im - d23_re; + let p3_im = d01_im + d23_re; + + let (r0_re, r1_re, r2_re, r3_re) = + transpose4x4_f64!(p0_re, p1_re, p2_re, p3_re); + let (r0_im, r1_im, r2_im, r3_im) = + transpose4x4_f64!(p0_im, p1_im, p2_im, p3_im); + + $va_re = r0_re; + $vb_re = r1_re; + $vc_re = r2_re; + $vd_re = r3_re; + $va_im = r0_im; + $vb_im = r1_im; + $vc_im = r2_im; + $vd_im = r3_im; }}; } // ---- Load first group and do stages 0+1 ---- + // Deferring v4-v7 loads reduces peak register pressure during transpose. let mut v0_re = f64x4::from_slice(simd, &re[0..4]); let mut v1_re = f64x4::from_slice(simd, &re[4..8]); let mut v2_re = f64x4::from_slice(simd, &re[8..12]); @@ -172,8 +96,8 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut let mut v2_im = f64x4::from_slice(simd, &im[8..12]); let mut v3_im = f64x4::from_slice(simd, &im[12..16]); - stages01!(v0_re, v1_re, v2_re, v3_re, - v0_im, v1_im, v2_im, v3_im); + radix4_transpose!(v0_re, v1_re, v2_re, v3_re, + v0_im, v1_im, v2_im, v3_im); // ---- Load second group and do stages 0+1 ---- let mut v4_re = f64x4::from_slice(simd, &re[16..20]); @@ -185,8 +109,8 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut let mut v6_im = f64x4::from_slice(simd, &im[24..28]); let mut v7_im = f64x4::from_slice(simd, &im[28..32]); - stages01!(v4_re, v5_re, v6_re, v7_re, - v4_im, v5_im, v6_im, v7_im); + radix4_transpose!(v4_re, v5_re, v6_re, v7_re, + v4_im, v5_im, v6_im, v7_im); // Butterfly macro: twiddle-multiply hi, then add/sub with lo. // out_lo = lo + tw*hi, out_hi = lo - tw*hi (via 2*lo - out_lo). 
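Both `radix4_transpose!` and the reverted `stages01!` lean on the same -j rotation for the odd stage-1 outputs; here is a scalar sketch of just that step (illustrative only, names invented here). Since -j*(re + j*im) = im - j*re, the rotation is a lane swap plus one negation, and the butterfly p1 = d01 + (-j)*d23, p3 = d01 - (-j)*d23 expands to exactly the add/sub pattern seen in the macros:

// -j * d23 is (d23_im, -d23_re): no multiply needed.
fn stage1_odd(d01: (f64, f64), d23: (f64, f64)) -> ((f64, f64), (f64, f64)) {
    let rot = (d23.1, -d23.0); // -j * d23
    let p1 = (d01.0 + rot.0, d01.1 + rot.1); // = (d01_re + d23_im, d01_im - d23_re)
    let p3 = (d01.0 - rot.0, d01.1 - rot.1); // = (d01_re - d23_im, d01_im + d23_re)
    (p1, p3)
}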
From bb293bc0710953d8f63cc1fedbcf9e823c1e302c Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 11 Apr 2026 23:26:55 +0100 Subject: [PATCH 08/17] Address the register spills in f32 codelet passes 0 and 1 --- src/kernels/codelets.rs | 159 +++++++++++++++++++--------------------- 1 file changed, 75 insertions(+), 84 deletions(-) diff --git a/src/kernels/codelets.rs b/src/kernels/codelets.rs index 3e8acb2..495ab81 100644 --- a/src/kernels/codelets.rs +++ b/src/kernels/codelets.rs @@ -69,10 +69,8 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut let p1_im = d01_im - d23_re; let p3_im = d01_im + d23_re; - let (r0_re, r1_re, r2_re, r3_re) = - transpose4x4_f64!(p0_re, p1_re, p2_re, p3_re); - let (r0_im, r1_im, r2_im, r3_im) = - transpose4x4_f64!(p0_im, p1_im, p2_im, p3_im); + let (r0_re, r1_re, r2_re, r3_re) = transpose4x4_f64!(p0_re, p1_re, p2_re, p3_re); + let (r0_im, r1_im, r2_im, r3_im) = transpose4x4_f64!(p0_im, p1_im, p2_im, p3_im); $va_re = r0_re; $vb_re = r1_re; @@ -96,8 +94,7 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut let mut v2_im = f64x4::from_slice(simd, &im[8..12]); let mut v3_im = f64x4::from_slice(simd, &im[12..16]); - radix4_transpose!(v0_re, v1_re, v2_re, v3_re, - v0_im, v1_im, v2_im, v3_im); + radix4_transpose!(v0_re, v1_re, v2_re, v3_re, v0_im, v1_im, v2_im, v3_im); // ---- Load second group and do stages 0+1 ---- let mut v4_re = f64x4::from_slice(simd, &re[16..20]); @@ -109,8 +106,7 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut let mut v6_im = f64x4::from_slice(simd, &im[24..28]); let mut v7_im = f64x4::from_slice(simd, &im[28..32]); - radix4_transpose!(v4_re, v5_re, v6_re, v7_re, - v4_im, v5_im, v6_im, v7_im); + radix4_transpose!(v4_re, v5_re, v6_re, v7_re, v4_im, v5_im, v6_im, v7_im); // Butterfly macro: twiddle-multiply hi, then add/sub with lo. // out_lo = lo + tw*hi, out_hi = lo - tw*hi (via 2*lo - out_lo). @@ -344,11 +340,65 @@ fn fft_dit_codelet_32_simd_f32(simd: S, reals: &mut [f32], imags: &mut let mut v3_im = f32x8::from_slice(simd, &im[24..32]); // ---- Stages 0+1+2 fused: 8-point DIT on all 4 vectors via transpose ---- - // Split each f32x8 into two f32x4 (lo=elems 0-3, hi=elems 4-7), transpose - // so each f32x4 holds one element position from all 4 groups, then do all - // butterfly stages as vertical f32x4 adds/subs/FMA. + // To reduce register pressure, process lo halves (elements 0-3) through + // stages 0+1 first, then hi halves (elements 4-7) through stages 0+1. + // Data stays in transposed (per-element) layout between stages 0+1 and + // stage 2, avoiding redundant transpose pairs. Only one inverse transpose + // at the end. Peak active register usage during stages 0+1: ~8 f32x4. { - // Step 1: Split f32x8 → f32x4 pairs + macro_rules! transpose4x4 { + ($g0:expr, $g1:expr, $g2:expr, $g3:expr) => {{ + let t0 = $g0.zip_low($g2); + let t1 = $g0.zip_high($g2); + let t2 = $g1.zip_low($g3); + let t3 = $g1.zip_high($g3); + ( + t0.zip_low(t2), + t0.zip_high(t2), + t1.zip_low(t3), + t1.zip_high(t3), + ) + }}; + } + + // Stages 0+1 (radix-4 DIT) on transposed data. + // Input: 4 f32x4 in per-group layout. Output: 4 f32x4 in per-element + // (transposed) layout — p0, p1, p2, p3 where each lane is one group. + macro_rules! 
radix4_transpose_fwd { + ($ga_re:expr, $gb_re:expr, $gc_re:expr, $gd_re:expr, + $ga_im:expr, $gb_im:expr, $gc_im:expr, $gd_im:expr) => {{ + let (e0_re, e1_re, e2_re, e3_re) = + transpose4x4!($ga_re, $gb_re, $gc_re, $gd_re); + let (e0_im, e1_im, e2_im, e3_im) = + transpose4x4!($ga_im, $gb_im, $gc_im, $gd_im); + + // Stage 0 (dist=1) + let s01_re = e0_re + e1_re; + let d01_re = e0_re - e1_re; + let s23_re = e2_re + e3_re; + let d23_re = e2_re - e3_re; + let s01_im = e0_im + e1_im; + let d01_im = e0_im - e1_im; + let s23_im = e2_im + e3_im; + let d23_im = e2_im - e3_im; + + // Stage 1 (dist=2): W4^0=1, W4^1=-j + let p0_re = s01_re + s23_re; + let p2_re = s01_re - s23_re; + let p0_im = s01_im + s23_im; + let p2_im = s01_im - s23_im; + + let p1_re = d01_re + d23_im; + let p3_re = d01_re - d23_im; + let p1_im = d01_im - d23_re; + let p3_im = d01_im + d23_re; + + // Return in per-element (transposed) layout + (p0_re, p1_re, p2_re, p3_re, p0_im, p1_im, p2_im, p3_im) + }}; + } + + // Process lo halves (elements 0-3) — result in transposed layout let (g0_lo_re, g0_hi_re) = v0_re.split(); let (g1_lo_re, g1_hi_re) = v1_re.split(); let (g2_lo_re, g2_hi_re) = v2_re.split(); @@ -358,85 +408,27 @@ fn fft_dit_codelet_32_simd_f32(simd: S, reals: &mut [f32], imags: &mut let (g2_lo_im, g2_hi_im) = v2_im.split(); let (g3_lo_im, g3_hi_im) = v3_im.split(); - // Step 2: 4×4 transpose on lo halves (re) - // After transpose, e_k_re[lane] = group lane's element k - macro_rules! transpose4x4 { - ($g0:expr, $g1:expr, $g2:expr, $g3:expr) => {{ - let t0 = $g0.zip_low($g2); // [g0[0], g2[0], g0[1], g2[1]] - let t1 = $g0.zip_high($g2); // [g0[2], g2[2], g0[3], g2[3]] - let t2 = $g1.zip_low($g3); // [g1[0], g3[0], g1[1], g3[1]] - let t3 = $g1.zip_high($g3); // [g1[2], g3[2], g1[3], g3[3]] - ( - t0.zip_low(t2), // [g0[0], g1[0], g2[0], g3[0]] - t0.zip_high(t2), // [g0[1], g1[1], g2[1], g3[1]] - t1.zip_low(t3), // [g0[2], g1[2], g2[2], g3[2]] - t1.zip_high(t3), // [g0[3], g1[3], g2[3], g3[3]] - ) - }}; - } + let (p0_re, p1_re, p2_re, p3_re, p0_im, p1_im, p2_im, p3_im) = radix4_transpose_fwd!( + g0_lo_re, g1_lo_re, g2_lo_re, g3_lo_re, g0_lo_im, g1_lo_im, g2_lo_im, g3_lo_im + ); + + // Process hi halves (elements 4-7) — result in transposed layout + let (p4_re, p5_re, p6_re, p7_re, p4_im, p5_im, p6_im, p7_im) = radix4_transpose_fwd!( + g0_hi_re, g1_hi_re, g2_hi_re, g3_hi_re, g0_hi_im, g1_hi_im, g2_hi_im, g3_hi_im + ); + + // Stage 2 (dist=4) — W8^k twiddles, already in per-element layout - let (e0_re, e1_re, e2_re, e3_re) = - transpose4x4!(g0_lo_re, g1_lo_re, g2_lo_re, g3_lo_re); - let (e4_re, e5_re, e6_re, e7_re) = - transpose4x4!(g0_hi_re, g1_hi_re, g2_hi_re, g3_hi_re); - let (e0_im, e1_im, e2_im, e3_im) = - transpose4x4!(g0_lo_im, g1_lo_im, g2_lo_im, g3_lo_im); - let (e4_im, e5_im, e6_im, e7_im) = - transpose4x4!(g0_hi_im, g1_hi_im, g2_hi_im, g3_hi_im); - - // Step 3: Stage 0 (dist=1) — butterfly between adjacent elements - let s01_re = e0_re + e1_re; - let d01_re = e0_re - e1_re; - let s23_re = e2_re + e3_re; - let d23_re = e2_re - e3_re; - let s45_re = e4_re + e5_re; - let d45_re = e4_re - e5_re; - let s67_re = e6_re + e7_re; - let d67_re = e6_re - e7_re; - - let s01_im = e0_im + e1_im; - let d01_im = e0_im - e1_im; - let s23_im = e2_im + e3_im; - let d23_im = e2_im - e3_im; - let s45_im = e4_im + e5_im; - let d45_im = e4_im - e5_im; - let s67_im = e6_im + e7_im; - let d67_im = e6_im - e7_im; - - // Step 4: Stage 1 (dist=2) — W4^0=1, W4^1=-j twiddles - // Twiddle=1: butterfly(s01, s23), butterfly(s45, s67) - let p0_re = 
s01_re + s23_re; - let p2_re = s01_re - s23_re; - let p4_re = s45_re + s67_re; - let p6_re = s45_re - s67_re; - let p0_im = s01_im + s23_im; - let p2_im = s01_im - s23_im; - let p4_im = s45_im + s67_im; - let p6_im = s45_im - s67_im; - - // Twiddle=-j: -j*(re+j*im) = (im, -re) - // butterfly(d01, d23*(-j)), butterfly(d45, d67*(-j)) - let p1_re = d01_re + d23_im; - let p3_re = d01_re - d23_im; - let p1_im = d01_im - d23_re; - let p3_im = d01_im + d23_re; - let p5_re = d45_re + d67_im; - let p7_re = d45_re - d67_im; - let p5_im = d45_im - d67_re; - let p7_im = d45_im + d67_re; - - // Step 5: Stage 2 (dist=4) — W8^k twiddles // W8^0 = 1+0j: just add/sub let r0_re = p0_re + p4_re; let r4_re = p0_re - p4_re; let r0_im = p0_im + p4_im; let r4_im = p0_im - p4_im; - // W8^1 = (1-j)/√2: FMA twiddle butterfly + // W8^1 = (1-j)/√2 const FRAC_1_SQRT_2: f32 = std::f32::consts::FRAC_1_SQRT_2; let tw1_re = f32x4::splat(simd, FRAC_1_SQRT_2); let tw1_im = f32x4::splat(simd, -FRAC_1_SQRT_2); - // twiddled = tw * p5 = (tw_re*p5_re - tw_im*p5_im, tw_re*p5_im + tw_im*p5_re) let tw_p5_re = tw1_im.mul_add(-p5_im, tw1_re * p5_re); let tw_p5_im = tw1_im.mul_add(p5_re, tw1_re * p5_im); let r1_re = p1_re + tw_p5_re; @@ -450,7 +442,7 @@ fn fft_dit_codelet_32_simd_f32(simd: S, reals: &mut [f32], imags: &mut let r2_im = p2_im - p6_re; let r6_im = p2_im + p6_re; - // W8^3 = (-1-j)/√2: FMA twiddle butterfly + // W8^3 = (-1-j)/√2 let tw3_re = f32x4::splat(simd, -FRAC_1_SQRT_2); let tw3_im = f32x4::splat(simd, -FRAC_1_SQRT_2); let tw_p7_re = tw3_im.mul_add(-p7_im, tw3_re * p7_re); @@ -460,7 +452,7 @@ fn fft_dit_codelet_32_simd_f32(simd: S, reals: &mut [f32], imags: &mut let r3_im = p3_im + tw_p7_im; let r7_im = p3_im - tw_p7_im; - // Step 6: 4×4 transpose back to per-group layout + // Single inverse transpose and recombine f32x4 → f32x8 let (g0_lo_re, g1_lo_re, g2_lo_re, g3_lo_re) = transpose4x4!(r0_re, r1_re, r2_re, r3_re); let (g0_hi_re, g1_hi_re, g2_hi_re, g3_hi_re) = @@ -470,7 +462,6 @@ fn fft_dit_codelet_32_simd_f32(simd: S, reals: &mut [f32], imags: &mut let (g0_hi_im, g1_hi_im, g2_hi_im, g3_hi_im) = transpose4x4!(r4_im, r5_im, r6_im, r7_im); - // Step 7: Recombine f32x4 → f32x8 v0_re = g0_lo_re.combine(g0_hi_re); v1_re = g1_lo_re.combine(g1_hi_re); v2_re = g2_lo_re.combine(g2_hi_re); From ee3958c6b405d802d2882cd79fc8e51599065c53 Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 11 Apr 2026 23:32:33 +0100 Subject: [PATCH 09/17] Tighten up stage 2 assembly in f32 codelet --- src/kernels/codelets.rs | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/kernels/codelets.rs b/src/kernels/codelets.rs index 495ab81..400966f 100644 --- a/src/kernels/codelets.rs +++ b/src/kernels/codelets.rs @@ -425,12 +425,13 @@ fn fft_dit_codelet_32_simd_f32(simd: S, reals: &mut [f32], imags: &mut let r0_im = p0_im + p4_im; let r4_im = p0_im - p4_im; - // W8^1 = (1-j)/√2 + // W8^1 = (1-j)/√2: twiddle * p5, expressed to avoid double-negation. + // tw = (s, -s) where s = 1/√2. 
+            // tw*p5 = (s*p5_re + s*p5_im, s*p5_im - s*p5_re)
             const FRAC_1_SQRT_2: f32 = std::f32::consts::FRAC_1_SQRT_2;
-            let tw1_re = f32x4::splat(simd, FRAC_1_SQRT_2);
-            let tw1_im = f32x4::splat(simd, -FRAC_1_SQRT_2);
-            let tw_p5_re = tw1_im.mul_add(-p5_im, tw1_re * p5_re);
-            let tw_p5_im = tw1_im.mul_add(p5_re, tw1_re * p5_im);
+            let s = f32x4::splat(simd, FRAC_1_SQRT_2);
+            let tw_p5_re = s.mul_add(p5_im, s * p5_re);
+            let tw_p5_im = s.mul_sub(p5_im, s * p5_re);
             let r1_re = p1_re + tw_p5_re;
             let r5_re = p1_re - tw_p5_re;
             let r1_im = p1_im + tw_p5_im;
@@ -442,15 +443,18 @@ fn fft_dit_codelet_32_simd_f32<S: Simd>(simd: S, reals: &mut [f32], imags: &mut
             let r2_im = p2_im - p6_re;
             let r6_im = p2_im + p6_re;
 
-            // W8^3 = (-1-j)/√2
-            let tw3_re = f32x4::splat(simd, -FRAC_1_SQRT_2);
-            let tw3_im = f32x4::splat(simd, -FRAC_1_SQRT_2);
-            let tw_p7_re = tw3_im.mul_add(-p7_im, tw3_re * p7_re);
-            let tw_p7_im = tw3_im.mul_add(p7_re, tw3_re * p7_im);
+            // W8^3 = (-1-j)/√2: twiddle * p7, expressed to avoid double-negation.
+            // tw = (-s, -s) where s = 1/√2.
+            // tw*p7_re = s*p7_im - s*p7_re  (fmsub, clean)
+            // tw*p7_im = -(s*p7_im + s*p7_re), but we compute the positive form
+            //   neg_tw_p7_im = s*p7_im + s*p7_re  (fmadd, clean)
+            // and swap the +/- in the butterfly for the im component.
+            let tw_p7_re = s.mul_sub(p7_im, s * p7_re);
+            let neg_tw_p7_im = s.mul_add(p7_im, s * p7_re);
             let r3_re = p3_re + tw_p7_re;
             let r7_re = p3_re - tw_p7_re;
-            let r3_im = p3_im + tw_p7_im;
-            let r7_im = p3_im - tw_p7_im;
+            let r3_im = p3_im - neg_tw_p7_im;
+            let r7_im = p3_im + neg_tw_p7_im;
 
             // Single inverse transpose and recombine f32x4 → f32x8
             let (g0_lo_re, g1_lo_re, g2_lo_re, g3_lo_re) =

From 35951fd23c99e0e5c4594a98378670a7491b92a3 Mon Sep 17 00:00:00 2001
From: "Sergey \"Shnatsel\" Davidoff"
Date: Fri, 17 Apr 2026 18:38:48 +0100
Subject: [PATCH 10/17] Add a comment on register spilling

---
 src/kernels/codelets.rs | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/kernels/codelets.rs b/src/kernels/codelets.rs
index 400966f..bfa4f97 100644
--- a/src/kernels/codelets.rs
+++ b/src/kernels/codelets.rs
@@ -12,6 +12,10 @@ use fearless_simd::{
 /// Register-resident implementation: all 32 complex values are loaded into f64x4 vectors,
 /// all 5 butterfly stages execute in registers with no intermediate memory traffic,
 /// then results are stored back.
+///
+/// In reality there is a lot of spilling to the stack,
+/// but empirically it performs consistently better than loading/storing after every step
+/// or even reducing the load/store traffic with radix-2^2.
#[inline(never)] pub fn fft_dit_codelet_32_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { simd.vectorize( From c69ca7c3b3ccb1fa86dac48585f6881e3b5d1e50 Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 18 Apr 2026 10:17:04 +0100 Subject: [PATCH 11/17] Use codelets unconditionally --- src/algorithms/dit.rs | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/algorithms/dit.rs b/src/algorithms/dit.rs index 267d188..cf95164 100644 --- a/src/algorithms/dit.rs +++ b/src/algorithms/dit.rs @@ -43,9 +43,10 @@ fn recursive_dit_fft_f64( if size <= L1_BLOCK_SIZE { // Use FFT-32 codelet to fuse stages 0-4 into a single pass per 32-element chunk - let start_stage = if planner.use_codelet_32 { + let codelet_stages = 5; + let start_stage = if stage_twiddle_idx == 0 && size >= power_of_two(codelet_stages) { fft_dit_codelet_32_f64(simd, &mut reals[..size], &mut imags[..size]); - 5 + codelet_stages } else { 0 }; @@ -109,9 +110,10 @@ fn recursive_dit_fft_f32( if size <= L1_BLOCK_SIZE { // Use FFT-32 codelet to fuse stages 0-4 into a single pass per 32-element chunk - let start_stage = if planner.use_codelet_32 { + let codelet_stages = 5; + let start_stage = if stage_twiddle_idx == 0 && size >= power_of_two(codelet_stages) { fft_dit_codelet_32_f32(simd, &mut reals[..size], &mut imags[..size]); - 5 + codelet_stages } else { 0 }; @@ -389,3 +391,10 @@ fn fft_32_dit_with_planner_and_opts_impl( } } } + +#[inline] +fn power_of_two(power: usize) -> usize { + // 2.pow() requires a lot of ugly type annotations so here's a helper function + debug_assert!(power < usize::BITS as usize); + 1 << power +} From c065e53bf187b5055f3425713662b9d2b3fa1c5a Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 18 Apr 2026 10:17:21 +0100 Subject: [PATCH 12/17] Drop unused import --- src/kernels/codelets.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kernels/codelets.rs b/src/kernels/codelets.rs index bfa4f97..45cd125 100644 --- a/src/kernels/codelets.rs +++ b/src/kernels/codelets.rs @@ -4,7 +4,7 @@ //! call, eliminating per-stage function call overhead and giving LLVM a wider optimization window. //! use fearless_simd::{ - f32x16, f32x4, f32x8, f64x4, Simd, SimdBase, SimdCombine, SimdFloat, SimdFrom, SimdSplit, + f32x4, f32x8, f64x4, Simd, SimdBase, SimdCombine, SimdFloat, SimdFrom, SimdSplit, }; /// FFT-32 codelet for `f64`: executes stages 0-4 (chunk_size 2 through 32) in a single function. 
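The gate that patch 11 adds is compact enough to restate as a standalone sketch (a paraphrase with a hypothetical helper name, not code from the patch): the codelet may only replace stages 0-4 when this recursion level really starts at stage 0, and only when the block holds at least one full 32-element chunk, since power_of_two(5) == 32:

fn codelet_32_applies(stage_twiddle_idx: usize, size: usize) -> bool {
    let codelet_stages = 5; // stages 0-4, chunk_size 2 through 32
    debug_assert_eq!(1usize << codelet_stages, 32);
    stage_twiddle_idx == 0 && size >= (1 << codelet_stages)
}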
From 62d7eed823eafa3a072da3e12c81156c46ba8baf Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 18 Apr 2026 10:28:21 +0100 Subject: [PATCH 13/17] Remove codelet control from planner now that it is always beneficial --- src/lib.rs | 83 ------------------------------------------ src/planner.rs | 99 +------------------------------------------------- 2 files changed, 1 insertion(+), 181 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 5c8e03a..ecb3d2b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -394,27 +394,6 @@ mod tests { } } - #[test] - fn heuristic_disables_codelet_for_small_sizes() { - let planner = PlannerDit64::new(16, Direction::Forward); // log_n = 4 - assert!(!planner.use_codelet_32); - } - - #[test] - fn heuristic_enables_codelet_at_threshold() { - let planner = PlannerDit64::new(32, Direction::Forward); // log_n = 5 - assert!(planner.use_codelet_32); - - let planner = PlannerDit64::new(8192, Direction::Forward); // log_n = 13 - assert!(planner.use_codelet_32); - } - - #[test] - fn heuristic_disables_codelet_above_threshold() { - let planner = PlannerDit64::new(16384, Direction::Forward); // log_n = 14 - assert!(!planner.use_codelet_32); - } - #[test] fn tune_mode_does_not_panic() { use crate::planner::PlannerMode; @@ -451,66 +430,4 @@ mod tests { } } } - - #[test] - fn codelet_forced_on_above_heuristic_threshold_f64() { - for n in 14..=15 { - let size = 1 << n; - let mut reals_original = vec![0.0f64; size]; - let mut imags_original = vec![0.0f64; size]; - gen_random_signal_f64(&mut reals_original, &mut imags_original); - - let mut reals = reals_original.clone(); - let mut imags = imags_original.clone(); - - let mut fwd = PlannerDit64::new(size, Direction::Forward); - assert!( - !fwd.use_codelet_32, - "heuristic should disable codelet at n={n}" - ); - fwd.use_codelet_32 = true; - - let mut inv = PlannerDit64::new(size, Direction::Reverse); - inv.use_codelet_32 = true; - - fft_64_dit_with_planner(&mut reals, &mut imags, &fwd); - fft_64_dit_with_planner(&mut reals, &mut imags, &inv); - - for i in 0..size { - assert_float_closeness(reals[i], reals_original[i], 1e-10); - assert_float_closeness(imags[i], imags_original[i], 1e-10); - } - } - } - - #[test] - fn codelet_forced_on_above_heuristic_threshold_f32() { - for n in 14..=15 { - let size = 1 << n; - let mut reals_original = vec![0.0f32; size]; - let mut imags_original = vec![0.0f32; size]; - gen_random_signal_f32(&mut reals_original, &mut imags_original); - - let mut reals = reals_original.clone(); - let mut imags = imags_original.clone(); - - let mut fwd = PlannerDit32::new(size, Direction::Forward); - assert!( - !fwd.use_codelet_32, - "heuristic should disable codelet at n={n}" - ); - fwd.use_codelet_32 = true; - - let mut inv = PlannerDit32::new(size, Direction::Reverse); - inv.use_codelet_32 = true; - - fft_32_dit_with_planner(&mut reals, &mut imags, &fwd); - fft_32_dit_with_planner(&mut reals, &mut imags, &inv); - - for i in 0..size { - assert_float_closeness(reals[i], reals_original[i], 1e-4); - assert_float_closeness(imags[i], imags_original[i], 1e-4); - } - } - } } diff --git a/src/planner.rs b/src/planner.rs index f9869cf..39d17a4 100644 --- a/src/planner.rs +++ b/src/planner.rs @@ -42,8 +42,6 @@ macro_rules! 
impl_planner_dit_for { pub(crate) log_n: usize, /// The level of SIMD instruction support, detected at runtime on x86 and hardcoded elsewhere pub(crate) simd_level: fearless_simd::Level, - /// Whether to use the fused 32-point codelet for stages 0-4 in L1 blocks - pub(crate) use_codelet_32: bool, } impl $struct_name { @@ -91,110 +89,15 @@ macro_rules! impl_planner_dit_for { } } - let use_codelet_32 = Self::estimate_use_codelet_32(log_n); - - let mut planner = Self { + let planner = Self { stage_twiddles, direction, log_n, simd_level, - use_codelet_32, }; - if matches!(mode, PlannerMode::Tune) { - planner.tune_codelet_32(num_points); - } - planner } - - /// Conservative, arch-independent heuristic for whether the 32-point - /// codelet is beneficial. - /// - /// At small sizes (N ≤ 8192) the codelet dominates runtime and - /// cross-block kernel eviction from the µop cache doesn't matter. - /// At large sizes the codelet's code footprint can evict the - /// cross-block kernel on platforms with small L1i caches (e.g., 32KB - /// on x86). Use [`PlannerMode::Tune`] to discover the real threshold - /// on your hardware. - fn estimate_use_codelet_32(log_n: usize) -> bool { - // Codelet needs at least 32 elements (5 stages) - if log_n < 5 { - return false; - } - - // Conservative threshold: enable only for N ≤ 8192 where the - // codelet dominates runtime. On platforms with large L1i (e.g., - // Apple Silicon at 192KB), Tune mode will discover that the - // codelet wins at larger sizes too. - // log_n <= 13 - true - } - - /// Benchmark both paths and set `use_codelet_32` to whichever is faster. - fn tune_codelet_32(&mut self, num_points: usize) { - if self.log_n < 5 { - self.use_codelet_32 = false; - return; - } - - let opts = crate::options::Options { - multithreaded_bit_reversal: false, - smallest_parallel_chunk_size: usize::MAX, - }; - - // Generate random complex signal via xorshift64 (no rand dependency) - let mut rng_state: u64 = 0x517C_C1B7_2722_0A95; - let mut next_f = || -> $precision { - rng_state ^= rng_state << 13; - rng_state ^= rng_state >> 7; - rng_state ^= rng_state << 17; - (rng_state as $precision) / (u64::MAX as $precision) * 2.0 - 1.0 - }; - let reals_orig: Vec<$precision> = (0..num_points).map(|_| next_f()).collect(); - let imags_orig: Vec<$precision> = (0..num_points).map(|_| next_f()).collect(); - let mut reals = reals_orig.clone(); - let mut imags = imags_orig.clone(); - - const WARMUP: usize = 3; - const ITERS: usize = 5; - - // Time WITHOUT codelet - self.use_codelet_32 = false; - for _ in 0..WARMUP { - reals.copy_from_slice(&reals_orig); - imags.copy_from_slice(&imags_orig); - $fft_func(&mut reals, &mut imags, &*self, &opts); - } - - let mut best_without = std::time::Duration::MAX; - for _ in 0..ITERS { - reals.copy_from_slice(&reals_orig); - imags.copy_from_slice(&imags_orig); - let start = std::time::Instant::now(); - $fft_func(&mut reals, &mut imags, &*self, &opts); - best_without = best_without.min(start.elapsed()); - } - - // Time WITH codelet - self.use_codelet_32 = true; - for _ in 0..WARMUP { - reals.copy_from_slice(&reals_orig); - imags.copy_from_slice(&imags_orig); - $fft_func(&mut reals, &mut imags, &*self, &opts); - } - - let mut best_with = std::time::Duration::MAX; - for _ in 0..ITERS { - reals.copy_from_slice(&reals_orig); - imags.copy_from_slice(&imags_orig); - let start = std::time::Instant::now(); - $fft_func(&mut reals, &mut imags, &*self, &opts); - best_with = best_with.min(start.elapsed()); - } - - self.use_codelet_32 = best_with < best_without; - } 
} }; } From ff6d8d0d778f445dcf45d72945b12bcd22611f3b Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 18 Apr 2026 10:33:27 +0100 Subject: [PATCH 14/17] Mark mode as unused to suppress compiler warning (for now) --- src/planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/planner.rs b/src/planner.rs index 39d17a4..b718d39 100644 --- a/src/planner.rs +++ b/src/planner.rs @@ -59,7 +59,7 @@ macro_rules! impl_planner_dit_for { /// leave performance on the table on platforms with large L1i caches. /// - [`PlannerMode::Tune`]: Benchmarks both paths at plan time. Use this /// when you can afford extra planning time (e.g., planner is reused). - pub fn with_mode(num_points: usize, direction: Direction, mode: PlannerMode) -> Self { + pub fn with_mode(num_points: usize, direction: Direction, _mode: PlannerMode) -> Self { assert!(num_points > 0 && num_points.is_power_of_two()); let simd_level = fearless_simd::Level::new(); From 6e9d21360106259c4843e0b7e68a5a5b9cdcdcc2 Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 18 Apr 2026 11:06:08 +0100 Subject: [PATCH 15/17] Replace zip_low()+zip_high() with interleave(), should have better performance on avx2 --- Cargo.lock | 3 +-- Cargo.toml | 3 +++ src/kernels/codelets.rs | 30 ++++++++++-------------------- 3 files changed, 14 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 702d8cb..9ec8a79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -380,8 +380,7 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "fearless_simd" version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76258897e51fd156ee03b6246ea53f3e0eb395d0b327e9961c4fc4c8b2fa151a" +source = "git+https://github.com/Shnatsel/fearless_simd.git?branch=interleave#7642be30901bc4a3702ebe51dfb0e4f81dadcd99" [[package]] name = "fftw" diff --git a/Cargo.toml b/Cargo.toml index 2d182e0..091e7b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,3 +62,6 @@ debug = true all-features = true [lints.rust] + +[patch.crates-io] +fearless_simd = {git = "https://github.com/Shnatsel/fearless_simd.git", branch = "interleave"} \ No newline at end of file diff --git a/src/kernels/codelets.rs b/src/kernels/codelets.rs index 45cd125..b9d4b24 100644 --- a/src/kernels/codelets.rs +++ b/src/kernels/codelets.rs @@ -33,16 +33,11 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut for (re, im) in reals.chunks_exact_mut(32).zip(imags.chunks_exact_mut(32)) { macro_rules! transpose4x4_f64 { ($g0:expr, $g1:expr, $g2:expr, $g3:expr) => {{ - let t0 = $g0.zip_low($g2); - let t1 = $g0.zip_high($g2); - let t2 = $g1.zip_low($g3); - let t3 = $g1.zip_high($g3); - ( - t0.zip_low(t2), - t0.zip_high(t2), - t1.zip_low(t3), - t1.zip_high(t3), - ) + let (t0, t1) = $g0.interleave($g2); + let (t2, t3) = $g1.interleave($g3); + let (r0, r1) = t0.interleave(t2); + let (r2, r3) = t1.interleave(t3); + (r0, r1, r2, r3) }}; } @@ -352,16 +347,11 @@ fn fft_dit_codelet_32_simd_f32(simd: S, reals: &mut [f32], imags: &mut { macro_rules! 
transpose4x4 { ($g0:expr, $g1:expr, $g2:expr, $g3:expr) => {{ - let t0 = $g0.zip_low($g2); - let t1 = $g0.zip_high($g2); - let t2 = $g1.zip_low($g3); - let t3 = $g1.zip_high($g3); - ( - t0.zip_low(t2), - t0.zip_high(t2), - t1.zip_low(t3), - t1.zip_high(t3), - ) + let (t0, t1) = $g0.interleave($g2); + let (t2, t3) = $g1.interleave($g3); + let (r0, r1) = t0.interleave(t2); + let (r2, r3) = t1.interleave(t3); + (r0, r1, r2, r3) }}; } From 5cfd39ac5864a40905a432ff1ceb0808d75562dc Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 18 Apr 2026 11:53:55 +0100 Subject: [PATCH 16/17] Reduce f64 codelet from 5 stages to 4, to reduce register pressure in the final stage --- src/algorithms/dit.rs | 8 +- src/kernels/codelets.rs | 182 ++++++++-------------------------------- 2 files changed, 39 insertions(+), 151 deletions(-) diff --git a/src/algorithms/dit.rs b/src/algorithms/dit.rs index cf95164..227d28c 100644 --- a/src/algorithms/dit.rs +++ b/src/algorithms/dit.rs @@ -17,7 +17,7 @@ use fearless_simd::{dispatch, Simd}; use crate::algorithms::bravo::{bit_rev_bravo_f32, bit_rev_bravo_f64}; -use crate::kernels::codelets::{fft_dit_codelet_32_f32, fft_dit_codelet_32_f64}; +use crate::kernels::codelets::{fft_dit_codelet_16_f64, fft_dit_codelet_32_f32}; use crate::kernels::dit::*; use crate::options::Options; use crate::parallel::run_maybe_in_parallel; @@ -42,10 +42,10 @@ fn recursive_dit_fft_f64( let log_size = size.ilog2() as usize; if size <= L1_BLOCK_SIZE { - // Use FFT-32 codelet to fuse stages 0-4 into a single pass per 32-element chunk - let codelet_stages = 5; + // Use FFT-16 codelet to fuse stages 0-3 into a single pass per 16-element chunk + let codelet_stages = 4; let start_stage = if stage_twiddle_idx == 0 && size >= power_of_two(codelet_stages) { - fft_dit_codelet_32_f64(simd, &mut reals[..size], &mut imags[..size]); + fft_dit_codelet_16_f64(simd, &mut reals[..size], &mut imags[..size]); codelet_stages } else { 0 diff --git a/src/kernels/codelets.rs b/src/kernels/codelets.rs index b9d4b24..cd828f3 100644 --- a/src/kernels/codelets.rs +++ b/src/kernels/codelets.rs @@ -7,30 +7,30 @@ use fearless_simd::{ f32x4, f32x8, f64x4, Simd, SimdBase, SimdCombine, SimdFloat, SimdFrom, SimdSplit, }; -/// FFT-32 codelet for `f64`: executes stages 0-4 (chunk_size 2 through 32) in a single function. +/// FFT-16 codelet for `f64`: executes stages 0-3 (chunk_size 2 through 16) in a single function. /// -/// Register-resident implementation: all 32 complex values are loaded into f64x4 vectors, -/// all 5 butterfly stages execute in registers with no intermediate memory traffic, +/// Register-resident implementation: all 16 complex values are loaded into f64x4 vectors, +/// all 4 butterfly stages execute in registers with no intermediate memory traffic, /// then results are stored back. /// -/// In reality there is a lot of spilling to the stack, -/// but empirically it performs consistently better than loading/storing after every step -/// or even reducing the load/store traffic with radix-2^2. +/// The limiting factor is register pressure: we cannot fuse more stages on AVX2 or NEON. +/// Trying to do so incurs haphazard loads/stores from the stack and ends up being slower +/// than running orderly single-stage passes that load/store predictably. 
#[inline(never)] -pub fn fft_dit_codelet_32_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { +pub fn fft_dit_codelet_16_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { simd.vectorize( #[inline(always)] - || fft_dit_codelet_32_simd_f64(simd, reals, imags), + || fft_dit_codelet_16_simd_f64(simd, reals, imags), ) } #[inline(always)] -fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { +fn fft_dit_codelet_16_simd_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { assert_eq!(reals.len(), imags.len()); let two = f64x4::splat(simd, 2.0); - for (re, im) in reals.chunks_exact_mut(32).zip(imags.chunks_exact_mut(32)) { + for (re, im) in reals.chunks_exact_mut(16).zip(imags.chunks_exact_mut(16)) { macro_rules! transpose4x4_f64 { ($g0:expr, $g1:expr, $g2:expr, $g3:expr) => {{ let (t0, t1) = $g0.interleave($g2); @@ -82,8 +82,7 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut }}; } - // ---- Load first group and do stages 0+1 ---- - // Deferring v4-v7 loads reduces peak register pressure during transpose. + // ---- Load and do stages 0+1 ---- let mut v0_re = f64x4::from_slice(simd, &re[0..4]); let mut v1_re = f64x4::from_slice(simd, &re[4..8]); let mut v2_re = f64x4::from_slice(simd, &re[8..12]); @@ -95,18 +94,6 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut radix4_transpose!(v0_re, v1_re, v2_re, v3_re, v0_im, v1_im, v2_im, v3_im); - // ---- Load second group and do stages 0+1 ---- - let mut v4_re = f64x4::from_slice(simd, &re[16..20]); - let mut v5_re = f64x4::from_slice(simd, &re[20..24]); - let mut v6_re = f64x4::from_slice(simd, &re[24..28]); - let mut v7_re = f64x4::from_slice(simd, &re[28..32]); - let mut v4_im = f64x4::from_slice(simd, &im[16..20]); - let mut v5_im = f64x4::from_slice(simd, &im[20..24]); - let mut v6_im = f64x4::from_slice(simd, &im[24..28]); - let mut v7_im = f64x4::from_slice(simd, &im[28..32]); - - radix4_transpose!(v4_re, v5_re, v6_re, v7_re, v4_im, v5_im, v6_im, v7_im); - // Butterfly macro: twiddle-multiply hi, then add/sub with lo. // out_lo = lo + tw*hi, out_hi = lo - tw*hi (via 2*lo - out_lo). macro_rules! 
butterfly { @@ -123,7 +110,7 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut } // ---- Stage 2: dist=4, W_8 twiddles ---- - // Butterfly pairs: (v0,v1), (v2,v3), (v4,v5), (v6,v7) + // Butterfly pairs: (v0,v1), (v2,v3) // All pairs use the same twiddle: W_8^{0,1,2,3} { let tw_re = f64x4::simd_from( @@ -147,14 +134,12 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut butterfly!(v0_re, v0_im, v1_re, v1_im, tw_re, tw_im); butterfly!(v2_re, v2_im, v3_re, v3_im, tw_re, tw_im); - butterfly!(v4_re, v4_im, v5_re, v5_im, tw_re, tw_im); - butterfly!(v6_re, v6_im, v7_re, v7_im, tw_re, tw_im); } // ---- Stage 3: dist=8, W_16 twiddles ---- - // Butterfly pairs: (v0,v2), (v1,v3), (v4,v6), (v5,v7) - // (v0,v2) and (v4,v6) use W_16^{0,1,2,3} - // (v1,v3) and (v5,v7) use W_16^{4,5,6,7} + // Butterfly pairs: (v0,v2), (v1,v3) + // (v0,v2) uses W_16^{0,1,2,3} + // (v1,v3) uses W_16^{4,5,6,7} { let tw_lo_re = f64x4::simd_from( simd, @@ -195,94 +180,6 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut butterfly!(v0_re, v0_im, v2_re, v2_im, tw_lo_re, tw_lo_im); butterfly!(v1_re, v1_im, v3_re, v3_im, tw_hi_re, tw_hi_im); - butterfly!(v4_re, v4_im, v6_re, v6_im, tw_lo_re, tw_lo_im); - butterfly!(v5_re, v5_im, v7_re, v7_im, tw_hi_re, tw_hi_im); - } - - // ---- Stage 4: dist=16, W_32 twiddles ---- - // Butterfly pairs: (v0,v4), (v1,v5), (v2,v6), (v3,v7) - // Each pair uses its own twiddle: W_32^{0..3}, W_32^{4..7}, W_32^{8..11}, W_32^{12..15} - { - let tw0_re = f64x4::simd_from( - simd, - [ - 1.0, // W_32^0 - 0.9807852804032304, // W_32^1 - 0.9238795325112867, // W_32^2 - 0.8314696123025452, // W_32^3 - ], - ); - let tw0_im = f64x4::simd_from( - simd, - [ - 0.0, // W_32^0 - -0.19509032201612825, // W_32^1 - -0.3826834323650898, // W_32^2 - -0.5555702330196022, // W_32^3 - ], - ); - - let tw1_re = f64x4::simd_from( - simd, - [ - std::f64::consts::FRAC_1_SQRT_2, // W_32^4 - 0.5555702330196022, // W_32^5 - 0.3826834323650898, // W_32^6 - 0.19509032201612825, // W_32^7 - ], - ); - let tw1_im = f64x4::simd_from( - simd, - [ - -std::f64::consts::FRAC_1_SQRT_2, // W_32^4 - -0.8314696123025452, // W_32^5 - -0.9238795325112867, // W_32^6 - -0.9807852804032304, // W_32^7 - ], - ); - - let tw2_re = f64x4::simd_from( - simd, - [ - 0.0, // W_32^8 - -0.19509032201612825, // W_32^9 - -0.3826834323650898, // W_32^10 - -0.5555702330196022, // W_32^11 - ], - ); - let tw2_im = f64x4::simd_from( - simd, - [ - -1.0, // W_32^8 - -0.9807852804032304, // W_32^9 - -0.9238795325112867, // W_32^10 - -0.8314696123025452, // W_32^11 - ], - ); - - let tw3_re = f64x4::simd_from( - simd, - [ - -std::f64::consts::FRAC_1_SQRT_2, // W_32^12 - -0.8314696123025452, // W_32^13 - -0.9238795325112867, // W_32^14 - -0.9807852804032304, // W_32^15 - ], - ); - let tw3_im = f64x4::simd_from( - simd, - [ - -std::f64::consts::FRAC_1_SQRT_2, // W_32^12 - -0.5555702330196022, // W_32^13 - -0.3826834323650898, // W_32^14 - -0.19509032201612825, // W_32^15 - ], - ); - - butterfly!(v0_re, v0_im, v4_re, v4_im, tw0_re, tw0_im); - butterfly!(v1_re, v1_im, v5_re, v5_im, tw1_re, tw1_im); - butterfly!(v2_re, v2_im, v6_re, v6_im, tw2_re, tw2_im); - butterfly!(v3_re, v3_im, v7_re, v7_im, tw3_re, tw3_im); } // ---- Store all vectors back ---- @@ -290,19 +187,11 @@ fn fft_dit_codelet_32_simd_f64(simd: S, reals: &mut [f64], imags: &mut v1_re.store_slice(&mut re[4..8]); v2_re.store_slice(&mut re[8..12]); v3_re.store_slice(&mut re[12..16]); - v4_re.store_slice(&mut re[16..20]); - v5_re.store_slice(&mut 
re[20..24]); - v6_re.store_slice(&mut re[24..28]); - v7_re.store_slice(&mut re[28..32]); v0_im.store_slice(&mut im[0..4]); v1_im.store_slice(&mut im[4..8]); v2_im.store_slice(&mut im[8..12]); v3_im.store_slice(&mut im[12..16]); - v4_im.store_slice(&mut im[16..20]); - v5_im.store_slice(&mut im[20..24]); - v6_im.store_slice(&mut im[24..28]); - v7_im.store_slice(&mut im[28..32]); } } @@ -599,13 +488,12 @@ mod tests { use super::*; use crate::kernels::dit::*; - fn run_stages_0_to_4_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { - assert_eq!(reals.len(), 32); + fn run_stages_0_to_3_f64(simd: S, reals: &mut [f64], imags: &mut [f64]) { + assert_eq!(reals.len(), 16); fft_dit_chunk_2(simd, reals, imags); fft_dit_chunk_4_f64(simd, reals, imags); fft_dit_chunk_8_f64(simd, reals, imags); fft_dit_chunk_16_f64(simd, reals, imags); - fft_dit_chunk_32_f64(simd, reals, imags); } fn run_stages_0_to_4_f32(simd: S, reals: &mut [f32], imags: &mut [f32]) { @@ -618,14 +506,14 @@ mod tests { } #[test] - fn codelet_32_f64_matches_staged() { + fn codelet_16_f64_matches_staged() { use fearless_simd::dispatch; let simd_level = fearless_simd::Level::new(); // Test with a simple impulse signal - let mut re_staged = vec![0.0f64; 32]; - let mut im_staged = vec![0.0f64; 32]; + let mut re_staged = vec![0.0f64; 16]; + let mut im_staged = vec![0.0f64; 16]; re_staged[0] = 1.0; let mut re_codelet = re_staged.clone(); @@ -634,12 +522,12 @@ mod tests { dispatch!(simd_level, simd => { simd.vectorize( #[inline(always)] - || run_stages_0_to_4_f64(simd, &mut re_staged, &mut im_staged), + || run_stages_0_to_3_f64(simd, &mut re_staged, &mut im_staged), ); - fft_dit_codelet_32_f64(simd, &mut re_codelet, &mut im_codelet); + fft_dit_codelet_16_f64(simd, &mut re_codelet, &mut im_codelet); }); - for i in 0..32 { + for i in 0..16 { assert!( (re_staged[i] - re_codelet[i]).abs() < 1e-14, "re[{i}]: staged={}, codelet={}", @@ -654,8 +542,8 @@ mod tests { ); } - let mut re_staged: Vec = (1..=32).map(|i| i as f64).collect(); - let mut im_staged: Vec = (1..=32).map(|i| -(i as f64) * 0.5).collect(); + let mut re_staged: Vec = (1..=16).map(|i| i as f64).collect(); + let mut im_staged: Vec = (1..=16).map(|i| -(i as f64) * 0.5).collect(); let mut re_codelet = re_staged.clone(); let mut im_codelet = im_staged.clone(); @@ -663,12 +551,12 @@ mod tests { dispatch!(simd_level, simd => { simd.vectorize( #[inline(always)] - || run_stages_0_to_4_f64(simd, &mut re_staged, &mut im_staged), + || run_stages_0_to_3_f64(simd, &mut re_staged, &mut im_staged), ); - fft_dit_codelet_32_f64(simd, &mut re_codelet, &mut im_codelet); + fft_dit_codelet_16_f64(simd, &mut re_codelet, &mut im_codelet); }); - for i in 0..32 { + for i in 0..16 { assert!( (re_staged[i] - re_codelet[i]).abs() < 1e-10, "re[{i}]: staged={}, codelet={}", @@ -751,12 +639,12 @@ mod tests { } #[test] - fn codelet_32_f64_multi_chunk() { + fn codelet_16_f64_multi_chunk() { use fearless_simd::dispatch; let simd_level = fearless_simd::Level::new(); - // Test that the codelet correctly processes multiple 32-element chunks + // Test that the codelet correctly processes multiple 16-element chunks let n = 128; let mut re_staged: Vec = (0..n).map(|i| (i as f64) * 0.1).collect(); let mut im_staged: Vec = (0..n).map(|i| -(i as f64) * 0.05).collect(); @@ -766,17 +654,17 @@ mod tests { dispatch!(simd_level, simd => { // Run individual stage kernels on all chunks - for chunk_start in (0..n).step_by(32) { - let re = &mut re_staged[chunk_start..chunk_start + 32]; - let im = &mut im_staged[chunk_start..chunk_start 
+ 32]; + for chunk_start in (0..n).step_by(16) { + let re = &mut re_staged[chunk_start..chunk_start + 16]; + let im = &mut im_staged[chunk_start..chunk_start + 16]; simd.vectorize( #[inline(always)] - || run_stages_0_to_4_f64(simd, re, im), + || run_stages_0_to_3_f64(simd, re, im), ); } // Run codelet on the full array - fft_dit_codelet_32_f64(simd, &mut re_codelet, &mut im_codelet); + fft_dit_codelet_16_f64(simd, &mut re_codelet, &mut im_codelet); }); for i in 0..n { From af675daeed9ae0ad10715ef8c984ba8a7c54de2b Mon Sep 17 00:00:00 2001 From: "Sergey \"Shnatsel\" Davidoff" Date: Sat, 18 Apr 2026 12:06:02 +0100 Subject: [PATCH 17/17] Polyfill interleave() until the upstream fearless_simd PR is merged --- Cargo.lock | 3 ++- Cargo.toml | 5 +---- src/kernels/codelets.rs | 30 ++++++++++++++++++++++-------- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9ec8a79..702d8cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -380,7 +380,8 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "fearless_simd" version = "0.4.0" -source = "git+https://github.com/Shnatsel/fearless_simd.git?branch=interleave#7642be30901bc4a3702ebe51dfb0e4f81dadcd99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76258897e51fd156ee03b6246ea53f3e0eb395d0b327e9961c4fc4c8b2fa151a" [[package]] name = "fftw" diff --git a/Cargo.toml b/Cargo.toml index 091e7b8..aa13898 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,7 +61,4 @@ debug = true [package.metadata.docs.rs] all-features = true -[lints.rust] - -[patch.crates-io] -fearless_simd = {git = "https://github.com/Shnatsel/fearless_simd.git", branch = "interleave"} \ No newline at end of file +[lints.rust] \ No newline at end of file diff --git a/src/kernels/codelets.rs b/src/kernels/codelets.rs index cd828f3..765dfb9 100644 --- a/src/kernels/codelets.rs +++ b/src/kernels/codelets.rs @@ -7,6 +7,20 @@ use fearless_simd::{ f32x4, f32x8, f64x4, Simd, SimdBase, SimdCombine, SimdFloat, SimdFrom, SimdSplit, }; +/// Equivalent to `a.interleave(b)` — returns `(a.zip_low(b), a.zip_high(b))`. +/// Slow polyfill for +#[inline(always)] +fn interleave_f64x4(a: f64x4, b: f64x4) -> (f64x4, f64x4) { + (a.zip_low(b), a.zip_high(b)) +} + +/// Equivalent to `a.interleave(b)` — returns `(a.zip_low(b), a.zip_high(b))`. +/// Slow polyfill for +#[inline(always)] +fn interleave_f32x4(a: f32x4, b: f32x4) -> (f32x4, f32x4) { + (a.zip_low(b), a.zip_high(b)) +} + /// FFT-16 codelet for `f64`: executes stages 0-3 (chunk_size 2 through 16) in a single function. /// /// Register-resident implementation: all 16 complex values are loaded into f64x4 vectors, @@ -33,10 +47,10 @@ fn fft_dit_codelet_16_simd_f64(simd: S, reals: &mut [f64], imags: &mut for (re, im) in reals.chunks_exact_mut(16).zip(imags.chunks_exact_mut(16)) { macro_rules! transpose4x4_f64 { ($g0:expr, $g1:expr, $g2:expr, $g3:expr) => {{ - let (t0, t1) = $g0.interleave($g2); - let (t2, t3) = $g1.interleave($g3); - let (r0, r1) = t0.interleave(t2); - let (r2, r3) = t1.interleave(t3); + let (t0, t1) = interleave_f64x4($g0, $g2); + let (t2, t3) = interleave_f64x4($g1, $g3); + let (r0, r1) = interleave_f64x4(t0, t2); + let (r2, r3) = interleave_f64x4(t1, t3); (r0, r1, r2, r3) }}; } @@ -236,10 +250,10 @@ fn fft_dit_codelet_32_simd_f32(simd: S, reals: &mut [f32], imags: &mut { macro_rules! 
transpose4x4 { ($g0:expr, $g1:expr, $g2:expr, $g3:expr) => {{ - let (t0, t1) = $g0.interleave($g2); - let (t2, t3) = $g1.interleave($g3); - let (r0, r1) = t0.interleave(t2); - let (r2, r3) = t1.interleave(t3); + let (t0, t1) = interleave_f32x4($g0, $g2); + let (t2, t3) = interleave_f32x4($g1, $g3); + let (r0, r1) = interleave_f32x4(t0, t2); + let (r2, r3) = interleave_f32x4(t1, t3); (r0, r1, r2, r3) }}; }
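A scalar reference for the polyfill's semantics (illustrative only; array-based stand-ins, not code from the patch): interleave(a, b) is exactly the (zip_low, zip_high) pair, and two rounds of it transpose a 4x4 block, which is what `transpose4x4_f64!` and `transpose4x4` compute:

// interleave(a, b) = ([a0, b0, a1, b1], [a2, b2, a3, b3]) for 4-lane vectors.
fn interleave_ref(a: [f64; 4], b: [f64; 4]) -> ([f64; 4], [f64; 4]) {
    ([a[0], b[0], a[1], b[1]], [a[2], b[2], a[3], b[3]])
}

fn transpose4x4_ref(rows: [[f64; 4]; 4]) -> [[f64; 4]; 4] {
    let (t0, t1) = interleave_ref(rows[0], rows[2]); // t0 = [m00, m20, m01, m21]
    let (t2, t3) = interleave_ref(rows[1], rows[3]); // t2 = [m10, m30, m11, m31]
    let (r0, r1) = interleave_ref(t0, t2); // r0 = [m00, m10, m20, m30]
    let (r2, r3) = interleave_ref(t1, t3); // r3 = [m03, m13, m23, m33]
    [r0, r1, r2, r3] // rows[i][j] ends up at [j][i]
}

Patch 15 expected the fused `interleave()` to lower to cheaper shuffles on AVX2 than separate `zip_low`/`zip_high` calls; the polyfill trades that back for compatibility with the crates.io release of fearless_simd until the upstream PR lands.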