1+ import gc
2+ import tempfile
3+ from pathlib import Path
4+
5+
16from .SM_preprocess import *
27from .surrogate_model import *
38from .utlis import *
49from .preprocessFIM import *
10+
511
# MODEL LOADING
def load_model(model):
    """Download the pretrained checkpoint from S3 and load it into *model*.

    Args:
        model: An instantiated module whose architecture matches the
            checkpoint's ``state_dict``.

    Returns:
        tuple: ``(model, device)`` — the model moved to the selected device
        and switched to eval mode, plus the chosen ``torch.device``.
    """
    # Anonymous (public-read) S3 access; no credentials required.
    fs = s3fs.S3FileSystem(anon=True)
    bucket_path = "sdmlab/SM_dataset/trained_model/SM_trainedmodel.ckpt"

    # Download to a temporary file. delete=False because torch.load reopens
    # the path after this context closes; we remove the file manually below.
    with fs.open(bucket_path, 'rb') as s3file:
        with tempfile.NamedTemporaryFile(suffix=".ckpt", delete=False) as tmp_ckpt:
            tmp_ckpt.write(s3file.read())
            tmp_ckpt_path = tmp_ckpt.name

    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): torch.load unpickles arbitrary objects; if the
        # checkpoint holds only tensors, weights_only=True would be safer —
        # verify against the checkpoint before changing.
        checkpoint = torch.load(tmp_ckpt_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
    finally:
        # FIX: the temp checkpoint was leaked on every call; always clean up.
        Path(tmp_ckpt_path).unlink(missing_ok=True)

    model.to(device)
    model.eval()

    return model, device
2730
# HELPER FUNCTIONS
def create_weight_map(M: int, N: int, device):
    """Create a 2-D Gaussian weight map for smooth overlapping-patch merging.

    The weight is 1.0 at the patch centre and decays with squared distance
    (sigma = min(M, N) / 2), so overlapping patch predictions are blended
    with centre-heavy averaging.

    Args:
        M: Patch height in pixels.
        N: Patch width in pixels.
        device: Target torch device for the returned tensor.

    Returns:
        torch.Tensor: Weight map of shape ``(1, M, N)``, dtype float32.
    """
    center_row, center_col = M // 2, N // 2

    # Vectorized distance field; ogrid yields broadcastable row/col axes.
    # FIX: dropped the dead np.zeros((M, N)) allocation that was immediately
    # overwritten by this expression.
    rows, cols = np.ogrid[:M, :N]
    dist_sq = (rows - center_row) ** 2 + (cols - center_col) ** 2
    weight_map = np.exp(-dist_sq / (2 * (min(M, N) / 2) ** 2))

    # Shape (1, M, N) so it broadcasts against (C, h, w) patch slices.
    return torch.from_numpy(weight_map).float().unsqueeze(0).to(device)
93- #Save the tif file
9447def save_image (image : torch .Tensor , path : Path , reference_tif : str ):
95- """Save the image as a .tif file.
96-
97- Args:
98- image (torch.Tensor): The image to save
99- path (Path): The path to save the image
100- """
48+ """Saves the prediction tensor as a GeoTIFF."""
10149 image_np = image .squeeze ().cpu ().numpy ().astype ('float32' )
10250 with rasterio .open (reference_tif ) as ref :
10351 meta = ref .meta .copy ()
@@ -111,60 +59,200 @@ def save_image(image: torch.Tensor, path: Path, reference_tif: str):
11159
11260 with rasterio .open (path , 'w' , ** meta ) as dst :
11361 dst .write (image_np , 1 )
62+
63+ # Apply water body mask
11464 mask_with_PWB (path , path )
11565
66+ # Binarize and compress
11667 with rasterio .open (path , 'r+' ) as dst :
11768 data = dst .read (1 )
11869 binary_data = np .where (data > 0 , 1 , 0 ).astype (np .uint8 )
11970 dst .write (binary_data , 1 )
12071
12172 compress_tif_lzw (path )
122-
123-
124- #ENHANCE THE LOW-FIDELITY FLOOD MAP
125- def enhanceFIM (huc_id , patch_size = (256 , 256 )):
126-
73+
# PREDICTION
def predict_optimized(dataset, model, shape: torch.Tensor, M: int = 256, N: int = 256,
                      stride: int = 128, device=None, batch_size=32):
    """Memory-efficient sliding-window inference over a large raster.

    - Uses tensor VIEWs instead of copies wherever the channel index allows.
    - Pads boundary patches on the fly instead of padding the whole image.
    - Streams fixed-size batches to the device while accumulating the
      weighted prediction on CPU.

    Args:
        dataset: Object exposing ``x_feature_index``, ``y_feature_index`` and
            ``lf_index`` describing the channel layout of ``shape``.
        model: Trained model, called as ``model(batch)`` under ``no_grad``;
            assumed to return one output channel per patch.
        shape: CPU tensor of shape (channels, rows, cols).
        M: Patch height.
        N: Patch width.
        stride: Step between patch origins (overlap = patch size - stride).
        device: Torch device used for the forward passes.
        batch_size: Number of patches per forward pass.

    Returns:
        tuple: ``(final_prediction, y, lf)`` — binarized prediction of shape
        (1, rows, cols), the target channel(s), and the low-fidelity channel.
    """
    # ---- SETUP INPUTS --------------------------------------------------
    if isinstance(dataset.x_feature_index, slice):
        X = shape[dataset.x_feature_index]      # slice indexing -> view
    elif sorted(dataset.x_feature_index) == list(range(shape.shape[0])):
        X = shape[:]                            # all channels -> still a view
    else:
        print("Warning: Creating tensor copy for non-contiguous indices (High RAM usage)")
        X = shape[dataset.x_feature_index]      # fancy indexing copies

    y = shape[dataset.y_feature_index]

    img_channels, img_rows, img_cols = X.shape

    # ---- OUTPUT ACCUMULATORS (kept on CPU to bound device memory) ------
    weighted_prediction_sum = torch.zeros((1, img_rows, img_cols), dtype=torch.float32, device='cpu')
    weight_sum = torch.zeros((1, img_rows, img_cols), dtype=torch.float32, device='cpu')

    # Gaussian blending map, shape (1, M, N).
    # FIX: the map is only ever used on CPU — build it there directly instead
    # of allocating on the device and copying back.
    weight_map_cpu = create_weight_map(M, N, 'cpu')

    total_steps = ((img_rows - 1) // stride + 1) * ((img_cols - 1) // stride + 1)
    processed_steps = 0

    print(f"Starting inference on {img_rows}x{img_cols} image...")

    batch_patches = []
    batch_coords = []

    def _flush_batch():
        """Run the model on queued patches and fold results into the accumulators."""
        nonlocal processed_steps
        batch_tensor = torch.stack(batch_patches).to(device)
        with torch.no_grad():
            preds = model(batch_tensor).cpu()

        for k, (r0, r1, c0, c1, h_val, w_val) in enumerate(batch_coords):
            # Crop away the zero-padded margin so only valid pixels accumulate.
            pred_valid = preds[k, :, :h_val, :w_val]
            weight_valid = weight_map_cpu[:, :h_val, :w_val]
            weighted_prediction_sum[:, r0:r1, c0:c1] += pred_valid * weight_valid
            weight_sum[:, r0:r1, c0:c1] += weight_valid

        processed_steps += len(batch_patches)
        print(f"Progress: {processed_steps}/{total_steps} ({100 * processed_steps / total_steps:.1f}%)", end='\r')

        batch_patches.clear()
        batch_coords.clear()
        del batch_tensor, preds

        # Periodically release cached device blocks on long runs.
        # FIX: guarded — the original called empty_cache() even on CPU-only runs.
        if torch.cuda.is_available() and processed_steps % (batch_size * 10) == 0:
            torch.cuda.empty_cache()

    # ---- SLIDING-WINDOW LOOP -------------------------------------------
    for r in range(0, img_rows, stride):
        for c in range(0, img_cols, stride):
            r_end = min(r + M, img_rows)
            c_end = min(c + N, img_cols)
            h_valid = r_end - r
            w_valid = c_end - c

            patch = X[:, r:r_end, c:c_end]  # CPU view, no copy

            # Boundary patches are zero-padded up to (M, N) on the fly.
            # FIX: fully-qualified pad — a bare `F` is not visibly imported.
            if h_valid < M or w_valid < N:
                patch = torch.nn.functional.pad(
                    patch, (0, N - w_valid, 0, M - h_valid), mode='constant', value=0)

            batch_patches.append(patch)
            batch_coords.append((r, r_end, c, c_end, h_valid, w_valid))

            if len(batch_patches) >= batch_size:
                _flush_batch()

    # Flush the final partial batch, if any.
    if batch_patches:
        _flush_batch()

    print("Progress: 100% - Inference Complete.")

    # ---- NORMALIZE AND FINALIZE ----------------------------------------
    epsilon = 1e-8  # avoids 0/0 where no patch covered a pixel
    final_prediction = weighted_prediction_sum / (weight_sum + epsilon)
    # Threshold soft scores into a binary flood mask.
    final_prediction = (final_prediction > 0.01).float()

    lf = shape[dataset.lf_index].unsqueeze(0) if isinstance(shape, torch.Tensor) else None

    return final_prediction, y, lf
185+
# MAIN FUNCTION
def enhanceFIM(huc_id, patch_size=(256, 256), batch_size=32):
    """Enhance every low-fidelity flood map for one HUC with the surrogate model.

    Loads the trained checkpoint, then for each low-fidelity map found under
    ``./HUC{huc_id}_forcings/`` runs sliding-window inference and writes the
    binarized prediction GeoTIFF to ``./Results/HUC{huc_id}/``.

    Args:
        huc_id: Hydrologic unit code identifying the input directory.
        patch_size: (height, width) of the inference patches.
        batch_size: Patches per forward pass; automatically retried at 4 on
            a CUDA out-of-memory error.
    """
    device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\n{'=' * 60}\nSYSTEM: {device_type.upper()}\n{'=' * 60}")

    data_dir = Path(f'./HUC{huc_id}_forcings/')
    model = AttentionUNet(channel=8)
    preprocessor = InferenceDataPreprocessor(data_dir=Path(data_dir), patch_size=patch_size, verbose=True)

    print("Loading model...")
    model, device = load_model(model)

    lf_files = preprocessor.get_all_lf_maps(huc_id)

    for idx, lf_path in enumerate(lf_files, 1):
        lf_filename = lf_path.name
        print(f"\nProcessing [{idx}/{len(lf_files)}]: {lf_filename}")

        static_stack = preprocessor.get_static_stack(huc_id)
        lf_tensor = preprocessor.tif_to_tensor(lf_path, feature_name='low_fidelity')

        print("Merging tensors...")
        area_tensor = torch.cat([static_stack, lf_tensor], dim=0)

        # FIX: restore the channel sanity check dropped in the rewrite —
        # the model above is hard-wired to 8 input channels.
        if area_tensor.shape[0] != 8:
            raise ValueError(
                f"Expected 8 channels, got {area_tensor.shape[0]} — "
                f"check missing static feature for HUC {huc_id}.")

        # Release the per-feature tensors; only the merged stack is needed.
        del static_stack
        del lf_tensor
        gc.collect()

        print(f"Tensor Shape: {area_tensor.shape} | "
              f"Memory: {area_tensor.element_size() * area_tensor.nelement() / 1e9:.2f} GB")

        # Minimal channel-layout descriptor consumed by predict_optimized.
        class Dummy:
            x_feature_index = slice(None)                 # model sees all channels
            y_feature_index = [area_tensor.shape[0] - 1]  # last channel is the target
            lf_index = area_tensor.shape[0] - 1           # ...and the low-fidelity map

        try:
            x, y, lf = predict_optimized(
                Dummy,
                model,
                area_tensor,
                M=patch_size[0],
                N=patch_size[1],
                stride=patch_size[0] // 2,
                device=device,
                batch_size=batch_size,
            )
        except RuntimeError as e:
            if "out of memory" in str(e):
                print("OOM Error. Retrying with batch_size=4...")
                torch.cuda.empty_cache()
                gc.collect()
                x, y, lf = predict_optimized(
                    Dummy, model, area_tensor, M=patch_size[0], N=patch_size[1],
                    stride=patch_size[0] // 2, device=device, batch_size=4,
                )
            else:
                # FIX: bare `raise` preserves the original traceback.
                raise

        # Free the large input stack before writing the output.
        del area_tensor
        gc.collect()
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        pred_dir = Path(f"./Results/HUC{huc_id}/")
        pred_dir.mkdir(parents=True, exist_ok=True)
        pred_path = pred_dir / f"SMprediction_{lf_filename}"

        save_image(x, pred_path, str(lf_path))
        print(f"✓ Saved: {pred_path}")

    print(f"\n{'=' * 60}\nCOMPLETED\n{'=' * 60}")
0 commit comments