Skip to content

Commit f8a0430

Browse files
committed
wip
1 parent e65a70d commit f8a0430

3 files changed

Lines changed: 116 additions & 99 deletions

File tree

accelforge/mapper/FFM/_join_pmappings/join_pmappings.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -213,32 +213,27 @@ def join_strategy_2(
213213
import json
214214
with open(_runtime_log_file, "a") as f:
215215
f.write(json.dumps({"round": i, "threshold": threshold}) + "\n")
216-
try:
217-
cur_compressed = prune_with_tolerance(
218-
compressed,
219-
objective_tolerance=threshold,
220-
resource_usage_tolerance=resource_usage_tolerance,
221-
print_progress=print_progress,
222-
)
223-
joined = join_pmappings(
224-
cur_compressed,
225-
spec,
226-
_pmapping_row_filter_function=filter_func,
227-
print_progress=print_progress,
228-
metrics=metrics,
216+
217+
cur_compressed = prune_with_tolerance(
218+
compressed,
219+
objective_tolerance=threshold,
220+
resource_usage_tolerance=resource_usage_tolerance,
221+
print_progress=print_progress,
222+
)
223+
joined = join_pmappings(
224+
cur_compressed,
225+
spec,
226+
_pmapping_row_filter_function=filter_func,
227+
print_progress=print_progress,
228+
metrics=metrics,
229+
)
230+
if i < len(thresholds) - 1:
231+
filter_func = OptimalityThresholder(
232+
joined,
233+
_pmapping_row_filter_function,
234+
spec.mapper._metric_aggregator,
235+
print_progress
229236
)
230-
if i < len(thresholds) - 1:
231-
filter_func = OptimalityThresholder(
232-
joined,
233-
_pmapping_row_filter_function,
234-
spec.mapper._metric_aggregator,
235-
print_progress
236-
)
237-
except Exception as e:
238-
if i == len(thresholds) - 1:
239-
raise
240-
if print_progress:
241-
print(f"Error with optimality threshold {threshold}: {e}")
242237

243238
return joined
244239

accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py

Lines changed: 84 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,23 @@ def reduce_precision(data: pd.DataFrame) -> pd.DataFrame:
9494

9595
class PmappingDataframe:
9696
def _assert_invariants_before_and_after(f):
97+
@functools.wraps(f)
9798
def wrapped(self, *args, **kwargs):
98-
self._assert_reservation_includes_live_tensors()
99-
self._assert_consistent_left_right_reservations()
100-
self._assert_reservation_inclusivity()
99+
try:
100+
self._assert_reservation_includes_live_tensors()
101+
self._assert_consistent_left_right_reservations()
102+
self._assert_reservation_inclusivity()
103+
except Exception as e:
104+
raise ValueError(f"broken invariance before calling {f}") from e
105+
101106
result = f(self, *args, **kwargs)
102-
self._assert_reservation_includes_live_tensors()
103-
self._assert_consistent_left_right_reservations()
104-
self._assert_reservation_inclusivity()
107+
108+
try:
109+
self._assert_reservation_includes_live_tensors()
110+
self._assert_consistent_left_right_reservations()
111+
self._assert_reservation_inclusivity()
112+
except Exception as e:
113+
raise ValueError(f"broken invariance after calling {f}") from e
105114
return result
106115
return wrapped
107116

@@ -145,7 +154,7 @@ def __init__(
145154
)
146155

147156
if fill_reservation_cols: # Affects PmappingDataframe so must go before
148-
self.fill_reservation_cols(fill_reservation_cols)
157+
self._fill_reservation_cols(fill_reservation_cols)
149158
if check_above_subset_below:
150159
self.check_above_subset_below()
151160

@@ -171,49 +180,6 @@ def rename(self, renames: dict[str, str]) -> "PmappingDataframe":
171180
new.data.rename(columns=renames, inplace=True)
172181
return new
173182

174-
@error_check_wrapper
175-
def fill_reservation_cols(self, columns: set | str):
176-
l_reservations, r_reservations = self._make_reservations()
177-
targets = []
178-
if columns == "auto":
179-
for left, reservations_dict in [
180-
(True, l_reservations),
181-
(False, r_reservations),
182-
]:
183-
for resource, reservations in reservations_dict.items():
184-
for r in sorted(reservations):
185-
above = _get_reservation_or_parent(
186-
resource, r - 1, l_reservations, r_reservations
187-
)
188-
if above is not None:
189-
below = reservation2col(resource, r, left=left)
190-
targets.append((r, above, below))
191-
else:
192-
for below in columns:
193-
if (name_nloops := col2reservation(below)) is None:
194-
raise ValueError(f"{below} is not a valid reservation column")
195-
name, nloops = name_nloops.name, name_nloops.nloops
196-
above = _get_reservation_or_parent(
197-
name, nloops - 1, l_reservations, r_reservations
198-
)
199-
if above is not None:
200-
targets.append((nloops, above, below))
201-
202-
# Sort so we go from top to bottom. Needed in case we have to max 0->1
203-
# then 1->2
204-
for _, above, below in sorted(targets, key=lambda x: x[0]):
205-
assert (
206-
above in self.data.columns
207-
), f"Missing column {above}. Have columns:\n\t" + "\n\t".join(
208-
list(self.data.columns)
209-
)
210-
assert (
211-
below in self.data.columns
212-
), f"Missing column {below}. Have columns:\n\t" + "\n\t".join(
213-
list(self.data.columns)
214-
)
215-
max_to_col(self.data, below, above)
216-
217183
@property
218184
def data(self) -> pd.DataFrame:
219185
return self._data
@@ -265,6 +231,9 @@ def free_to_loop_index(self, loop_index: int) -> bool:
265231
if loop_index < -1:
266232
raise ValueError("loop_index must be >= -1")
267233

234+
if len(self.data) == 0:
235+
return False
236+
268237
# We keep reservations under loop_index, which is index loop_index+1
269238
reservation_max_index = loop_index + 1
270239

@@ -425,12 +394,6 @@ def check_match(la: Loop, lb: Loop, param: str):
425394
df = pd.merge(sd, rd, how="cross", suffixes=["", MERGE_SUFFIX])
426395

427396
df = reduce_precision(df)
428-
# We made a new column! Update our reservations so future iterations
429-
# know about it.
430-
l_reservations, r_reservations = self._make_reservations()
431-
432-
# Assert all reservations are >= 0
433-
assert (self.data[target] >= 0).all(), f"Negative reservation: {target}"
434397

435398
# Drop all fused loop columns that are not used anymore
436399
remaining_symbols = compatibility_joined.symbols()
@@ -439,6 +402,12 @@ def check_match(la: Loop, lb: Loop, param: str):
439402
]
440403
df = df.drop(columns=dropcols)
441404

405+
# Making sure the column you want originated from the right pmapping
406+
# is tricky. Use this function!
407+
def col_from_right_pmapping(col):
408+
col_with_suffix = col + MERGE_SUFFIX
409+
return col_with_suffix if col_with_suffix in df else col
410+
442411
# Number of combinations
443412
n_total_pmappings = self.n_total_pmappings * right.n_total_pmappings
444413
n_valid_pmappings = self.n_valid_pmappings * right.n_valid_pmappings
@@ -481,9 +450,7 @@ def check_match(la: Loop, lb: Loop, param: str):
481450
adjustment_val = adjustment[nloops]
482451
if nloops not in right_df_r_reservations[resource]:
483452
raise RuntimeError("bug")
484-
target = reservation2col(resource, nloops)
485-
if target not in df:
486-
target += MERGE_SUFFIX
453+
target = col_from_right_pmapping(reservation2col(resource, nloops))
487454
add_to_col(df, target, adjustment_val)
488455

489456
# Make sure everything is done in increasing loop order so we don't have
@@ -523,9 +490,7 @@ def iter_reservations(reservations_dict):
523490
)
524491
) is None:
525492
continue
526-
right_merge_source = source + MERGE_SUFFIX
527-
if right_merge_source in df:
528-
source = right_merge_source
493+
source = col_from_right_pmapping(source)
529494
for target in get_reservation_cols_with(
530495
df,
531496
name=resource,
@@ -546,9 +511,7 @@ def iter_reservations(reservations_dict):
546511
)
547512
) is None:
548513
continue
549-
right_merge_source = source + MERGE_SUFFIX
550-
if right_merge_source in df:
551-
source = right_merge_source
514+
source = col_from_right_pmapping(source)
552515
target = reservation2col(resource, nloops)
553516
add_to_col(df, target, source)
554517

