Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 71 additions & 19 deletions crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use crate::attr::{AggregatedSpirvAttributes, IntrinsicType};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use crate::spirv_type::{SpirvType, name_type_id};
use itertools::Itertools;
use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word};
use rustc_abi::ExternAbi as Abi;
Expand Down Expand Up @@ -317,8 +317,33 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
}
}

/// Returns the only non-ZST field of `layout`, provided that field starts at
/// offset 0 and has the same size and ABI alignment as `layout` itself.
///
/// Yields `None` when `layout` has no non-ZST field, has more than one, or when
/// the sole non-ZST field does not line up with the outer layout.
///
/// Only the structural "newtype wrapper" shape is checked here. A caller that
/// wants to substitute the field's SPIR-V type for the outer type must
/// *separately* establish that the ABIs agree - either through
/// `BackendRepr::eq_up_to_validity`, or because the outer type is
/// `#[repr(transparent)]`.
fn sole_structural_newtype_field<'tcx>(
    cx: &CodegenCx<'tcx>,
    layout: TyAndLayout<'tcx>,
) -> Option<TyAndLayout<'tcx>> {
    // Lazily walk the fields, keeping each index alongside its layout so the
    // surviving candidate does not have to be looked up again afterwards.
    let mut candidates = (0..layout.fields.count())
        .map(|idx| (idx, layout.field(cx, idx)))
        .filter(|(_, candidate)| !candidate.is_zst());

    // Exactly one non-ZST field is required; zero or several disqualify.
    let (Some((idx, inner)), None) = (candidates.next(), candidates.next()) else {
        return None;
    };

    let starts_at_zero = layout.fields.offset(idx) == Size::ZERO;
    let same_size = inner.size == layout.size;
    let same_align = inner.align.abi == layout.align.abi;

    if starts_at_zero && same_size && same_align {
        Some(inner)
    } else {
        None
    }
}

impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
fn spirv_type(&self, mut span: Span, cx: &CodegenCx<'tcx>) -> Word {
let mut has_block_attr = false;
if let TyKind::Adt(adt, args) = *self.ty.kind() {
if span == DUMMY_SP {
span = cx.tcx.def_span(adt.did());
Expand All @@ -337,6 +362,8 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
)),
);

has_block_attr = attrs.block.is_some();

