Bug Report for https://neetcode.io/problems/weight-initialization
The expected output depends on when `torch.randn` is called. The required call structure is not clearly defined in the problem statement and is prone to human error, causing a lot of headaches, frustration, and a wild goose chase.

Using your provided solution produces the following (expected) output:

[0.72, 0.82, 0.7, 0.68, 0.5]
def check_activations(
    self, num_layers: int, input_dim: int, hidden_dim: int, init_type: str
) -> List[float]:
    torch.manual_seed(0)
    dims = [input_dim] + [hidden_dim] * num_layers
    weights = []
    for i in range(num_layers):
        if init_type == "kaiming":
            std = math.sqrt(2.0 / dims[i])
        elif init_type == "xavier":
            std = math.sqrt(2.0 / (dims[i] + dims[i + 1]))
        elif init_type == "random":
            std = 1.0  # plain N(0,1), no scaling
        w = torch.randn(dims[i + 1], dims[i]) * std
        weights.append(w)
    x = torch.randn(1, input_dim)
    stds = []
    for w in weights:
        x = x @ w.T
        x = torch.relu(x)
        stds.append(round(x.std().item(), 2))
    return stds
Refactoring the same code as follows produces a completely different result:

[0.88, 1, 1.02, 0.77, 0.72]
def check_activations(
    self, num_layers: int, input_dim: int, hidden_dim: int, init_type: str
) -> List[float]:
    torch.manual_seed(0)
    dims = [input_dim] + [hidden_dim] * num_layers
    weights = []
    x = torch.randn(1, input_dim)
    stds = []
    for i in range(num_layers):
        if init_type == "kaiming":
            std = math.sqrt(2.0 / dims[i])
        elif init_type == "xavier":
            std = math.sqrt(2.0 / (dims[i] + dims[i + 1]))
        elif init_type == "random":
            std = 1.0  # plain N(0,1), no scaling
        w = torch.randn(dims[i + 1], dims[i]) * std
        x = torch.relu(x @ w.T)
        stds.append(round(x.std().item(), 2))
        # weights.append(w)
    # x = torch.randn(1, input_dim)
    # stds = []
    # for w in weights:
    #     x = x @ w.T
    #     x = torch.relu(x)
    #     stds.append(round(x.std().item(), 2))
    return stds
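The discrepancy comes from RNG consumption order: with a fixed seed, every `randn` call consumes the next values from the generator's stream, so sampling `x` before the weights (second version) versus after all the weights (first version) assigns entirely different numbers to each tensor, even though the total set of calls is identical. A minimal sketch of the same effect using Python's stdlib `random` in place of `torch` (the `draws` helper and its `order` parameter are illustrative, not part of the original code):

```python
import random

def draws(order):
    # With a fixed seed, which values each variable receives depends
    # purely on the order in which the generator is consumed.
    random.seed(0)
    if order == "weights_first":
        w = [random.gauss(0, 1) for _ in range(3)]  # weights drawn first
        x = random.gauss(0, 1)                      # input drawn last
    else:  # "input_first"
        x = random.gauss(0, 1)                      # input drawn first
        w = [random.gauss(0, 1) for _ in range(3)]  # weights drawn after
    return x, w

x1, w1 = draws("weights_first")
x2, w2 = draws("input_first")
# Same seed, same number of draws overall, but the values land on
# different variables: the first draw goes to w1[0] in one ordering
# and to x2 in the other, so x1 != x2 and w1 != w2.
print(x1 != x2)
print(w1[0] == x2)
```

This is why the problem statement would need to pin down the exact order of `randn` calls for the expected output to be reproducible.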