diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index d86db1cbd00..a8cdc0deb3b 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -2423,7 +2423,46 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { "pointercast called on non-pointer dest type: {other:?}" )), }; - let dest_pointee_size = self.lookup_type(dest_pointee).sizeof(self); + + let dst_pointee_ty = self.lookup_type(dest_pointee); + let dest_pointee_size = dst_pointee_ty.sizeof(self); + let src_pointee_ty = self.lookup_type(ptr_pointee); + + // array -> element: *[T; N] -> *T + if let SpirvType::Array { + element: elem_ty, .. + } = src_pointee_ty + && elem_ty == dest_pointee + { + let zero = self.constant_u32(self.span(), 0).def(self); + return self + .emit() + .in_bounds_access_chain(dest_ty, None, ptr.def(self), [zero]) + .unwrap() + .with_type(dest_ty); + } + + // array -> RuntimeArray: *[T; N] -> *[T] + if let SpirvType::Array { + element: elem_ty, .. + } = src_pointee_ty + && let SpirvType::RuntimeArray { + element: rt_elem_ty, + } = dst_pointee_ty + && elem_ty == rt_elem_ty + { + let zero = self.constant_u32(self.span(), 0).def(self); + let elem_ptr_ty = self.type_ptr_to(elem_ty); + let elem_ptr = self + .emit() + .in_bounds_access_chain(elem_ptr_ty, None, ptr.def(self), [zero]) + .unwrap(); + return self + .emit() + .bitcast(dest_ty, None, elem_ptr) + .unwrap() + .with_type(dest_ty); + } if let Some((indices, _)) = self.recover_access_chain_from_offset( ptr_pointee, diff --git a/tests/compiletests/ui/lang/core/array-slice-cast.rs b/tests/compiletests/ui/lang/core/array-slice-cast.rs new file mode 100644 index 00000000000..bc1d252df58 --- /dev/null +++ b/tests/compiletests/ui/lang/core/array-slice-cast.rs @@ -0,0 +1,49 @@ +// build-pass +// compile-flags: -C llvm-args=--disassemble +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "%\d+ = OpString .*\n" -> "" +// normalize-stderr-test "^(; .*\n)*" -> "" +// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> "" +// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple" +// ignore-spv1.0 +// ignore-spv1.1 +// ignore-spv1.2 +// ignore-spv1.3 +// ignore-spv1.4 +// ignore-spv1.5 +// ignore-spv1.6 +// ignore-vulkan1.0 +// ignore-vulkan1.1 + +use spirv_std::spirv; + +const CONST_ARRAY: [u32; 3] = [1, 2, 3]; + +#[spirv(compute(threads(64)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[u32; 3], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut u32, +) { + let mut out = 0; + + // &[u32] fat pointer from a runtime storage buffer array + let slice: &[u32] = input; + out += slice[1]; + + // *[u32; 3] -> *u32 via AccessChain (array->element path) + let array_ptr: *const [u32; 3] = input; + let element_ptr: *const u32 = array_ptr as *const u32; + out += unsafe { *element_ptr }; + + // &[u32] fat pointer from a runtime storage buffer array + let slice: &[u32] = &CONST_ARRAY; + out += slice[1]; + + // *[u32; 3] -> *u32 via AccessChain (array->element path) + let array_ptr: *const [u32; 3] = &CONST_ARRAY; + let element_ptr: *const u32 = array_ptr as *const u32; + out += unsafe { *element_ptr }; + + *output = out; +} diff --git a/tests/compiletests/ui/lang/core/array-slice-cast.stderr b/tests/compiletests/ui/lang/core/array-slice-cast.stderr new file mode 100644 index 00000000000..eacadc5c73a --- /dev/null +++ b/tests/compiletests/ui/lang/core/array-slice-cast.stderr @@ -0,0 +1,57 @@ +OpCapability Shader +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %1 "main" %2 %3 %4 +OpExecutionMode %1 LocalSize 64 1 1 +OpName %2 "input" +OpName %3 "output" +OpDecorate %6 ArrayStride 4 +OpDecorate %7 Block +OpMemberDecorate %7 0 Offset 0 +OpDecorate %8 Block +OpMemberDecorate %8 0 Offset 0 +OpDecorate %2 NonWritable +OpDecorate %2 Binding 0 +OpDecorate %2 DescriptorSet 0 +OpDecorate %3 Binding 1 +OpDecorate %3 DescriptorSet 0 +%9 = OpTypeInt 32 0 +%10 = OpConstant %9 3 +%6 = OpTypeArray %9 %10 +%7 = OpTypeStruct %6 +%11 = OpTypePointer StorageBuffer %7 +%8 = OpTypeStruct %9 +%12 = OpTypePointer StorageBuffer %8 +%13 = OpTypeVoid +%14 = OpTypeFunction %13 +%15 = OpTypePointer StorageBuffer %6 +%2 = OpVariable %11 StorageBuffer +%16 = OpConstant %9 0 +%17 = OpTypePointer StorageBuffer %9 +%3 = OpVariable %12 StorageBuffer +%18 = OpConstant %9 1 +%19 = OpTypePointer Private %9 +%20 = OpTypeArray %9 %10 +%21 = OpTypePointer Private %20 +%22 = OpConstant %9 2 +%23 = OpConstantComposite %20 %18 %22 %10 +%4 = OpVariable %21 Private %23 +%1 = OpFunction %13 None %14 +%24 = OpLabel +%25 = OpInBoundsAccessChain %15 %2 %16 +%26 = OpInBoundsAccessChain %17 %3 %16 +%27 = OpInBoundsAccessChain %17 %25 %18 +%28 = OpLoad %9 %27 +%29 = OpIAdd %9 %16 %28 +%30 = OpInBoundsAccessChain %17 %25 %16 +%31 = OpLoad %9 %30 +%32 = OpIAdd %9 %29 %31 +%33 = OpInBoundsAccessChain %19 %4 %18 +%34 = OpLoad %9 %33 +%35 = OpIAdd %9 %32 %34 +%36 = OpInBoundsAccessChain %19 %4 %16 +%37 = OpLoad %9 %36 +%38 = OpIAdd %9 %35 %37 +OpStore %26 %38 +OpNoLine +OpReturn +OpFunctionEnd