Skip to content

Commit 87e9a43

Browse files
update the custom boundary FIM enhancement using surrogate model
1 parent 13d851e commit 87e9a43

5 files changed

Lines changed: 538 additions & 156 deletions

File tree

172 KB
Binary file not shown.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "fimserve"
3-
version = "0.1.90"
3+
version = "0.1.91"
44
description = "Framework which is developed with the purpose of quickly generating Flood Inundation Maps (FIM) for emergency response and risk assessment. It is developed under Surface Dynamics Modeling Lab (SDML)."
55
authors = [{ name = "Supath Dhital", email = "sdhital@crimson.ua.edu" }]
66
maintainers = [{ name = "Supath Dhital", email = "sdhital@crimson.ua.edu" }]
Lines changed: 191 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,51 @@
1+
import gc
2+
import tempfile
3+
from pathlib import Path
4+
5+
16
from .SM_preprocess import *
27
from .surrogate_model import *
38
from .utlis import *
49
from .preprocessFIM import *
10+
511

12+
# MODEL LOADING
def load_model(model):
    """Download the trained checkpoint from S3 and load it into *model*.

    Args:
        model: An instantiated torch.nn.Module whose architecture matches
            the stored checkpoint (state_dict keys must line up).

    Returns:
        tuple: (model, device) — the model in eval mode, moved to 'cuda'
        when available, otherwise 'cpu'.
    """
    fs = s3fs.S3FileSystem(anon=True)
    bucket_path = "sdmlab/SM_dataset/trained_model/SM_trainedmodel.ckpt"

    # Stream the checkpoint into a local temp file; torch.load needs a
    # seekable file/path, not the raw S3 stream.
    with fs.open(bucket_path, 'rb') as s3file:
        with tempfile.NamedTemporaryFile(suffix=".ckpt", delete=False) as tmp_ckpt:
            tmp_ckpt.write(s3file.read())
            tmp_ckpt_path = tmp_ckpt.name

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        checkpoint = torch.load(tmp_ckpt_path, map_location=device)
    finally:
        # FIX: delete=False means the file would otherwise leak on disk
        # after every call.
        Path(tmp_ckpt_path).unlink(missing_ok=True)

    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()

    return model, device

# HELPER FUNCTIONS
def create_weight_map(M: int, N: int, device):
    """Build a (1, M, N) Gaussian weight map for smooth patch merging.

    The weight peaks (1.0) at the patch centre and decays with squared
    distance, so predictions near patch edges contribute less when
    overlapping tiles are blended.

    Args:
        M: Patch height (rows).
        N: Patch width (cols).
        device: Target torch device for the returned tensor.

    Returns:
        torch.FloatTensor of shape (1, M, N) on *device*.
    """
    # FIX: dropped the dead np.zeros((M, N)) allocation that the original
    # immediately overwrote, and renamed the centres by axis role
    # (row/col) instead of the misleading x/y labels.
    row_c, col_c = M // 2, N // 2

    # Open grids: rows is (M, 1), cols is (1, N) -> broadcast to (M, N).
    rows, cols = np.ogrid[:M, :N]
    dist_sq = (cols - col_c) ** 2 + (rows - row_c) ** 2

    sigma_sq = (min(M, N) / 2) ** 2
    weight_map = np.exp(-dist_sq / (2 * sigma_sq))

    # Single leading channel dim so the map broadcasts against (C, M, N).
    return torch.from_numpy(weight_map).float().unsqueeze(0).to(device)
9447
def save_image(image: torch.Tensor, path: Path, reference_tif: str):
95-
"""Save the image as a .tif file.
96-
97-
Args:
98-
image (torch.Tensor): The image to save
99-
path (Path): The path to save the image
100-
"""
48+
"""Saves the prediction tensor as a GeoTIFF."""
10149
image_np = image.squeeze().cpu().numpy().astype('float32')
10250
with rasterio.open(reference_tif) as ref:
10351
meta = ref.meta.copy()
@@ -111,60 +59,200 @@ def save_image(image: torch.Tensor, path: Path, reference_tif: str):
11159

11260
with rasterio.open(path, 'w', **meta) as dst:
11361
dst.write(image_np, 1)
62+
63+
# Apply water body mask
11464
mask_with_PWB(path, path)
11565

