
Bug Report for weight-initialization #5810

@spironan

Bug Report for https://neetcode.io/problems/weight-initialization

The expected output depends on the order in which `torch.randn` is called after seeding.
The problem statement does not specify this order, so it is easy to get wrong.
This caused a lot of headaches, frustration, and a wild goose chase.
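
The underlying cause is that `torch.manual_seed(0)` resets one global random stream, and each `torch.randn` call takes the next values from it. The minimal sketch below (the helper `draw` and the tensor shapes are hypothetical, chosen only for illustration) shows that swapping the order of two calls changes the values each tensor receives, while repeating the same order reproduces them exactly.

```python
import torch


def draw(x_first: bool):
    """Hypothetical helper: sample a weight matrix and an input from one seeded stream."""
    torch.manual_seed(0)  # reset the global CPU generator
    if x_first:
        x = torch.randn(1, 4)  # input takes the first values of the stream
        w = torch.randn(3, 4)  # weights take the values after it
    else:
        w = torch.randn(3, 4)  # weights take the first values of the stream
        x = torch.randn(1, 4)  # input takes the values after them
    return w, x


w_a, x_a = draw(x_first=False)
w_b, x_b = draw(x_first=True)
print(torch.equal(x_a, x_b))  # False: same seed, different order

w_c, x_c = draw(x_first=False)
print(torch.equal(x_a, x_c) and torch.equal(w_a, w_c))  # True: same seed, same order
```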

Using your provided solution produces the following (expected) output:
[0.72,0.82,0.7,0.68,0.5]

```python
import math
from typing import List

import torch


class Solution:
    def check_activations(
        self, num_layers: int, input_dim: int, hidden_dim: int, init_type: str
    ) -> List[float]:
        torch.manual_seed(0)
        dims = [input_dim] + [hidden_dim] * num_layers
        weights = []

        # 1) Sample every weight matrix first...
        for i in range(num_layers):
            if init_type == "kaiming":
                std = math.sqrt(2.0 / dims[i])
            elif init_type == "xavier":
                std = math.sqrt(2.0 / (dims[i] + dims[i + 1]))
            elif init_type == "random":
                std = 1.0  # plain N(0,1), no scaling

            w = torch.randn(dims[i + 1], dims[i]) * std
            weights.append(w)

        # 2) ...then sample the input, after all the weights
        x = torch.randn(1, input_dim)
        stds = []
        for w in weights:
            x = x @ w.T
            x = torch.relu(x)
            stds.append(round(x.std().item(), 2))

        return stds
```
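
In the reference solution, every weight matrix is sampled before `x`, whereas in the refactored version `x` is the first `torch.randn` call after seeding. The small helper below (hypothetical, for illustration only) counts how many weight elements are sampled before `x` in each ordering, which makes it clear that the two versions feed `x` from different parts of the random stream.

```python
def elements_before_input(
    num_layers: int, input_dim: int, hidden_dim: int, x_first: bool
) -> int:
    """Hypothetical helper: number of weight elements sampled before the input x."""
    if x_first:
        # Refactored version: x is the very first randn call after seeding
        return 0
    dims = [input_dim] + [hidden_dim] * num_layers
    # Reference version: layer i draws a dims[i+1] x dims[i] matrix before x
    return sum(dims[i + 1] * dims[i] for i in range(num_layers))


print(elements_before_input(5, 4, 8, x_first=False))  # 4*8 + 4*(8*8) = 288
print(elements_before_input(5, 4, 8, x_first=True))   # 0
```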

Refactoring the same code as follows produces a completely different result:
[0.88,1,1.02,0.77,0.72]

```python
import math
from typing import List

import torch


class Solution:
    def check_activations(
        self, num_layers: int, input_dim: int, hidden_dim: int, init_type: str
    ) -> List[float]:
        torch.manual_seed(0)
        dims = [input_dim] + [hidden_dim] * num_layers
        weights = []

        # The input is now sampled BEFORE any weight matrix, so every
        # weight below is drawn from a different part of the random stream
        x = torch.randn(1, input_dim)
        stds = []

        for i in range(num_layers):
            if init_type == "kaiming":
                std = math.sqrt(2.0 / dims[i])
            elif init_type == "xavier":
                std = math.sqrt(2.0 / (dims[i] + dims[i + 1]))
            elif init_type == "random":
                std = 1.0  # plain N(0,1), no scaling

            # Each weight is sampled and applied immediately
            w = torch.randn(dims[i + 1], dims[i]) * std
            x = torch.relu(x @ w.T)
            stds.append(round(x.std().item(), 2))
            # weights.append(w)

        # x = torch.randn(1, input_dim)
        # stds = []
        # for w in weights:
        #     x = x @ w.T
        #     x = torch.relu(x)
        #     stds.append(round(x.std().item(), 2))

        return stds
```
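
One possible way to make the expected output independent of call order (a suggestion only, not how the problem is currently graded) is to give the weights and the input their own `torch.Generator` objects. The sketch below assumes separate seeds (0 for the weights, 1 for the input) and uses hypothetical names `layer_std` and `check_activations_decoupled`; it computes the activations in both the reference order and the refactored order and gets identical results.

```python
import math
from typing import List

import torch


def layer_std(init_type: str, fan_in: int, fan_out: int) -> float:
    # Same scaling rules as the solutions above
    if init_type == "kaiming":
        return math.sqrt(2.0 / fan_in)
    if init_type == "xavier":
        return math.sqrt(2.0 / (fan_in + fan_out))
    if init_type == "random":
        return 1.0
    raise ValueError(f"unknown init_type: {init_type}")


def check_activations_decoupled(
    num_layers: int, input_dim: int, hidden_dim: int, init_type: str, x_first: bool
) -> List[float]:
    # Hypothetical variant: independent generators for the weights and the input
    w_gen = torch.Generator().manual_seed(0)
    x_gen = torch.Generator().manual_seed(1)
    dims = [input_dim] + [hidden_dim] * num_layers

    def make_weight(i: int) -> torch.Tensor:
        std = layer_std(init_type, dims[i], dims[i + 1])
        return torch.randn(dims[i + 1], dims[i], generator=w_gen) * std

    stds = []
    if x_first:
        # Refactored order: input first, each weight drawn inside the loop
        x = torch.randn(1, input_dim, generator=x_gen)
        for i in range(num_layers):
            x = torch.relu(x @ make_weight(i).T)
            stds.append(round(x.std().item(), 2))
    else:
        # Reference order: all weights first, then the input
        weights = [make_weight(i) for i in range(num_layers)]
        x = torch.randn(1, input_dim, generator=x_gen)
        for w in weights:
            x = torch.relu(x @ w.T)
            stds.append(round(x.std().item(), 2))
    return stds


a = check_activations_decoupled(5, 16, 32, "kaiming", x_first=False)
b = check_activations_decoupled(5, 16, 32, "kaiming", x_first=True)
print(a == b)  # True: the order of the randn calls no longer matters
```

Because each stream is consumed only by one kind of tensor, interleaving the weight and input draws no longer shifts which values each tensor receives.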
