Skip to content

Commit 99e6989

Browse files
fisherxue authored
and claude committed
Remove _apply_temporal_reuse_corrections post-processing step
Temporal reuse should be expressed purely through mapping structure (placing Storage nodes above or below loops), not detected implicitly. If a tensor's Storage is below an irrelevant loop, fills are inflated — and that's correct per the mapping. Users can split Storage nodes to control reuse explicitly. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 769aca9 commit 99e6989

2 files changed

Lines changed: 0 additions & 228 deletions

File tree

accelforge/model/run_model.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from accelforge.model.sparse_adjustment import (
1919
apply_sparse_adjustments,
2020
LatencyInfo,
21-
_apply_temporal_reuse_corrections,
2221
)
2322
from accelforge.mapper.FFM._join_pmappings.pmapping_dataframe import (
2423
nameloop2col,
@@ -50,10 +49,6 @@ def run_model(
5049
job, add_reservations=add_reservations
5150
)
5251

53-
# Temporal reuse correction: divide inflated parent-facing stats for
54-
# buffers that sit inside contiguous irrelevant temporal loops.
55-
_apply_temporal_reuse_corrections(reuse, spec, job)
56-
5752
# Phase 1: Dense latency (before sparse adjustments)
5853
latency = component_latency(reuse, job.flattened_arch, pmapping, spec)
5954
try:

accelforge/model/sparse_adjustment.py

Lines changed: 0 additions & 223 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,9 @@
1414
Storage as StorageNode,
1515
Toll as TollNode,
1616
Compute as ComputeNode,
17-
Reservation,
1817
)
1918

