-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpeekVAE.m
More file actions
100 lines (69 loc) · 2.44 KB
/
peekVAE.m
File metadata and controls
100 lines (69 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
function peekVAE(encoderNet,decoderNet,XTest,nRecons)
%PEEKVAE Visual sanity checks for a trained variational autoencoder.
%   PEEKVAE(encoderNet, decoderNet, XTest, nRecons) opens nRecons figures.
%   Each figure compares a randomly chosen test image with its
%   reconstruction (sigmoid of the decoder output on a sampled latent code).
%   It then opens one figure with a 2-D PCA projection of the latent mean
%   and log-variance of every observation in XTest.
%
%   PEEKVAE(encoderNet, decoderNet, XTest) uses nRecons = 1.
%
%   Inputs
%     encoderNet - network evaluated with FORWARD; its output is split in
%                  half along the first dimension into [zMean; zLogvar].
%     decoderNet - network evaluated with FORWARD on the sampled latent code.
%     XTest      - test images, indexed as XTest(:,:,:,i). EXTRACTDATA is
%                  called on it, so it is presumably a formatted 'SSCB'
%                  dlarray (TODO confirm against the caller).
%     nRecons    - number of reconstruction figures to open (default 1).
if nargin < 4 || isempty(nRecons)
    nRecons = 1;
end
visualizeReconstruction(XTest, nRecons, encoderNet, decoderNet);
visualizeLatentSpace(XTest, encoderNet);
end
function [zSampled, zMean, zLogvar] = sampling(encoderNet, x)
%SAMPLING Encode x and draw one latent sample per observation.
%   [zSampled, zMean, zLogvar] = SAMPLING(encoderNet, x) runs FORWARD on
%   the encoder. The first half of the output's rows is taken as zMean and
%   the second half as zLogvar (log-variance). It then draws
%       z = zMean + exp(0.5*zLogvar) .* epsilon,   epsilon ~ N(0, I)
%   which is the reparameterization trick. The sample is reshaped to
%   1-by-1-by-d-by-N and returned as an 'SSCB' dlarray so it can be fed
%   straight to the decoder.
%
%   Assumes the encoder output collapses to d-by-N under (:,:) indexing
%   (e.g. a 'CB' dlarray from a fully connected layer) -- TODO confirm.
compressed = forward(encoderNet, x);
nRows = size(compressed, 1);
if mod(nRows, 2) ~= 0
    error('peekVAE:sampling:oddEncoderOutput', ...
        ['Encoder output must have an even number of rows ' ...
         '([zMean; zLogvar]); got %d.'], nRows);
end
d = nRows/2;
zMean = compressed(1:d,:);
zLogvar = compressed(1+d:end,:);
sz = size(zMean);
epsilon = randn(sz);          % standard-normal noise, one entry per element of zMean
sigma = exp(.5 * zLogvar);    % standard deviation: exp(logvar/2) = sqrt(variance)
z = epsilon .* sigma + zMean; % scale and shift N(0,1) noise to N(zMean, sigma.^2)
z = reshape(z, [1,1,sz]);     % d-by-N -> 1-by-1-by-d-by-N to match the 'SSCB' labels
zSampled = dlarray(z, 'SSCB');
end
function visualizeReconstruction(XTest,nRecons, encoderNet, decoderNet)
%VISUALIZERECONSTRUCTION Show random test images next to their reconstructions.
%   VISUALIZERECONSTRUCTION(XTest, nRecons, encoderNet, decoderNet) runs
%   nRecons iterations. Each one:
%     - picks a random index along dimension 4 of XTest (with replacement),
%     - encodes and samples it with SAMPLING,
%     - decodes with FORWARD and applies SIGMOID,
%     - shows [original | column of ones | reconstruction] in a new figure.
nObs = size(XTest, 4);
for k = 1:nRecons
    idx = randi(nObs, 1); % random test observation; repeats are possible
    X = XTest(:,:,:,idx);
    [z, ~, ~] = sampling(encoderNet, X);
    XPred = sigmoid(forward(decoderNet, z));
    % gather is a no-op for CPU data and moves GPU data to the host.
    X = gather(extractdata(X));
    XPred = gather(extractdata(XPred));
    % One-pixel separator column. It spans every channel so the horizontal
    % concatenation also works for multi-channel (e.g. RGB) images, not
    % only for single-channel ones.
    separator = ones(size(X,1), 1, size(X,3), 'like', X);
    comparison = [X, separator, XPred];
    figure; imshow(comparison,[]), title("Example ground truth image vs. reconstructed image")
end
end
function visualizeLatentSpace(XTest, encoderNet)
%VISUALIZELATENTSPACE Scatter the first two PCA scores of the latent codes.
%   VISUALIZELATENTSPACE(XTest, encoderNet) does the following:
%     - encodes all of XTest in a single SAMPLING call,
%     - projects the per-observation latent means and log-variances onto
%       their principal components with PCA,
%     - plots PC2 (x-axis) against PC1 (y-axis, reversed) in two
%       side-by-side subplots under a shared "Latent space" title.
[~, zMean, zLogvar] = sampling(encoderNet, XTest);
% stripdims drops the dlarray labels; the transpose gives one row per
% observation, which is the layout PCA expects.
zMean = gather(extractdata(stripdims(zMean)'));
zLogvar = gather(extractdata(stripdims(zLogvar)'));
[~,scoreMean] = pca(zMean);
[~,scoreLogvar] = pca(zLogvar);
if size(scoreMean,2) < 2 || size(scoreLogvar,2) < 2
    error('peekVAE:visualizeLatentSpace:tooFewComponents', ...
        ['Need at least 2 principal components to plot: the latent ' ...
         'dimension must be >= 2 and XTest must hold at least 3 observations.']);
end
figure;
ah = subplot(1,2,1);
scatter(scoreMean(:,2),scoreMean(:,1),[]);
ah.YDir = 'reverse';
axis equal
xlabel("Z_m_u(2)")
ylabel("Z_m_u(1)")
ah = subplot(1,2,2);
scatter(scoreLogvar(:,2),scoreLogvar(:,1),[]);
ah.YDir = 'reverse';
xlabel("Z_v_a_r(2)")
ylabel("Z_v_a_r(1)")
axis equal
% Use a figure-level title. A TITLE issued before SUBPLOT goes onto an
% axes that SUBPLOT deletes because it overlaps the new subplots, so the
% title would never be shown.
sgtitle("Latent space")
end
% GENERATE Decode random latent vectors and show the results as an image tile.
%   GENERATE(decoderNet, latentDim) draws one latent vector of length
%   latentDim from N(0, I). It decodes the vector with PREDICT followed by
%   SIGMOID and shows the result in a new figure.
%   GENERATE(decoderNet, latentDim, nSamples) draws and tiles nSamples
%   vectors instead (e.g. nSamples = 25 for a 5-by-5 grid).
function generate(decoderNet, latentDim, nSamples)
if nargin < 3 || isempty(nSamples)
    nSamples = 1; % preserves the original single-sample behaviour
end
randomNoise = dlarray(randn(1,1,latentDim,nSamples),'SSCB');
generatedImage = sigmoid(predict(decoderNet, randomNoise));
% gather is a no-op for CPU data and moves GPU data to the host for imtile.
generatedImage = gather(extractdata(generatedImage));
figure;
imshow(imtile(generatedImage, "ThumbnailSize", [100,100]))
title("Generated random samples")
drawnow
end