if let Some(intrinsic_type_attr) = attrs.intrinsic_type.map(|attr| attr.value)
&& let Ok(spirv_type) =
trans_intrinsic_type(cx, span, *self, args, intrinsic_type_attr)
Expand Down Expand Up @@ -375,23 +402,12 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
// a new one, offering the `(a, b)` shape `rustc_codegen_ssa`
// expects, while letting noop pointercasts access the sole
// `BackendRepr::ScalarPair` field - this is the approach taken here
let mut non_zst_fields = (0..self.fields.count())
.map(|i| (i, self.field(cx, i)))
.filter(|(_, field)| !field.is_zst());
let sole_non_zst_field = match (non_zst_fields.next(), non_zst_fields.next()) {
(Some(field), None) => Some(field),
_ => None,
};
if let Some((i, field)) = sole_non_zst_field {
// Only unpack a newtype if the field and the newtype line up
// perfectly, in every way that could potentially affect ABI.
if self.fields.offset(i) == Size::ZERO
&& field.size == self.size
&& field.align.abi == self.align.abi
&& field.backend_repr.eq_up_to_validity(&self.backend_repr)
{
return field.spirv_type(span, cx);
}
// Only unpack a newtype if the field and the newtype line up
// perfectly, in every way that could potentially affect ABI.
if let Some(field) = sole_structural_newtype_field(cx, *self)
&& field.backend_repr.eq_up_to_validity(&self.backend_repr)
{
return field.spirv_type(span, cx);
}

// Note: We can't use auto_struct_layout here because the spirv types here might be undefined due to
Expand Down Expand Up @@ -443,7 +459,43 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
.tcx
.dcx()
.fatal("scalable vectors are not supported in SPIR-V backend"),
BackendRepr::Memory { sized: _ } => trans_aggregate(cx, span, *self),
BackendRepr::Memory { sized: _ } => {
// For `#[repr(transparent)]` newtypes, reuse the single non-ZST
// field's SPIR-V type directly instead of wrapping it in an
// `OpTypeStruct`. This mirrors the `ScalarPair` newtype-unpacking
// above, but with two differences:
//
// 1. We gate on `#[repr(transparent)]` explicitly. Without this
// guard, semantically opaque single-field structs (e.g.
// `struct Viewport { rect: Vec4 }`) would also collapse,
// losing their `OpTypeStruct` identity - which matters for
// SPIR-V interface matching, `OpMemberDecorate`, and
// push-constant block declarations.
//
// 2. We additionally refuse to collapse if the wrapper carries
// a `#[spirv(block)]` attribute, since that decoration needs
// an `OpTypeStruct` to land on.
//
// We omit the `eq_up_to_validity` check on `backend_repr`: the
// outer wrapper always has `BackendRepr::Memory` here while the
// inner field may have a different repr (e.g. a type annotated
// with `#[rust_gpu::spirv(vector)]` goes through the intrinsic
// path above and returns an `OpTypeVector`). The reprs will
// never match, but `#[repr(transparent)]` guarantees full ABI
// identity, so the structural checks in the helper suffice.
Comment on lines +479 to +485
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is AI BS, if you add eq_up_to_validity here you'll get the same result. In general, this patch just allows #[transparent] to inline glam vecs (rather #[rust_gpu::vector::v1] annotated structs) from what I can tell.

My entire library ecosystem depends on #[repr(transparent)] working as intended, and rust-gpu was failing that.

Could you explain what exactly fails? Since this doesn't seem to be about optimization, but about something failing compilation in a certain context?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also #[spirv(Block)] is deprecated and doesn't need to be checked

if let TyKind::Adt(adt, _) = self.ty.kind()
&& adt.repr().transparent()
&& !has_block_attr
&& let Some(field) = sole_structural_newtype_field(cx, *self)
{
let inner_id = field.spirv_type(span, cx);
// Preserve the wrapper's name as an `OpName` alias on the
// inner SPIR-V type so disassembly still shows it.
name_type_id(cx, inner_id, TyLayoutNameKey::from(*self));
return inner_id;
}
trans_aggregate(cx, span, *self)
}
}
}
}
Expand Down
19 changes: 12 additions & 7 deletions crates/rustc_codegen_spirv/src/spirv_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,17 @@ pub enum SpirvType<'tcx> {
RayQueryKhr,
}

/// Attaches an `OpName` to an already-defined type `id`, at most once per
/// `(id, name_key)` pair.
///
/// In contrast to [`SpirvType::def_with_name`], no type is defined here: the
/// caller supplies an existing `id`. That makes this suitable for an id shared
/// by several Rust types (e.g. a collapsed `#[repr(transparent)]` newtype,
/// where the wrapper's name is recorded on the inner type's id).
pub fn name_type_id<'tcx>(cx: &CodegenCx<'tcx>, id: Word, name_key: TyLayoutNameKey<'tcx>) {
    let mut names_by_id = cx.type_cache.type_names.borrow_mut();

    // `insert` reports whether `name_key` was newly recorded for this `id`.
    let newly_recorded = names_by_id.entry(id).or_default().insert(name_key);
    if !newly_recorded {
        // This name was already emitted for `id`; nothing more to do.
        return;
    }

    cx.emit_global().name(id, name_key.to_string());
}

