From f12f4f2a3eaf29c24ca4dd0915362064820f1974 Mon Sep 17 00:00:00 2001 From: 39ali Date: Wed, 15 Apr 2026 12:41:56 +0300 Subject: [PATCH 1/3] add support for unions --- .../src/builder/builder_methods.rs | 563 ++++++++++++++++++ tests/compiletests/ui/lang/core/unions.rs | 155 +++++ 2 files changed, 718 insertions(+) create mode 100644 tests/compiletests/ui/lang/core/unions.rs diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index d86db1cbd00..1a452ff82a6 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -336,6 +336,21 @@ fn memset_dynamic_scalar( .unwrap() } +#[derive(Clone)] +struct ScalarLeaf { + // the access-chain index sequence to reach this leaf + indices: Vec, + // the SPIR-V `Word` for the scalar type + ty: Word, + bits: u32, +} + +/// matched pair of src/dst scalar-leaf groups with equal total bit widths, +struct LeafGroup { + src: Vec, + dst: Vec, +} + impl<'a, 'tcx> Builder<'a, 'tcx> { #[instrument(level = "trace", skip(self))] fn ordering_to_semantics_def(&mut self, ordering: AtomicOrdering) -> SpirvValue { @@ -1278,6 +1293,544 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { builder.insert_into_block(index, inst).unwrap(); result_id.with_type(ptr_ty) } + + /// attempt to copy memory between two pointers of different but layout-compatible types + /// by recursively decomposing both into scalar leaves and emitting element-wise + /// `OpLoad` + `OpStore` pairs (with `OpBitcast` for type mismatches). This is used to + /// support union field access, e.g. `[f32; 5]` -> `struct { a: f64, b: [f32; 3] }`. + /// handles both 1:1 (same scalar count) and N:M (different counts but same total bits) + + /// returns `false` if the types are not compatible. + fn memcpy_const_size(&mut self, dst: SpirvValue, src: SpirvValue) -> bool { + let src_stripped = src.strip_ptrcasts(); + let dst_stripped = dst.strip_ptrcasts(); + + let src_pointee = match self.lookup_type(src_stripped.ty) { + SpirvType::Pointer { pointee } => pointee, + _ => return false, + }; + let dst_pointee = match self.lookup_type(dst_stripped.ty) { + SpirvType::Pointer { pointee } => pointee, + _ => return false, + }; + + let mut src_leaves = Vec::new(); + let mut dst_leaves = Vec::new(); + if !self.collect_scalar_leaves(src_pointee, &mut Vec::new(), &mut src_leaves) + || !self.collect_scalar_leaves(dst_pointee, &mut Vec::new(), &mut dst_leaves) + { + return false; + } + + // group leaves by equal total bit widths. + //this handles both 1:1 (same scalar count) + // and N:M (different counts, same total bits) cases. + let groups = match Self::group_leaves_by_bits(&src_leaves, &dst_leaves) { + Some(g) => g, + None => return false, + }; + + trace!( + "memcpy_const_size: {} groups ({} src leaves, {} dst leaves) from {} to {}", + groups.len(), + src_leaves.len(), + dst_leaves.len(), + self.debug_type(src_pointee), + self.debug_type(dst_pointee), + ); + + for LeafGroup { + src: src_group, + dst: dst_group, + } in groups + { + match (src_group.as_slice(), dst_group.as_slice()) { + //1 -> 1: simple scalar copy, possibly with OpBitcast + ( + [ + ScalarLeaf { + indices: src_indices, + ty: src_ty, + .. + }, + ], + [ + ScalarLeaf { + indices: dst_indices, + ty: dst_ty, + .. + }, + ], + ) => { + let val = self.load_scalar_leaf(src_stripped, src_indices, *src_ty); + let val = self.bitcast_leaf_if_needed(val, *src_ty, *dst_ty); + self.store_scalar_leaf(dst_stripped, dst_indices, *dst_ty, val); + } + + // N -> 1: pack N same-width scalars into a vector, then bitcast + ( + src_group, + [ + ScalarLeaf { + indices: dst_indices, + ty: dst_ty, + bits: dst_bits, + }, + ], + ) => { + let n = src_group.len(); + if n > 4 { + return false; + } + // pick the element type for the intermediate vector. + // all src scalars have the same bit width (guaranteed by grouping). + let vec_elem_bits = src_group[0].bits; + let vec_elem_ty = + SpirvType::Integer(vec_elem_bits, false).def(self.span(), self); + // load each src scalar, bitcasting if needed. + let components: Vec<_> = src_group + .iter() + .map(|ScalarLeaf { indices, ty, .. }| { + let val = self.load_scalar_leaf(src_stripped, indices, *ty); + self.bitcast_leaf_if_needed(val, *ty, vec_elem_ty) + }) + .collect(); + let component_ids: Vec<_> = components.iter().map(|v| v.def(self)).collect(); + + let vec_ty = SpirvType::Vector { + element: vec_elem_ty, + count: n as u32, + size: Size::from_bits(*dst_bits), + align: Align::from_bits(vec_elem_bits as u64).unwrap(), + } + .def(self.span(), self); + let vec_val = self + .emit() + .composite_construct(vec_ty, None, component_ids) + .unwrap() + .with_type(vec_ty); + + let val = self.bitcast_leaf_if_needed(vec_val, vec_ty, *dst_ty); + self.store_scalar_leaf(dst_stripped, dst_indices, *dst_ty, val); + } + + // 1 -> N: bitcast src scalar to a vector, then extract each component + ( + [ + ScalarLeaf { + indices: src_indices, + ty: src_ty, + bits: src_bits, + }, + ], + dst_group, + ) => { + let n = dst_group.len(); + if n > 4 { + return false; + } + let vec_elem_bits = dst_group[0].bits; + let vec_elem_ty = + SpirvType::Integer(vec_elem_bits, false).def(self.span(), self); + let vec_ty = SpirvType::Vector { + element: vec_elem_ty, + count: n as u32, + size: Size::from_bits(*src_bits), + align: Align::from_bits(vec_elem_bits as u64).unwrap(), + } + .def(self.span(), self); + // load src scalar, bitcast to the intermediate vector type. + let val = self.load_scalar_leaf(src_stripped, src_indices, *src_ty); + let vec_val = self.bitcast_leaf_if_needed(val, *src_ty, vec_ty); + // extract each component and store to the corresponding dst leaf. + for ( + i, + ScalarLeaf { + indices: dst_indices, + ty: dst_ty, + .. + }, + ) in dst_group.iter().enumerate() + { + let component = self + .emit() + .composite_extract(vec_elem_ty, None, vec_val.def(self), [i as u32]) + .unwrap() + .with_type(vec_elem_ty); + let component = + self.bitcast_leaf_if_needed(component, vec_elem_ty, *dst_ty); + self.store_scalar_leaf(dst_stripped, dst_indices, *dst_ty, component); + } + } + + _ => return false, + } + } + + true + } + + /// attempt to store a composite value `val` into `dest_ptr` whose pointee type has a + /// different but layout-compatible SPIR-V type, by decomposing both types into scalar + /// leaves and emitting `OpCompositeExtract` + `OpStore` pairs (with `OpBitcast` for + /// type mismatches). Handles 1:1, N->1, and 1->N groups + /// Returns `true` if the store was emitted, `false` if the types are not compatible. + fn try_store_composite_scalar_leaves(&mut self, val: SpirvValue, dest_ptr: SpirvValue) -> bool { + let dest_pointee = match self.lookup_type(dest_ptr.ty) { + SpirvType::Pointer { pointee } => pointee, + _ => return false, + }; + + let mut src_leaves = Vec::new(); + let mut dst_leaves = Vec::new(); + if !self.collect_scalar_leaves(val.ty, &mut Vec::new(), &mut src_leaves) + || !self.collect_scalar_leaves(dest_pointee, &mut Vec::new(), &mut dst_leaves) + { + return false; + } + + let groups = match Self::group_leaves_by_bits(&src_leaves, &dst_leaves) { + Some(g) => g, + None => return false, + }; + + trace!( + "try_store_composite_via_scalar_leaves: {} groups ({} src leaves, {} dst leaves)", + groups.len(), + src_leaves.len(), + dst_leaves.len(), + ); + + for LeafGroup { + src: src_group, + dst: dst_group, + } in groups + { + match (src_group.as_slice(), dst_group.as_slice()) { + // 1 -> 1 + ( + [ + ScalarLeaf { + indices: src_indices, + ty: src_ty, + .. + }, + ], + [ + ScalarLeaf { + indices: dst_indices, + ty: dst_ty, + .. + }, + ], + ) => { + let extracted = self.extract_scalar_leaf(val, src_indices, *src_ty); + let casted = self.bitcast_leaf_if_needed(extracted, *src_ty, *dst_ty); + self.store_scalar_leaf(dest_ptr, dst_indices, *dst_ty, casted); + } + + // N -> 1: extract N scalars, pack into a vector, bitcast + ( + src_group, + [ + ScalarLeaf { + indices: dst_indices, + ty: dst_ty, + bits: dst_bits, + }, + ], + ) => { + let n = src_group.len(); + if n > 4 { + return false; + } + let vec_elem_bits = src_group[0].bits; + let vec_elem_ty = + SpirvType::Integer(vec_elem_bits, false).def(self.span(), self); + let components: Vec<_> = src_group + .iter() + .map(|ScalarLeaf { indices, ty, .. }| { + let extracted = self.extract_scalar_leaf(val, indices, *ty); + self.bitcast_leaf_if_needed(extracted, *ty, vec_elem_ty) + }) + .collect(); + let component_ids: Vec<_> = components.iter().map(|v| v.def(self)).collect(); + let vec_ty = SpirvType::Vector { + element: vec_elem_ty, + count: n as u32, + size: Size::from_bits(*dst_bits), + align: Align::from_bits(vec_elem_bits as u64).unwrap(), + } + .def(self.span(), self); + let vec_val = self + .emit() + .composite_construct(vec_ty, None, component_ids) + .unwrap() + .with_type(vec_ty); + let casted = self.bitcast_leaf_if_needed(vec_val, vec_ty, *dst_ty); + self.store_scalar_leaf(dest_ptr, dst_indices, *dst_ty, casted); + } + + // 1 -> N: extract src scalar, bitcast to vector, unpack + ( + [ + ScalarLeaf { + indices: src_indices, + ty: src_ty, + bits: src_bits, + }, + ], + dst_group, + ) => { + let n = dst_group.len(); + if n > 4 { + return false; + } + let vec_elem_bits = dst_group[0].bits; + let vec_elem_ty = + SpirvType::Integer(vec_elem_bits, false).def(self.span(), self); + let vec_ty = SpirvType::Vector { + element: vec_elem_ty, + count: n as u32, + size: Size::from_bits(*src_bits), + align: Align::from_bits(vec_elem_bits as u64).unwrap(), + } + .def(self.span(), self); + let extracted = self.extract_scalar_leaf(val, src_indices, *src_ty); + let vec_val = self.bitcast_leaf_if_needed(extracted, *src_ty, vec_ty); + for ( + i, + ScalarLeaf { + indices: dst_indices, + ty: dst_ty, + .. + }, + ) in dst_group.iter().enumerate() + { + let component = self + .emit() + .composite_extract(vec_elem_ty, None, vec_val.def(self), [i as u32]) + .unwrap() + .with_type(vec_elem_ty); + let casted = self.bitcast_leaf_if_needed(component, vec_elem_ty, *dst_ty); + self.store_scalar_leaf(dest_ptr, dst_indices, *dst_ty, casted); + } + } + + _ => return false, + } + } + + true + } + + fn load_scalar_leaf( + &mut self, + base_ptr: SpirvValue, + indices: &[u32], + scalar_ty: Word, + ) -> SpirvValue { + let ptr = if indices.is_empty() { + base_ptr + } else { + let ptr_ty = self.type_ptr_to(scalar_ty); + let idx_ids: Vec<_> = indices + .iter() + .map(|&i| self.constant_u32(self.span(), i).def(self)) + .collect(); + self.emit() + .in_bounds_access_chain(ptr_ty, None, base_ptr.def(self), idx_ids) + .unwrap() + .with_type(ptr_ty) + }; + let ptr_ty = self.type_ptr_to(scalar_ty); + let ptr = if ptr.ty == ptr_ty { + ptr + } else { + // base_ptr was already the scalar pointer (indices empty) + ptr + }; + self.emit() + .load(scalar_ty, None, ptr.def(self), None, []) + .unwrap() + .with_type(scalar_ty) + } + + fn extract_scalar_leaf( + &mut self, + base_val: SpirvValue, + indices: &[u32], + scalar_ty: Word, + ) -> SpirvValue { + if indices.is_empty() { + base_val + } else { + self.emit() + .composite_extract(scalar_ty, None, base_val.def(self), indices.iter().copied()) + .unwrap() + .with_type(scalar_ty) + } + } + + fn store_scalar_leaf( + &mut self, + base_ptr: SpirvValue, + indices: &[u32], + scalar_ty: Word, + val: SpirvValue, + ) { + let ptr = if indices.is_empty() { + base_ptr + } else { + let ptr_ty = self.type_ptr_to(scalar_ty); + let idx_ids: Vec<_> = indices + .iter() + .map(|&i| self.constant_u32(self.span(), i).def(self)) + .collect(); + self.emit() + .in_bounds_access_chain(ptr_ty, None, base_ptr.def(self), idx_ids) + .unwrap() + .with_type(ptr_ty) + }; + self.emit() + .store(ptr.def(self), val.def(self), None, []) + .unwrap(); + } + + fn bitcast_leaf_if_needed( + &mut self, + val: SpirvValue, + src_ty: Word, + dst_ty: Word, + ) -> SpirvValue { + if src_ty == dst_ty { + val + } else { + self.emit() + .bitcast(dst_ty, None, val.def(self)) + .unwrap() + .with_type(dst_ty) + } + } + + /// group scalar leaves from `src` and `dst` into pairs where each pair has equal total + /// bit widths. This allows N:M reinterpretation (e.g. two f32s -> one f64). + /// returns `None` if the total bit widths differ or if a leaf has no defined bit width + fn group_leaves_by_bits(src: &[ScalarLeaf], dst: &[ScalarLeaf]) -> Option> { + let mut groups = Vec::new(); + let mut src_i = 0usize; + let mut dst_i = 0usize; + let mut src_acc = 0u64; + let mut dst_acc = 0u64; + let mut cur_src: Vec = Vec::new(); + let mut cur_dst: Vec = Vec::new(); + + while src_i < src.len() || dst_i < dst.len() { + if src_acc <= dst_acc && src_i < src.len() { + let leaf = src[src_i].clone(); + if leaf.bits == 0 { + return None; + } + src_acc += leaf.bits as u64; + cur_src.push(leaf); + src_i += 1; + } else if dst_i < dst.len() { + let leaf = dst[dst_i].clone(); + if leaf.bits == 0 { + return None; + } + dst_acc += leaf.bits as u64; + cur_dst.push(leaf); + dst_i += 1; + } else { + return None; + } + + if src_acc == dst_acc && src_acc > 0 { + // bth sides have accumulated the same number of bits: flush the group. + groups.push(LeafGroup { + src: std::mem::take(&mut cur_src), + dst: std::mem::take(&mut cur_dst), + }); + src_acc = 0; + dst_acc = 0; + } + } + + // if anything remains unflushed, total bits didn't balance. + if src_acc != dst_acc { + return None; + } + + Some(groups) + } + + ///collect all scalar leaf types reachable from `ty` via `OpAccessChain`, + /// along with the index sequence needed to reach each one and its bit width. + /// returns `false` if the type contains any non-decomposable composite members + /// (e.g. images, samplers, pointers ...) + fn collect_scalar_leaves( + &self, + ty: Word, + indices: &mut Vec, + result: &mut Vec, + ) -> bool { + match self.lookup_type(ty) { + SpirvType::Integer(bits, _) | SpirvType::Float(bits) => { + result.push(ScalarLeaf { + indices: indices.clone(), + ty, + bits, + }); + true + } + SpirvType::Adt { field_types, .. } => { + for (i, &field_ty) in field_types.iter().enumerate() { + indices.push(i as u32); + if !self.collect_scalar_leaves(field_ty, indices, result) { + return false; + } + indices.pop(); + } + true + } + SpirvType::Array { element, count } => { + let n = match self.builder.lookup_const_scalar(count) { + Some(n) => n as u32, + None => return false, + }; + for i in 0..n { + indices.push(i); + if !self.collect_scalar_leaves(element, indices, result) { + return false; + } + indices.pop(); + } + true + } + SpirvType::Vector { element, count, .. } => { + for i in 0..count { + indices.push(i); + if !self.collect_scalar_leaves(element, indices, result) { + return false; + } + indices.pop(); + } + true + } + SpirvType::Matrix { element, count } => { + for i in 0..count { + indices.push(i); + if !self.collect_scalar_leaves(element, indices, result) { + return false; + } + indices.pop(); + } + true + } + _ => false, + } + } } impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { @@ -1958,7 +2511,16 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } fn store(&mut self, val: Self::Value, ptr: Self::Value, _align: Align) -> Self::Value { + let orig_ptr = ptr; let (ptr, access_ty) = self.adjust_pointer_for_typed_access(ptr, val.ty); + + if access_ty != val.ty { + let dest_ptr = orig_ptr.strip_ptrcasts(); + if self.try_store_composite_scalar_leaves(val, dest_ptr) { + return val; + } + } + let val = self.bitcast(val, access_ty); self.emit() @@ -2856,6 +3418,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { .copy_memory(dst.def(self), src.def(self), None, None, empty()) .unwrap(); } + } else if const_size.is_some() && self.memcpy_const_size(dst, src) { } else { self.emit() .copy_memory_sized( diff --git a/tests/compiletests/ui/lang/core/unions.rs b/tests/compiletests/ui/lang/core/unions.rs new file mode 100644 index 00000000000..06a4c7d8e3c --- /dev/null +++ b/tests/compiletests/ui/lang/core/unions.rs @@ -0,0 +1,155 @@ +// build-pass +// compile-flags: -C target-feature=+Float64,+Int8,+Int16 + +use spirv_std::glam::Vec4; +use spirv_std::spirv; + + +#[repr(C)] +#[derive(Clone, Copy)] +struct Data { + a: i32, + b: [f32; 3], + c: u32, +} + +union DataOrArray { + arr: [f32; 5], + str: Data, +} + +impl DataOrArray { + fn arr(self) -> [f32; 5] { + unsafe { self.arr } + } + fn new(arr: [f32; 5]) -> Self { + Self { arr } + } +} + +#[spirv(fragment)] +pub fn union_mixed_types() { + let dora = DataOrArray::new([1.0, 2.0, 3.0, 4.0, 5.0]); + let arr = dora.arr(); + +} + + +union TwoU8OrU16 { + bytes: [u8; 2], + half: u16, +} + +#[spirv(fragment)] +pub fn union_two_u8_as_u16(output: &mut u16) { + let u = TwoU8OrU16 { bytes: [0xABu8, 0xCDu8] }; + *output = unsafe { u.half }; +} + + + + +// ── N:M groups: [u8; 4] -> [u16; 2] +// Group 1: u8(0)+u8(1) = u16(0) +// Group 2: u8(2)+u8(3) = u16(1) + +union FourU8OrTwoU16 { + bytes: [u8; 4], + halves: [u16; 2], +} + +#[spirv(fragment)] +pub fn fragment_union_four_u8_as_two_u16(output: &mut [u16; 2]) { + let u = FourU8OrTwoU16 { bytes: [1u8, 2u8, 3u8, 4u8] }; + *output = unsafe { u.halves }; +} + +#[spirv(fragment)] +pub fn fragment_union_two_u16_as_four_u8(output: &mut [u8; 4]) { + let u = FourU8OrTwoU16 { halves: [0x0102u16, 0x0304u16] }; + *output = unsafe { u.bytes }; +} + +// [f32; 2] -> f64 (64 bits) +union TwoF32OrF64 { + pair: [f32; 2], + wide: f64, +} + +#[spirv(fragment)] +pub fn fragment_union_two_f32_as_f64(output: &mut f64) { + let u = TwoF32OrF64 { pair: [1.0f32, 2.0f32] }; + *output = unsafe { u.wide }; +} + +// f64 -> [f32; 2] + +#[spirv(fragment)] +pub fn fragment_union_f64_as_two_f32(output: &mut [f32; 2]) { + let u = TwoF32OrF64 { wide: 0.0f64 }; + *output = unsafe { u.pair }; +} + +// struct -> array, same total size, no padding +// struct { x: u32, y: u32, z: u32 } (12 bytes) -> [f32; 3] (12 bytes). + +#[repr(C)] +#[derive(Clone, Copy)] +pub struct ThreeU32 { + pub x: u32, + pub y: u32, + pub z: u32, +} + +union ThreeU32OrF32Array { + s: ThreeU32, + a: [f32; 3], +} + +#[spirv(fragment)] +pub fn fragment_union_struct_u32_vs_f32_array(output: &mut [f32; 3]) { + let u = ThreeU32OrF32Array { + s: ThreeU32 { x: 0, y:3 , z: 2 }, + }; + *output = unsafe { u.a }; +} + +// ── 1:1, struct -> array, four f32 fields (16 bytes, no padding) + +#[repr(C)] +#[derive(Clone, Copy)] +pub struct MyVec4 { + pub x: f32, + pub y: f32, + pub z: f32, + pub w: f32, +} + +union ArrayOrMyVec4 { + arr: [f32; 4], + vec: MyVec4, +} + +#[spirv(fragment)] +pub fn union_array_vs_struct_vec4(output: &mut MyVec4) { + let u = ArrayOrMyVec4 { arr: [1.0, 2.0, 3.0, 4.0] }; + *output = unsafe { u.vec }; +} + + +union FourF32OrTwoF64 { + narrow: [f32; 4], + wide: [f64; 2], +} + +#[spirv(fragment)] +pub fn union_four_f32_as_two_f64(output: &mut [f64; 2]) { + let u = FourF32OrTwoF64 { narrow: [1.0f32, 2.0f32, 3.0f32, 4.0f32] }; + *output = unsafe { u.wide }; +} + +#[spirv(fragment)] +pub fn union_two_f64_as_four_f32(output: &mut [f32; 4]) { + let u = FourF32OrTwoF64 { wide: [0.0f64, 1.0f64] }; + *output = unsafe { u.narrow }; +} From 046f1051c0779b3a8a511dfdef186255ad3dbe07 Mon Sep 17 00:00:00 2001 From: 39ali Date: Wed, 15 Apr 2026 13:00:36 +0300 Subject: [PATCH 2/3] fmt tests --- tests/compiletests/ui/lang/core/unions.rs | 43 +++++++++++++---------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/tests/compiletests/ui/lang/core/unions.rs b/tests/compiletests/ui/lang/core/unions.rs index 06a4c7d8e3c..4954dc5c278 100644 --- a/tests/compiletests/ui/lang/core/unions.rs +++ b/tests/compiletests/ui/lang/core/unions.rs @@ -4,7 +4,6 @@ use spirv_std::glam::Vec4; use spirv_std::spirv; - #[repr(C)] #[derive(Clone, Copy)] struct Data { @@ -31,10 +30,8 @@ impl DataOrArray { pub fn union_mixed_types() { let dora = DataOrArray::new([1.0, 2.0, 3.0, 4.0, 5.0]); let arr = dora.arr(); - } - union TwoU8OrU16 { bytes: [u8; 2], half: u16, @@ -42,16 +39,15 @@ union TwoU8OrU16 { #[spirv(fragment)] pub fn union_two_u8_as_u16(output: &mut u16) { - let u = TwoU8OrU16 { bytes: [0xABu8, 0xCDu8] }; + let u = TwoU8OrU16 { + bytes: [0xABu8, 0xCDu8], + }; *output = unsafe { u.half }; } - - - // ── N:M groups: [u8; 4] -> [u16; 2] -// Group 1: u8(0)+u8(1) = u16(0) -// Group 2: u8(2)+u8(3) = u16(1) +// Group 1: u8(0)+u8(1) = u16(0) +// Group 2: u8(2)+u8(3) = u16(1) union FourU8OrTwoU16 { bytes: [u8; 4], @@ -60,13 +56,17 @@ union FourU8OrTwoU16 { #[spirv(fragment)] pub fn fragment_union_four_u8_as_two_u16(output: &mut [u16; 2]) { - let u = FourU8OrTwoU16 { bytes: [1u8, 2u8, 3u8, 4u8] }; + let u = FourU8OrTwoU16 { + bytes: [1u8, 2u8, 3u8, 4u8], + }; *output = unsafe { u.halves }; } #[spirv(fragment)] pub fn fragment_union_two_u16_as_four_u8(output: &mut [u8; 4]) { - let u = FourU8OrTwoU16 { halves: [0x0102u16, 0x0304u16] }; + let u = FourU8OrTwoU16 { + halves: [0x0102u16, 0x0304u16], + }; *output = unsafe { u.bytes }; } @@ -78,11 +78,13 @@ union TwoF32OrF64 { #[spirv(fragment)] pub fn fragment_union_two_f32_as_f64(output: &mut f64) { - let u = TwoF32OrF64 { pair: [1.0f32, 2.0f32] }; + let u = TwoF32OrF64 { + pair: [1.0f32, 2.0f32], + }; *output = unsafe { u.wide }; } -// f64 -> [f32; 2] +// f64 -> [f32; 2] #[spirv(fragment)] pub fn fragment_union_f64_as_two_f32(output: &mut [f32; 2]) { @@ -109,7 +111,7 @@ union ThreeU32OrF32Array { #[spirv(fragment)] pub fn fragment_union_struct_u32_vs_f32_array(output: &mut [f32; 3]) { let u = ThreeU32OrF32Array { - s: ThreeU32 { x: 0, y:3 , z: 2 }, + s: ThreeU32 { x: 0, y: 3, z: 2 }, }; *output = unsafe { u.a }; } @@ -132,11 +134,12 @@ union ArrayOrMyVec4 { #[spirv(fragment)] pub fn union_array_vs_struct_vec4(output: &mut MyVec4) { - let u = ArrayOrMyVec4 { arr: [1.0, 2.0, 3.0, 4.0] }; + let u = ArrayOrMyVec4 { + arr: [1.0, 2.0, 3.0, 4.0], + }; *output = unsafe { u.vec }; } - union FourF32OrTwoF64 { narrow: [f32; 4], wide: [f64; 2], @@ -144,12 +147,16 @@ union FourF32OrTwoF64 { #[spirv(fragment)] pub fn union_four_f32_as_two_f64(output: &mut [f64; 2]) { - let u = FourF32OrTwoF64 { narrow: [1.0f32, 2.0f32, 3.0f32, 4.0f32] }; + let u = FourF32OrTwoF64 { + narrow: [1.0f32, 2.0f32, 3.0f32, 4.0f32], + }; *output = unsafe { u.wide }; } #[spirv(fragment)] pub fn union_two_f64_as_four_f32(output: &mut [f32; 4]) { - let u = FourF32OrTwoF64 { wide: [0.0f64, 1.0f64] }; + let u = FourF32OrTwoF64 { + wide: [0.0f64, 1.0f64], + }; *output = unsafe { u.narrow }; } From 2698fbbaa8b424f21016c0d0832da6c17973762c Mon Sep 17 00:00:00 2001 From: 39ali Date: Wed, 15 Apr 2026 13:14:37 +0300 Subject: [PATCH 3/3] fix clippy --- .../src/builder/builder_methods.rs | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 1a452ff82a6..bae216af815 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -1299,7 +1299,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { /// `OpLoad` + `OpStore` pairs (with `OpBitcast` for type mismatches). This is used to /// support union field access, e.g. `[f32; 5]` -> `struct { a: f64, b: [f32; 3] }`. /// handles both 1:1 (same scalar count) and N:M (different counts but same total bits) - /// returns `false` if the types are not compatible. fn memcpy_const_size(&mut self, dst: SpirvValue, src: SpirvValue) -> bool { let src_stripped = src.strip_ptrcasts(); @@ -1495,7 +1494,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { }; trace!( - "try_store_composite_via_scalar_leaves: {} groups ({} src leaves, {} dst leaves)", + "try_store_composite_scalar_leaves: {} groups ({} src leaves, {} dst leaves)", groups.len(), src_leaves.len(), dst_leaves.len(), @@ -1808,17 +1807,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } true } - SpirvType::Vector { element, count, .. } => { - for i in 0..count { - indices.push(i); - if !self.collect_scalar_leaves(element, indices, result) { - return false; - } - indices.pop(); - } - true - } - SpirvType::Matrix { element, count } => { + SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => { for i in 0..count { indices.push(i); if !self.collect_scalar_leaves(element, indices, result) {