Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ doctest = false
[dependencies]
arrayvec = { version = "0.7.4", default-features = false }
rgb = { version = "0.8.47", default-features = false, features = ["bytemuck"] }
wide = { version = "1.1.1", default-features = false }
rayon = { version = "1.10.0", optional = true }
thread_local = { version = "1.1.8", optional = true }
# Used only in no_std
Expand Down
20 changes: 5 additions & 15 deletions src/hist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,23 +51,20 @@ pub(crate) struct HistItem {
pub color: f_pixel,
pub adjusted_weight: f32,
pub perceptual_weight: f32,
/// temporary in median cut
pub mc_color_weight: f32,
pub tmp: HistSortTmp,
/// Reused: mc_sort_value during median cut, then likely_palette_index after
pub tmp: u32,
}

impl HistItem {
// Safety: just an int, and it's been initialized when constructing the object
#[inline(always)]
pub fn mc_sort_value(&self) -> u32 {
unsafe { self.tmp.mc_sort_value }
self.tmp
}

// The u32 has been initialized when constructing the object, and u8/u16 is smaller than that
#[inline(always)]
pub fn likely_palette_index(&self) -> PalIndex {
assert!(mem::size_of::<PalIndex>() <= mem::size_of::<u32>());
unsafe { self.tmp.likely_palette_index }
self.tmp as PalIndex
}
}

Expand All @@ -83,13 +80,6 @@ impl fmt::Debug for HistItem {
}
}

#[repr(C)]
#[derive(Clone, Copy)]
pub union HistSortTmp {
pub mc_sort_value: u32,
pub likely_palette_index: PalIndex,
}

impl Histogram {
/// Creates histogram object that will be used to collect color statistics from multiple images.
///
Expand Down Expand Up @@ -323,7 +313,7 @@ impl Histogram {
adjusted_weight: if cfg!(debug_assertions) { f32::NAN } else { 0. },
perceptual_weight: if cfg!(debug_assertions) { f32::NAN } else { 0. },
mc_color_weight: if cfg!(debug_assertions) { f32::NAN } else { 0. },
tmp: HistSortTmp { mc_sort_value: if cfg!(debug_assertions) { !0 } else { 0 } },
tmp: if cfg!(debug_assertions) { !0 } else { 0 },
});
let mut items = items.into_boxed_slice();