impl SpirvType<'_> {
/// Note: `Builder::type_*` should be called *nowhere else* but here, to ensure
/// `CodegenCx::type_defs` stays up-to-date
Expand Down Expand Up @@ -266,13 +277,7 @@ impl SpirvType<'_> {
name_key: TyLayoutNameKey<'tcx>,
) -> Word {
let id = self.def(def_span, cx);

// Only emit `OpName` if this is the first time we see this name.
let mut type_names = cx.type_cache.type_names.borrow_mut();
if type_names.entry(id).or_default().insert(name_key) {
cx.emit_global().name(id, name_key.to_string());
}

name_type_id(cx, id, name_key);
id
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@ OpExecutionMode %1 LocalSize 1 1 1
OpExecutionMode %1 OutputVertices 9
OpExecutionMode %1 OutputPrimitivesNV 3
OpExecutionMode %1 OutputTrianglesNV
OpName %16 "core::ops::Range<usize>"
OpMemberName %16 0 "start"
OpMemberName %16 1 "end"
OpName %16 "ops::try_trait::NeverShortCircuit<[char
OpName %17 "core::ops::Range<usize>"
OpMemberName %17 0 "start"
OpMemberName %17 1 "end"
OpName %2 "positions"
OpName %3 "out_per_vertex"
OpName %4 "out_per_vertex2"
OpName %5 "indices"
OpName %6 "out_per_primitive"
OpName %7 "out_per_primitive2"
OpMemberDecorate %16 0 Offset 0
OpMemberDecorate %16 1 Offset 4
OpMemberDecorate %17 0 Offset 0
OpMemberDecorate %17 1 Offset 4
OpDecorate %2 BuiltIn Position
OpDecorate %3 Location 0
OpDecorate %4 Location 1
Expand All @@ -26,48 +26,48 @@ OpDecorate %6 Location 2
OpDecorate %6 PerPrimitiveNV
OpDecorate %7 Location 3
OpDecorate %7 PerPrimitiveNV
%17 = OpTypeFloat 32
%18 = OpTypeVector %17 4
%19 = OpTypeInt 32 0
%20 = OpConstant %19 9
%21 = OpTypeArray %18 %20
%22 = OpTypePointer Output %21
%23 = OpTypeVector %19 3
%24 = OpConstant %19 3
%25 = OpTypeArray %23 %24
%26 = OpTypePointer Output %25
%27 = OpTypeArray %19 %20
%28 = OpTypePointer Output %27
%29 = OpTypeArray %17 %20
%30 = OpTypePointer Output %29
%31 = OpTypeArray %19 %24
%32 = OpTypePointer Output %31
%33 = OpTypeArray %17 %24
%18 = OpTypeFloat 32
%19 = OpTypeVector %18 4
%20 = OpTypeInt 32 0
%21 = OpConstant %20 9
%22 = OpTypeArray %19 %21
%23 = OpTypePointer Output %22
%24 = OpTypeVector %20 3
%25 = OpConstant %20 3
%26 = OpTypeArray %24 %25
%27 = OpTypePointer Output %26
%28 = OpTypeArray %20 %21
%29 = OpTypePointer Output %28
%30 = OpTypeArray %18 %21
%31 = OpTypePointer Output %30
%16 = OpTypeArray %20 %25
%32 = OpTypePointer Output %16
%33 = OpTypeArray %18 %25
%34 = OpTypePointer Output %33
%35 = OpTypeVoid
%36 = OpTypeFunction %35
%16 = OpTypeStruct %19 %19
%37 = OpConstant %19 0
%38 = OpUndef %16
%17 = OpTypeStruct %20 %20
%37 = OpConstant %20 0
%38 = OpUndef %17
%39 = OpTypeBool
%40 = OpConstantFalse %39
%41 = OpConstant %19 1
%41 = OpConstant %20 1
%42 = OpTypeInt 32 1
%43 = OpConstant %42 0
%44 = OpConstant %17 3204448256
%45 = OpConstant %17 1056964608
%46 = OpConstant %17 0
%47 = OpConstant %17 1065353216
%48 = OpTypePointer Output %18
%2 = OpVariable %22 Output
%49 = OpConstant %19 2
%50 = OpTypePointer Output %19
%3 = OpVariable %28 Output
%51 = OpTypePointer Output %17
%4 = OpVariable %30 Output
%52 = OpTypePointer Output %23
%5 = OpVariable %26 Output
%44 = OpConstant %18 3204448256
%45 = OpConstant %18 1056964608
%46 = OpConstant %18 0
%47 = OpConstant %18 1065353216
%48 = OpTypePointer Output %19
%2 = OpVariable %23 Output
%49 = OpConstant %20 2
%50 = OpTypePointer Output %20
%3 = OpVariable %29 Output
%51 = OpTypePointer Output %18
%4 = OpVariable %31 Output
%52 = OpTypePointer Output %24
%5 = OpVariable %27 Output
%6 = OpVariable %32 Output
%53 = OpConstant %19 42
%53 = OpConstant %20 42
%7 = OpVariable %34 Output
%54 = OpConstant %17 1116340224
%54 = OpConstant %18 1116340224
Loading