-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSquash_s.py
More file actions
24 lines (22 loc) · 1.04 KB
/
Squash_s.py
File metadata and controls
24 lines (22 loc) · 1.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import tensorflow as tf
import numpy as np
# **********************************************
def squash(s, axis=-1, epsilon=1e-7, name=None):
    """Rescale the vectors of `s` that lie along `axis` with the squash formula.

    For every vector v along `axis` the result is
    ``(|v|^2 / (1 + |v|^2)) * (v / sqrt(|v|^2 + epsilon))``, i.e. the
    direction of v scaled by a factor that lies in [0, 1).

    Args:
        s: Tensor to squash.
        axis: Axis along which the squared norm is summed (default: last).
        epsilon: Small constant added under the square root so the division
            below never divides by exactly zero for an all-zero vector.
        name: Optional name for the enclosing name scope; "squash" is used
            when it is None.

    Returns:
        The element-wise product of the scale factor and the normalised
        input, as built by the ops below.
    """
    # NOTE(review): `keep_dims` is the older spelling of this argument; later
    # TF 1.x releases document it as deprecated in favour of `keepdims`.
    # Confirm the targeted TensorFlow version before renaming it.
    with tf.name_scope(name, default_name="squash"):
        # The reduced axis is kept (size 1) so the norm broadcasts against s.
        sq_length = tf.reduce_sum(tf.square(s), axis=axis, keep_dims=True)
        length = tf.sqrt(sq_length + epsilon)
        scale = sq_length / (1. + sq_length)
        direction = s / length
        return scale * direction
def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False, name=None):
    """Return ``sqrt(sum(s ** 2, axis) + epsilon)``.

    This is the Euclidean norm along `axis` with `epsilon` added under the
    square root, which keeps the argument of `tf.sqrt` away from zero (where
    the square root has an unbounded derivative).

    Args:
        s: Tensor whose norm is computed.
        axis: Axis along which the squares are summed (default: last).
        epsilon: Small constant added to the squared norm before the root.
        keep_dims: Forwarded to `tf.reduce_sum`; when true the reduced axis
            is retained with length 1.
        name: Optional name for the enclosing name scope; "safe_norm" is
            used when it is None.

    Returns:
        Tensor holding the epsilon-stabilised norm.
    """
    # NOTE(review): `tf.reduce_sum(..., keep_dims=...)` is the older argument
    # spelling; later TF 1.x releases document it as deprecated in favour of
    # `keepdims`. Confirm the targeted TensorFlow version before changing it.
    with tf.name_scope(name, default_name="safe_norm"):
        sum_of_squares = tf.reduce_sum(
            tf.square(s), axis=axis, keep_dims=keep_dims)
        return tf.sqrt(sum_of_squares + epsilon)
def squash_arr(s, axis=-1, epsilon=1e-7):
    """Apply the squash formula to a NumPy array along `axis`.

    For every vector v along `axis` the result is
    ``(|v|^2 / (1 + |v|^2)) * (v / sqrt(|v|^2 + epsilon))``.

    Args:
        s: Array (or array-like accepted by NumPy) to squash.
        axis: Axis along which the squared norm is summed (default: last).
        epsilon: Small constant added under the square root so an all-zero
            vector is divided by a non-zero value and maps to zeros.

    Returns:
        Array with the same shape as `s` holding the squashed vectors.
    """
    # keepdims=True retains the reduced axis with length 1 so that the norm
    # and the scale factor broadcast against `s` along the correct axis.
    # Without it, inputs with more than one dimension either fail to
    # broadcast or are divided by the norms of the wrong vectors.
    squared_norm = np.sum(np.square(s), axis=axis, keepdims=True)
    safe_norm = np.sqrt(squared_norm + epsilon)
    squash_factor = squared_norm / (1. + squared_norm)
    unit_vector = s / safe_norm
    return squash_factor * unit_vector