From aa79d420ca9a0942da744075a24b948fb10f691a Mon Sep 17 00:00:00 2001
From: Lilith River
Date: Sun, 18 Jan 2026 03:20:34 -0700
Subject: [PATCH 1/3] Replace platform-specific SIMD with portable wide crate

Replace hand-written SSE2 (x86_64) and NEON (aarch64) intrinsics in
f_pixel::diff() with safe, portable SIMD using the wide crate's f32x4.

- Direct bytemuck cast from ARGBF to f32x4 (both are Pod)
- Single implementation works on all platforms (x86, ARM, WASM, etc.)
- Includes scalar reference and brute-force comparison test
- ~70 lines removed, eliminates all unsafe in this function
---
 Cargo.toml |   1 +
 src/pal.rs | 128 ++++++++++++++++++++++++----------------------
 2 files changed, 59 insertions(+), 70 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 93f6940..a9b49c2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -44,6 +44,7 @@ doctest = false
 
 [dependencies]
 arrayvec = { version = "0.7.4", default-features = false }
 rgb = { version = "0.8.47", default-features = false, features = ["bytemuck"] }
+wide = { version = "1.1.1", default-features = false }
 rayon = { version = "1.10.0", optional = true }
 thread_local = { version = "1.1.8", optional = true } # Used only in no_std
diff --git a/src/pal.rs b/src/pal.rs
index 59d618e..7b3aebf 100644
--- a/src/pal.rs
+++ b/src/pal.rs
@@ -3,6 +3,7 @@ use arrayvec::ArrayVec;
 use core::iter;
 use core::ops::{Deref, DerefMut};
 use rgb::prelude::*;
+use wide::f32x4;
 
 #[cfg(all(not(feature = "std"), feature = "no_std"))]
 use crate::no_std_compat::*;
@@ -24,20 +25,37 @@ const LIQ_WEIGHT_MSE: f64 = 0.45;
 
 /// 4xf32 color using internal gamma.
 ///
-/// ARGB layout is important for x86 SIMD.
-/// I've created the newtype wrapper to try a 16-byte alignment, but it didn't improve perf :(
-#[cfg_attr(
-    any(target_arch = "x86_64", all(target_feature = "neon", target_arch = "aarch64")),
-    repr(C, align(16))
-)]
+/// ARGB layout: [a, r, g, b] as f32x4
+#[repr(C, align(16))]
 #[derive(Debug, Copy, Clone, Default, PartialEq)]
 #[allow(non_camel_case_types)]
 pub struct f_pixel(pub ARGBF);
 
 impl f_pixel {
-    #[cfg(not(any(target_arch = "x86_64", all(target_feature = "neon", target_arch = "aarch64"))))]
+    /// Compute perceptual color difference using portable SIMD.
+    ///
+    /// Computes max(onblack², onwhite²) for RGB channels and sums them,
+    /// where onblack = self - other, onwhite = onblack + alpha_diff.
     #[inline(always)]
     pub fn diff(&self, other: &f_pixel) -> f32 {
+        // ARGBF and f32x4 are both Pod and same size/alignment
+        let px: f32x4 = rgb::bytemuck::cast(self.0);
+        let py: f32x4 = rgb::bytemuck::cast(other.0);
+
+        // alpha is at index 0 in ARGBF layout
+        let alpha_diff = f32x4::splat(other.0.a - self.0.a);
+        let onblack = px - py;
+        let onwhite = onblack + alpha_diff;
+        let max_sq = (onblack * onblack).max(onwhite * onwhite);
+
+        // Sum RGB channels (indices 1, 2, 3), skip alpha (index 0)
+        let arr: [f32; 4] = max_sq.into();
+        arr[1] + arr[2] + arr[3]
+    }
+
+    /// Scalar reference implementation for verification
+    #[cfg(test)]
+    fn diff_scalar(&self, other: &f_pixel) -> f32 {
         let alphas = other.0.a - self.0.a;
         let black = self.0 - other.0;
         let white = ARGBF {
@@ -51,69 +69,6 @@ impl f_pixel {
         (black.b * black.b).max(white.b * white.b)
     }
 
-    #[cfg(all(target_feature = "neon", target_arch = "aarch64"))]
-    #[inline(always)]
-    pub fn diff(&self, other: &Self) -> f32 {
-        unsafe {
-            use core::arch::aarch64::*;
-
-            let px = vld1q_f32((self as *const Self).cast::<f32>());
-            let py = vld1q_f32((other as *const Self).cast::<f32>());
-
-            // y.a - x.a
-            let mut alphas = vsubq_f32(py, px);
-            alphas = vdupq_laneq_f32(alphas, 0); // copy first to all four
-
-            let mut onblack = vsubq_f32(px, py); // x - y
-            let mut onwhite = vaddq_f32(onblack, alphas); // x - y + (y.a - x.a)
-
-            onblack = vmulq_f32(onblack, onblack);
-            onwhite = vmulq_f32(onwhite, onwhite);
-
-            let max = vmaxq_f32(onwhite, onblack);
-
-            let mut max_r = [0.; 4];
-            vst1q_f32(max_r.as_mut_ptr(), max);
-
-            let mut max_gb = [0.; 4];
-            vst1q_f32(max_gb.as_mut_ptr(), vpaddq_f32(max, max));
-
-            // add rgb, not a
-
-            max_r[1] + max_gb[1]
-        }
-    }
-
-    #[cfg(target_arch = "x86_64")]
-    #[inline(always)]
-    pub fn diff(&self, other: &f_pixel) -> f32 {
-        unsafe {
-            use core::arch::x86_64::*;
-
-            let px = _mm_loadu_ps(self as *const f_pixel as *const f32);
-            let py = _mm_loadu_ps(other as *const f_pixel as *const f32);
-
-            // y.a - x.a
-            let mut alphas = _mm_sub_ss(py, px);
-            alphas = _mm_shuffle_ps(alphas, alphas, 0); // copy first to all four
-
-            let mut onblack = _mm_sub_ps(px, py); // x - y
-            let mut onwhite = _mm_add_ps(onblack, alphas); // x - y + (y.a - x.a)
-
-            onblack = _mm_mul_ps(onblack, onblack);
-            onwhite = _mm_mul_ps(onwhite, onwhite);
-            let max = _mm_max_ps(onwhite, onblack);
-
-            // the compiler is better at horizontal add than I am
-            let mut tmp = [0.; 4];
-            _mm_storeu_ps(tmp.as_mut_ptr(), max);
-
-            // add rgb, not a
-            let res = tmp[1] + tmp[2] + tmp[3];
-            res
-        }
-    }
-
     #[inline]
     pub(crate) fn to_rgb(self, gamma: f64) -> RGBA {
         if self.is_fully_transparent() {
@@ -434,6 +389,39 @@ fn diff_test() {
     assert!(c.diff(&b) < c.diff(&d));
 }
 
+/// Verify SIMD diff matches scalar reference implementation across edge cases
+#[test]
+fn diff_simd_matches_scalar() {
+    // Test values: boundaries and slight overflow (common in dithering)
+    let values: &[f32] = &[0.0, 0.5, 1.0, -0.01, 1.01];
+    for &a1 in values {
+        for &r1 in values {
+            for &g1 in values {
+                for &b1 in values {
+                    let px1 = f_pixel(ARGBF { a: a1, r: r1, g: g1, b: b1 });
+                    for &a2 in values {
+                        for &r2 in values {
+                            for &g2 in values {
+                                for &b2 in values {
+                                    let px2 = f_pixel(ARGBF { a: a2, r: r2, g: g2, b: b2 });
+                                    let simd = px1.diff(&px2);
+                                    let scalar = px1.diff_scalar(&px2);
+                                    let diff = (simd - scalar).abs();
+                                    assert!(
+                                        diff < 1e-5 || diff < scalar.abs() * 1e-5,
+                                        "SIMD {simd} != scalar {scalar} (diff {diff}) for {:?} vs {:?}",
+                                        px1.0, px2.0
+                                    );
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
 #[test]
 fn alpha_test() {
     let gamma = gamma_lut(0.45455);

From bfa4657e3a77397d9f97f68ec4a7c37c4273dc7e Mon Sep 17 00:00:00 2001
From: Lilith River
Date: Sun, 18 Jan 2026 03:23:39 -0700
Subject: [PATCH 2/3] Remove HistSortTmp union in favor of plain u32

The union was used to reinterpret the same memory as either mc_sort_value
(u32) during median cut sorting, or likely_palette_index (PalIndex) after.
Since PalIndex fits in u32, we can simply store u32 and cast on read.
This eliminates unsafe union field access with no runtime cost.
---
 src/hist.rs      | 20 +++++---------------
 src/kmeans.rs    |  2 +-
 src/mediancut.rs |  4 ++--
 3 files changed, 8 insertions(+), 18 deletions(-)

diff --git a/src/hist.rs b/src/hist.rs
index 39ee00f..1ac9be7 100644
--- a/src/hist.rs
+++ b/src/hist.rs
@@ -51,23 +51,20 @@ pub(crate) struct HistItem {
     pub color: f_pixel,
     pub adjusted_weight: f32,
     pub perceptual_weight: f32,
-    /// temporary in median cut
     pub mc_color_weight: f32,
-    pub tmp: HistSortTmp,
+    /// Reused: mc_sort_value during median cut, then likely_palette_index after
+    pub tmp: u32,
 }
 
 impl HistItem {
-    // Safety: just an int, and it's been initialized when constructing the object
     #[inline(always)]
     pub fn mc_sort_value(&self) -> u32 {
-        unsafe { self.tmp.mc_sort_value }
+        self.tmp
     }
 
-    // The u32 has been initialized when constructing the object, and u8/u16 is smaller than that
     #[inline(always)]
     pub fn likely_palette_index(&self) -> PalIndex {
-        assert!(mem::size_of::<PalIndex>() <= mem::size_of::<u32>());
-        unsafe { self.tmp.likely_palette_index }
+        self.tmp as PalIndex
     }
 }
 
@@ -83,13 +80,6 @@ impl fmt::Debug for HistItem {
     }
 }
 
-#[repr(C)]
-#[derive(Clone, Copy)]
-pub union HistSortTmp {
-    pub mc_sort_value: u32,
-    pub likely_palette_index: PalIndex,
-}
-
 impl Histogram {
     /// Creates histogram object that will be used to collect color statistics from multiple images.
     ///
@@ -323,7 +313,7 @@ impl Histogram {
                 adjusted_weight: if cfg!(debug_assertions) { f32::NAN } else { 0. },
                 perceptual_weight: if cfg!(debug_assertions) { f32::NAN } else { 0. },
                 mc_color_weight: if cfg!(debug_assertions) { f32::NAN } else { 0. },
-                tmp: HistSortTmp { mc_sort_value: if cfg!(debug_assertions) { !0 } else { 0 } },
+                tmp: if cfg!(debug_assertions) { !0 } else { 0 },
             });
 
         let mut items = items.into_boxed_slice();
diff --git a/src/kmeans.rs b/src/kmeans.rs
index c193538..1c25e35 100644
--- a/src/kmeans.rs
+++ b/src/kmeans.rs
@@ -93,7 +93,7 @@ impl Kmeans {
         self.weighed_diff_sum += batch.iter_mut().map(|item| {
             let px = item.color;
             let (matched, mut diff) = n.search(&px, item.likely_palette_index());
-            item.tmp.likely_palette_index = matched;
+            item.tmp = matched as u32;
             if adjust_weight {
                 let remapped = colors[matched as usize];
                 let (_, new_diff) = n.search(&f_pixel(px.0 + px.0 - remapped.0), matched);
diff --git a/src/mediancut.rs b/src/mediancut.rs
index b83d5eb..f256353 100644
--- a/src/mediancut.rs
+++ b/src/mediancut.rs
@@ -96,7 +96,7 @@ impl<'hist> MBox<'hist> {
             let chans: [f32; 4] = rgb::bytemuck::cast(item.color.0);
             // Only the first channel really matters. But other channels are included, because when trying median cut
             // many times with different histogram weights, I don't want sort randomness to influence the outcome.
-            item.tmp.mc_sort_value = (u32::from((chans[channels[0].chan] * 65535.) as u16) << 16)
+            item.tmp = (u32::from((chans[channels[0].chan] * 65535.) as u16) << 16)
                 | u32::from(((chans[channels[2].chan] + chans[channels[1].chan] / 2. + chans[channels[3].chan] / 4.) * 65535.) as u16); // box will be split to make color_weight of each side even
         }
     }
@@ -252,7 +252,7 @@ impl<'hist> MedianCutter<'hist> {
 
         let mut palette = PalF::new();
         for (i, mbox) in self.boxes.iter_mut().enumerate() {
-            mbox.colors.iter_mut().for_each(move |a| a.tmp.likely_palette_index = i as _);
+            mbox.colors.iter_mut().for_each(move |a| a.tmp = i as _);
 
             // store total color popularity (perceptual_weight is approximation of it)
             let pop = mbox.colors.iter().map(|a| f64::from(a.perceptual_weight)).sum::<f64>();

From 4bf8a30f66e8cff4887fe19d503731e714aacd3c Mon Sep 17 00:00:00 2001
From: Lilith River
Date: Sun, 18 Jan 2026 03:24:41 -0700
Subject: [PATCH 3/3] Remove unnecessary get_unchecked in qsort_pivot

This sorts 3 pivot indices and accesses them 3 times total. The bounds
check overhead is negligible compared to the partitioning work that
follows. Safe indexing is simpler and equally fast.
---
 src/mediancut.rs | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/mediancut.rs b/src/mediancut.rs
index f256353..0fd37ef 100644
--- a/src/mediancut.rs
+++ b/src/mediancut.rs
@@ -141,11 +141,7 @@ fn qsort_pivot(base: &[HistItem]) -> usize {
         return len / 2;
     }
     let mut pivots = [8, len / 2, len - 1];
-    // LLVM can't see it's in bounds :(
-    pivots.sort_unstable_by_key(move |&idx| unsafe {
-        debug_assert!(base.get(idx).is_some());
-        base.get_unchecked(idx)
-    }.mc_sort_value());
+    pivots.sort_unstable_by_key(move |&idx| base[idx].mc_sort_value());
     pivots[1]
 }
 
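For reviewers who have not used the wide crate before, here is a minimal standalone sketch of the max(onblack², onwhite²) kernel that patch 1 introduces. It is illustrative only and not part of the patches: perceptual_diff and Argb are hypothetical stand-ins for f_pixel::diff and ARGBF, and it builds the vectors with wide's From<[f32; 4]> conversion rather than the bytemuck cast used in the patch. The only dependency assumed is wide itself, as added to Cargo.toml above.

// Illustrative sketch only (not part of the patches): a distilled version of
// the diff() kernel from patch 1, written against plain arrays instead of the
// crate-internal f_pixel/ARGBF types.
use wide::f32x4;

/// Hypothetical stand-in for ARGBF: lanes are [alpha, red, green, blue].
type Argb = [f32; 4];

/// Sum over R, G, B of max(onblack^2, onwhite^2), mirroring f_pixel::diff.
fn perceptual_diff(x: Argb, y: Argb) -> f32 {
    let px = f32x4::from(x);
    let py = f32x4::from(y);

    // Broadcast the alpha difference (lane 0) to all four lanes.
    let alpha_diff = f32x4::splat(y[0] - x[0]);

    let onblack = px - py;              // difference as seen over black
    let onwhite = onblack + alpha_diff; // difference as seen over white
    let max_sq = (onblack * onblack).max(onwhite * onwhite);

    // Horizontal sum of the R, G, B lanes; the alpha lane (0) is skipped.
    let arr: [f32; 4] = max_sq.into();
    arr[1] + arr[2] + arr[3]
}

fn main() {
    let opaque_red  = [1.0, 1.0, 0.0, 0.0];
    let opaque_blue = [1.0, 0.0, 0.0, 1.0];
    let clear_red   = [0.0, 1.0, 0.0, 0.0];

    // Identical colors have zero difference; distinct colors do not.
    assert!(perceptual_diff(opaque_red, opaque_red) == 0.0);
    assert!(perceptual_diff(opaque_red, opaque_blue) > 0.0);
    // A pure alpha change still registers through the "on white" term.
    assert!(perceptual_diff(opaque_red, clear_red) > 0.0);
}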