-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
137 lines (117 loc) · 5.68 KB
/
model.py
File metadata and controls
137 lines (117 loc) · 5.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# flake8: noqa: E203
from typing import Any, Dict, Optional, Tuple
import numpy as np
import lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock3d(nn.Module):
    """Basic 3D convolution block: Conv3d -> BatchNorm3d -> LeakyReLU.

    With the default ``kernel_size=3, padding=1`` the spatial dimensions of
    the input are preserved; only the channel count changes.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
        """
        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size (cubic).
            padding: zero-padding applied on each side of each spatial dim.
        """
        super().__init__()  # modern zero-arg super (Py3 idiom)
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.bn = nn.BatchNorm3d(out_channels)
        self.relu = nn.LeakyReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> batch norm -> LeakyReLU to a (N, C, D, H, W) tensor."""
        result: torch.Tensor = self.relu(self.bn(self.conv(x)))
        return result
class UNet3d(pl.LightningModule):
    """U-Net-style 3D denoising model over voxel block maps.

    Takes a noisy volume of block ids, produces per-voxel logits over the
    block vocabulary (``unique_blocks_dict``), and trains with masked,
    per-class-weighted categorical cross-entropy. The encoder/decoder use
    additive skip connections between matching resolutions (no pooling, so
    spatial dims are constant throughout).
    """

    def __init__(
        self,
        train_len_dataloader: int,
        unique_blocks_dict: Dict[str, int],
        latent_dim: int = 64,
        unique_counts_coefficients: Optional[np.ndarray] = None,
    ):
        """
        Args:
            train_len_dataloader: total number of optimizer steps; used as
                ``total_steps`` for the OneCycleLR schedule.
            unique_blocks_dict: mapping from block name (e.g.
                ``"minecraft:air"``) to integer class index.
            latent_dim: kept for interface compatibility; not referenced by
                the current architecture.
            unique_counts_coefficients: optional per-class loss weights,
                indexed by class id. Defaults to uniform weights.
        """
        super().__init__()
        self.train_len_dataloader = train_len_dataloader
        self.unique_blocks_dict = unique_blocks_dict
        self.reverse_unique_blocks_dict = {v: k for k, v in unique_blocks_dict.items()}
        self.latent_dim = latent_dim
        if unique_counts_coefficients is None:
            unique_counts_coefficients = np.ones(len(unique_blocks_dict))
        # Register as a (non-persistent) buffer so the tensor follows the
        # module across device moves; the original pinned it to a device
        # chosen at construction time, which breaks if Lightning later moves
        # the model. persistent=False keeps state_dict compatible with
        # checkpoints saved by the original implementation.
        self.register_buffer(
            "unique_counts_coefficients",
            torch.from_numpy(unique_counts_coefficients).float(),
            persistent=False,
        )
        self.conv_input = ConvBlock3d(1, 32)
        self.conv1 = ConvBlock3d(32, 64)
        self.conv2 = ConvBlock3d(64, 128)
        self.conv3 = ConvBlock3d(128, 256)
        self.conv4 = ConvBlock3d(256, 512)
        self.conv6 = ConvBlock3d(512, 256)
        self.conv7 = ConvBlock3d(256, 128)
        self.conv8 = ConvBlock3d(128, 64)
        self.conv9 = ConvBlock3d(64, 32)
        # Final projection to per-class logits (no activation; loss and
        # forward() apply softmax/cross-entropy as appropriate).
        self.conv_output = nn.Conv3d(32, len(unique_blocks_dict), kernel_size=3, padding=1)

    def ml_core(self, x: torch.Tensor) -> torch.Tensor:
        """Run encoder + decoder with additive skips; returns raw logits."""
        # Encode input
        out_conv_input = self.conv_input(x)
        out_conv_1 = self.conv1(out_conv_input)
        out_conv_2 = self.conv2(out_conv_1)
        out_conv_3 = self.conv3(out_conv_2)
        out_conv_4 = self.conv4(out_conv_3)
        # Decode with additive skip connections to the matching encoder level
        out_conv_6 = self.conv6(out_conv_4) + out_conv_3
        out_conv_7 = self.conv7(out_conv_6) + out_conv_2
        out_conv_8 = self.conv8(out_conv_7) + out_conv_1
        out_conv_9 = self.conv9(out_conv_8) + out_conv_input
        out_conv_output: torch.Tensor = self.conv_output(out_conv_9)
        return out_conv_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-voxel class probabilities (softmax over channel dim)."""
        reconstruction = self.ml_core(x)
        reconstruction = F.softmax(reconstruction, dim=1)
        return reconstruction

    def step(self, batch: Tuple[np.ndarray, np.ndarray, np.ndarray], batch_idx: int, mode: str) -> torch.Tensor:
        """Shared train/val step: masked, class-weighted cross-entropy.

        Args:
            batch: ``(block_maps, noisy_block_maps, masks)`` numpy arrays;
                masks select which voxels contribute to loss and accuracy.
            batch_idx: index of the batch (unused, Lightning convention).
            mode: logging prefix, ``"train"`` or ``"val"``.

        Returns:
            Scalar training loss.
        """
        block_maps, noisy_block_maps, masks = batch
        targets = self.pre_process(block_maps)
        noisy_inputs = self.pre_process(noisy_block_maps).float().unsqueeze(1)  # add channel dim
        # self.device tracks Lightning's placement of this module; avoids
        # the hard-coded cuda/cpu choice the original used.
        mask_float = torch.from_numpy(masks).float().to(self.device)
        reconstruction = self.ml_core(noisy_inputs)
        # Masked accuracy: average only over voxels selected by the mask.
        # (The original averaged over ALL voxels, so masked-out zeros
        # diluted the metric.) clamp guards against an all-zero mask.
        correct = (reconstruction.argmax(dim=1) == targets).float() * mask_float
        accuracy = correct.sum() / mask_float.sum().clamp(min=1.0)
        # Masked, per-class-weighted categorical cross-entropy.
        reconstruction_loss = F.cross_entropy(reconstruction, targets, reduction="none")
        reconstruction_loss = reconstruction_loss * mask_float
        reconstruction_loss = reconstruction_loss * self.unique_counts_coefficients[targets]
        # NOTE(review): kept the original normalization (mean over all
        # voxels, not just masked ones) to preserve training dynamics; the
        # effective loss therefore scales with mask density.
        reconstruction_loss = reconstruction_loss.mean()
        # Total loss
        loss = reconstruction_loss
        loss_dict = {
            "reconstruction_loss": reconstruction_loss,
            "loss": loss,
            "accuracy": accuracy,
            "learning_rate": self.trainer.optimizers[0].param_groups[0]["lr"],
        }
        for name, value in loss_dict.items():
            self.log(
                f"{mode}_{name}",
                value,
                on_step=True,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                batch_size=block_maps.shape[0],
            )
        return loss

    def pre_process(self, x: np.ndarray) -> torch.Tensor:
        """Map an array of block names to class-id tensor on this module's device.

        Unknown block names fall back to the ``"minecraft:air"`` class.
        """
        vectorized_x = np.vectorize(lambda x: self.unique_blocks_dict.get(x, self.unique_blocks_dict["minecraft:air"]))(
            x
        )
        vectorized_x = vectorized_x.astype(np.int64)
        x_tensor = torch.from_numpy(vectorized_x)
        x_tensor = x_tensor.to(self.device)
        return x_tensor

    def post_process(self, x: torch.Tensor) -> np.ndarray:
        """Convert logits/probabilities back to an array of block names."""
        # .cpu() is required before .numpy(): the original crashed when the
        # tensor lived on a CUDA device.
        class_ids = x.argmax(dim=1).detach().cpu().numpy()
        predicted_block_maps: np.ndarray = np.vectorize(self.reverse_unique_blocks_dict.get)(class_ids)
        return predicted_block_maps

    def training_step(self, batch: Tuple[np.ndarray, np.ndarray, np.ndarray], batch_idx: int) -> torch.Tensor:
        return self.step(batch, batch_idx, "train")

    def validation_step(self, batch: Tuple[np.ndarray, np.ndarray, np.ndarray], batch_idx: int) -> torch.Tensor:
        return self.step(batch, batch_idx, "val")

    def configure_optimizers(self) -> Any:
        """Adam + OneCycleLR advanced once per optimizer step.

        OneCycleLR is a per-step schedule sized by ``total_steps``; with
        Lightning's default (per-epoch) interval it would never complete its
        cycle, so the scheduler is returned with ``interval="step"``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-3, total_steps=self.train_len_dataloader
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def on_train_start(self) -> None:
        """Print the module summary when training begins."""
        print(self)