-
Notifications
You must be signed in to change notification settings - Fork 159
Expand file tree
/
Copy pathcomplexFunctions.py
More file actions
executable file
·220 lines (182 loc) · 5.86 KB
/
complexFunctions.py
File metadata and controls
executable file
·220 lines (182 loc) · 5.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: spopoff
"""
import torch
from torch.nn.functional import (
avg_pool2d,
dropout,
dropout2d,
interpolate,
max_pool2d,
relu,
sigmoid,
tanh,
softmax,
)
from torch.nn.functional import max_pool2d, avg_pool2d, dropout, dropout2d, interpolate
from torch import tanh, relu, sigmoid
def complex_matmul(A, B):
    """
    Multiply two complex tensors with ``torch.matmul`` semantics.

    The product is assembled from four real matrix products using
    (a + ib)(c + id) = (ac - bd) + i(ad + bc). Both operands have ``.real``
    and ``.imag`` read, so complex tensors are expected. The result is
    returned as ``complex64``.
    """
    a, b = A.real, A.imag
    c, d = B.real, B.imag
    real_part = torch.matmul(a, c) - torch.matmul(b, d)
    imag_part = torch.matmul(a, d) + torch.matmul(b, c)
    return real_part.type(torch.complex64) + 1j * imag_part.type(torch.complex64)
def complex_avg_pool2d(inp, *args, **kwargs):
    """
    Average-pool a complex tensor by pooling its real and imaginary parts
    independently.

    Any extra positional and keyword arguments are forwarded unchanged to
    ``avg_pool2d`` for both parts. Returns a ``complex64`` tensor.
    """
    pooled_re, pooled_im = (
        avg_pool2d(part, *args, **kwargs) for part in (inp.real, inp.imag)
    )
    return pooled_re.type(torch.complex64) + 1j * pooled_im.type(torch.complex64)
def complex_normalize(inp):
    """
    Standardize the real and imaginary parts of a complex tensor separately.

    Each part is shifted by its own mean and divided by its own standard
    deviation (``Tensor.std`` defaults), both computed over all elements.
    Returns a ``complex64`` tensor.
    """
    def _standardize(part):
        return (part - part.mean()) / part.std()

    real_norm = _standardize(inp.real)
    imag_norm = _standardize(inp.imag)
    return real_norm.type(torch.complex64) + 1j * imag_norm.type(torch.complex64)
def complex_relu(inp):
    """Apply ReLU to the real and imaginary parts independently; returns ``complex64``."""
    rectified_re = relu(inp.real).type(torch.complex64)
    rectified_im = relu(inp.imag).type(torch.complex64)
    return rectified_re + 1j * rectified_im
def complex_sigmoid(inp):
    """Apply the logistic sigmoid to the real and imaginary parts independently; returns ``complex64``."""
    squashed_re = sigmoid(inp.real).type(torch.complex64)
    squashed_im = sigmoid(inp.imag).type(torch.complex64)
    return squashed_re + 1j * squashed_im
def complex_tanh(inp):
    """Apply tanh to the real and imaginary parts independently; returns ``complex64``."""
    squashed_re = tanh(inp.real).type(torch.complex64)
    squashed_im = tanh(inp.imag).type(torch.complex64)
    return squashed_re + 1j * squashed_im
def complex_opposite(inp):
    """Return ``-inp`` as a ``complex64`` tensor by negating both parts."""
    negated_re = -inp.real.type(torch.complex64)
    negated_im = -inp.imag.type(torch.complex64)
    return negated_re + 1j * negated_im
def complex_stack(inp, dim):
    """
    Stack a sequence of complex tensors along a new dimension ``dim``.

    Real and imaginary parts are stacked separately with ``torch.stack`` and
    recombined as ``complex64``.
    """
    real_parts = [t.real for t in inp]
    imag_parts = [t.imag for t in inp]
    stacked_real = torch.stack(real_parts, dim).type(torch.complex64)
    stacked_imag = torch.stack(imag_parts, dim).type(torch.complex64)
    return stacked_real + 1j * stacked_imag
def _retrieve_elements_from_indices(tensor, indices):
    """
    Gather values of ``tensor`` at flat positions within its last two
    dimensions, as given by ``indices``.

    The last two dimensions of both tensors are collapsed into one, and
    ``indices`` selects entries along that collapsed axis (this matches the
    indices produced by ``max_pool2d(..., return_indices=True)``). The result
    has the shape of ``indices``.
    """
    flat_idx = indices.flatten(start_dim=-2)
    picked = tensor.flatten(start_dim=-2).gather(dim=-1, index=flat_idx)
    return picked.view_as(indices)
def complex_upsample(
    inp,
    size=None,
    scale_factor=None,
    mode="nearest",
    align_corners=None,
    recompute_scale_factor=None,
):
    """
    Upsample a complex tensor by interpolating its real and imaginary parts
    independently with ``interpolate`` and recombining them.

    Every keyword argument is forwarded unchanged to ``interpolate`` for both
    parts. Returns a ``complex64`` tensor.
    """
    interp_kwargs = dict(
        size=size,
        scale_factor=scale_factor,
        mode=mode,
        align_corners=align_corners,
        recompute_scale_factor=recompute_scale_factor,
    )
    real_up = interpolate(inp.real, **interp_kwargs)
    imag_up = interpolate(inp.imag, **interp_kwargs)
    return real_up.type(torch.complex64) + 1j * imag_up.type(torch.complex64)
def complex_upsample2(
    inp,
    size=None,
    scale_factor=None,
    mode="nearest",
    align_corners=None,
    recompute_scale_factor=None,
):
    """
    Upsample a complex tensor in polar form.

    The magnitude (``inp.abs()``) and the phase (``atan2(imag, real)``) are
    interpolated independently with ``interpolate`` and then recombined as
    ``|z| * (cos(phase) + i*sin(phase))``. Every keyword argument is forwarded
    unchanged to both ``interpolate`` calls.

    Caveat: the phase is interpolated as a raw angle in [-pi, pi]. With a
    non-``"nearest"`` mode, neighbours on either side of the -pi/pi
    wrap-around are blended toward 0 rather than toward pi.
    """
    interp_kwargs = dict(
        size=size,
        scale_factor=scale_factor,
        mode=mode,
        align_corners=align_corners,
        recompute_scale_factor=recompute_scale_factor,
    )
    magnitude = interpolate(inp.abs(), **interp_kwargs)
    phase = interpolate(torch.atan2(inp.imag, inp.real), **interp_kwargs)
    unit_phasor = torch.cos(phase).type(torch.complex64) + 1j * torch.sin(
        phase
    ).type(torch.complex64)
    return magnitude * unit_phasor
def complex_max_pool2d(
    inp,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
    return_indices=False,
):
    """
    Perform complex max pooling by selecting, in each pooling window, the
    element with the largest absolute value.

    The window maxima are found with ``max_pool2d`` on ``inp.abs()``. The
    phase of each selected element is then gathered from the input, so the
    output carries the selected complex value (magnitude and phase), cast to
    ``complex64``.

    ``kernel_size``, ``stride``, ``padding``, ``dilation`` and ``ceil_mode``
    are forwarded unchanged to ``max_pool2d``.

    Returns:
        The pooled ``complex64`` tensor. If ``return_indices`` is True, the
        tuple ``(pooled, indices)`` is returned instead, where ``indices`` are
        the positions chosen by ``max_pool2d``. This is the same convention as
        ``max_pool2d(..., return_indices=True)``.
    """
    # Indices are always requested from max_pool2d: they are needed to fetch
    # the phase of each selected element.
    absolute_value, indices = max_pool2d(
        inp.abs(),
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        ceil_mode=ceil_mode,
        return_indices=True,
    )
    absolute_value = absolute_value.type(torch.complex64)
    # The phase is computed with atan2 rather than torch.angle; the original
    # author notes that the derivative for 'angle' is not implemented.
    angle = torch.atan2(inp.imag, inp.real)
    # Keep only the phase values at the positions selected by the max pool.
    angle = _retrieve_elements_from_indices(angle, indices)
    pooled = absolute_value * (
        torch.cos(angle).type(torch.complex64)
        + 1j * torch.sin(angle).type(torch.complex64)
    )
    # Honor return_indices like max_pool2d does (it used to be accepted but
    # silently ignored).
    if return_indices:
        return pooled, indices
    return pooled
def complex_dropout(inp, p=0.5, training=True):
    """
    Apply dropout to a complex tensor using one mask shared by the real and
    imaginary parts.

    A real-valued mask of ones is passed through ``dropout``, and the input is
    multiplied by the result, so a dropped element loses both parts at once.
    ``dropout`` behaves as follows:
      - when ``training`` is True, it zeroes each entry with probability
        ``p`` and rescales the survivors by ``1 / (1 - p)``;
      - otherwise it returns the mask unchanged.

    Args:
        inp: Complex tensor to apply dropout to.
        p: Probability of zeroing an element.
        training: If False, the mask stays all ones and ``inp``'s values are
            returned unchanged.

    Returns:
        ``mask * inp``. The float32 mask is promoted to ``inp``'s complex dtype
        by the multiplication.
    """
    # A single real mask keeps the real and imaginary parts dropped together
    # (not a particularly clean solution, but it guarantees a shared mask).
    # ``inp.shape`` is passed as one size argument so 0-d inputs work too.
    mask = torch.ones(inp.shape, dtype=torch.float32, device=inp.device)
    # ``dropout`` already rescales by 1 / (1 - p) in training mode. Scaling
    # again would double-scale in training, wrongly scale the input in eval
    # mode, and divide by zero for p == 1.
    mask = dropout(mask, p, training)
    return mask * inp
def complex_dropout2d(inp, p=0.5, training=True):
    """
    Apply channel-wise dropout (``dropout2d``) to a complex tensor using one
    mask shared by the real and imaginary parts.

    A real-valued mask of ones is passed through ``dropout2d``, and the input
    is multiplied by the result. ``dropout2d`` behaves as follows:
      - when ``training`` is True, it zeroes whole channels with probability
        ``p`` and rescales the kept ones by ``1 / (1 - p)``;
      - otherwise it returns the mask unchanged.

    Args:
        inp: Complex tensor; presumably (N, C, H, W) or (C, H, W) as expected
            by ``dropout2d`` — confirm against callers.
        p: Probability of zeroing a channel.
        training: If False, the mask stays all ones and ``inp``'s values are
            returned unchanged.

    Returns:
        ``mask * inp``. The float32 mask is promoted to ``inp``'s complex dtype
        by the multiplication.
    """
    # A single real mask keeps the real and imaginary parts dropped together
    # (not a particularly clean solution, but it guarantees a shared mask).
    mask = torch.ones(inp.shape, dtype=torch.float32, device=inp.device)
    # ``dropout2d`` already rescales by 1 / (1 - p) in training mode. Scaling
    # again would double-scale in training, wrongly scale the input in eval
    # mode, and divide by zero for p == 1.
    mask = dropout2d(mask, p, training)
    return mask * inp
def complex_softmax(inp, dim):
    """
    Apply softmax along ``dim`` to the real and imaginary parts separately.

    Each part is normalized on its own, so the result is not a softmax of the
    complex values taken as a whole. Returns a ``complex64`` tensor.
    """
    soft_re, soft_im = (softmax(part, dim=dim) for part in (inp.real, inp.imag))
    return soft_re.type(torch.complex64) + 1j * soft_im.type(torch.complex64)