DDP and batch loss weighting are likely biased by local valid-residue counts #79

@jeffreyHoelzel

Each rank computes a local masked mean BCE loss, and DDP then averages gradients equally across ranks. The result is not equivalent to a global valid-residue-weighted loss whenever ranks or batches differ in their valid-residue counts.

Evidence:

  • src/pepseqpred/core/train/trainer.py
    • Loss is (loss_raw * mask).sum() / mask.sum() per local batch.
    • DDP then averages gradients across ranks equally.
  • src/pepseqpred/apps/train_ffnn_cli.py
    • IDs are partitioned across ranks by estimated embedding file size, not by valid positive/negative residue count.

Why this can hurt:

  • A rank with few valid residues can contribute the same gradient weight as a rank with many valid residues.
  • Long proteins, label sparsity, and source/pathogen-specific label density can widen the gap in valid-residue counts between ranks, which amplifies the imbalance.
  • Multi-pathogen data is especially vulnerable if some groups have sparse labels or many uncertain residues.
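
To make the size of the effect concrete, here is a small worked example in plain Python. The per-residue loss values and the two-rank setup are made up for illustration and are not taken from the repository; the point is only to show how an equal-weight average of local masked means differs from a single global masked mean.

```python
# Hypothetical per-residue BCE values for the valid (unmasked) residues
# on two ranks. These numbers are illustrative, not measured.
rank_losses = {
    0: [2.0, 2.0],   # sparse rank: 2 valid residues
    1: [0.5] * 8,    # dense rank: 8 valid residues
}

# What each rank computes today: a local masked mean.
local_means = {r: sum(v) / len(v) for r, v in rank_losses.items()}

# Equal-weight averaging across ranks (what DDP does to gradients)
# corresponds to an unweighted mean of the local means.
mean_of_local_means = sum(local_means.values()) / len(local_means)

# Global valid-residue-weighted loss: one numerator, one denominator.
total_loss = sum(sum(v) for v in rank_losses.values())
total_valid = sum(len(v) for v in rank_losses.values())
global_mean = total_loss / total_valid

# Effective weight of a single residue under the current scheme.
world_size = len(rank_losses)
per_residue_weight = {
    r: 1.0 / (world_size * len(v)) for r, v in rank_losses.items()
}

print(local_means)
print(mean_of_local_means, global_mean)
print(per_residue_weight)
```

In this toy case each residue on the sparse rank carries four times the weight of a residue on the dense rank, and the averaged loss (1.25) is well above the global masked mean (0.8).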

Planning direction:

  • Log valid residue count per batch and rank.
  • Consider globally normalized loss using summed numerator and denominator across ranks.
  • Alternatively, ensure that per-rank partitioning balances valid-residue and positive-residue counts, not just embedding file size.
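
One possible shape for the globally normalized option is sketched below as a pure-Python simulation, so it runs without a GPU or a process group. It assumes each rank switches from a local mean to a local sum scaled by world_size / global_valid, which cancels DDP's equal-weight averaging. The gradient values and the `ddp_average` helper are hypothetical. In a real training loop the global count would presumably come from an all-reduce (SUM) of the local `mask.sum()`, for example via `torch.distributed.all_reduce`, and would need a guard against a zero count.

```python
def ddp_average(per_rank_values):
    """Mimic DDP's equal-weight averaging of per-rank gradients."""
    return sum(per_rank_values) / len(per_rank_values)


# Hypothetical per-residue gradient contributions for one scalar parameter.
rank_grads = [
    [0.9, 0.7],                                   # rank 0: 2 valid residues
    [0.1, -0.2, 0.3, 0.0, 0.1, -0.1, 0.2, 0.0],   # rank 1: 8 valid residues
]
world_size = len(rank_grads)

# Stand-in for an all-reduce (SUM) of mask.sum() across ranks.
global_valid = sum(len(g) for g in rank_grads)

# Current behaviour: local masked mean per rank, then equal averaging.
current = ddp_average([sum(g) / len(g) for g in rank_grads])

# Sketch of the fix: local sum scaled by world_size / global_valid,
# so that the equal averaging yields sum-over-all-residues / global_valid.
proposed = ddp_average(
    [sum(g) * world_size / global_valid for g in rank_grads]
)

# Reference: gradient of a single global masked mean.
reference = sum(sum(g) for g in rank_grads) / global_valid

print(current, proposed, reference)
```

Under these assumptions the proposed scaling matches the global reference gradient, while the current scheme overweights the sparse rank. Logging `global_valid` alongside the per-rank counts would also cover the first planning item.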

Labels: bug
