88from .surrogate_model import *
99from .utlis import *
1010from .preprocessFIM import *
11-
11+
1212
1313# MODEL LOADING
1414def load_model (model ):
1515 """Downloads and loads the model checkpoint."""
1616 fs = s3fs .S3FileSystem (anon = True )
1717 bucket_path = "sdmlab/SM_dataset/trained_model/SM_trainedmodel.ckpt"
1818
19- with fs .open (bucket_path , 'rb' ) as s3file :
19+ with fs .open (bucket_path , "rb" ) as s3file :
2020 with tempfile .NamedTemporaryFile (suffix = ".ckpt" , delete = False ) as tmp_ckpt :
2121 tmp_ckpt .write (s3file .read ())
2222 tmp_ckpt_path = tmp_ckpt .name
2323
24- device = torch .device (' cuda' if torch .cuda .is_available () else ' cpu' )
24+ device = torch .device (" cuda" if torch .cuda .is_available () else " cpu" )
2525 checkpoint = torch .load (tmp_ckpt_path , map_location = device )
26- model .load_state_dict (checkpoint [' state_dict' ])
26+ model .load_state_dict (checkpoint [" state_dict" ])
2727 model .to (device )
2828 model .eval ()
2929
@@ -39,42 +39,54 @@ def create_weight_map(M: int, N: int, device):
3939
4040 # Vectorized calculation
4141 Y , X = np .ogrid [:M , :N ]
42- dist_sq = (X - center_y )** 2 + (Y - center_x )** 2
43- weight_map = np .exp (- dist_sq / (2 * (min (M , N ) / 2 )** 2 ))
42+ dist_sq = (X - center_y ) ** 2 + (Y - center_x ) ** 2
43+ weight_map = np .exp (- dist_sq / (2 * (min (M , N ) / 2 ) ** 2 ))
4444
4545 # FIX: Only unsqueeze once to get shape (1, M, N)
4646 return torch .from_numpy (weight_map ).float ().unsqueeze (0 ).to (device )
4747
48+
def save_image(image: torch.Tensor, path: Path, reference_tif: str):
    """Save a prediction tensor as a single-band GeoTIFF, then mask, binarize and compress it.

    The output file at ``path`` is written three times in place:
    1. raw float32 prediction with georeferencing copied from ``reference_tif``,
    2. water-body mask applied via ``mask_with_PWB``,
    3. values binarized (``> 0`` -> 1, else 0) and LZW-compressed.

    Args:
        image: Prediction tensor; moved to CPU and squeezed before writing.
            NOTE(review): assumes the squeezed array is exactly 2-D (H, W) —
            confirm upstream output shape.
        path: Destination GeoTIFF path.
        reference_tif: Existing GeoTIFF whose metadata (CRS, transform, ...)
            is copied onto the output.
    """
    image_np = image.squeeze().cpu().numpy().astype("float32")
    with rasterio.open(reference_tif) as ref:
        meta = ref.meta.copy()
        meta.update(
            {
                "driver": "GTiff",
                "height": image_np.shape[0],
                "width": image_np.shape[1],
                "count": 1,
                "dtype": "float32",
            }
        )

    with rasterio.open(path, "w", **meta) as dst:
        dst.write(image_np, 1)

    # Apply water body mask (in place: source and destination are the same file)
    mask_with_PWB(path, path)

    # Binarize and compress
    with rasterio.open(path, "r+") as dst:
        data = dst.read(1)
        binary_data = np.where(data > 0, 1, 0).astype(np.uint8)
        dst.write(binary_data, 1)

    compress_tif_lzw(path)
