Skip to content

Commit 0b8550e

Browse files
committed
Add temporal reuse tracking for weight fills at shared_glb
1 parent a101dd9 commit 0b8550e

6 files changed

Lines changed: 60 additions & 6 deletions

File tree

accelforge/model/_looptree/reuse/symbolic/symbolic.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,12 @@ class BuffetStats:
146146

147147
persistent: bool = field(default=False)
148148

149+
# Temporal reuse tracking: True if a relevant temporal loop has processed
150+
# this buffet since the last Storage node set total_reads_to_parent.
151+
# When False and an irrelevant temporal is encountered, parent-facing attrs
152+
# are not multiplied (the buffer persists across irrelevant iterations).
153+
_has_relevant_temporal_above: bool = field(default=False)
154+
149155
@property
150156
def n_loops_above(self) -> int:
151157
if self.persistent:
@@ -158,14 +164,22 @@ def n_loops_above(self, value: int):
158164

159165
def repeat_temporal(self, factor: int, is_fully_relevant: bool) -> "BuffetStats":
160166
new = copy.copy(self)
167+
# Temporal reuse: if the loop is irrelevant and no relevant temporal
168+
# has intervened since the Storage node set parent-facing stats, the
169+
# buffer persists across iterations — skip parent-facing attrs.
170+
skip_parent = not is_fully_relevant and not self._has_relevant_temporal_above
161171
for attr in self.__dict__:
162172
if not attr.startswith(("total_", "max_", "min_")):
163173
continue
164174
if "skipped_first" in attr and not is_fully_relevant:
165175
continue # First actions occur once per relevant iteration.
166176
if attr == "max_occupancy":
167177
continue # Max occupancy is not affected by temporal loops above
178+
if "parent" in attr and skip_parent:
179+
continue # Temporal reuse: buffer persists across irrelevant iters.
168180
setattr(new, attr, getattr(new, attr) * factor)
181+
if is_fully_relevant:
182+
new._has_relevant_temporal_above = True
169183
return new
170184

171185
def repeat_spatial(self, factor: int, reuse_parent_accesses: bool) -> "BuffetStats":
@@ -204,7 +218,10 @@ def min(self, **kwargs: Any):
204218
def __add__(self, other: "BuffetStats") -> "BuffetStats":
205219
new = copy.copy(self)
206220
for attr in self.__dict__:
207-
if attr.startswith("min_"):
221+
if attr == "_has_relevant_temporal_above":
222+
# Combine conservatively: if either has relevant above, so does result
223+
setattr(new, attr, getattr(self, attr) or getattr(other, attr))
224+
elif attr.startswith("min_"):
208225
setattr(
209226
new, attr, min_nonzero(getattr(self, attr), getattr(other, attr))
210227
)
@@ -1178,6 +1195,11 @@ def inherit_add(attr: str, default_value: Any = fills) -> Any:
11781195
inherit_add("total_skipped_first_reads_to_parent")
11791196
inherit_add("min_per_parent_skipped_first_reads_to_parent")
11801197

1198+
# Reset temporal reuse tracking: this Storage node just set fresh
1199+
# parent-facing stats; irrelevant temporals above should not
1200+
# multiply them until a relevant temporal intervenes.
1201+
stats._has_relevant_temporal_above = False
1202+
11811203
# ==============================================================================
11821204
# Convert to actions. These are not used used upward; they are used to get
11831205
# energy and latency.
@@ -1354,6 +1376,10 @@ def analyze_compute(
13541376
stats.total_skipped_first_reads_to_parent = 1
13551377
stats.min_per_parent_skipped_first_reads_to_parent = 1
13561378
stats.max_occupancy = 1
1379+
# Compute-level accesses have no buffering: every iteration reads from
1380+
# parent regardless of relevancy. Mark as having a "relevant temporal
1381+
# above" so that irrelevant temporal loops still multiply parent attrs.
1382+
stats._has_relevant_temporal_above = True
13571383
result_accumulator.buffet_stats[buffet] = stats
13581384

13591385
network_node = info.job.spec.arch.find_first_of_type_above(

tests/input_files/table7/mapping_conv1.yaml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ mapping:
2828
rank_variable: c
2929
tile_shape: 1
3030

31-
# === shared_glb level (Inputs + Outputs; Weights bypass) ===
31+
# === shared_glb level ===
3232
- !Storage
3333
tensors: [Inputs, Outputs]
3434
component: shared_glb
@@ -37,6 +37,12 @@ mapping:
3737
- !Temporal
3838
rank_variable: m
3939
tile_shape: 32
40+
41+
# Weights buffered at shared_glb (split: m relevant above, p irrelevant below)
42+
- !Storage
43+
tensors: [Weights]
44+
component: shared_glb
45+
4046
- !Temporal
4147
rank_variable: p
4248
tile_shape: 1

tests/input_files/table7/mapping_conv2.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,17 @@ mapping:
4141
name: X
4242
component: PEColumns
4343

44-
# shared_glb temporal (outer→inner: M, N, P)
44+
# shared_glb temporal: M (W-relevant)
4545
- !Temporal
4646
rank_variable: m
4747
tile_shape: 16
48+
49+
# Weights buffered at shared_glb (split: m relevant above, n/p irrelevant below)
50+
- !Storage
51+
tensors: [Weights]
52+
component: shared_glb
53+
54+
# shared_glb temporal: N, P (W-irrelevant, below shared_glb[W] for reuse)
4855
- !Temporal
4956
rank_variable: n
5057
tile_shape: 1

tests/input_files/table7/mapping_conv3.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,12 @@ mapping:
3535
name: X
3636
component: PEColumns
3737

38-
# shared_glb temporal (outer→inner: N, P)
38+
# Weights buffered at shared_glb (split: n/p irrelevant below for reuse)
39+
- !Storage
40+
tensors: [Weights]
41+
component: shared_glb
42+
43+
# shared_glb temporal: N, P (W-irrelevant, below shared_glb[W] for reuse)
3944
- !Temporal
4045
rank_variable: n
4146
tile_shape: 1

tests/input_files/table7/mapping_conv4.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,12 @@ mapping:
3636
name: X
3737
component: PEColumns
3838

39-
# shared_glb temporal (outer→inner: P)
39+
# Weights buffered at shared_glb (split: p irrelevant below for reuse)
40+
- !Storage
41+
tensors: [Weights]
42+
component: shared_glb
43+
44+
# shared_glb temporal: P (W-irrelevant, below shared_glb[W] for reuse)
4045
- !Temporal
4146
rank_variable: p
4247
tile_shape: 1

tests/input_files/table7/mapping_conv5.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,12 @@ mapping:
3636
name: X
3737
component: PEColumns
3838

39-
# shared_glb temporal (outer→inner: P)
39+
# Weights buffered at shared_glb (split: p irrelevant below for reuse)
40+
- !Storage
41+
tensors: [Weights]
42+
component: shared_glb
43+
44+
# shared_glb temporal: P (W-irrelevant, below shared_glb[W] for reuse)
4045
- !Temporal
4146
rank_variable: p
4247
tile_shape: 1

0 commit comments

Comments
 (0)