@@ -568,7 +531,6 @@ def iter_reservations(reservations_dict):
568531
max_to_col(df, target, source)
569532
else:
570533
add_to_col(df, target, source)
571-
complete_df = df
572534
df = df.drop(columns=dropcols)
573535

574536
result = PmappingDataframe(
@@ -761,6 +723,49 @@ def has_reservations(self):
761723
# ============================================================================
762724
# Helper functions
763725
# ============================================================================
726+
@error_check_wrapper
727+
def _fill_reservation_cols(self, columns: set | str):
728+
l_reservations, r_reservations = self._make_reservations()
729+
targets = []
730+
if columns == "auto":
731+
for left, reservations_dict in [
732+
(True, l_reservations),
733+
(False, r_reservations),
734+
]:
735+
for resource, reservations in reservations_dict.items():
736+
for r in sorted(reservations):
737+
above = _get_reservation_or_parent(
738+
resource, r - 1, l_reservations, r_reservations
739+
)
740+
if above is not None:
741+
below = reservation2col(resource, r, left=left)
742+
targets.append((r, above, below))
743+
else:
744+
for below in columns:
745+
if (name_nloops := col2reservation(below)) is None:
746+
raise ValueError(f"{below} is not a valid reservation column")
747+
name, nloops = name_nloops.name, name_nloops.nloops
748+
above = _get_reservation_or_parent(
749+
name, nloops - 1, l_reservations, r_reservations
750+
)
751+
if above is not None:
752+
targets.append((nloops, above, below))
753+
754+
# Sort so we go from top to bottom. Needed in case we have to max 0->1
755+
# then 1->2
756+
for _, above, below in sorted(targets, key=lambda x: x[0]):
757+
assert (
758+
above in self.data.columns
759+
), f"Missing column {above}. Have columns:\n\t" + "\n\t".join(
760+
list(self.data.columns)
761+
)
762+
assert (
763+
below in self.data.columns
764+
), f"Missing column {below}. Have columns:\n\t" + "\n\t".join(
765+
list(self.data.columns)
766+
)
767+
max_to_col(self.data, below, above)
768+
764769
def _free_live_reservation_to_loop_index(self, loop_index: int):
765770
reservation_index = loop_index+1
766771
dropcols = []
@@ -906,13 +911,17 @@ def _shift_bottom_reservation_left(self):
906911
thread=DEFAULT_THREAD,
907912
):
908913
right_key = col2live_reservation(live_tensor_in_right)
909-
new_live_tensor = live_reservation2col(
910-
resource,
911-
right_key.tensor,
912-
bottom_loop_index,
913-
thread_i
914-
)
915-
max_to_col(df, new_live_tensor, live_tensor_in_right)
914+
for thread_j in left_concurrent_threads:
915+
new_live_tensor = live_reservation2col(
916+
resource,
917+
right_key.tensor,
918+
bottom_loop_index,
919+
thread_j
920+
)
921+
if thread_j == thread_i:
922+
max_to_col(df, new_live_tensor, live_tensor_in_right)
923+
else:
924+
df.loc[:,new_live_tensor] = 0
916925
df.drop(columns=[live_tensor_in_right], inplace=True)
917926