7477
78+
# PREDICTION
76- def predict_optimized (dataset , model , shape : torch .Tensor , M : int = 256 , N : int = 256 ,
77- stride : int = 128 , device = None , batch_size = 32 ):
80+ def predict_optimized (
81+ dataset ,
82+ model ,
83+ shape : torch .Tensor ,
84+ M : int = 256 ,
85+ N : int = 256 ,
86+ stride : int = 128 ,
87+ device = None ,
88+ batch_size = 32 ,
89+ ):
7890 """
7991 Highly optimized prediction loop.
8092 - Uses VIEWs instead of COPIES for memory efficiency.
@@ -88,16 +100,20 @@ def predict_optimized(dataset, model, shape: torch.Tensor, M: int = 256, N: int
88100 elif sorted (dataset .x_feature_index ) == list (range (shape .shape [0 ])):
89101 X = shape [:]
90102 else :
91- print ("Warning: Creating tensor copy for non-contiguous indices (High RAM usage)" )
103+ print (
104+ "Warning: Creating tensor copy for non-contiguous indices (High RAM usage)"
105+ )
92106 X = shape [dataset .x_feature_index ]
93107
94108 y = shape [dataset .y_feature_index ]
95109
96110 img_channels , img_rows , img_cols = X .shape
97111
98112 # SETUP OUTPUT ACCUMULATORS (On CPU)
99- weighted_prediction_sum = torch .zeros ((1 , img_rows , img_cols ), dtype = torch .float32 , device = 'cpu' )
100- weight_sum = torch .zeros ((1 , img_rows , img_cols ), dtype = torch .float32 , device = 'cpu' )
113+ weighted_prediction_sum = torch .zeros (
114+ (1 , img_rows , img_cols ), dtype = torch .float32 , device = "cpu"
115+ )
116+ weight_sum = torch .zeros ((1 , img_rows , img_cols ), dtype = torch .float32 , device = "cpu" )
101117
102118 # Weight map: Shape (1, M, N)
103119 weight_map_gpu = create_weight_map (M , N , device )
@@ -127,7 +143,7 @@ def predict_optimized(dataset, model, shape: torch.Tensor, M: int = 256, N: int
127143 if h_valid < M or w_valid < N :
128144 pad_h = M - h_valid
129145 pad_w = N - w_valid
130- patch = F .pad (patch , (0 , pad_w , 0 , pad_h ), mode = ' constant' , value = 0 )
146+ patch = F .pad (patch , (0 , pad_w , 0 , pad_h ), mode = " constant" , value = 0 )
131147
132148 batch_patches .append (patch )
133149 batch_coords .append ((r , r_end , c , c_end , h_valid , w_valid ))
@@ -144,18 +160,23 @@ def predict_optimized(dataset, model, shape: torch.Tensor, M: int = 256, N: int
144160 pred_valid = preds [k , :, :h_val , :w_val ]
145161 weight_valid = weight_map_cpu [:, :h_val , :w_val ]
146162
147- weighted_prediction_sum [:, r0 :r1 , c0 :c1 ] += pred_valid * weight_valid
163+ weighted_prediction_sum [:, r0 :r1 , c0 :c1 ] += (
164+ pred_valid * weight_valid
165+ )
148166 weight_sum [:, r0 :r1 , c0 :c1 ] += weight_valid
149167
150168 processed_steps += len (batch_patches )
151- print (f" Progress: { processed_steps } /{ total_steps } ({ 100 * processed_steps / total_steps :.1f} %)" , end = '\r ' )
169+ print (
170+ f" Progress: { processed_steps } /{ total_steps } ({ 100 * processed_steps / total_steps :.1f} %)" ,
171+ end = "\r " ,
172+ )
152173
153174 batch_patches = []
154175 batch_coords = []
155176 del batch_tensor , preds
156177
157178 if processed_steps % (batch_size * 10 ) == 0 :
158- torch .cuda .empty_cache ()
179+ torch .cuda .empty_cache ()
159180
160181 # Process remaining patches
161182 if batch_patches :
@@ -178,21 +199,24 @@ def predict_optimized(dataset, model, shape: torch.Tensor, M: int = 256, N: int
178199
179200 lf_idx = dataset .lf_index
180201 if isinstance (shape , torch .Tensor ):
181- lf = shape [lf_idx ].unsqueeze (0 )
202+ lf = shape [lf_idx ].unsqueeze (0 )
182203 else :
183- lf = None
204+ lf = None
184205
185206 return final_prediction , y , lf
186207
208+
187209# MAIN FUNCTION
188210def enhanceFIM (huc_id , patch_size = (256 , 256 ), batch_size = 32 ):
189211
190- device_type = ' cuda' if torch .cuda .is_available () else ' cpu'
212+ device_type = " cuda" if torch .cuda .is_available () else " cpu"
191213 print (f"\n { '=' * 60 } \n SYSTEM: { device_type .upper ()} \n { '=' * 60 } " )
192214
193- data_dir = Path (f' ./HUC{ huc_id } _forcings/' )
215+ data_dir = Path (f" ./HUC{ huc_id } _forcings/" )
194216 model = AttentionUNet (channel = 8 )
195- preprocessor = InferenceDataPreprocessor (data_dir = Path (data_dir ), patch_size = patch_size , verbose = True )
217+ preprocessor = InferenceDataPreprocessor (
218+ data_dir = Path (data_dir ), patch_size = patch_size , verbose = True
219+ )
196220
197221 print ("Loading model..." )
198222 model , device = load_model (model )
@@ -204,7 +228,7 @@ def enhanceFIM(huc_id, patch_size=(256, 256), batch_size=32):
204228 print (f"\n Processing [{ idx } /{ len (lf_files )} ]: { lf_filename } " )
205229
206230 static_stack = preprocessor .get_static_stack (huc_id )
207- lf_tensor = preprocessor .tif_to_tensor (lf_path , feature_name = ' low_fidelity' )
231+ lf_tensor = preprocessor .tif_to_tensor (lf_path , feature_name = " low_fidelity" )
208232
209233 print ("Merging tensors..." )
210234 area_tensor = torch .cat ([static_stack , lf_tensor ], dim = 0 )
@@ -213,7 +237,9 @@ def enhanceFIM(huc_id, patch_size=(256, 256), batch_size=32):
213237 del lf_tensor
214238 gc .collect ()
215239
216- print (f"Tensor Shape: { area_tensor .shape } | Memory: { area_tensor .element_size () * area_tensor .nelement () / 1e9 :.2f} GB" )
240+ print (
241+ f"Tensor Shape: { area_tensor .shape } | Memory: { area_tensor .element_size () * area_tensor .nelement () / 1e9 :.2f} GB"
242+ )
217243
218244 class Dummy :
219245 x_feature_index = slice (None )
@@ -229,7 +255,7 @@ class Dummy:
229255 N = patch_size [1 ],
230256 stride = patch_size [0 ] // 2 ,
231257 device = device ,
232- batch_size = batch_size
258+ batch_size = batch_size ,
233259 )
234260
235261 except RuntimeError as e :
@@ -238,15 +264,21 @@ class Dummy:
238264 torch .cuda .empty_cache ()
239265 gc .collect ()
240266 x , y , lf = predict_optimized (
241- Dummy , model , area_tensor , M = patch_size [0 ], N = patch_size [1 ],
242- stride = patch_size [0 ] // 2 , device = device , batch_size = 4
267+ Dummy ,
268+ model ,
269+ area_tensor ,
270+ M = patch_size [0 ],
271+ N = patch_size [1 ],
272+ stride = patch_size [0 ] // 2 ,
273+ device = device ,
274+ batch_size = 4 ,
243275 )
244276 else :
245277 raise e
246278
247279 del area_tensor
248280 gc .collect ()
249- if device .type == ' cuda' :
281+ if device .type == " cuda" :
250282 torch .cuda .empty_cache ()
251283
252284 pred_dir = Path (f"./Results/HUC{ huc_id } /" )
@@ -256,4 +288,4 @@ class Dummy:
256288 save_image (x , pred_path , str (lf_path ))
257289 print (f"✓ Saved: { pred_path } " )
258290
259- print (f"\n { '=' * 60 } \n COMPLETED\n { '=' * 60 } " )
291+ print (f"\n { '=' * 60 } \n COMPLETED\n { '=' * 60 } " )
0 commit comments