-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathloss.py
More file actions
77 lines (54 loc) · 2.72 KB
/
loss.py
File metadata and controls
77 lines (54 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
'''
Copyright (c) 2025 Bashayer Abdallah
Licensed under CC BY-NC 4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
Commercial use is prohibited.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
class siLogLoss(nn.Module):
    """Scale-invariant logarithmic (SILog) loss for monocular depth estimation.

    Computes 10 * sqrt(var(g) + 0.15 * mean(g)^2) where
    g = log(pred) - log(target); the 0.15 factor weights the
    scale-invariance term.
    """

    def __init__(self):
        # BUG FIX: the original called super(SILogLoss, self).__init__(),
        # but the class is named `siLogLoss`, so instantiation raised
        # NameError. Zero-argument super() avoids the name mismatch.
        super().__init__()
        self.name = 'SILog'

    def forward(self, input, target, interpolate=True):
        """Compute the SILog loss.

        Args:
            input: predicted depth; must be 4-D (B, C, H, W) when
                `interpolate` is True so bilinear resizing is valid.
            target: ground-truth depth; its last two dims give the
                spatial size used for resizing.
            interpolate: if True, bilinearly resize `input` to the
                target's spatial size before comparing.

        Returns:
            Scalar tensor with the SILog loss value.
        """
        # Epsilon guards log(0), division by zero, and sqrt(0).
        eps = 1e-6
        if interpolate:
            input = nn.functional.interpolate(
                input, target.shape[-2:], mode='bilinear', align_corners=True)
        # log(x + eps) keeps the log finite when depths contain zeros.
        g = torch.log(input + eps) - torch.log(target + eps)
        Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2)
        # eps inside sqrt keeps the gradient finite when Dg == 0.
        return 10 * torch.sqrt(Dg + eps)
class berHuLoss(nn.Module):
    """Reverse Huber (BerHu) loss: L1 for small errors, scaled L2 for large.

    The switch point c is `threshold_ratio` times the current batch's
    maximum absolute error, recomputed every forward pass.
    """

    def __init__(self, threshold_ratio=0.2):
        """
        Args:
            threshold_ratio: fraction of the max absolute error at which
                the loss switches from L1 to L2 behaviour.
        """
        # BUG FIX: the original called super(BerHuLoss, self).__init__(),
        # but the class is named `berHuLoss`, so instantiation raised
        # NameError. Zero-argument super() avoids the name mismatch.
        super().__init__()
        self.threshold_ratio = threshold_ratio

    def forward(self, input, target):
        """Compute the BerHu loss between `input` and `target`.

        Returns:
            Scalar tensor: mean over elementwise BerHu penalties.
        """
        eps = 1e-6  # small value to avoid numerical issues
        diff = torch.abs(input - target)
        # Threshold c scales with the worst error in this batch; the
        # eps floor keeps c strictly positive when input == target.
        max_diff = torch.max(diff).item()
        c = max(self.threshold_ratio * max_diff, eps)
        # L1 below the threshold, (d^2 + c^2) / (2c) above it — the
        # classic BerHu form, continuous at d == c.
        mask = diff <= c
        l1_part = diff[mask]
        l2_part = (diff[~mask] ** 2 + c ** 2) / (2 * c + eps)
        loss = torch.cat([l1_part, l2_part]).mean()
        return loss
class edgeLoss(nn.Module):
    """Mean-squared error between predicted and ground-truth edge maps."""

    def forward(self, edge_map_pred, edge_map_gt):
        """Return MSE(edge_map_pred, edge_map_gt) as a scalar tensor.

        FIX: the original added the same eps=1e-6 to BOTH operands before
        the MSE; the constant cancels in the difference, so it had no
        numerical effect (and the comment claiming it prevented NaNs was
        misleading). The no-op is removed.
        """
        return F.mse_loss(edge_map_pred, edge_map_gt)
# From MDE-ED (https://ieeexplore.ieee.org/document/11084697)
def edge_loss(pred_depth, gt_depth):
    """Gradient (edge) L1 loss between predicted and ground-truth depth.

    Args:
        pred_depth: predicted depth, 4-D tensor (B, C, H, W).
        gt_depth: ground-truth depth, same shape as `pred_depth`.

    Returns:
        Scalar tensor: L1 loss on vertical gradients plus L1 loss on
        horizontal gradients.

    BUG FIX: the original summed the vertical-gradient tensor of shape
    (B, C, H-1, W) with the horizontal-gradient tensor of shape
    (B, C, H, W-1); those shapes are not broadcast-compatible, so the
    function raised a RuntimeError for any input with H > 1 and W > 1.
    The two gradient directions are now compared separately and their
    L1 losses summed.
    """
    # Vertical (along H) absolute finite differences.
    grad_pred_y = torch.abs(pred_depth[:, :, 1:, :] - pred_depth[:, :, :-1, :])
    grad_gt_y = torch.abs(gt_depth[:, :, 1:, :] - gt_depth[:, :, :-1, :])
    # Horizontal (along W) absolute finite differences.
    grad_pred_x = torch.abs(pred_depth[:, :, :, 1:] - pred_depth[:, :, :, :-1])
    grad_gt_x = torch.abs(gt_depth[:, :, :, 1:] - gt_depth[:, :, :, :-1])
    return F.l1_loss(grad_pred_y, grad_gt_y) + F.l1_loss(grad_pred_x, grad_gt_x)