Expand Down
2 changes: 1 addition & 1 deletion src/kmeans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl Kmeans {
self.weighed_diff_sum += batch.iter_mut().map(|item| {
let px = item.color;
let (matched, mut diff) = n.search(&px, item.likely_palette_index());
item.tmp.likely_palette_index = matched;
item.tmp = matched as u32;
if adjust_weight {
let remapped = colors[matched as usize];
let (_, new_diff) = n.search(&f_pixel(px.0 + px.0 - remapped.0), matched);
Expand Down
10 changes: 3 additions & 7 deletions src/mediancut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl<'hist> MBox<'hist> {
let chans: [f32; 4] = rgb::bytemuck::cast(item.color.0);
// Only the first channel really matters. But other channels are included, because when trying median cut
// many times with different histogram weights, I don't want sort randomness to influence the outcome.
item.tmp.mc_sort_value = (u32::from((chans[channels[0].chan] * 65535.) as u16) << 16)
item.tmp = (u32::from((chans[channels[0].chan] * 65535.) as u16) << 16)
| u32::from(((chans[channels[2].chan] + chans[channels[1].chan] / 2. + chans[channels[3].chan] / 4.) * 65535.) as u16); // box will be split to make color_weight of each side even
}
}
Expand Down Expand Up @@ -141,11 +141,7 @@ fn qsort_pivot(base: &[HistItem]) -> usize {
return len / 2;
}
let mut pivots = [8, len / 2, len - 1];
// LLVM can't see it's in bounds :(
pivots.sort_unstable_by_key(move |&idx| unsafe {
debug_assert!(base.get(idx).is_some());
base.get_unchecked(idx)
}.mc_sort_value());
pivots.sort_unstable_by_key(move |&idx| base[idx].mc_sort_value());
pivots[1]
}

Expand Down Expand Up @@ -252,7 +248,7 @@ impl<'hist> MedianCutter<'hist> {
let mut palette = PalF::new();

for (i, mbox) in self.boxes.iter_mut().enumerate() {
mbox.colors.iter_mut().for_each(move |a| a.tmp.likely_palette_index = i as _);
mbox.colors.iter_mut().for_each(move |a| a.tmp = i as _);

// store total color popularity (perceptual_weight is approximation of it)
let pop = mbox.colors.iter().map(|a| f64::from(a.perceptual_weight)).sum::<f64>();
Expand Down
128 changes: 58 additions & 70 deletions src/pal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use arrayvec::ArrayVec;
use core::iter;
use core::ops::{Deref, DerefMut};
use rgb::prelude::*;
use wide::f32x4;

#[cfg(all(not(feature = "std"), feature = "no_std"))]
use crate::no_std_compat::*;
Expand All @@ -24,20 +25,37 @@ const LIQ_WEIGHT_MSE: f64 = 0.45;

/// 4xf32 color using internal gamma.
///
/// ARGB layout is important for x86 SIMD.
/// I've created the newtype wrapper to try a 16-byte alignment, but it didn't improve perf :(
#[cfg_attr(
any(target_arch = "x86_64", all(target_feature = "neon", target_arch = "aarch64")),
repr(C, align(16))
)]
/// ARGB layout: [a, r, g, b] as f32x4
#[repr(C, align(16))]
#[derive(Debug, Copy, Clone, Default, PartialEq)]
#[allow(non_camel_case_types)]
pub struct f_pixel(pub ARGBF);

impl f_pixel {
#[cfg(not(any(target_arch = "x86_64", all(target_feature = "neon", target_arch = "aarch64"))))]
/// Compute perceptual color difference using portable SIMD.
///
/// Computes max(onblack², onwhite²) for RGB channels and sums them,
/// where onblack = self - other, onwhite = onblack + alpha_diff.
#[inline(always)]
pub fn diff(&self, other: &f_pixel) -> f32 {
// ARGBF and f32x4 are both Pod and same size/alignment
let px: f32x4 = rgb::bytemuck::cast(self.0);
let py: f32x4 = rgb::bytemuck::cast(other.0);

// alpha is at index 0 in ARGBF layout
let alpha_diff = f32x4::splat(other.0.a - self.0.a);
let onblack = px - py;
let onwhite = onblack + alpha_diff;
let max_sq = (onblack * onblack).max(onwhite * onwhite);

// Sum RGB channels (indices 1, 2, 3), skip alpha (index 0)
let arr: [f32; 4] = max_sq.into();
arr[1] + arr[2] + arr[3]
}

/// Scalar reference implementation for verification
#[cfg(test)]
fn diff_scalar(&self, other: &f_pixel) -> f32 {
let alphas = other.0.a - self.0.a;
let black = self.0 - other.0;
let white = ARGBF {
Expand All @@ -51,69 +69,6 @@ impl f_pixel {
(black.b * black.b).max(white.b * white.b)
}

#[cfg(all(target_feature = "neon", target_arch = "aarch64"))]
#[inline(always)]
pub fn diff(&self, other: &Self) -> f32 {
unsafe {
use core::arch::aarch64::*;

let px = vld1q_f32((self as *const Self).cast::<f32>());
let py = vld1q_f32((other as *const Self).cast::<f32>());

// y.a - x.a
let mut alphas = vsubq_f32(py, px);
alphas = vdupq_laneq_f32(alphas, 0); // copy first to all four

let mut onblack = vsubq_f32(px, py); // x - y
let mut onwhite = vaddq_f32(onblack, alphas); // x - y + (y.a - x.a)

onblack = vmulq_f32(onblack, onblack);
onwhite = vmulq_f32(onwhite, onwhite);

let max = vmaxq_f32(onwhite, onblack);

let mut max_r = [0.; 4];
vst1q_f32(max_r.as_mut_ptr(), max);

let mut max_gb = [0.; 4];
vst1q_f32(max_gb.as_mut_ptr(), vpaddq_f32(max, max));

// add rgb, not a

max_r[1] + max_gb[1]
}
}

#[cfg(target_arch = "x86_64")]
#[inline(always)]
pub fn diff(&self, other: &f_pixel) -> f32 {
unsafe {
use core::arch::x86_64::*;

let px = _mm_loadu_ps(self as *const f_pixel as *const f32);
let py = _mm_loadu_ps(other as *const f_pixel as *const f32);

// y.a - x.a
let mut alphas = _mm_sub_ss(py, px);
alphas = _mm_shuffle_ps(alphas, alphas, 0); // copy first to all four

let mut onblack = _mm_sub_ps(px, py); // x - y
let mut onwhite = _mm_add_ps(onblack, alphas); // x - y + (y.a - x.a)

onblack = _mm_mul_ps(onblack, onblack);
onwhite = _mm_mul_ps(onwhite, onwhite);
let max = _mm_max_ps(onwhite, onblack);

// the compiler is better at horizontal add than I am
let mut tmp = [0.; 4];
_mm_storeu_ps(tmp.as_mut_ptr(), max);

// add rgb, not a
let res = tmp[1] + tmp[2] + tmp[3];
res
}
}

#[inline]
pub(crate) fn to_rgb(self, gamma: f64) -> RGBA {
if self.is_fully_transparent() {
Expand Down Expand Up @@ -434,6 +389,39 @@ fn diff_test() {
assert!(c.diff(&b) < c.diff(&d));
}

/// Verify SIMD diff matches scalar reference implementation across edge cases
#[test]
fn diff_simd_matches_scalar() {
// Test values: boundaries and slight overflow (common in dithering)
let values: &[f32] = &[0.0, 0.5, 1.0, -0.01, 1.01];
for &a1 in values {
for &r1 in values {
for &g1 in values {
for &b1 in values {
let px1 = f_pixel(ARGBF { a: a1, r: r1, g: g1, b: b1 });
for &a2 in values {
for &r2 in values {
for &g2 in values {
for &b2 in values {
let px2 = f_pixel(ARGBF { a: a2, r: r2, g: g2, b: b2 });
let simd = px1.diff(&px2);
let scalar = px1.diff_scalar(&px2);
let diff = (simd - scalar).abs();
assert!(
diff < 1e-5 || diff < scalar.abs() * 1e-5,
"SIMD {simd} != scalar {scalar} (diff {diff}) for {:?} vs {:?}",
px1.0, px2.0
);
}
}
}
}
}
}
}
}
}

#[test]
fn alpha_test() {
let gamma = gamma_lut(0.45455);
Expand Down