|
14 | 14 | Storage as StorageNode, |
15 | 15 | Toll as TollNode, |
16 | 16 | Compute as ComputeNode, |
17 | | - Reservation, |
18 | 17 | ) |
19 | 18 |
|
20 | 19 | from accelforge.frontend.spec import Spec |
21 | | -from accelforge.frontend._workload_isl._symbolic import ( |
22 | | - get_rank_variable_relevancy, |
23 | | - Irrelevant, |
24 | | -) |
25 | | -from accelforge.frontend.workload import TensorName |
26 | 20 | from accelforge.mapper.FFM._make_pmappings.pmapper_job import Job |
27 | 21 | from accelforge.model._looptree.reuse.symbolic import ( |
28 | 22 | Compute, |
@@ -237,223 +231,6 @@ def _get_tensor_rank_variables(einsum, tensor_name: str) -> set[str]: |
237 | 231 | return rank_vars |
238 | 232 |
|
239 | 233 |
|
240 | | -def _apply_temporal_reuse_corrections( |
241 | | - reuse: SymbolicAnalysisOutput, |
242 | | - spec: Spec, |
243 | | - job: Job, |
244 | | -) -> None: |
245 | | - """Correct inflated fills caused by irrelevant temporal loops in the dense model. |
246 | | -
|
247 | | - The dense model's repeat_temporal multiplies ALL buffet stats (including |
248 | | - lower-level stats that propagate upward) by the temporal iteration count, |
249 | | - regardless of whether the loop variable is relevant to the tensor. When |
250 | | - contiguous innermost irrelevant temporals sit above a storage zone, the |
251 | | - buffer retains data across those iterations — the "temporal reuse" concept |
252 | | - from Sparseloop. |
253 | | -
|
254 | | - This function computes the reuse factor for each buffet by walking the |
255 | | - per-tensor mapping upward from each Storage/Toll node, collecting the |
256 | | - innermost contiguous block of irrelevant temporal iterations (skipping |
257 | | - Spatials and Reservations, continuing through Tolls). It then applies |
258 | | - delta-based corrections to the inflated stats and action counts. |
259 | | -
|
260 | | - Only corrected buffets and their parents are modified — all other buffet |
261 | | - stats remain untouched. |
262 | | - """ |
263 | | - if not hasattr(reuse, "tensor2mapping") or not reuse.tensor2mapping: |
264 | | - return |
265 | | - |
266 | | - workload = spec.workload |
267 | | - einsum_name = job.einsum_name |
268 | | - einsum = workload.einsums[einsum_name] |
269 | | - |
270 | | - for tensor_name, mapping in reuse.tensor2mapping.items(): |
271 | | - relevancy = get_rank_variable_relevancy(einsum, TensorName(tensor_name)) |
272 | | - nodes = mapping.nodes |
273 | | - |
274 | | - # Build a dict of temporal iteration counts by walking top-down |
275 | | - # and tracking the remaining shape at each node. |
276 | | - shape = dict(job.rank_variable_bounds) |
277 | | - node_iterations: dict[int, int] = {} # node_index -> iteration_count |
278 | | - for idx, node in enumerate(nodes): |
279 | | - if isinstance(node, (TemporalNode, SpatialNode)): |
280 | | - rv = str(node.rank_variable) if node.rank_variable else None |
281 | | - if rv and rv in shape and node.tile_shape is not None: |
282 | | - try: |
283 | | - ts = int(node.tile_shape) |
284 | | - remaining = int(shape[rv]) |
285 | | - iters = math.ceil(remaining / ts) if ts > 0 else 1 |
286 | | - node_iterations[idx] = iters |
287 | | - shape[rv] = ts |
288 | | - except (TypeError, ValueError): |
289 | | - pass |
290 | | - |
291 | | - # For each Storage/Toll node that holds this tensor, compute the |
292 | | - # temporal reuse factor from the zone above it. |
293 | | - for i, node in enumerate(nodes): |
294 | | - if not isinstance(node, (StorageNode, TollNode)): |
295 | | - continue |
296 | | - if tensor_name not in [str(t) for t in node.tensors]: |
297 | | - continue |
298 | | - |
299 | | - buffet = Buffet(tensor_name, einsum_name, node.component) |
300 | | - if buffet not in reuse.buffet_stats: |
301 | | - continue |
302 | | - |
303 | | - # Walk upward: collect contiguous innermost irrelevant temporals. |
304 | | - # Skip Spatials, Reservations, and Tolls (pass-through). |
305 | | - # Stop at relevant Temporal or Storage (parent boundary). |
306 | | - reuse_factor = 1 |
307 | | - for j in range(i - 1, -1, -1): |
308 | | - above = nodes[j] |
309 | | - if isinstance(above, (SpatialNode, Reservation)): |
310 | | - continue |
311 | | - if isinstance(above, TollNode): |
312 | | - # Continue through Toll only if it doesn't hold the tensor |
313 | | - # (i.e., it's a pass-through for this tensor's data path) |
314 | | - if tensor_name in [str(t) for t in above.tensors]: |
315 | | - continue |
316 | | - continue |
317 | | - if isinstance(above, TemporalNode): |
318 | | - rv = str(above.rank_variable) if above.rank_variable else None |
319 | | - if rv and isinstance(relevancy.get(rv), Irrelevant): |
320 | | - iters = node_iterations.get(j, 1) |
321 | | - if iters > 1: |
322 | | - reuse_factor *= iters |
323 | | - continue |
324 | | - else: |
325 | | - break # Relevant temporal → end of contiguous block |
326 | | - if isinstance(above, StorageNode): |
327 | | - break # Parent storage boundary |
328 | | - |
329 | | - if reuse_factor <= 1: |
330 | | - continue |
331 | | - |
332 | | - # Delta-based correction: only modify this buffet and its parent. |
333 | | - stats = reuse.buffet_stats[buffet] |
334 | | - reduction = 1.0 - 1.0 / reuse_factor # fraction to subtract |
335 | | - |
336 | | - # Save old values for delta computation. |
337 | | - old_reads_to_parent = float(stats.total_reads_to_parent) |
338 | | - old_max_reads_to_parent = float(stats.max_per_parent_reads_to_parent) |
339 | | - old_skip_reads = float(stats.total_skipped_first_reads_to_parent) |
340 | | - old_min_skip_reads = float( |
341 | | - stats.min_per_parent_skipped_first_reads_to_parent |
342 | | - ) |
343 | | - |
344 | | - # Correct element counts. |
345 | | - inv = 1.0 / reuse_factor |
346 | | - stats.total_reads_to_parent *= inv |
347 | | - stats.max_per_parent_reads_to_parent *= inv |
348 | | - stats.total_skipped_first_reads_to_parent *= inv |
349 | | - stats.min_per_parent_skipped_first_reads_to_parent *= inv |
350 | | - |
351 | | - # Correct this buffet's fill action counts (write_actions from fills). |
352 | | - component_obj = spec.arch.find(buffet.level) |
353 | | - if not isinstance(component_obj, arch.TensorHolder): |
354 | | - continue |
355 | | - ta = _find_tensor_access(einsum, buffet.tensor) |
356 | | - if ta is None: |
357 | | - continue |
358 | | - count_writes = not isinstance(component_obj, arch.Toll) |
359 | | - if count_writes: |
360 | | - bpvs = component_obj.bits_per_value_scale[buffet.tensor] |
361 | | - bpv = bpvs * ta.bits_per_value |
362 | | - write_bpa = component_obj.actions["write"].bits_per_action |
363 | | - write_scale = bpv / write_bpa |
364 | | - |
365 | | - delta_write = old_reads_to_parent * reduction * write_scale |
366 | | - stats.total_write_actions -= delta_write |
367 | | - stats.max_per_unit_write_actions -= delta_write |
368 | | - delta_skip_write = old_skip_reads * reduction * write_scale |
369 | | - stats.total_skipped_first_write_actions -= delta_skip_write |
370 | | - stats.min_per_unit_skipped_first_write_actions -= delta_skip_write |
371 | | - |
372 | | - # Propagate correction upward through the buffet chain. |
373 | | - # Tolls with propagate_child_results add child.reads_to_parent |
374 | | - # to their own reads_to_parent via inherit_add BEFORE spatial/ |
375 | | - # temporal multiplications. Thus the absolute delta in |
376 | | - # reads_to_parent is the same at every level in the chain. |
377 | | - # |
378 | | - # Walk up: correct each Toll's reads_to_parent and action counts, |
379 | | - # then correct the first Storage parent's read action counts. |
380 | | - delta_reads = old_reads_to_parent * reduction |
381 | | - delta_max_reads = old_max_reads_to_parent * reduction |
382 | | - delta_skip = old_skip_reads * reduction |
383 | | - delta_min_skip = old_min_skip_reads * reduction |
384 | | - |
385 | | - cur = buffet |
386 | | - while True: |
387 | | - parent_buffet = _get_parent_buffet(reuse, cur) |
388 | | - if parent_buffet is None: |
389 | | - break |
390 | | - parent_stats = reuse.buffet_stats[parent_buffet] |
391 | | - parent_obj = spec.arch.find(parent_buffet.level) |
392 | | - if not isinstance(parent_obj, arch.TensorHolder): |
393 | | - break |
394 | | - |
395 | | - p_bpvs = parent_obj.bits_per_value_scale[parent_buffet.tensor] |
396 | | - p_bpv = p_bpvs * ta.bits_per_value |
397 | | - p_read_bpa = parent_obj.actions["read"].bits_per_action |
398 | | - p_read_scale = p_bpv / p_read_bpa |
399 | | - is_toll = isinstance(parent_obj, arch.Toll) |
400 | | - |
401 | | - if is_toll: |
402 | | - # Toll: correct its reads_to_parent (inherited from child) |
403 | | - # and continue upward. |
404 | | - parent_stats.total_reads_to_parent -= delta_reads |
405 | | - parent_stats.max_per_parent_reads_to_parent -= delta_max_reads |
406 | | - parent_stats.total_skipped_first_reads_to_parent -= delta_skip |
407 | | - parent_stats.min_per_parent_skipped_first_reads_to_parent -= ( |
408 | | - delta_min_skip |
409 | | - ) |
410 | | - # Toll read_actions (serving child) — usually 0 energy. |
411 | | - parent_stats.total_read_actions -= delta_reads * p_read_scale |
412 | | - parent_stats.max_per_unit_read_actions -= ( |
413 | | - delta_max_reads * p_read_scale |
414 | | - ) |
415 | | - parent_stats.total_skipped_first_read_actions -= ( |
416 | | - delta_skip * p_read_scale |
417 | | - ) |
418 | | - parent_stats.min_per_unit_skipped_first_read_actions -= ( |
419 | | - delta_min_skip * p_read_scale |
420 | | - ) |
421 | | - cur = parent_buffet |
422 | | - continue |
423 | | - else: |
424 | | - # Storage: correct read actions from serving child fills. |
425 | | - parent_stats.total_read_actions -= delta_reads * p_read_scale |
426 | | - parent_stats.max_per_unit_read_actions -= ( |
427 | | - delta_max_reads * p_read_scale |
428 | | - ) |
429 | | - parent_stats.total_skipped_first_read_actions -= ( |
430 | | - delta_skip * p_read_scale |
431 | | - ) |
432 | | - parent_stats.min_per_unit_skipped_first_read_actions -= ( |
433 | | - delta_min_skip * p_read_scale |
434 | | - ) |
435 | | - break |
436 | | - |
437 | | - |
438 | | -def _get_parent_buffet( |
439 | | - reuse: SymbolicAnalysisOutput, |
440 | | - buffet: Buffet, |
441 | | -) -> Buffet | None: |
442 | | - """Find the parent (outer-level) Buffet key for the same tensor. |
443 | | -
|
444 | | - buffet_stats are ordered inner-to-outer, so the parent is the next |
445 | | - matching entry after the current buffet in forward iteration order. |
446 | | - """ |
447 | | - seen = False |
448 | | - for b in reuse.buffet_stats: |
449 | | - if not seen: |
450 | | - seen = b == buffet |
451 | | - continue |
452 | | - if b.tensor == buffet.tensor and b.einsum == buffet.einsum: |
453 | | - return b |
454 | | - return None |
455 | | - |
456 | | - |
457 | 234 | def _compute_buffet_tile_shapes( |
458 | 235 | reuse: SymbolicAnalysisOutput, |
459 | 236 | job: Job, |
|
0 commit comments