-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvis_noise.py
More file actions
28 lines (24 loc) · 1009 Bytes
/
vis_noise.py
File metadata and controls
28 lines (24 loc) · 1009 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
"""
Script that saves sample MNIST / CIFAR-10 test images at several noise levels as PNG files, for visual inspection of the added noise.
"""
from main import get_split
import matplotlib.pyplot as plt
from torchvision.utils import save_image
import numpy as np
import torch
import os
# Noise levels to visualize; each is passed (possibly scaled, see below) to
# get_split as noise_lvl. Their exact meaning is defined by get_split in main.
noise_lvls = [0, 0.1, 0.2, 0.3, 0.4, 0.5]
# Noise type: randomly chosen pixels are set to 2 * the standard deviation of
# the pixel values over the whole image, i.e. a value corresponding to near
# white / full intensity. Passed to get_split as noise_type.
two_sd = 0
# Dataset to visualize: "mnist" or "cifar10".
dataset = "cifar10"
# CIFAR-10 needs a lower noise level than MNIST to reach a similar
# classification difficulty for a human, so its levels are divided by 3.
noise_divisor = {"mnist": 1, "cifar10": 3}[dataset]
# Test-set indices of the sample images to save.
sample_idxs = (10, 29)
# Stop iterating the loader once this index is reached; without an early exit
# the whole test split would be walked for every noise level.
last_idx = max(sample_idxs)

for i, noise_lvl in enumerate(noise_lvls):
    _, test_data = get_split(
        dataset, noise_type=two_sd, noise_lvl=noise_lvl / noise_divisor
    )
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
    for idx, (img, _target) in enumerate(test_loader):
        if idx in sample_idxs:
            # File name: <dataset>_<noise level index>_<test-set index>.png
            save_image(img, f"{dataset}_{i}_{idx}.png")
        # Checked after saving so the last wanted image is written before
        # leaving; the original placed this check where it could never fire.
        if idx >= last_idx:
            break