Skip to content

Commit b1f2ad3

Browse files
authored
Merge pull request #24 from SamChou05/vae
Adding docstrings for pt_models
2 parents 730b16e + 25af569 commit b1f2ad3

1 file changed

Lines changed: 128 additions & 11 deletions

File tree

afqinsight/nn/pt_models.py

Lines changed: 128 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,19 @@ def forward(self, x):
704704

705705

706706
class VariationalAutoencoder(nn.Module):
707+
"""
708+
Variational Autoencoder (VAE) model.
709+
710+
Parameters
711+
----------
712+
input_shape : int
713+
The number of features in the input data.
714+
latent_dims : int
715+
The number of dimensions in the latent space.
716+
dropout : float
717+
The dropout rate.
718+
"""
719+
707720
def __init__(self, input_shape=100, latent_dims=20, dropout=0.2):
708721
super().__init__()
709722
self.encoder = VariationalEncoder(input_shape, latent_dims, dropout=dropout)
@@ -716,7 +729,46 @@ def __init__(self, input_shape=100, latent_dims=20, dropout=0.2):
716729
else "cpu"
717730
)
718731

732+
def reparameterize(self, mean, logvar):
733+
"""
734+
Reparameterization trick to separate random
735+
and deterministic parts of the latent space.
736+
737+
Parameters
738+
----------
739+
mean : torch.Tensor
740+
The mean of the latent space.
741+
logvar : torch.Tensor
742+
The log variance of the latent space.
743+
744+
Returns
745+
-------
746+
z : torch.Tensor
747+
The reparameterized latent space.
748+
"""
749+
std = torch.exp(0.5 * logvar)
750+
eps = torch.randn_like(std)
751+
z = mean + eps * std
752+
return z
753+
719754
def forward(self, x):
755+
"""
756+
Forward pass of the VAE model.
757+
758+
Parameters
759+
----------
760+
x : torch.Tensor
761+
The input data.
762+
763+
Returns
764+
-------
765+
x_hat: torch.Tensor
766+
The reconstructed data.
767+
mean: torch.Tensor
768+
The mean of the latent space.
769+
logvar: torch.Tensor
770+
The log variance of the latent space.
771+
"""
720772
mean, logvar = self.encoder(x)
721773

722774
z = self.reparameterize(mean, logvar)
@@ -725,11 +777,6 @@ def forward(self, x):
725777

726778
return x_hat, mean, logvar
727779

728-
def reparameterize(self, mean, logvar):
729-
std = torch.exp(0.5 * logvar)
730-
eps = torch.randn_like(std)
731-
return mean + eps * std
732-
733780
def fit(self, train_data, epochs=500, lr=0.001, kl_weight=0.001):
734781
self.train()
735782
opt = torch.optim.Adam(self.parameters(), lr=lr)
@@ -829,6 +876,19 @@ def fit_transform(self, data, epochs=20, kl_weight=0.001):
829876

830877

831878
class Autoencoder(nn.Module):
879+
"""
880+
Autoencoder model.
881+
882+
Parameters
883+
----------
884+
input_shape : int
885+
The number of features in the input data.
886+
latent_dims : int
887+
The number of dimensions in the latent space.
888+
dropout : float
889+
The dropout rate.
890+
"""
891+
832892
def __init__(self, input_shape=100, latent_dims=20, dropout=0.2):
833893
super().__init__()
834894
self.encoder = Encoder(input_shape, latent_dims, dropout=dropout)
@@ -930,6 +990,19 @@ def fit_transform(self, data, epochs=20):
930990

931991

932992
class Conv1DVariationalAutoencoder(nn.Module):
993+
"""
994+
Convolutional Variational Autoencoder (VAE) model.
995+
996+
Parameters
997+
----------
998+
num_tracts : int
999+
The number of tracts in the input data.
1000+
latent_dims : int
1001+
The number of dimensions in the latent space.
1002+
dropout : float
1003+
The dropout rate.
1004+
"""
1005+
9331006
def __init__(self, num_tracts=48, latent_dims=20, dropout=0.2):
9341007
super().__init__()
9351008
self.encoder = Conv1DVariationalEncoder(num_tracts, latent_dims, dropout)
@@ -943,21 +1016,52 @@ def __init__(self, num_tracts=48, latent_dims=20, dropout=0.2):
9431016
)
9441017

9451018
def reparameterize(self, mean, logvar):
1019+
"""
1020+
Reparameterization trick to separate random and
1021+
deterministic parts of the latent space.
1022+
1023+
Parameters
1024+
----------
1025+
mean : torch.Tensor
1026+
The mean of the latent space.
1027+
logvar : torch.Tensor
1028+
The log variance of the latent space.
1029+
1030+
Returns
1031+
-------
1032+
z : torch.Tensor
1033+
The reparameterized latent space.
1034+
"""
9461035
std = torch.exp(0.5 * logvar)
9471036
eps = torch.randn_like(std)
9481037
z = mean + eps * std
9491038
return z
9501039

9511040
def forward(self, x):
952-
(
953-
mean,
954-
logvar,
955-
) = self.encoder(x)
1041+
"""
1042+
Forward pass of the Convolutional VAE model.
1043+
1044+
Parameters
1045+
----------
1046+
x : torch.Tensor
1047+
The input data.
1048+
1049+
Returns
1050+
-------
1051+
x_hat: torch.Tensor
1052+
The reconstructed data.
1053+
mean: torch.Tensor
1054+
The mean of the latent space.
1055+
logvar: torch.Tensor
1056+
The log variance of the latent space.
1057+
"""
1058+
1059+
mean, logvar = self.encoder(x)
9561060

9571061
z = self.reparameterize(mean, logvar)
9581062

959-
x_prime = self.decoder(z)
960-
return x_prime, mean, logvar
1063+
x_hat = self.decoder(z)
1064+
return x_hat, mean, logvar
9611065

9621066
def fit(self, train_data, epochs=500, lr=0.001, kl_weight=0.001):
9631067
self.train()
@@ -1056,6 +1160,19 @@ def fit_transform(self, data, epochs=20, kl_weight=0.001):
10561160

10571161

10581162
class Conv1DAutoencoder(nn.Module):
1163+
"""
1164+
Convolutional Autoencoder model.
1165+
1166+
Parameters
1167+
----------
1168+
num_tracts : int
1169+
The number of tracts in the input data.
1170+
latent_dims : int
1171+
The number of dimensions in the latent space.
1172+
dropout : float
1173+
The dropout rate.
1174+
"""
1175+
10591176
def __init__(self, num_tracts=48, latent_dims=20, dropout=0.2):
10601177
super().__init__()
10611178
self.encoder = Conv1DEncoder(num_tracts, latent_dims, dropout)

0 commit comments

Comments
 (0)