Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
328 changes: 328 additions & 0 deletions src/spv/lift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,13 @@ enum Merge<L> {
},
}

/// Relationship between loop body condition and loop-continue condition.
#[derive(Copy, Clone, PartialEq, Eq)]
enum CondRelation {
Same,
Not,
}

impl<'a> NeedsIdsCollector<'a> {
fn alloc_ids<E>(
self,
Expand Down Expand Up @@ -508,6 +515,324 @@ impl FuncAt<'_, Node> {
}

impl<'a> FuncLifting<'a> {
/// Recompute incoming edge counts for each CFG point in `blocks`.
///
/// This has to run after control-flow rewrites and before dead-block
/// pruning, to avoid stale predecessor counts.
fn recompute_use_counts(
blocks: &FxIndexMap<CfgPoint, BlockLifting<'a>>,
use_counts: &mut FxHashMap<CfgPoint, usize>,
) {
use_counts.clear();
use_counts.reserve(blocks.len());
let all_edges = blocks.first().map(|(&entry_point, _)| entry_point).into_iter().chain(
blocks.values().flat_map(|block| {
block
.terminator
.merge
.iter()
.flat_map(|merge| {
let (a, b) = match merge {
Merge::Selection(a) => (a, None),
Merge::Loop { loop_merge: a, loop_continue: b } => (a, Some(b)),
};
[a].into_iter().chain(b)
})
.chain(&block.terminator.targets)
.copied()
}),
);
for target in all_edges {
*use_counts.entry(target).or_default() += 1;
}
}

/// Return `true` iff `point` is an empty pass-through block branching only
/// to `target`.
fn is_passthrough_branch_to(
blocks: &FxIndexMap<CfgPoint, BlockLifting<'a>>,
point: CfgPoint,
target: CfgPoint,
) -> bool {
let Some(block) = blocks.get(&point) else {
return false;
};
block.phis.is_empty()
&& block.insts.is_empty()
&& block.terminator.attrs == AttrSet::default()
&& matches!(&*block.terminator.kind, cfg::ControlInstKind::Branch)
&& block.terminator.inputs.is_empty()
&& block.terminator.targets.as_slice() == [target]
&& block.terminator.target_phi_values.keys().all(|&phi_target| phi_target == target)
&& block.terminator.merge.is_none()
}

fn is_const_opcode(cx: &Context, v: Value, opcode: spec::Opcode) -> bool {
match v {
Value::Const(c) => match &cx[c].kind {
ConstKind::SpvInst { spv_inst_and_const_inputs } => {
spv_inst_and_const_inputs.0.opcode == opcode
}
_ => false,
},
_ => false,
}
}

/// Determine whether `continue_cond` is equal to `body_cond` or its
/// logical negation.
fn continue_cond_relation(
cx: &Context,
func_def_body: &'a crate::FuncDefBody,
continue_cond: Value,
body_cond: Value,
) -> Option<CondRelation> {
if continue_cond == body_cond {
return Some(CondRelation::Same);
}

let wk = &spec::Spec::get().well_known;
match continue_cond {
Value::NodeOutput { node, output_idx } => {
let node_def = func_def_body.at(node).def();
let NodeKind::Select { kind: SelectionKind::BoolCond, scrutinee, cases } =
&node_def.kind
else {
return None;
};
if *scrutinee != body_cond || cases.len() != 2 {
return None;
}

let output_idx = output_idx as usize;
let true_case_outputs = &func_def_body.at(cases[0]).def().outputs;
let false_case_outputs = &func_def_body.at(cases[1]).def().outputs;
if output_idx >= true_case_outputs.len() || output_idx >= false_case_outputs.len() {
return None;
}

let on_true = true_case_outputs[output_idx];
let on_false = false_case_outputs[output_idx];
if Self::is_const_opcode(cx, on_true, wk.OpConstantTrue)
&& Self::is_const_opcode(cx, on_false, wk.OpConstantFalse)
{
Some(CondRelation::Same)
} else if Self::is_const_opcode(cx, on_true, wk.OpConstantFalse)
&& Self::is_const_opcode(cx, on_false, wk.OpConstantTrue)
{
Some(CondRelation::Not)
} else {
None
}
}

_ => None,
}
}

/// Rewrite `loop_continue` into an unconditional backedge while preserving
/// only phi payloads for the loop header edge.
fn rewrite_continue_as_unconditional_backedge(
blocks: &mut FxIndexMap<CfgPoint, BlockLifting<'a>>,
loop_continue: CfgPoint,
) {
let continue_block = blocks.get_mut(&loop_continue).unwrap();
continue_block.terminator.kind = Cow::Owned(cfg::ControlInstKind::Branch);
continue_block.terminator.inputs = [].into_iter().collect();
let header_point = continue_block.terminator.targets[0];
continue_block.terminator.targets = [header_point].into_iter().collect();
continue_block.terminator.target_phi_values.retain(|&target, _| target == header_point);
}

/// Canonicalize strict loop shortcut patterns:
/// * loop header branches to a body select,
/// * one body arm is an empty pass-through to body merge,
/// * body merge branches to `loop_continue`,
/// * `loop_continue` conditionally branches to header/merge.
///
/// Rewriting the pass-through arm directly to `loop_merge` avoids
/// preserving this fragile shape in lifted CFG.
fn canonicalize_loop_continue_shortcuts(
cx: &Context,
func_def_body: &'a crate::FuncDefBody,
blocks: &mut FxIndexMap<CfgPoint, BlockLifting<'a>>,
) {
let mut loop_continue_shortcuts =
SmallVec::<[(CfgPoint, CfgPoint, CfgPoint, CfgPoint, bool); 4]>::new();
for (&header_point, header_block) in &*blocks {
let header_term = &header_block.terminator;
let Some(Merge::Loop { loop_merge, loop_continue }) = header_term.merge else {
continue;
};
if header_term.attrs != AttrSet::default()
|| !matches!(&*header_term.kind, cfg::ControlInstKind::Branch)
|| !header_term.inputs.is_empty()
|| header_term.targets.len() != 1
|| !header_term.target_phi_values.is_empty()
{
continue;
}

let body_point = header_term.targets[0];
let Some(body_block) = blocks.get(&body_point) else {
continue;
};
let body_term = &body_block.terminator;
let Some(Merge::Selection(body_merge)) = body_term.merge else {
continue;
};
if body_term.attrs != AttrSet::default()
|| !matches!(
&*body_term.kind,
cfg::ControlInstKind::SelectBranch(SelectionKind::BoolCond)
)
|| body_term.inputs.len() != 1
|| body_term.targets.len() != 2
|| !body_term.target_phi_values.is_empty()
{
continue;
}

let Some(continue_block) = blocks.get(&loop_continue) else {
continue;
};
if !continue_block.phis.is_empty() || !continue_block.insts.is_empty() {
continue;
}
let continue_term = &continue_block.terminator;
if continue_term.attrs != AttrSet::default()
|| !matches!(
&*continue_term.kind,
cfg::ControlInstKind::SelectBranch(SelectionKind::BoolCond)
)
|| continue_term.inputs.len() != 1
|| continue_term.targets.as_slice() != [header_point, loop_merge]
|| continue_term.target_phi_values.keys().any(|&target| target != header_point)
|| continue_term.merge.is_some()
{
continue;
}

let t0 = body_term.targets[0];
let t1 = body_term.targets[1];
let (_work_target, pass_target) =
if Self::is_passthrough_branch_to(blocks, t0, body_merge) {
(t1, t0)
} else if Self::is_passthrough_branch_to(blocks, t1, body_merge) {
(t0, t1)
} else {
continue;
};
let Some(cond_relation) = Self::continue_cond_relation(
cx,
func_def_body,
continue_term.inputs[0],
body_term.inputs[0],
) else {
continue;
};
let continue_routes_work_to_header = match cond_relation {
CondRelation::Same => pass_target == t1,
CondRelation::Not => pass_target == t0,
};
if !continue_routes_work_to_header {
continue;
}

let body_merge_preds: SmallVec<[CfgPoint; 4]> = blocks
.iter()
.filter_map(|(&point, block)| {
block.terminator.targets.contains(&body_merge).then_some(point)
})
.collect();
if body_merge_preds.len() != 2 || !body_merge_preds.contains(&pass_target) {
continue;
}
let Some(other_body_merge_pred) =
body_merge_preds.into_iter().find(|&point| point != pass_target)
else {
continue;
};

let continue_pred_count = blocks
.values()
.filter(|block| block.terminator.targets.contains(&loop_continue))
.count();
if continue_pred_count != 1 {
continue;
}

let Some(loop_merge_block) = blocks.get(&loop_merge) else {
continue;
};
if !loop_merge_block.phis.is_empty() {
continue;
}

let Some(body_merge_block) = blocks.get(&body_merge) else {
continue;
};
let merge_term = &body_merge_block.terminator;
if merge_term.attrs != AttrSet::default()
|| !matches!(&*merge_term.kind, cfg::ControlInstKind::Branch)
|| !merge_term.inputs.is_empty()
|| merge_term.targets.as_slice() != [loop_continue]
|| !merge_term.target_phi_values.is_empty()
|| merge_term.merge.is_some()
{
continue;
}

let body_merge_phi_count = body_merge_block.phis.len();
let payload_arity_to = |source: CfgPoint, target: CfgPoint| {
blocks
.get(&source)
.and_then(|block| block.terminator.target_phi_values.get(&target))
.map_or(0, |values| values.len())
};
if payload_arity_to(pass_target, body_merge) != body_merge_phi_count
|| payload_arity_to(other_body_merge_pred, body_merge) != body_merge_phi_count
{
continue;
}

let header_phi_count = header_block.phis.len();
let continue_payload_arity =
continue_term.target_phi_values.get(&header_point).map_or(0, |values| values.len());
if continue_payload_arity != header_phi_count {
continue;
}

loop_continue_shortcuts.push((
body_point,
pass_target,
loop_merge,
loop_continue,
continue_routes_work_to_header,
));
}
for (
body_point,
pass_target,
loop_merge,
loop_continue,
rewrite_continue_to_unconditional_backedge,
) in loop_continue_shortcuts
{
let body_block = blocks.get_mut(&body_point).unwrap();
body_block.terminator.merge = None;
for target in &mut body_block.terminator.targets {
if *target == pass_target {
*target = loop_merge;
}
}

if rewrite_continue_to_unconditional_backedge {
Self::rewrite_continue_as_unconditional_backedge(blocks, loop_continue);
}
}
}

fn from_func_decl<E>(
cx: &Context,
func_decl: &'a FuncDecl,
Expand Down Expand Up @@ -913,6 +1238,9 @@ impl<'a> FuncLifting<'a> {
}
}

Self::canonicalize_loop_continue_shortcuts(cx, func_def_body, &mut blocks);
Self::recompute_use_counts(&blocks, &mut use_counts);

// Remove now-unused blocks.
blocks.retain(|point, _| use_counts.get(point).is_some_and(|&count| count > 0));

Expand Down
Binary file added tests/data/basic.frag.glsl.dbg.spvbin
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/loop-continue-shortcut.repro.spvbin
Binary file not shown.
Loading
Loading