918927
for thread_j in left_concurrent_threads:
@@ -1174,6 +1183,7 @@ def _assert_reservation_includes_live_tensors(self):
11741183
):
11751184
continue
11761185
if (self.data[res_col] < self.data[col]).any():
1186+
breakpoint()
11771187
raise RuntimeError(f"reservation smaller than reservation for live tensor {col}")
11781188

11791189
def _assert_reservation_inclusivity(self):
@@ -1375,3 +1385,4 @@ def _get_reservation_or_parent(
13751385
left = False
13761386
level -= 1
13771387
return None
1388+

tests/test_mapper.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,21 @@ def test_tpuv4i_gpt3(self):
4040
spec = Spec.from_yaml(
4141
af.examples.arches.tpu_v4i,
4242
af.examples.workloads.gpt3_6_7B,
43+
jinja_parse_data={"N_TOKENS": 2048},
4344
)
4445
spec.mapper.metrics = Metrics.ENERGY | Metrics.LATENCY
4546
spec.mapper.n_concurrent_threads = 2
46-
mappings = spec.map_workload_to_arch()
47+
# mappings = spec.map_workload_to_arch()
48+
49+
import pickle
50+
51+
pmappings = af.mapper.FFM.make_pmappings(spec)
52+
with open("tmp.pkl", "wb") as f:
53+
pickle.dump(pmappings, f)
54+
55+
with open("tmp.pkl", "rb") as f:
56+
pmappings = pickle.load(f)
57+
mappings = af.mapper.FFM.join_pmappings(pmappings, spec.mapper.metrics)
4758

4859
class ActionChecker(unittest.TestCase):
4960
def _check_memory_actions_exist(self, spec, memory_names, result):

0 commit comments

Comments (0)