@@ -94,14 +94,23 @@ def reduce_precision(data: pd.DataFrame) -> pd.DataFrame:
9494
9595class PmappingDataframe :
9696 def _assert_invariants_before_and_after (f ):
97+ @functools .wraps (f )
9798 def wrapped (self , * args , ** kwargs ):
98- self ._assert_reservation_includes_live_tensors ()
99- self ._assert_consistent_left_right_reservations ()
100- self ._assert_reservation_inclusivity ()
99+ try :
100+ self ._assert_reservation_includes_live_tensors ()
101+ self ._assert_consistent_left_right_reservations ()
102+ self ._assert_reservation_inclusivity ()
103+ except Exception as e :
104+ raise ValueError (f"broken invariance before calling { f } " ) from e
105+
101106 result = f (self , * args , ** kwargs )
102- self ._assert_reservation_includes_live_tensors ()
103- self ._assert_consistent_left_right_reservations ()
104- self ._assert_reservation_inclusivity ()
107+
108+ try :
109+ self ._assert_reservation_includes_live_tensors ()
110+ self ._assert_consistent_left_right_reservations ()
111+ self ._assert_reservation_inclusivity ()
112+ except Exception as e :
113+ raise ValueError (f"broken invariance after calling { f } " ) from e
105114 return result
106115 return wrapped
107116
@@ -145,7 +154,7 @@ def __init__(
145154 )
146155
147156 if fill_reservation_cols : # Affects PmappingDataframe so must go before
148- self .fill_reservation_cols (fill_reservation_cols )
157+ self ._fill_reservation_cols (fill_reservation_cols )
149158 if check_above_subset_below :
150159 self .check_above_subset_below ()
151160
@@ -171,49 +180,6 @@ def rename(self, renames: dict[str, str]) -> "PmappingDataframe":
171180 new .data .rename (columns = renames , inplace = True )
172181 return new
173182
174- @error_check_wrapper
175- def fill_reservation_cols (self , columns : set | str ):
176- l_reservations , r_reservations = self ._make_reservations ()
177- targets = []
178- if columns == "auto" :
179- for left , reservations_dict in [
180- (True , l_reservations ),
181- (False , r_reservations ),
182- ]:
183- for resource , reservations in reservations_dict .items ():
184- for r in sorted (reservations ):
185- above = _get_reservation_or_parent (
186- resource , r - 1 , l_reservations , r_reservations
187- )
188- if above is not None :
189- below = reservation2col (resource , r , left = left )
190- targets .append ((r , above , below ))
191- else :
192- for below in columns :
193- if (name_nloops := col2reservation (below )) is None :
194- raise ValueError (f"{ below } is not a valid reservation column" )
195- name , nloops = name_nloops .name , name_nloops .nloops
196- above = _get_reservation_or_parent (
197- name , nloops - 1 , l_reservations , r_reservations
198- )
199- if above is not None :
200- targets .append ((nloops , above , below ))
201-
202- # Sort so we go from top to bottom. Needed in case we have to max 0->1
203- # then 1->2
204- for _ , above , below in sorted (targets , key = lambda x : x [0 ]):
205- assert (
206- above in self .data .columns
207- ), f"Missing column { above } . Have columns:\n \t " + "\n \t " .join (
208- list (self .data .columns )
209- )
210- assert (
211- below in self .data .columns
212- ), f"Missing column { below } . Have columns:\n \t " + "\n \t " .join (
213- list (self .data .columns )
214- )
215- max_to_col (self .data , below , above )
216-
217183 @property
218184 def data (self ) -> pd .DataFrame :
219185 return self ._data
@@ -265,6 +231,9 @@ def free_to_loop_index(self, loop_index: int) -> bool:
265231 if loop_index < - 1 :
266232 raise ValueError ("loop_index must be >= -1" )
267233
234+ if len (self .data ) == 0 :
235+ return False
236+
268237 # We keep reservations under loop_index, which is index loop_index+1
269238 reservation_max_index = loop_index + 1
270239
@@ -425,12 +394,6 @@ def check_match(la: Loop, lb: Loop, param: str):
425394 df = pd .merge (sd , rd , how = "cross" , suffixes = ["" , MERGE_SUFFIX ])
426395
427396 df = reduce_precision (df )
428- # We made a new column! Update our reservations so future iterations
429- # know about it.
430- l_reservations , r_reservations = self ._make_reservations ()
431-
432- # Assert all reservations are >= 0
433- assert (self .data [target ] >= 0 ).all (), f"Negative reservation: { target } "
434397
435398 # Drop all fused loop columns that are not used anymore
436399 remaining_symbols = compatibility_joined .symbols ()
@@ -439,6 +402,12 @@ def check_match(la: Loop, lb: Loop, param: str):
439402 ]
440403 df = df .drop (columns = dropcols )
441404
405+ # Making sure the column you want originated from the right pmapping
406+ # is tricky. Use this function!
407+ def col_from_right_pmapping (col ):
408+ col_with_suffix = col + MERGE_SUFFIX
409+ return col_with_suffix if col_with_suffix in df else col
410+
442411 # Number of combinations
443412 n_total_pmappings = self .n_total_pmappings * right .n_total_pmappings
444413 n_valid_pmappings = self .n_valid_pmappings * right .n_valid_pmappings
@@ -481,9 +450,7 @@ def check_match(la: Loop, lb: Loop, param: str):
481450 adjustment_val = adjustment [nloops ]
482451 if nloops not in right_df_r_reservations [resource ]:
483452 raise RuntimeError ("bug" )
484- target = reservation2col (resource , nloops )
485- if target not in df :
486- target += MERGE_SUFFIX
453+ target = col_from_right_pmapping (reservation2col (resource , nloops ))
487454 add_to_col (df , target , adjustment_val )
488455
489456 # Make sure everything is done in increasing loop order so we don't have
@@ -523,9 +490,7 @@ def iter_reservations(reservations_dict):
523490 )
524491 ) is None :
525492 continue
526- right_merge_source = source + MERGE_SUFFIX
527- if right_merge_source in df :
528- source = right_merge_source
493+ source = col_from_right_pmapping (source )
529494 for target in get_reservation_cols_with (
530495 df ,
531496 name = resource ,
@@ -546,9 +511,7 @@ def iter_reservations(reservations_dict):
546511 )
547512 ) is None :
548513 continue
549- right_merge_source = source + MERGE_SUFFIX
550- if right_merge_source in df :
551- source = right_merge_source
514+ source = col_from_right_pmapping (source )
552515 target = reservation2col (resource , nloops )
553516 add_to_col (df , target , source )
554517
@@ -568,7 +531,6 @@ def iter_reservations(reservations_dict):
568531 max_to_col (df , target , source )
569532 else :
570533 add_to_col (df , target , source )
571- complete_df = df
572534 df = df .drop (columns = dropcols )
573535
574536 result = PmappingDataframe (
@@ -761,6 +723,49 @@ def has_reservations(self):
761723 # ============================================================================
762724 # Helper functions
763725 # ============================================================================
726+ @error_check_wrapper
727+ def _fill_reservation_cols (self , columns : set | str ):
728+ l_reservations , r_reservations = self ._make_reservations ()
729+ targets = []
730+ if columns == "auto" :
731+ for left , reservations_dict in [
732+ (True , l_reservations ),
733+ (False , r_reservations ),
734+ ]:
735+ for resource , reservations in reservations_dict .items ():
736+ for r in sorted (reservations ):
737+ above = _get_reservation_or_parent (
738+ resource , r - 1 , l_reservations , r_reservations
739+ )
740+ if above is not None :
741+ below = reservation2col (resource , r , left = left )
742+ targets .append ((r , above , below ))
743+ else :
744+ for below in columns :
745+ if (name_nloops := col2reservation (below )) is None :
746+ raise ValueError (f"{ below } is not a valid reservation column" )
747+ name , nloops = name_nloops .name , name_nloops .nloops
748+ above = _get_reservation_or_parent (
749+ name , nloops - 1 , l_reservations , r_reservations
750+ )
751+ if above is not None :
752+ targets .append ((nloops , above , below ))
753+
754+ # Sort so we go from top to bottom. Needed in case we have to max 0->1
755+ # then 1->2
756+ for _ , above , below in sorted (targets , key = lambda x : x [0 ]):
757+ assert (
758+ above in self .data .columns
759+ ), f"Missing column { above } . Have columns:\n \t " + "\n \t " .join (
760+ list (self .data .columns )
761+ )
762+ assert (
763+ below in self .data .columns
764+ ), f"Missing column { below } . Have columns:\n \t " + "\n \t " .join (
765+ list (self .data .columns )
766+ )
767+ max_to_col (self .data , below , above )
768+
764769 def _free_live_reservation_to_loop_index (self , loop_index : int ):
765770 reservation_index = loop_index + 1
766771 dropcols = []
@@ -906,13 +911,17 @@ def _shift_bottom_reservation_left(self):
906911 thread = DEFAULT_THREAD ,
907912 ):
908913 right_key = col2live_reservation (live_tensor_in_right )
909- new_live_tensor = live_reservation2col (
910- resource ,
911- right_key .tensor ,
912- bottom_loop_index ,
913- thread_i
914- )
915- max_to_col (df , new_live_tensor , live_tensor_in_right )
914+ for thread_j in left_concurrent_threads :
915+ new_live_tensor = live_reservation2col (
916+ resource ,
917+ right_key .tensor ,
918+ bottom_loop_index ,
919+ thread_j
920+ )
921+ if thread_j == thread_i :
922+ max_to_col (df , new_live_tensor , live_tensor_in_right )
923+ else :
924+ df .loc [:,new_live_tensor ] = 0
916925 df .drop (columns = [live_tensor_in_right ], inplace = True )
917926
918927 for thread_j in left_concurrent_threads :
@@ -1174,6 +1183,7 @@ def _assert_reservation_includes_live_tensors(self):
11741183 ):
11751184 continue
11761185 if (self .data [res_col ] < self .data [col ]).any ():
1186+ breakpoint ()
11771187 raise RuntimeError (f"reservation smaller than reservation for live tensor { col } " )
11781188
11791189 def _assert_reservation_inclusivity (self ):
@@ -1375,3 +1385,4 @@ def _get_reservation_or_parent(
13751385 left = False
13761386 level -= 1
13771387 return None
1388+
0 commit comments