-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathrandom_diffusion.py
More file actions
224 lines (172 loc) · 9.47 KB
/
random_diffusion.py
File metadata and controls
224 lines (172 loc) · 9.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import sys
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as T
import time
class RandomDiffusion:
    """Patch-wise DDIM sampler for generating large latent-diffusion images.

    Instead of denoising the whole latent at once, random patches of size
    ``patch_size`` (in pixel space; ``patch_size // 8`` in latent space) are
    denoised one DDIM step at a time, tracked by a per-pixel step counter, until
    every latent location has completed all ``sampling_steps`` steps.

    Parameters
    ----------
    trainer : object exposing ``vae`` and ``ema.ema_model`` (project type).
    patch_size : pixel-space patch side length (latent patch is patch_size // 8).
    sampling_steps : number of DDIM steps per location.
    cond_scale : classifier-free guidance scale; stored shifted by +1 as
        expected by ``ddim_multimask``.
    device : torch device string for all computation.
    """

    def __init__(self, trainer, patch_size=512, sampling_steps=50, cond_scale=3.0, device='cuda:0'):
        self.vae = trainer.vae.to(device)
        self.ema_model = trainer.ema.ema_model.to(device)
        self.sampling_steps = sampling_steps
        # ddim_multimask expects the guidance scale offset by one.
        self.cond_scale = cond_scale + 1
        self.patch_size = patch_size
        self.device = device
        # DDIM time discretization over [−1, 999]; walked from high noise down,
        # stored as consecutive (t, t_next) pairs indexed by step number.
        times = torch.linspace(-1, 1000 - 1, steps=self.sampling_steps + 1).to(device)
        times = list(reversed(times.int().tolist()))
        self.time_pairs = list(zip(times[:-1], times[1:]))

    @torch.no_grad()
    def encode(self, images, labels):
        """Encode images to VAE latents, scaled down by the fixed factor 50.

        ``labels`` is accepted for interface compatibility but unused here.
        """
        return self.vae.encode(images).latent_dist.sample() / 50

    @torch.no_grad()
    def decode(self, z):
        """Decode latents to images, clipped to the [0, 1] range."""
        return torch.clip(self.vae.decode(z).sample, 0, 1)

    @torch.no_grad()
    def sample_one(self, x_t, x_stack, masks, times):
        """Run one DDIM step on the locations currently at the minimum step.

        ``times`` holds the per-location step index of the patch; only the
        locations at the patch minimum are advanced, the rest are returned
        unchanged. Locations that are ahead of the minimum are temporarily
        reset from ``x_stack[:, 0]`` (the snapshot of the previous completed
        pass) so the model sees a consistent noise level.
        """
        img = x_t.clone()
        uniques = torch.unique(times)
        vmin = uniques[0]
        if len(uniques) > 1:
            # Replace locations that are ahead of the minimum step with their
            # stored state so the whole patch is at a consistent time.
            for unique in uniques[1:]:
                to_change = torch.where(times == unique, 1, 0)
                x_t = x_stack[:, 0] * to_change + x_t * (to_change == 0)
        # Downsample the conditioning masks to latent resolution, keeping them
        # integer-valued and within the original label range; -1 (ignore) is
        # mapped to 0.
        masks = masks.float()
        mask_min, mask_max = masks.min(), masks.max()
        masks = T.Lambda(lambda x: torch.nn.functional.interpolate(x, size=x_t.shape[-1]))(masks).repeat(1, 1, 1, 1)
        masks = torch.clamp(torch.round(masks), mask_min, mask_max)
        masks = torch.where(masks == -1., 0., masks)
        t, t_next = self.time_pairs[vmin][0], self.time_pairs[vmin][1]
        pred = self.ema_model.ddim_multimask(x_t, masks, t, t_next, cond_scale=self.cond_scale)
        # Only commit the prediction at locations that were at the minimum step.
        time_mask = torch.where(times == vmin, 1, 0).to(self.device)
        return pred * time_mask + img * (time_mask == 0)

    def get_value_coordinates(self, tensor):
        """Return the coordinates of the minimum value of ``tensor``, shuffled.

        Shape is (k, tensor.dim()); the random permutation makes patch
        selection stochastic.
        """
        value_indices = torch.nonzero(tensor == tensor.min(), as_tuple=False)
        random_indices = value_indices[torch.randperm(value_indices.size(0))]
        return random_indices

    def random_crop(self, image, i, j, latent=True):
        """Crop a patch centered at (i, j).

        With ``latent=True`` (i, j) are latent coordinates and the patch is
        ``patch_size // 8`` wide; otherwise (i, j) are still latent coordinates
        but the crop is taken at pixel scale (×8) with full ``patch_size``.
        """
        if latent:
            p = self.patch_size // 16
            return image[..., i - p:i + p, j - p:j + p]
        else:
            p = self.patch_size // 2
            return image[..., i * 8 - p:i * 8 + p, j * 8 - p:j * 8 + p]

    def decoding_tiled_image(self, z, size):
        """Decode a large latent tile-by-tile and reassemble to ``size``."""
        p8 = self.patch_size // 8
        # BUGFIX: was t[None].cuda(), which ignored the configured device.
        img = [self.decode(t[None].to(self.device)) for t in tile(z, p8)]
        return untile(torch.cat(img, 0), size, self.patch_size)

    def hann_tile_overlap(self, z):
        """Decode ``z`` with Hann-window blending across tile seams.

        Four shifted decodings (full, vertical half-shift, horizontal
        half-shift, and both) are blended with Hann windows to hide the tile
        boundaries of independent VAE decodes.
        """
        assert z.shape[-1] % (self.patch_size // 8) == 0, 'For simplicity, the code is implemented to generate images that are multiples of the (patch size//8) in the latent. Crop z to a multiple of the (patch size//8)'
        windows = hann_window(self.patch_size)
        b, c, h, w = z.shape
        # Undo the /50 latent scaling applied in encode().
        z_scaled = z.clone() * 50
        p = self.patch_size
        p16 = self.patch_size // 16
        # Full image decoding
        img = self.decoding_tiled_image(z_scaled, (1, 3, h * 8, w * 8))
        # Vertical slice decoding (shifted half a patch horizontally)
        img_v = self.decoding_tiled_image(z_scaled[..., p16:-p16], (1, 3, h * 8, w * 8 - p))
        # Horizontal slice decoding (shifted half a patch vertically)
        img_h = self.decoding_tiled_image(z_scaled[..., p16:-p16, :], (1, 3, h * 8 - p, w * 8))
        # Cross slice decoding (shifted in both directions)
        img_cross = self.decoding_tiled_image(z_scaled[..., p16:-p16, p16:-p16],
                                              (1, 3, h * 8 - p, w * 8 - p))
        # Blend the shifted decodings over the seams with tiled Hann windows.
        b, c, h, w = img.shape
        repeat_v = windows['vertical'].repeat(b, c, h // p, w // p - 1).to(self.device)
        repeat_h = windows['horizontal'].repeat(b, c, h // p - 1, w // p).to(self.device)
        repeat_c = windows['center'].repeat(b, c, h // p - 1, w // p - 1).to(self.device)
        img[..., p // 2:-p // 2] = img[..., p // 2:-p // 2] * (1 - repeat_v) + img_v * repeat_v
        img[..., p // 2:-p // 2, :] = img[..., p // 2:-p // 2, :] * (1 - repeat_h) + img_h * repeat_h
        img[..., p // 2:-p // 2, p // 2:-p // 2] = img[..., p // 2:-p // 2, p // 2:-p // 2] * (1 - repeat_c) + img_cross * repeat_c
        return img

    @torch.no_grad()
    def __call__(self, masks):
        """Generate a latent image of the same spatial size as ``masks``.

        Repeatedly picks a random location at the minimum step count, crops a
        patch around it, advances that patch one DDIM step, and writes the
        result back, until every location has completed all steps.
        Returns the final latent (pass through decode/hann_tile_overlap to get
        pixels).
        """
        b, c, _, image_size = masks.shape
        masks = masks.to(self.device)
        # img_stack[:, 0] holds the snapshot of the previous completed pass,
        # img_stack[:, 1] accumulates the current pass.
        img_stack = torch.randn(b, 2, 4, image_size // 8, image_size // 8).to(self.device)
        # Per-location count of completed DDIM steps.
        times = torch.zeros((1, 1, image_size // 8, image_size // 8)).int().to(self.device)
        img0 = torch.randn(b, 4, image_size // 8, image_size // 8).to(self.device)
        p = self.patch_size // 16
        s = image_size // 8
        while times.float().mean() != (self.sampling_steps - 1):
            sys.stdout.flush()
            # Random coordinate among the least-denoised locations, clamped so
            # the patch stays inside the latent.
            random_indices = self.get_value_coordinates(times[0, 0])[0]
            i, j = torch.clamp(random_indices, p, s - p).tolist()
            print(f"\r Generation {times.float().mean()*100/(self.sampling_steps-1):.2f}%, indices=[{(i-p)*8:04d}:{(i+p)*8:04d},{(j-p)*8:04d}:{(j+p)*8:04d}]", end="")
            sub_img = self.random_crop(img0, i, j)
            sub_img_stack = self.random_crop(img_stack, i, j)
            sub_time = self.random_crop(times, i, j)
            sub_mask = self.random_crop(masks, i, j, latent=False)
            if sub_time.float().mean() != (self.sampling_steps - 1):
                sub_img = self.sample_one(sub_img, sub_img_stack, sub_mask, sub_time)
                mask_changed = torch.where(sub_time == sub_time.min(), 1, 0).to(self.device)
                img0[..., i - p:i + p, j - p:j + p] = sub_img
                # Record the newly advanced locations into the current-pass slot.
                img_stack[:, 1, :, i - p:i + p, j - p:j + p] = sub_img * mask_changed + sub_img_stack[:, 1] * (mask_changed == 0)
                times[..., i - p:i + p, j - p:j + p] = torch.where(sub_time == sub_time.min(), sub_time + 1, sub_time)
            if torch.all(times == times.max()):
                # A full pass completed everywhere: promote current to snapshot.
                img_stack[:, 0] = img_stack[:, 1]
        return img0
# Helpers
def tile(x, p):
    """Split a (B, C, H, W) tensor into non-overlapping p x p patches.

    Returns a (B * H/p * W/p, C, p, p) tensor ordered row-major per image.
    """
    batch, channels, _, _ = x.shape
    patches = x.unfold(2, p, p).unfold(3, p, p)
    patches = patches.reshape(batch, channels, -1, p, p)
    patches = patches.permute(0, 2, 1, 3, 4).contiguous()
    return patches.reshape(-1, channels, p, p)
def untile(x_tiled, original_shape, p):
    """Inverse of ``tile``: reassemble p x p patches into a full image.

    Parameters
    ----------
    x_tiled : (B * H/p * W/p, C, p, p) patch tensor, row-major per image.
    original_shape : the (B, C, H, W) shape to reconstruct.
    p : patch side length; H and W must be multiples of p.
    """
    B, C, H, W = original_shape
    # Regroup patches per image, restore the (rows, cols) grid, then
    # interleave the patch axes back into spatial order.
    x = x_tiled.reshape(B, (H // p) * (W // p), C, p, p).permute(0, 2, 1, 3, 4)
    x = x.reshape(B, C, H // p, W // p, p, p)
    return x.permute(0, 1, 2, 4, 3, 5).reshape(B, C, H, W)
def save_tensor_as_png(tensor, filename):
    """Write a (C, H, W) float tensor with values in [0, 1] to a PNG file.

    The tensor is moved to CPU, rearranged to HWC, scaled to 8-bit, and saved
    with maximum PNG compression.
    """
    # Channels-first -> channels-last, on the host, as a NumPy array.
    array = tensor.permute(1, 2, 0).cpu().numpy()
    # Scale [0, 1] floats to the 8-bit range expected by PIL.
    array_u8 = (array * 255).astype(np.uint8)
    Image.fromarray(array_u8).save(filename, format='PNG', optimize=True, compress_level=9)
def corners(subwindows: list):
    """Compose a square window by taking one quadrant from each input.

    ``subwindows`` is [up-left, up-right, down-right, down-left]; each
    contributes its own quadrant, and the result is 1 elsewhere (nowhere,
    since the four quadrants cover the window).
    """
    up_left, up_right, down_right, down_left = subwindows
    out = torch.ones_like(up_left)
    half = out.shape[0] // 2
    out[:half, :half] = up_left[:half, :half]
    out[:half, half:] = up_right[:half, half:]
    out[half:, half:] = down_right[half:, half:]
    out[half:, :half] = down_left[half:, :half]
    return out
def hann_window(size=512):
    """Build the family of 2-D Hann blending windows used for seam blending.

    Returns a dict of (size, size) tensors: nine position-specific windows
    (corners, edges, center) plus six composite variants used when blending
    shifted decodings. Edge/corner windows fall back to the 1-D profile along
    the side that touches the image border so borders keep full weight.
    """
    idx = torch.arange(size, dtype=torch.float)
    # 1-D Hann profile and its 2-D separable product.
    w1d = 0.5 * (1 - torch.cos(2 * torch.pi * idx / (size - 1)))
    w2d = w1d[:, None] * w1d
    rows = torch.arange(size)[:, None]
    cols = torch.arange(size)
    # Half-windows: flat (1-D) toward one side, 2-D Hann toward the other.
    w_up = torch.where(rows < size // 2, w1d, w2d)
    w_down = torch.where(rows > size // 2, w1d, w2d)
    w_right = torch.where(cols > size // 2, w1d[:, None], w2d)
    w_left = torch.where(cols < size // 2, w1d[:, None], w2d)
    flat = torch.ones((size, size))
    return {
        'up-left': corners([flat, w_up, w2d, w_left]),
        'up': w_up,
        'up-right': corners([w_up, flat, w_right, w2d]),
        'left': w_left,
        'center': w2d,
        'right': w_right,
        'down-left': corners([w_left, w2d, w_down, flat]),
        'down': w_down,
        'down-right': corners([w2d, w_right, flat, w_down]),
        'up-up': corners([w_left, w_right, flat, flat]),
        'down-down': corners([flat, flat, w_right, w_left]),
        'left-left': corners([w_up, flat, flat, w_down]),
        'right-right': corners([flat, w_up, w_down, flat]),
        'vertical': corners([w_up, w_up, w_down, w_down]),
        'horizontal': corners([w_left, w_right, w_right, w_left]),
    }
def hann_tensors(t1, t2, windows):
    """Blend ``t2`` into ``t1`` with the nine primary Hann windows, averaged.

    Parameters
    ----------
    t1, t2 : (1, 3, H, W) tensors to combine.
    windows : dict as returned by ``hann_window`` (15 entries; only the first
        nine — the position-specific windows — participate, matching the
        divisor ``len(windows) - 6``).
    """
    n_windows = len(windows) - 6
    # BUGFIX: dict views are not subscriptable in Python 3; the original
    # ``windows.keys()[:9]`` raised TypeError. Materialize the keys first.
    blend_keys = list(windows.keys())[:9]
    t_mix = torch.zeros_like(t1).to(t1.device)
    for key in blend_keys:
        # BUGFIX: accumulate each windowed blend. The original assigned
        # ``t_mix = ...`` each iteration, discarding all but the last window
        # and making the division by n_windows meaningless.
        t_mix = t_mix + t1 + windows[key].repeat(1, 3, 1, 1) * t2
    return t_mix / n_windows