2019
from accelforge.frontend.spec import Spec
21-
from accelforge.frontend._workload_isl._symbolic import (
22-
get_rank_variable_relevancy,
23-
Irrelevant,
24-
)
25-
from accelforge.frontend.workload import TensorName
2620
from accelforge.mapper.FFM._make_pmappings.pmapper_job import Job
2721
from accelforge.model._looptree.reuse.symbolic import (
2822
Compute,
@@ -237,223 +231,6 @@ def _get_tensor_rank_variables(einsum, tensor_name: str) -> set[str]:
237231
return rank_vars
238232

239233

240-
def _apply_temporal_reuse_corrections(
241-
reuse: SymbolicAnalysisOutput,
242-
spec: Spec,
243-
job: Job,
244-
) -> None:
245-
"""Correct inflated fills caused by irrelevant temporal loops in the dense model.
246-
247-
The dense model's repeat_temporal multiplies ALL buffet stats (including
248-
lower-level stats that propagate upward) by the temporal iteration count,
249-
regardless of whether the loop variable is relevant to the tensor. When
250-
contiguous innermost irrelevant temporals sit above a storage zone, the
251-
buffer retains data across those iterations — the "temporal reuse" concept
252-
from Sparseloop.
253-
254-
This function computes the reuse factor for each buffet by walking the
255-
per-tensor mapping upward from each Storage/Toll node, collecting the
256-
innermost contiguous block of irrelevant temporal iterations (skipping
257-
Spatials and Reservations, continuing through Tolls). It then applies
258-
delta-based corrections to the inflated stats and action counts.
259-
260-
Only corrected buffets and their parents are modified — all other buffet
261-
stats remain untouched.
262-
"""
263-
if not hasattr(reuse, "tensor2mapping") or not reuse.tensor2mapping:
264-
return
265-
266-
workload = spec.workload
267-
einsum_name = job.einsum_name
268-
einsum = workload.einsums[einsum_name]
269-
270-
for tensor_name, mapping in reuse.tensor2mapping.items():
271-
relevancy = get_rank_variable_relevancy(einsum, TensorName(tensor_name))
272-
nodes = mapping.nodes
273-
274-
# Build a dict of temporal iteration counts by walking top-down
275-
# and tracking the remaining shape at each node.
276-
shape = dict(job.rank_variable_bounds)
277-
node_iterations: dict[int, int] = {} # node_index -> iteration_count
278-
for idx, node in enumerate(nodes):
279-
if isinstance(node, (TemporalNode, SpatialNode)):
280-
rv = str(node.rank_variable) if node.rank_variable else None
281-
if rv and rv in shape and node.tile_shape is not None:
282-
try:
283-
ts = int(node.tile_shape)
284-
remaining = int(shape[rv])
285-
iters = math.ceil(remaining / ts) if ts > 0 else 1
286-
node_iterations[idx] = iters
287-
shape[rv] = ts
288-
except (TypeError, ValueError):
289-
pass
290-
291-
# For each Storage/Toll node that holds this tensor, compute the
292-
# temporal reuse factor from the zone above it.
293-
for i, node in enumerate(nodes):
294-
if not isinstance(node, (StorageNode, TollNode)):
295-
continue
296-
if tensor_name not in [str(t) for t in node.tensors]:
297-
continue
298-
299-
buffet = Buffet(tensor_name, einsum_name, node.component)
300-
if buffet not in reuse.buffet_stats:
301-
continue
302-
303-
# Walk upward: collect contiguous innermost irrelevant temporals.
304-
# Skip Spatials, Reservations, and Tolls (pass-through).
305-
# Stop at relevant Temporal or Storage (parent boundary).
306-
reuse_factor = 1
307-
for j in range(i - 1, -1, -1):
308-
above = nodes[j]
309-
if isinstance(above, (SpatialNode, Reservation)):
310-
continue
311-
if isinstance(above, TollNode):
312-
# Continue through Toll only if it doesn't hold the tensor
313-
# (i.e., it's a pass-through for this tensor's data path)
314-
if tensor_name in [str(t) for t in above.tensors]:
315-
continue
316-
continue
317-
if isinstance(above, TemporalNode):
318-
rv = str(above.rank_variable) if above.rank_variable else None
319-
if rv and isinstance(relevancy.get(rv), Irrelevant):
320-
iters = node_iterations.get(j, 1)
321-
if iters > 1:
322-
reuse_factor *= iters
323-
continue
324-
else:
325-
break # Relevant temporal → end of contiguous block
326-
if isinstance(above, StorageNode):
327-
break # Parent storage boundary
328-
329-
if reuse_factor <= 1:
330-
continue
331-
332-
# Delta-based correction: only modify this buffet and its parent.
333-
stats = reuse.buffet_stats[buffet]
334-
reduction = 1.0 - 1.0 / reuse_factor # fraction to subtract
335-
336-
# Save old values for delta computation.
337-
old_reads_to_parent = float(stats.total_reads_to_parent)
338-
old_max_reads_to_parent = float(stats.max_per_parent_reads_to_parent)
339-
old_skip_reads = float(stats.total_skipped_first_reads_to_parent)
340-
old_min_skip_reads = float(
341-
stats.min_per_parent_skipped_first_reads_to_parent
342-
)
343-
344-
# Correct element counts.
345-
inv = 1.0 / reuse_factor
346-
stats.total_reads_to_parent *= inv
347-
stats.max_per_parent_reads_to_parent *= inv
348-
stats.total_skipped_first_reads_to_parent *= inv
349-
stats.min_per_parent_skipped_first_reads_to_parent *= inv
350-
351-
# Correct this buffet's fill action counts (write_actions from fills).
352-
component_obj = spec.arch.find(buffet.level)
353-
if not isinstance(component_obj, arch.TensorHolder):
354-
continue
355-
ta = _find_tensor_access(einsum, buffet.tensor)
356-
if ta is None:
357-
continue
358-
count_writes = not isinstance(component_obj, arch.Toll)
359-
if count_writes:
360-
bpvs = component_obj.bits_per_value_scale[buffet.tensor]
361-
bpv = bpvs * ta.bits_per_value
362-
write_bpa = component_obj.actions["write"].bits_per_action
363-
write_scale = bpv / write_bpa
364-
365-
delta_write = old_reads_to_parent * reduction * write_scale
366-
stats.total_write_actions -= delta_write
367-
stats.max_per_unit_write_actions -= delta_write
368-
delta_skip_write = old_skip_reads * reduction * write_scale
369-
stats.total_skipped_first_write_actions -= delta_skip_write
370-
stats.min_per_unit_skipped_first_write_actions -= delta_skip_write
371-
372-
# Propagate correction upward through the buffet chain.
373-
# Tolls with propagate_child_results add child.reads_to_parent
374-
# to their own reads_to_parent via inherit_add BEFORE spatial/
375-
# temporal multiplications. Thus the absolute delta in
376-
# reads_to_parent is the same at every level in the chain.
377-
#
378-
# Walk up: correct each Toll's reads_to_parent and action counts,
379-
# then correct the first Storage parent's read action counts.
380-
delta_reads = old_reads_to_parent * reduction
381-
delta_max_reads = old_max_reads_to_parent * reduction
382-
delta_skip = old_skip_reads * reduction
383-
delta_min_skip = old_min_skip_reads * reduction
384-
385-
cur = buffet
386-
while True:
387-
parent_buffet = _get_parent_buffet(reuse, cur)
388-
if parent_buffet is None:
389-
break
390-
parent_stats = reuse.buffet_stats[parent_buffet]
391-
parent_obj = spec.arch.find(parent_buffet.level)
392-
if not isinstance(parent_obj, arch.TensorHolder):
393-
break
394-
395-
p_bpvs = parent_obj.bits_per_value_scale[parent_buffet.tensor]
396-
p_bpv = p_bpvs * ta.bits_per_value
397-
p_read_bpa = parent_obj.actions["read"].bits_per_action
398-
p_read_scale = p_bpv / p_read_bpa
399-
is_toll = isinstance(parent_obj, arch.Toll)
400-
401-
if is_toll:
402-
# Toll: correct its reads_to_parent (inherited from child)
403-
# and continue upward.
404-
parent_stats.total_reads_to_parent -= delta_reads
405-
parent_stats.max_per_parent_reads_to_parent -= delta_max_reads
406-
parent_stats.total_skipped_first_reads_to_parent -= delta_skip
407-
parent_stats.min_per_parent_skipped_first_reads_to_parent -= (
408-
delta_min_skip
409-
)
410-
# Toll read_actions (serving child) — usually 0 energy.
411-
parent_stats.total_read_actions -= delta_reads * p_read_scale
412-
parent_stats.max_per_unit_read_actions -= (
413-
delta_max_reads * p_read_scale
414-
)
415-
parent_stats.total_skipped_first_read_actions -= (
416-
delta_skip * p_read_scale
417-
)
418-
parent_stats.min_per_unit_skipped_first_read_actions -= (
419-
delta_min_skip * p_read_scale
420-
)
421-
cur = parent_buffet
422-
continue
423-
else:
424-
# Storage: correct read actions from serving child fills.
425-
parent_stats.total_read_actions -= delta_reads * p_read_scale
426-
parent_stats.max_per_unit_read_actions -= (
427-
delta_max_reads * p_read_scale
428-
)
429-
parent_stats.total_skipped_first_read_actions -= (
430-
delta_skip * p_read_scale
431-
)
432-
parent_stats.min_per_unit_skipped_first_read_actions -= (
433-
delta_min_skip * p_read_scale
434-
)
435-
break
436-
437-
438-
def _get_parent_buffet(
439-
reuse: SymbolicAnalysisOutput,
440-
buffet: Buffet,
441-
) -> Buffet | None:
442-
"""Find the parent (outer-level) Buffet key for the same tensor.
443-
444-
buffet_stats are ordered inner-to-outer, so the parent is the next
445-
matching entry after the current buffet in forward iteration order.
446-
"""
447-
seen = False
448-
for b in reuse.buffet_stats:
449-
if not seen:
450-
seen = b == buffet
451-
continue
452-
if b.tensor == buffet.tensor and b.einsum == buffet.einsum:
453-
return b
454-
return None
455-
456-
457234
def _compute_buffet_tile_shapes(
458235
reuse: SymbolicAnalysisOutput,
459236
job: Job,

0 commit comments

Comments
 (0)