
sptrain test #6 (Open)

mahf708 wants to merge 4 commits into main from sp-train
Conversation

@mahf708 (Collaborator) commented Feb 24, 2026

Short description of why the PR is needed and how it satisfies those requirements, in sentence form.

Changes:

  • symbol (e.g. fme.core.my_function) or script and concise description of changes or added feature

  • Can group multiple related symbols on a single bullet

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Resolves # (delete if none)

@peterdschwartz

Working on adding tests for equality between the distributed and non-distributed ops. I notice that the pre-existing test for ConditionalLayerNorm fails when using the distributed global layer norm, but I think that may be due to a flaw in the test itself:

@pytest.mark.parametrize("global_layer_norm", [True, False])
@pytest.mark.parametrize("n_channels", [32])
@pytest.mark.parametrize("embed_dim_scalar", [9, 0])
@pytest.mark.parametrize("embed_dim_noise", [10, 0])
@pytest.mark.parametrize("embed_dim_labels", [11, 0])
@pytest.mark.parametrize("embed_dim_pos", [18, 0])
@pytest.mark.parametrize("img_shape", [(8, 16)])
def test_conditional_layer_norm(
    n_channels: int,
    img_shape: tuple[int, int],
    global_layer_norm: bool,
    embed_dim_scalar: int,
    embed_dim_labels: int,
    embed_dim_noise: int,
    embed_dim_pos: int,
):
    epsilon = 1e-6
    device = get_device()
    conditional_layer_norm = ConditionalLayerNorm(
        n_channels,
        img_shape,
        context_config=ContextConfig(
            embed_dim_scalar=embed_dim_scalar,
            embed_dim_labels=embed_dim_labels,
            embed_dim_noise=embed_dim_noise,
            embed_dim_pos=embed_dim_pos,
        ),
        global_layer_norm=global_layer_norm,
        epsilon=epsilon,
    ).to(device)
    x = torch.randn(1, n_channels, *img_shape, device=device) * 5 + 2
    context_embedding_scalar = torch.randn(1, embed_dim_scalar, device=device)
    context_embedding_labels = torch.randn(1, embed_dim_labels, device=device)
    context_embedding_noise = torch.randn(1, embed_dim_noise, *img_shape, device=device)
    context_embedding_pos = torch.randn(1, embed_dim_pos, *img_shape, device=device)
    context = Context(
        embedding_scalar=context_embedding_scalar,
        noise=context_embedding_noise,
        labels=context_embedding_labels,
        embedding_pos=context_embedding_pos,
    )
    output = conditional_layer_norm(x, context)
    assert output.shape == x.shape
    torch.testing.assert_close(
        output.mean(), torch.tensor(0.0, device=device), atol=1e-3, rtol=0
    )
    torch.testing.assert_close(
        output.std(), torch.tensor(1.0, device=device), atol=1e-3, rtol=0
    )
    if not global_layer_norm:
        zero = torch.zeros(1, *img_shape, device=device)
        torch.testing.assert_close(output.mean(dim=1), zero, atol=1e-3, rtol=0)
        torch.testing.assert_close(
            (((n_channels - 1) / n_channels) ** 0.5 * output.std(dim=1) - 1),
            zero,
            atol=1e-3,
            rtol=0,
        )

Basically, it asserts that the mean and std after applying a random context weight/bias should still be close to 0 and 1, but mathematically that doesn't hold. So I'm thinking it fails with the DistributedGlobalLayerNorm because its result is numerically slightly different, pushing it just outside the tolerance.
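To see why the assertion can't hold in general, here's a minimal pure-Python sketch (no torch, made-up sizes; not the fme implementation) of what the conditional affine does to already-normalized data. Each channel has mean 0 / std 1 only before the context-derived scale and shift are applied; afterwards the global mean equals the mean of the shifts and the global variance picks up the scales:

```python
import random
import statistics

random.seed(0)
n_channels, n_pix = 32, 128

# Normalize each channel to mean 0, std 1, like the layer norm step.
data = [[random.gauss(0, 1) for _ in range(n_pix)] for _ in range(n_channels)]
for c in range(n_channels):
    mu = statistics.fmean(data[c])
    sd = statistics.pstdev(data[c])
    data[c] = [(v - mu) / sd for v in data[c]]

# Random per-channel scale/shift, standing in for the weight/bias a
# conditional layer norm derives from a random context embedding.
scale = [1 + random.gauss(0, 0.5) for _ in range(n_channels)]
shift = [random.gauss(0, 0.5) for _ in range(n_channels)]
out = [[scale[c] * v + shift[c] for v in data[c]] for c in range(n_channels)]

flat = [v for ch in out for v in ch]
# By the law of total variance, the global mean is exactly mean(shift) and
# the global variance is exactly mean(scale**2) + var(shift) -- neither is
# 0/1 in general.
print(statistics.fmean(flat), statistics.pstdev(flat))
```

With random context embeddings, the output mean/std land wherever the scales and shifts put them; they stay near 0/1 only when the affine parameters happen to be near identity, which would explain why the existing test passes at all and why a slightly different numerical path can push it past the 1e-3 tolerance.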

The test I added for the DistributedGlobalLayerNorm does show that it is equivalent to nn.LayerNorm:

@pytest.mark.parallel
def test_distributed_global_layer_norm():
    """DistributedGlobalLayerNorm on sharded input == LayerNorm on full tensor.
    Note: Since no training is active, the weights and biases for nn.LayerNorm 
    and DGLN are ones and zeros, so the norms should also be idempotent
    """
    device = get_device()
    dist = Distributed.get_instance()

    B, C, H, W = 2, 4, 8, 12
    # Make sure the spatial mesh divides H and W; otherwise this test is invalid.
    assert H % max(dist.h_size, 1) == 0
    assert W % max(dist.w_size, 1) == 0

    torch.manual_seed(1234)
    x_full = torch.randn(B, C, H, W, device=device)

    # Reference LayerNorm over (C, H, W) on the full tensor.
    norm_shape = (C, H, W)
    ref_ln = torch.nn.LayerNorm(
        norm_shape, eps=1e-5, elementwise_affine=True
    ).to(device)
    y_ref = ref_ln(x_full)

    # DistributedGlobalLayerNorm with matching affine params.
    H_DIM, W_DIM = -2, -1
    x_local = dist.scatter_spatial(x_full, h_dim=H_DIM, w_dim=W_DIM)
    norm_shape_local = (C, x_local.shape[H_DIM], x_local.shape[W_DIM])

    dln = DistributedGlobalLayerNorm(
        norm_shape_local, eps=1e-5, elementwise_affine=True
    ).to(device)

    y_local = dln(x_local)
    y = dist.gather_spatial(y_local, H_DIM, W_DIM)

    torch.testing.assert_close(y, y_ref, atol=1e-5, rtol=1e-5)
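For reference, the reduction a distributed global layer norm has to get right can be sketched in a few lines on a single process (pure Python, made-up shard layout; the class internals are an assumption on my part): each "rank" contributes its local sum and sum of squares, the all-reduce is just a plain Python sum here, and the combined statistics match the full-tensor mean/variance exactly:

```python
import statistics

full = [0.5 * i - 3 for i in range(48)]  # stand-in for the full tensor
shards = [full[i::4] for i in range(4)]  # 4 "ranks" holding disjoint shards

n = sum(len(s) for s in shards)
total = sum(sum(s) for s in shards)                     # all-reduce(SUM) of local sums
total_sq = sum(sum(v * v for v in s) for s in shards)   # all-reduce(SUM) of local sums of squares

mean = total / n
var = total_sq / n - mean * mean  # E[x^2] - E[x]^2

assert abs(mean - statistics.fmean(full)) < 1e-9
assert abs(var - statistics.pvariance(full)) < 1e-9
```

After the two reduced scalars are in hand, each rank can normalize its shard locally as `(x - mean) / sqrt(var + eps)`, which is why the gathered output should agree with nn.LayerNorm on the full tensor up to floating-point noise.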
