Skip to content

Commit 33561fc

Browse files
committed
CFG-based loop body detection, remove Theta union, fix DCE used_ids
Three fixes for the remaining CI difftest failures: 1. Loop body detection: Replace the contiguous index range (header..merge) with a worklist-based CFG traversal that starts at the loop header, follows every operand that resolves to a block label in the function, and stops at the merge block (the merge block itself is not included). SPIR-V does not guarantee a contiguous block layout, so the old range could miss non-contiguous loop body blocks that should be protected from RVSDG transformation. Fixes control_flow_complex. 2. Remove Theta-value union: The union of id{N} with Theta(true, id{N}, FConst(0.0)) (or the corresponding Const(0)/BoolConst(0) init) puts the synthetic zero init value in the same e-class as the actual computation, allowing the extractor to pick the zero constant instead of the real value. Without the union, the Theta term still exists in the egraph, so LoopInvariant propagation continues to work. Fixes math_ops. 3. DCE used_ids: Add the operand references of instructions in types_global_values to the used_ids set. OpConstantComposite and other globals may reference IDs that were aliased by the optimizer. Fixes the undefined-id failure in matrix_ops.
1 parent 3d85fc6 commit 33561fc

1 file changed

Lines changed: 50 additions & 15 deletions

File tree

  • rust/spirv-tools-opt/src/direct

rust/spirv-tools-opt/src/direct/mod.rs

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -220,16 +220,43 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
220220
) = (inst.operands.first(), inst.operands.get(1))
221221
{
222222
let label_map = &func_block_labels[func_idx];
223-
if let Some(&merge_idx) = label_map.get(merge_label) {
224-
let continue_idx = label_map.get(continue_label).copied();
225-
// Loop body spans from header to just before merge block
226-
let body_indices: Vec<usize> = (block_idx..merge_idx).collect();
227-
loop_constructs.push(LoopInfo {
228-
body_block_indices: body_indices,
229-
continue_block_idx: continue_idx,
230-
func_idx,
231-
});
223+
let continue_idx = label_map.get(continue_label).copied();
224+
// Collect loop body via CFG traversal from header, stopping
225+
// at the merge block. This handles non-contiguous block layouts
226+
// that a simple (header..merge) index range would miss.
227+
let mut body_indices: Vec<usize> = Vec::new();
228+
let mut visited: HashSet<usize> = HashSet::new();
229+
let mut worklist: Vec<usize> = vec![block_idx];
230+
let merge_idx = label_map.get(merge_label).copied();
231+
while let Some(idx) = worklist.pop() {
232+
if !visited.insert(idx) {
233+
continue;
234+
}
235+
// Don't include the merge block itself
236+
if Some(idx) == merge_idx {
237+
continue;
238+
}
239+
body_indices.push(idx);
240+
// Follow all branch targets from this block
241+
if let Some(blk) = func.blocks.get(idx) {
242+
for bi in &blk.instructions {
243+
for op in &bi.operands {
244+
if let Some(target_label) = op.id_ref_any() {
245+
if let Some(&target_idx) =
246+
label_map.get(&target_label)
247+
{
248+
worklist.push(target_idx);
249+
}
250+
}
251+
}
252+
}
253+
}
232254
}
255+
loop_constructs.push(LoopInfo {
256+
body_block_indices: body_indices,
257+
continue_block_idx: continue_idx,
258+
func_idx,
259+
});
233260
}
234261
}
235262
}
@@ -467,12 +494,12 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
467494
.map_err(|e| EgglogOptError::ExecutionError(e.to_string()))?;
468495
theta_bound_ids.insert(id);
469496

470-
// Union original ID with Theta - after saturation, the egraph will
471-
// have propagated LoopInvariant through the expression if applicable
472-
let union_cmd = format!("(union id{} theta_{})", id, id);
473-
egraph
474-
.parse_and_run_program(None, &union_cmd)
475-
.map_err(|e| EgglogOptError::ExecutionError(e.to_string()))?;
497+
// NOTE: We intentionally do NOT union id{N} with theta_{N}.
498+
// Doing so puts FConst(0.0)/Const(0)/BoolConst(0) init values
499+
// into the same e-class as the actual computation, which can
500+
// cause the extractor to pick the zero constant instead of
501+
// the real value. The Theta term exists in the egraph so
502+
// LoopInvariant rules can still detect and mark invariants.
476503
}
477504
}
478505
}
@@ -1870,6 +1897,14 @@ fn remove_dead_instructions(module: &mut Module, true_roots: &HashSet<Word>) ->
18701897
used_ids.insert(id);
18711898
}
18721899
}
1900+
// Also mark operand references from types_global_values as used.
1901+
// OpConstantComposite and other globals may reference function-body
1902+
// IDs that were aliased/rewritten by the optimizer.
1903+
for op in &inst.operands {
1904+
if let Some(ref_id) = op.id_ref_any() {
1905+
used_ids.insert(ref_id);
1906+
}
1907+
}
18731908
}
18741909

18751910
let mut removed_any = false;

0 commit comments

Comments
 (0)