|
91 | 91 | }, |
92 | 92 | { |
93 | 93 | "cell_type": "code", |
94 | | - "execution_count": null, |
| 94 | + "execution_count": 1, |
95 | 95 | "id": "a15172b8", |
96 | 96 | "metadata": {}, |
97 | | - "outputs": [], |
| 97 | + "outputs": [ |
| 98 | + { |
| 99 | + "name": "stdout", |
| 100 | + "output_type": "stream", |
| 101 | + "text": [ |
| 102 | + "PyTorch 2.3.0.post100 | device: cpu\n", |
| 103 | + "Train: 60000 | Test: 10000\n", |
| 104 | + "Image shape: torch.Size([1, 28, 28])\n", |
| 105 | + "Loaders ready (num_workers=0)\n" |
| 106 | + ] |
| 107 | + } |
| 108 | + ], |
98 | 109 | "source": [ |
99 | 110 | "# ── Standard imports ──────────────────────────────────────────────────────────\n", |
100 | 111 | "import math, time, os\n", |
|
166 | 177 | }, |
167 | 178 | { |
168 | 179 | "cell_type": "code", |
169 | | - "execution_count": null, |
| 180 | + "execution_count": 2, |
170 | 181 | "id": "f3df2b1e", |
171 | 182 | "metadata": {}, |
172 | | - "outputs": [], |
| 183 | + "outputs": [ |
| 184 | + { |
| 185 | + "name": "stdout", |
| 186 | + "output_type": "stream", |
| 187 | + "text": [ |
| 188 | + "VAE parameters: 258,025\n", |
| 189 | + "VAE(\n", |
| 190 | + " (encoder): Sequential(\n", |
| 191 | + " (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", |
| 192 | + " (1): LeakyReLU(negative_slope=0.2)\n", |
| 193 | + " (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", |
| 194 | + " (3): LeakyReLU(negative_slope=0.2)\n", |
| 195 | + " (4): Flatten(start_dim=1, end_dim=-1)\n", |
| 196 | + " )\n", |
| 197 | + " (fc_mu): Linear(in_features=3136, out_features=20, bias=True)\n", |
| 198 | + " (fc_log_var): Linear(in_features=3136, out_features=20, bias=True)\n", |
| 199 | + " (fc_dec): Linear(in_features=20, out_features=3136, bias=True)\n", |
| 200 | + " (decoder): Sequential(\n", |
| 201 | + " (0): Unflatten(dim=1, unflattened_size=(64, 7, 7))\n", |
| 202 | + " (1): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", |
| 203 | + " (2): LeakyReLU(negative_slope=0.2)\n", |
| 204 | + " (3): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", |
| 205 | + " )\n", |
| 206 | + ")\n" |
| 207 | + ] |
| 208 | + } |
| 209 | + ], |
173 | 210 | "source": [ |
174 | 211 | "# ── VAE Model ─────────────────────────────────────────────────────────────\n", |
175 | 212 | "class VAE(nn.Module):\n", |
|
277 | | - "execution_count": null, |
| 314 | + "execution_count": 3, |
278 | 315 | "id": "409ab2e8", |
279 | 316 | "metadata": {}, |
280 | | - "outputs": [], |
| 317 | + "outputs": [ |
| 318 | + { |
| 319 | + "name": "stdout", |
| 320 | + "output_type": "stream", |
| 321 | + "text": [ |
| 322 | + "Training VAE on MNIST (20 epochs)...\n", |
| 323 | + "(Uses [0,1]-normalised pixels and Bernoulli decoder)\n" |
| 324 | + ] |
| 325 | + } |
| 326 | + ], |
281 | 327 | "source": [ |
282 | 328 | "def train_vae(model, loader, epochs=20, lr=1e-3, beta=1.0):\n", |
283 | 329 | " \"\"\"\n", |
|
0 commit comments