Got 3 channels instead of 1 channel while training on MNIST-Style #84

@Jia-Bin

Hi,
While running the following command on a single GPU:

python train_alae.py -c configs/mnist.yaml

I got this error:

RuntimeError: Given groups=1, weight of size [256, 1, 1, 1], expected input[32, 3, 4, 4] to have 1 channels, but got 3 channels instead
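This is PyTorch's input-channel check for convolutions. A conv2d weight has shape (out_channels, in_channels / groups, kH, kW), so the input's channel dimension (dim 1 of an NCHW batch) must equal weight.shape[1] * groups. Here the weight [256, 1, 1, 1] is a 1x1 conv expecting 1 channel, but the batch [32, 3, 4, 4] has 3 channels. The stdlib-only sketch below reproduces just that shape check; the function names are hypothetical and are not part of ALAE or PyTorch.

```python
def expected_in_channels(weight_shape, groups=1):
    """Channels a conv input must have: weight.shape[1] * groups."""
    return weight_shape[1] * groups


def check_conv_input(weight_shape, input_shape, groups=1):
    """Mimic the channel check conv2d performs on an NCHW input.

    weight_shape: (out_channels, in_channels // groups, kH, kW)
    input_shape:  (N, C, H, W)
    Returns None if the shapes are compatible, otherwise an error message.
    """
    expected = expected_in_channels(weight_shape, groups)
    got = input_shape[1]
    if got != expected:
        return (f"expected input{list(input_shape)} to have {expected} channels, "
                f"but got {got} channels instead")
    return None


# Shapes taken from the error above: a 1x1 conv built for 1-channel
# (grayscale) input receives a batch of 3-channel images.
print(check_conv_input((256, 1, 1, 1), (32, 3, 4, 4)))  # mismatch message
print(check_conv_input((256, 1, 1, 1), (32, 1, 4, 4)))  # None
```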

Here is the full log:
Adjusting learning rate of group 0 to 1.5000e-03.
Adjusting learning rate of group 1 to 1.5000e-03.
Adjusting learning rate of group 0 to 1.5000e-03.
Adjusting learning rate of group 1 to 1.5000e-03.
2023-04-11 09:44:07,335 logger INFO: Saving checkpoint to mnist_results/model_tmp_lod0.pth
2023-04-11 09:44:07,398 logger INFO:
[1/60] - ptime: 34.82, loss_d: 2.0644474, loss_g: 0.8213913, lae: 0.6135449, blend: 1.000, lr: 0.001500000000, 0.001500000000, max mem: 152.636719",
Traceback (most recent call last):
  File "train_alae.py", line 348, in <module>
    world_size=gpu_count)
  File "/home/jbhuang/MyWork/Python/GAN/ALAE/launcher.py", line 131, in run
    _run(0, world_size, fn, defaults, write_log, no_cuda, args)
  File "/home/jbhuang/MyWork/Python/GAN/ALAE/launcher.py", line 96, in _run
    fn(**matching_args)
  File "train_alae.py", line 337, in train
    model.module if hasattr(model, "module") else model, cfg, encoder_optimizer, decoder_optimizer)
  File "train_alae.py", line 64, in save_sample
    Z, _ = model.encode(sample_in, lod2batch.lod, blend_factor)
  File "/home/jbhuang/MyWork/Python/GAN/ALAE/model.py", line 109, in encode
    Z = self.encoder(x, lod, blend_factor)
  File "/home/jbhuang/work_p37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jbhuang/MyWork/Python/GAN/ALAE/net.py", line 345, in forward
    return self.encode(x, lod)
  File "/home/jbhuang/MyWork/Python/GAN/ALAE/net.py", line 311, in encode
    x = self.from_rgb[self.layer_count - lod - 1](x)
  File "/home/jbhuang/work_p37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jbhuang/MyWork/Python/GAN/ALAE/net.py", line 257, in forward
    x = self.from_rgb(x)
  File "/home/jbhuang/work_p37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jbhuang/MyWork/Python/GAN/ALAE/lreq.py", line 169, in forward
    dilation=self.dilation, groups=self.groups)
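The log shows that epoch [1/60] finished and a checkpoint was saved before the crash, and the traceback places the failure in save_sample at model.encode(sample_in, ...). So the training batches apparently had 1 channel and only the sample batch had 3. One plausible cause, which is an assumption and not confirmed here, is that the sample images are stored as RGB while the MNIST config builds the model for single-channel input. The cleaner fix would be sample images that match the model's channel count; as a quick workaround to try, a 3-channel NCHW tensor can be reduced with sample_in.mean(dim=1, keepdim=True) before encoding. The dependency-free sketch below shows the same per-image channel averaging with a hypothetical helper, to_single_channel, on a C x H x W image stored as nested lists.

```python
def to_single_channel(img_chw):
    """Collapse a C x H x W image (nested lists) to 1 x H x W by averaging channels.

    Hypothetical helper: for a grayscale digit saved as RGB, the three channels
    are identical, so the average reproduces the original intensities.
    """
    channels = len(img_chw)
    if channels == 1:
        return img_chw  # already single-channel, nothing to do
    height, width = len(img_chw[0]), len(img_chw[0][0])
    gray = [
        [sum(img_chw[c][y][x] for c in range(channels)) / channels
         for x in range(width)]
        for y in range(height)
    ]
    return [gray]


channel = [[0, 255], [128, 0]]
rgb = [channel, channel, channel]  # 3 x 2 x 2 with identical channels
gray = to_single_channel(rgb)
print(len(gray), gray)  # 1 [[[0.0, 255.0], [128.0, 0.0]]]
```

Averaging is only lossless when the channels are identical; genuinely colored inputs would need a proper grayscale conversion instead.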

Any help or suggestions would be much appreciated. Thank you!