-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathloss.py
More file actions
30 lines (22 loc) · 676 Bytes
/
loss.py
File metadata and controls
30 lines (22 loc) · 676 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Author: Vlad Niculae <vlad@vene.ro>
# License: BSD 3 clause
import numpy as np
def squared_loss(y_true, y_pred, return_derivative=False):
    """Half the sum of squared residuals, ``0.5 * sum((y_pred - y_true) ** 2)``.

    Parameters
    ----------
    y_true : array-like
        Target values.
    y_pred : array-like
        Predicted values; must broadcast against ``y_true`` (normally the
        same shape).
    return_derivative : bool, default False
        If True, also return the gradient of the loss with respect to
        ``y_pred``.

    Returns
    -------
    obj : float
        The loss value.
    derivative : ndarray
        ``y_pred - y_true``; only returned when ``return_derivative`` is True.
    """
    diff = np.asarray(y_pred) - np.asarray(y_true)
    # np.vdot flattens both arguments, so this is the sum of squared
    # residuals for inputs of any shape. np.dot(diff, diff) only does that
    # for 1-D input; on 2-D (multi-output) arrays it is a matrix product.
    obj = 0.5 * np.vdot(diff, diff)
    if return_derivative:
        return obj, diff
    return obj
def squared_hinge_loss(y_true, y_scores, return_derivative=False):
    """Squared hinge loss, ``sum(max(0, 1 - y_true * y_scores) ** 2)``.

    Parameters
    ----------
    y_true : array-like
        Binary labels, expected to take values in the set {-1, +1} (not the
        open interval (-1, 1)). Other values are not rejected.
    y_scores : array-like
        Real-valued decision scores; must broadcast against ``y_true``.
    return_derivative : bool, default False
        If True, also return the gradient of the loss with respect to
        ``y_scores``.

    Returns
    -------
    obj : float
        The loss value.
    derivative : ndarray
        ``-2 * y_true * max(0, 1 - y_true * y_scores)``; only returned when
        ``return_derivative`` is True.
    """
    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_scores)
    # Per-sample margin violation; zero wherever the margin y * s is >= 1.
    z = np.maximum(0, 1 - y_true * y_scores)
    obj = np.sum(z ** 2)
    if return_derivative:
        # d/ds max(0, 1 - y*s)**2 == -2 * y * max(0, 1 - y*s)
        return obj, -2 * y_true * z
    return obj
def get_loss(name):
    """Look up a loss function by name.

    Parameters
    ----------
    name : str
        Either ``'squared'`` or ``'squared-hinge'``.

    Returns
    -------
    loss : callable
        ``squared_loss`` or ``squared_hinge_loss`` respectively.

    Raises
    ------
    KeyError
        If ``name`` is not a known loss; the message lists the valid names.
    """
    losses = {'squared': squared_loss,
              'squared-hinge': squared_hinge_loss}
    try:
        return losses[name]
    except KeyError:
        # Keep KeyError for callers that catch it, but say what is allowed.
        raise KeyError("Unknown loss {!r}; expected one of {}.".format(
            name, sorted(losses)))