@@ -704,6 +704,19 @@ def forward(self, x):
704704
705705
706706class VariationalAutoencoder (nn .Module ):
707+ """
708+ Variational Autoencoder (VAE) model.
709+
710+ Parameters
711+ ----------
712+ input_shape : int
713+ The number of features in the input data.
714+ latent_dims : int
715+ The number of dimensions in the latent space.
716+ dropout : float
717+ The dropout rate.
718+ """
719+
707720 def __init__ (self , input_shape = 100 , latent_dims = 20 , dropout = 0.2 ):
708721 super ().__init__ ()
709722 self .encoder = VariationalEncoder (input_shape , latent_dims , dropout = dropout )
@@ -716,7 +729,46 @@ def __init__(self, input_shape=100, latent_dims=20, dropout=0.2):
716729 else "cpu"
717730 )
718731
def reparameterize(self, mean, logvar):
    """
    Reparameterization trick: split the latent sample into a
    deterministic part (mean, std) and a random part (eps) so that
    gradients can flow through mean and logvar.

    Parameters
    ----------
    mean : torch.Tensor
        The mean of the latent distribution.
    logvar : torch.Tensor
        The log variance of the latent distribution.

    Returns
    -------
    z : torch.Tensor
        A latent sample z = mean + eps * std with eps ~ N(0, I),
        same shape as ``mean``.
    """
    # std = sqrt(var) = exp(0.5 * log(var))
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mean + eps * std
    return z
753+
def forward(self, x):
    """
    Forward pass of the VAE model: encode, sample via the
    reparameterization trick, then decode.

    Parameters
    ----------
    x : torch.Tensor
        The input data.

    Returns
    -------
    x_hat : torch.Tensor
        The reconstructed data.
    mean : torch.Tensor
        The mean of the latent distribution.
    logvar : torch.Tensor
        The log variance of the latent distribution.
    """
    mean, logvar = self.encoder(x)

    z = self.reparameterize(mean, logvar)

    # NOTE(review): this line sits in a gap between diff hunks; it is
    # reconstructed from the `return x_hat, ...` below and the parallel
    # Conv1D forward — confirm against the original file.
    x_hat = self.decoder(z)

    return x_hat, mean, logvar
727779
728- def reparameterize (self , mean , logvar ):
729- std = torch .exp (0.5 * logvar )
730- eps = torch .randn_like (std )
731- return mean + eps * std
732-
733780 def fit (self , train_data , epochs = 500 , lr = 0.001 , kl_weight = 0.001 ):
734781 self .train ()
735782 opt = torch .optim .Adam (self .parameters (), lr = lr )
@@ -829,6 +876,19 @@ def fit_transform(self, data, epochs=20, kl_weight=0.001):
829876
830877
831878class Autoencoder (nn .Module ):
879+ """
880+ Autoencoder model.
881+
882+ Parameters
883+ ----------
884+ input_shape : int
885+ The number of features in the input data.
886+ latent_dims : int
887+ The number of dimensions in the latent space.
888+ dropout : float
889+ The dropout rate.
890+ """
891+
832892 def __init__ (self , input_shape = 100 , latent_dims = 20 , dropout = 0.2 ):
833893 super ().__init__ ()
834894 self .encoder = Encoder (input_shape , latent_dims , dropout = dropout )
@@ -930,6 +990,19 @@ def fit_transform(self, data, epochs=20):
930990
931991
932992class Conv1DVariationalAutoencoder (nn .Module ):
993+ """
994+ Convolutional Variational Autoencoder (VAE) model.
995+
996+ Parameters
997+ ----------
998+ num_tracts : int
999+ The number of tracts in the input data.
1000+ latent_dims : int
1001+ The number of dimensions in the latent space.
1002+ dropout : float
1003+ The dropout rate.
1004+ """
1005+
9331006 def __init__ (self , num_tracts = 48 , latent_dims = 20 , dropout = 0.2 ):
9341007 super ().__init__ ()
9351008 self .encoder = Conv1DVariationalEncoder (num_tracts , latent_dims , dropout )
@@ -943,21 +1016,52 @@ def __init__(self, num_tracts=48, latent_dims=20, dropout=0.2):
9431016 )
9441017
def reparameterize(self, mean, logvar):
    """
    Reparameterization trick: split the latent sample into a
    deterministic part (mean, std) and a random part (eps) so that
    gradients can flow through mean and logvar.

    Parameters
    ----------
    mean : torch.Tensor
        The mean of the latent distribution.
    logvar : torch.Tensor
        The log variance of the latent distribution.

    Returns
    -------
    z : torch.Tensor
        A latent sample z = mean + eps * std with eps ~ N(0, I),
        same shape as ``mean``.
    """
    # std = sqrt(var) = exp(0.5 * log(var))
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mean + eps * std
    return z
9501039
def forward(self, x):
    """
    Forward pass of the Convolutional VAE model: encode, sample via
    the reparameterization trick, then decode.

    Parameters
    ----------
    x : torch.Tensor
        The input data.

    Returns
    -------
    x_hat : torch.Tensor
        The reconstructed data.
    mean : torch.Tensor
        The mean of the latent distribution.
    logvar : torch.Tensor
        The log variance of the latent distribution.
    """
    mean, logvar = self.encoder(x)

    z = self.reparameterize(mean, logvar)

    x_hat = self.decoder(z)
    return x_hat, mean, logvar
9611065
9621066 def fit (self , train_data , epochs = 500 , lr = 0.001 , kl_weight = 0.001 ):
9631067 self .train ()
@@ -1056,6 +1160,19 @@ def fit_transform(self, data, epochs=20, kl_weight=0.001):
10561160
10571161
10581162class Conv1DAutoencoder (nn .Module ):
1163+ """
1164+ Convolutional Autoencoder model.
1165+
1166+ Parameters
1167+ ----------
1168+ num_tracts : int
1169+ The number of tracts in the input data.
1170+ latent_dims : int
1171+ The number of dimensions in the latent space.
1172+ dropout : float
1173+ The dropout rate.
1174+ """
1175+
10591176 def __init__ (self , num_tracts = 48 , latent_dims = 20 , dropout = 0.2 ):
10601177 super ().__init__ ()
10611178 self .encoder = Conv1DEncoder (num_tracts , latent_dims , dropout )
0 commit comments