diff --git a/models/swin_transformer.py b/models/swin_transformer.py index dde06bc5..bb36cc4d 100644 --- a/models/swin_transformer.py +++ b/models/swin_transformer.py @@ -68,9 +68,9 @@ def window_reverse(windows, window_size, H, W): Returns: x: (B, H, W, C) """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + C = int(windows.shape[-1]) + x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) return x