66+
# Binarize and compress
11667
with rasterio.open(path, 'r+') as dst:
11768
data = dst.read(1)
11869
binary_data = np.where(data > 0, 1, 0).astype(np.uint8)
11970
dst.write(binary_data, 1)
12071

12172
compress_tif_lzw(path)
# PREDICTION
def predict_optimized(dataset, model, shape: torch.Tensor, M: int = 256, N: int = 256,
                      stride: int = 128, device=None, batch_size=32):
    """
    Memory-efficient sliding-window inference over a large raster.

    - Uses VIEWs instead of COPIES for memory efficiency.
    - Performs on-the-fly zero padding at the right/bottom boundaries.
    - Streams batches to the device while keeping the full maps on CPU.

    Args:
        dataset: Object exposing x_feature_index, y_feature_index and lf_index.
        model: Trained torch model, called as model(batch) under no_grad.
        shape: Full input tensor, (channels, rows, cols), resident on CPU.
        M, N: Patch height / width.
        stride: Step between patch origins (overlap = M - stride).
        device: Torch device used for the forward passes.
        batch_size: Number of patches per forward pass.

    Returns:
        tuple: (final_prediction, y, lf) — binarized prediction, target
        channel and low-fidelity channel, all CPU tensors.
    """
    # SETUP INPUTS — avoid fancy-index copies when a view suffices.
    if isinstance(dataset.x_feature_index, slice):
        X = shape[dataset.x_feature_index]
    elif sorted(dataset.x_feature_index) == list(range(shape.shape[0])):
        X = shape[:]
    else:
        print("Warning: Creating tensor copy for non-contiguous indices (High RAM usage)")
        X = shape[dataset.x_feature_index]

    y = shape[dataset.y_feature_index]

    img_channels, img_rows, img_cols = X.shape

    # SETUP OUTPUT ACCUMULATORS (kept on CPU to bound GPU memory)
    weighted_prediction_sum = torch.zeros((1, img_rows, img_cols), dtype=torch.float32, device='cpu')
    weight_sum = torch.zeros((1, img_rows, img_cols), dtype=torch.float32, device='cpu')

    # Weight map: shape (1, M, N)
    weight_map_gpu = create_weight_map(M, N, device)
    weight_map_cpu = weight_map_gpu.cpu()

    def _flush(patches, coords):
        # One forward pass; accumulate weighted predictions back on CPU.
        batch_tensor = torch.stack(patches).to(device)
        with torch.no_grad():
            preds = model(batch_tensor).cpu()
        for k, (r0, r1, c0, c1, h_val, w_val) in enumerate(coords):
            pred_valid = preds[k, :, :h_val, :w_val]
            weight_valid = weight_map_cpu[:, :h_val, :w_val]
            weighted_prediction_sum[:, r0:r1, c0:c1] += pred_valid * weight_valid
            weight_sum[:, r0:r1, c0:c1] += weight_valid

    # BATCH PROCESSING LOOP
    batch_patches = []
    batch_coords = []

    total_steps = ((img_rows - 1) // stride + 1) * ((img_cols - 1) // stride + 1)
    processed_steps = 0

    print(f" Starting inference on {img_rows}x{img_cols} image...")

    for r in range(0, img_rows, stride):
        for c in range(0, img_cols, stride):
            r_end = min(r + M, img_rows)
            c_end = min(c + N, img_cols)

            h_valid = r_end - r
            w_valid = c_end - c

            # Extract patch from the CPU tensor (a view, not a copy)
            patch = X[:, r:r_end, c:c_end]

            # Handle boundary padding on the fly (zeros, right/bottom only)
            if h_valid < M or w_valid < N:
                patch = F.pad(patch, (0, N - w_valid, 0, M - h_valid), mode='constant', value=0)

            batch_patches.append(patch)
            batch_coords.append((r, r_end, c, c_end, h_valid, w_valid))

            # INFERENCE STEP
            if len(batch_patches) >= batch_size:
                _flush(batch_patches, batch_coords)

                processed_steps += len(batch_patches)
                print(f" Progress: {processed_steps}/{total_steps} ({100*processed_steps/total_steps:.1f}%)", end='\r')

                batch_patches = []
                batch_coords = []

                # FIX: only touch the CUDA allocator when CUDA is in use.
                if torch.cuda.is_available() and processed_steps % (batch_size * 10) == 0:
                    torch.cuda.empty_cache()

    # Process remaining patches
    if batch_patches:
        _flush(batch_patches, batch_coords)
        processed_steps += len(batch_patches)  # FIX: count the tail batch too

    print(f" Progress: 100% - Inference Complete.")

    # NORMALIZE AND FINALIZE
    epsilon = 1e-8
    final_prediction = weighted_prediction_sum / (weight_sum + epsilon)
    final_prediction = (final_prediction > 0.01).float()

    lf_idx = dataset.lf_index
    if isinstance(shape, torch.Tensor):
        lf = shape[lf_idx].unsqueeze(0)
    else:
        lf = None

    return final_prediction, y, lf
# MAIN FUNCTION
def enhanceFIM(huc_id, patch_size=(256, 256), batch_size=32):
    """Enhance every low-fidelity FIM for one HUC with the surrogate model.

    Reads forcings from ./HUC{huc_id}_forcings/, runs sliding-window
    inference per low-fidelity map, and writes GeoTIFF results to
    ./Results/HUC{huc_id}/.

    Args:
        huc_id: HUC watershed identifier.
        patch_size: (height, width) of the inference patches.
        batch_size: Patches per forward pass; retried at 4 on CUDA OOM.

    Raises:
        ValueError: If the merged tensor does not have the expected 8 channels.
    """
    device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\n{'='*60}\nSYSTEM: {device_type.upper()}\n{'='*60}")

    data_dir = Path(f'./HUC{huc_id}_forcings/')
    model = AttentionUNet(channel=8)
    preprocessor = InferenceDataPreprocessor(data_dir=Path(data_dir), patch_size=patch_size, verbose=True)

    print("Loading model...")
    model, device = load_model(model)

    lf_files = preprocessor.get_all_lf_maps(huc_id)

    for idx, lf_path in enumerate(lf_files, 1):
        lf_filename = lf_path.name
        print(f"\nProcessing [{idx}/{len(lf_files)}]: {lf_filename}")

        static_stack = preprocessor.get_static_stack(huc_id)
        lf_tensor = preprocessor.tif_to_tensor(lf_path, feature_name='low_fidelity')

        print("Merging tensors...")
        area_tensor = torch.cat([static_stack, lf_tensor], dim=0)

        # FIX: restore the channel-count sanity check (dropped in this
        # commit) so a missing static feature fails fast instead of after
        # a long inference pass.
        if area_tensor.shape[0] != 8:
            raise ValueError(
                f"Expected 8 channels, got {area_tensor.shape[0]} — check missing static feature for HUC {huc_id}."
            )

        # Free the intermediates before the big inference allocation.
        del static_stack
        del lf_tensor
        gc.collect()

        print(f"Tensor Shape: {area_tensor.shape} | Memory: {area_tensor.element_size() * area_tensor.nelement() / 1e9:.2f} GB")

        # Lightweight adapter telling predict_optimized which channels are
        # inputs (all), the target, and the low-fidelity layer (last).
        class Dummy:
            x_feature_index = slice(None)
            y_feature_index = [area_tensor.shape[0] - 1]
            lf_index = area_tensor.shape[0] - 1

        try:
            x, y, lf = predict_optimized(
                Dummy,
                model,
                area_tensor,
                M=patch_size[0],
                N=patch_size[1],
                stride=patch_size[0] // 2,
                device=device,
                batch_size=batch_size
            )

        except RuntimeError as e:
            if "out of memory" in str(e):
                print("OOM Error. Retrying with batch_size=4...")
                torch.cuda.empty_cache()
                gc.collect()
                x, y, lf = predict_optimized(
                    Dummy, model, area_tensor, M=patch_size[0], N=patch_size[1],
                    stride=patch_size[0] // 2, device=device, batch_size=4
                )
            else:
                # FIX: bare raise preserves the original traceback.
                raise

        del area_tensor
        gc.collect()
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        pred_dir = Path(f"./Results/HUC{huc_id}/")
        pred_dir.mkdir(parents=True, exist_ok=True)
        pred_path = pred_dir / f"SMprediction_{lf_filename}"

        save_image(x, pred_path, str(lf_path))
        print(f"✓ Saved: {pred_path}")

    print(f"\n{'='*60}\nCOMPLETED\n{'='*60}")

0 commit comments

Comments
 (0)