Skip to content

Add AeroJEPA model + SuperWing tutorial recipe (experimental)#1690

Open
fgiral000 wants to merge 51 commits into
NVIDIA:mainfrom
fgiral000:aerojepa-integration
Open

Add AeroJEPA model + SuperWing tutorial recipe (experimental)#1690
fgiral000 wants to merge 51 commits into
NVIDIA:mainfrom
fgiral000:aerojepa-integration

Conversation

@fgiral000
Copy link
Copy Markdown

PhysicsNeMo Pull Request

Description

Adds the AeroJEPA model and a SuperWing tutorial recipe under
physicsnemo.experimental and examples/cfd/external_aerodynamics/.
AeroJEPA is a Joint-Embedding Predictive Architecture for 3D
aerodynamic surrogate modeling: instead of mapping geometry directly to
a flow field, it predicts a latent representation of the flow from a
latent representation of the geometry and operating conditions, and
reconstructs the field through a continuous implicit decoder when
needed (Giral et al., arXiv:2605.05586).

What this PR delivers:

  • Model at physicsnemo.experimental.models.aerojepa.
    AeroJEPA composes a context encoder, a target encoder, a query-token
    field decoder (collectively AeroJEPATrunk), and a JEPA predictor
    head (PrototypeTokenJEPAHead) into a single
    physicsnemo.core.module.Module. The training path takes context
    positions/features, independent target encoder surface/volume inputs,
    and operating conditions; the predictor predicts target tokens, and
    the decoder evaluates the field at user-supplied query points.
    predict is a no-grad inference wrapper; decode_field_chunked
    supports memory-bounded evaluation over very large query sets.
    Concrete encoders (ContextTransformer, TargetTransformer,
    PointTransformer), the QueryTokenDecoder, and the encoder ABCs
    are all exposed as composable components.
  • Building blocks at
    physicsnemo.experimental.models.aerojepa.layers. TokenSet and
    EncoderOutput token dataclasses, a deterministic
    FourierPositionalEncoding, ResidualMLP, the
    LocalPointTransformerBlock / LocalTokenCrossAttentionBlock
    attention blocks (with optional AdaLN / AdaLN-Zero conditioning), the
    PointCloudTokenizer (seven center-selection strategies with k-NN
    cluster pooling), token batching / mask / k-NN helpers, and prototype
    anchor build / load utilities. TokenSet and EncoderOutput are
    re-exported from the model package for convenience.
  • Losses at physicsnemo.experimental.models.aerojepa.losses.
    SIGReg and TokenLatentSIGReg (a sketch isotropic-Gaussian
    regularizer for latent-token distributions, with a padding-aware
    wrapper), the flatten_valid_token_features /
    reshape_token_features_for_sigreg masking helpers, and the
    reconstruction loss family (MSELoss / RelativeL2Loss /
    RelativeMSELoss / RelativeL2MSELoss, each with functional and
    nn.Module forms, optional per-channel weights stored as a
    persistent buffer, optional per-point weights, and an optional
    validity mask).
  • Tutorial recipe at
    examples/cfd/external_aerodynamics/aerojepa. End-to-end Hydra-driven
    workflow on the public SuperWing dataset (Yang et al.,
    arXiv:2512.14397): dataset download via the Hugging Face Hub
    (yunplus/SuperWing), automatic split-by-geometry manifest and
    per-channel normalization stats, JEPA training (reconstruction +
    latent + SIGReg with linear warmups; AdamW +
    warmup-cosine; optional EMA), checkpointed inference with chunked
    decoding, three-panel GT | Pred | |Error| field plots for the three
    surface channels (Cp, Cf_tau, Cf_z), per-channel relative-L2 /
    RMSE / MAE metrics on the test split, and a pressure-only CL/CD
    post-processor that integrates the surface field and emits a per-case
    CSV plus a parity scatter.

Checklist

Tests

  • 193 unit tests under test/experimental/models/aerojepa/
    (constructor + attribute checks, non-regression shape checks on the
    encoders, decoder, predictor, trunk, top-level model, layers, and
    losses). pytest test/experimental/models/aerojepa/ -q passes
    locally on CPU (~20 s).
  • Full SuperWing end-to-end smoke-tested on a single GPU:
    train.py -> inference.py -> superwing_metrics -> superwing_forces.
    Training losses decrease monotonically; inference produces field
    plots, per-case field-error metrics, and a force-coefficient parity
    scatter.

Dependencies

No new core dependencies. The example recipe adds optional
example-side dependencies in
examples/cfd/external_aerodynamics/aerojepa/requirements.txt
(Hugging Face Hub for the dataset download, plotting and
post-processing utilities). Pre-commit hooks, ruff, interrogate,
markdownlint, and the SPDX license check pass on every file in the
PR.

fgiral000 added 30 commits June 1, 2026 13:15
Create an empty subpackage for the AeroJEPA reusable building blocks
(attention blocks, geometry tokenizer, context/target encoders, decoder,
predictor) that land in subsequent commits. Establishes the SPDX
license header, module docstring, and ``__all__`` placeholder so that
follow-up commits only need to register new public symbols.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
TokenSet bundles token features with their geometric coordinates and
optional mask, global token, and auxiliary side data; EncoderOutput is
a thin wrapper used by context and target encoders to surface a global
summary alongside the per-token output. Includes raw-string docstrings
with Parameters/Examples sections (three executable doctests), modern
union syntax, and the SPDX header.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
A deterministic log-frequency sinusoidal positional encoding used to
lift continuous query coordinates into a high-dimensional feature
space before the implicit decoder consumes them. Distinct from
physicsnemo.nn.FourierEmbedding (random Gaussian frequencies on
scalar timesteps); this variant uses fixed log-powers of pi on
multi-dim coordinates with the standard sin/cos band layout.
Includes an out_dim property, jaxtyping on forward, and an
executable doctest.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
A private _gpu_knn module bundling chunked torch.cdist plus topk for
building homogeneous (gpu_knn_self) and bipartite (gpu_knn_bipartite)
k-NN graphs and inverse-distance interpolation (gpu_knn_interpolate).
Pure PyTorch, no warp or custom CUDA — works on CPU too, just slower.
The leading underscore on the filename makes the module package-
private; callers live inside the aerojepa subpackage only.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
A token_utils module with the helpers used by the AeroJEPA tokenizer,
encoders and attention blocks: gather_rows, counts_to_mask,
flatten_padded_batch / unflatten_to_padded, compute_batch_offset_step,
flatten_batched_coords, chunked_knn_indices (CPU/GPU dispatcher with
the AE_KNN_BACKEND env override), masked_mean, trim_batched_tokens,
and pad_token_sets. Behavior preserved; types modernized and the
TokenSet import is package-relative.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
A reusable trio of attention building blocks: ResidualMLP (pre-norm
residual MLP with optional AdaLN / AdaLN-Zero conditioning),
LocalPointTransformerBlock (local self-attention over a per-point
k-NN graph with learned relative-position bias), and
LocalTokenCrossAttentionBlock (cross-attention from queries to a
per-query k-NN of context tokens, with a 5-way conditioning MLP that
modulates query and key/value sides independently). Behavior
preserved: zero-init conditioning MLPs give an identity transform at
construction time, and the N<=1 / empty-input fallbacks short-circuit
the same way they do upstream.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
A tokenizer module that reduces a raw point set to a bounded token
budget before attention. Seven strategies: identity, random, FPS,
random/FPS/voxel-FPS cluster pooling, and prototype-anchored
clustering. The cluster strategies return the kNN indices that link
each token center back to the source points, allowing a downstream
encoder to replace the default feature mean with a learned pooling
(e.g. the message-passing PointClusterGraphPool that lands with the
encoders in PR NVIDIA#3). Behavior preserved including the non-persistent
prototype_coords buffer and the per-sample loop used by the
prototype strategy.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Build the fixed k-means anchor set used by the
data_prototype_cluster tokenizer strategy. The build pass walks a
training dataset, tokenizes each sample to obtain candidate token
coordinates, optionally subsamples, runs chunked Lloyd-iteration
k-means with empty-cluster FPS refill, sorts the centers
lexicographically, and serializes them with a JSON metadata blob.
Two load functions (target / context - identical file layout) and
two ensure_* helpers (load-if-exists else build) round out the
public surface. Behavior preserved; the seed argument governs
k-means initialization and candidate subsampling but not the
tokenizer pass, which intentionally uses random sampling.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Export the 23 public symbols from the six source modules at the
package level: TokenSet/EncoderOutput dataclasses,
FourierPositionalEncoding, ResidualMLP and the two local attention
blocks, PointCloudTokenizer, the ten batching/mask/kNN helpers,
and the six prototype anchor build/load functions. Module
docstring tightened to reflect the actual contents (encoders /
decoder / predictor land in physicsnemo.experimental.models.aerojepa
in a later PR). The package-private _gpu_knn helpers remain
accessible via their submodule path but are intentionally not
re-exported.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Re-export the five AeroJEPA nn.Module layer classes
(FourierPositionalEncoding, ResidualMLP, LocalPointTransformerBlock,
LocalTokenCrossAttentionBlock, PointCloudTokenizer) at the
experimental.nn parent namespace, alongside the existing FLARE
and DiffusionUNet3D family. Data types (TokenSet, EncoderOutput),
batching/mask helpers, and prototype-anchor builders stay scoped
to the aerojepa subpackage to keep the parent namespace focused
on actual layers.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Tests for TokenSet and EncoderOutput covering construction (both
batched and unbatched), the is_batched / token_dim properties, the
with_updates immutability + selective-replacement contract, and the
independence of the default aux dict across instances. Uses the
shared device fixture so the CUDA path runs when available.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…locks

Six new test files covering positional_encoding, attention_blocks,
point_tokenizer, token_utils, _gpu_knn, and prototype_anchors. 85 new
tests covering: constructor validation paths, forward output shapes,
edge cases (N<=1 LPT fallback, empty cross-attention, empty/single-
point kNN, missing voxel_size, non-persistent prototype_coords
buffer), identity-at-init of AdaLN-Zero conditioning MLPs, the
AE_KNN_BACKEND env override, and build/load round-trips with a tiny
fake dataset. All tests use the shared device fixture so CUDA runs
when available; CPU run is 18 s wall.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Documents the new physicsnemo.experimental.nn.aerojepa subpackage
contributed across the preceding 12 commits on this branch: token
dataclasses, Fourier positional encoding, ResidualMLP, the two local
attention blocks, PointCloudTokenizer, token batching/mask/kNN
helpers, and prototype anchor utilities, plus the parent-namespace
re-export of the five layer classes.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Create an empty subpackage for JEPA-style losses and regularizers
(SIGReg, TokenLatentSIGReg, the padding-aware masking helpers, and
the reconstruction loss family) that land in subsequent commits.
Establishes the SPDX license header, module docstring, and
``__all__`` placeholder so that follow-up commits only need to
register new public symbols.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Two utilities used by SIGReg / TokenLatentSIGReg to flatten padded
batched token features and reshape them into the (T, B, D) layout
SIGReg expects. flatten_valid_token_features is a passthrough on
rank-2 inputs and uses boolean masking on rank-3 inputs;
reshape_token_features_for_sigreg adds the leading T=1 axis and
emits a zero-element (1, 0, D) placeholder when the mask removes
every row.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
SIGReg pushes a learned latent toward N(0, I) by comparing the
empirical Fourier characteristic function of random projections
against the reference Gaussian one on a uniform knot grid (the
LeWorldModel construction). Three non-learnable buffers cache the
knot positions, the reference window, and the trapezoidal +
window-weighted integration weights. TokenLatentSIGReg is a thin
wrapper that accepts (B, N, D) or (N, D) features plus an optional
mask, drops padded rows via the masking helpers, and short-circuits
to a zero scalar when the mask removes every row.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Four loss families exposed as functional and nn.Module variants:
mse_loss / MSELoss (channel-weighted MSE with mask + point weights),
relative_l2_loss / RelativeL2Loss (per-channel relative L2 averaged
over channels), relative_mse_loss / RelativeMSELoss (relative MSE
with selectable pointwise vs channel_max normalization), and the
relative_l2_mse_loss / RelativeL2MSELoss hybrid that linearly
combines the L2 and MSE terms. Channel weights are stored as a
persistent float32 buffer on the Module variants when supplied, and
as a non-persistent None buffer otherwise.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Move the JEPA losses subpackage from
physicsnemo.experimental.metrics.jepa to .metrics.aerojepa so it
mirrors the nn.aerojepa naming. Populate the package __init__ with
the 12 public re-exports from masking, sigreg, and reconstruction
(flatten/reshape token helpers, SIGReg/TokenLatentSIGReg, and the
four reconstruction loss families both functional and as nn.Module).

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Three test files mirroring the source modules: test_masking,
test_sigreg, test_reconstruction. 37 tests covering constructor
validation, forward shape, edge cases (rank-1, empty batch, all-False
mask), the SIGReg buffer layout, state_dict persistence of
channel_weights on the reconstruction Module variants, both modes of
relative_mse_loss, and the hybrid degenerating to either of its two
sub-losses when the corresponding weight is zero. CPU run is 4 s wall;
device fixture picks up CUDA automatically.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Documents the new physicsnemo.experimental.metrics.aerojepa
subpackage contributed across the preceding 6 commits on this
branch: SIGReg / TokenLatentSIGReg regularizers, masking helpers,
and the four reconstruction loss families.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Create an empty subpackage for the top-level AeroJEPA model and its
model-specific subcomponents (context/target/point encoders, decoder,
predictor, trunk) that land in subsequent commits. Module docstring
points readers to experimental.nn.aerojepa for the reusable building
blocks the model is composed from.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Abstract base classes BaseContextEncoder and BaseTargetEncoder
(plus the encoders subpackage init) define the contract concrete
encoders must satisfy: a required forward returning an EncoderOutput
and an optional forward_batched gated by a supports_batched_forward
class flag. The context encoder's forward args are named context_pos
/ context_feat (these bundle the boundary and any volumetric samples
in whole-domain models; the SDF channel in context_feat distinguishes
the two halves at inference). The target encoder keeps the surface
/ volume split because training-time subsamplings for the two are
intentionally decoupled.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Context tokens are produced from geometry alone - operating
conditions enter the model downstream at the predictor head, not
at the context branch. Remove gen_params from BaseContextEncoder
forward and forward_batched signatures. Class docstring spells
out the intent. BaseTargetEncoder is untouched.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
PointTransformer (point.py) is a point-cloud encoder building block:
tokenizes the input via PointCloudTokenizer, embeds tokens with a
Fourier positional encoding plus per-feature linear projection,
optionally adds a conditioning vector, runs a stack of
LocalPointTransformerBlock layers with configurable dilation, and
emits an EncoderOutput. Two entry points - encode_single for
unbatched inputs and forward_batched for padded batches with
per-batch coordinate offsetting so the inner k-NN does not mix
tokens across batch items. The same file carries the
build_geometry_features helper (assembles per-point features from
positions and optional SDF / normals / n-dot channels) and the
message-passing PointClusterGraphPool used when
tokenizer_cluster_pooling='graph'.

ContextTransformer (context.py) is the concrete BaseContextEncoder.
Takes context_pos and context_feat - no gen_params, since operating
conditions enter the model downstream at the predictor head.
Internally wraps PointTransformer with conditioning disabled.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Mirror the context-side change: target encoders take their inputs
straight, with no gen_params threaded through. Operating conditions
enter the model only at the predictor head.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
JEPA target encoders are self-attention only. Remove context_tokens
from forward and forward_batched.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Concrete BaseTargetEncoder that wraps an inner PointTransformer.
Forward concatenates surface and volume into one bundled point set;
forward_batched weaves variable-length surface and volume halves per
batch via counts_to_mask. Self-attention only.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Implicit field decoder driven by cross-attention to target tokens.
Per-query embedding is a Fourier positional encoding plus optional
SDF channel and optional cond vector; cross-attention to the target
token set refines it, a trunk MLP and head produce the output.
Several optional behaviors wire in: wall-velocity gate, pressure
split head (MLP or SIREN), final SIREN refinement, extra SDF
features. Both forward (single) and forward_batched (padded)
process queries in chunks of query_chunk_size and return
(pred, query_embeddings). SineLayer and SirenHead are the small
SIREN building blocks used by the optional heads.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
The JEPA predictor head. Maps a target-token coordinate set to
predicted target-token features, given context tokens and a
conditioning vector. Operating conditions enter the model here
(via the cond argument), projected once and threaded into every
self- and cross-attention block. Accepts both unbatched (rank-2
context features) and padded batched (rank-3) inputs;
target_positions and cond are broadcast across the batch when
their leading dim is 1. The forward signature uses target_positions
as the parameter name (not target_coords) for consistency with
the rest of the model API.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Owns context encoder, target encoder, and decoder, and wires them
together. encode_context runs both encoders and emits a dict with
context tokens, target tokens, and the decoder-side cond_global.
decode_queries decodes a target token set at supplied query
positions, optionally producing a per-query mask logit when the
mask head is enabled. forward_single and forward_batch are
convenience wrappers chaining the two phases for unbatched and
padded batched inputs respectively. Public args use context_pos /
context_feat naming; gen_params is used to build cond_global but
is not threaded into the encoders.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
fgiral000 added 20 commits June 1, 2026 13:18
Composes AeroJEPATrunk and PrototypeTokenJEPAHead into a single
physicsnemo.Module with full MOD-001 / MOD-006 / MOD-010 compliance:
@DataClass AeroJEPAMetaData inheriting ModelMetaData, jaxtyping on
all public methods, validation guarded by torch.compiler.is_compiling,
constructor taking typed components (no cfg dict, no kwargs).

The forward entry takes context_pos / context_feat / gen_params /
query_pos / query_sdf and derives target-token positions internally
via build_target_token_coords (the target encoder's tokenizer with
a placeholder feature tensor); callers no longer supply target_coords.
predict is a no-grad wrapper around forward. encode_geometry,
encode_geometry_and_flow, predict_field_tokens, decode_field, and
build_target_token_coords are exposed for training-loop callers and
for latent-optimization workflows that want to cache target
coordinates across many predictor evaluations.
decode_field_chunked wraps the decoder in chunked + autocast +
CPU-offload for memory-bounded inference on very large query sets.
The class docstring ships an executable doctest that wires the whole
chain and asserts a forward-pass shape.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Re-export AeroJEPA + AeroJEPAMetaData (top-level model), AeroJEPATrunk
and PrototypeTokenJEPAHead (composable components), QueryTokenDecoder,
BaseContextEncoder / BaseTargetEncoder (ABCs for custom encoders),
and the three concrete encoders ContextTransformer / PointTransformer
/ TargetTransformer.

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Ten new test files mirroring the source modules:
encoders/test_base, encoders/test_context, encoders/test_point,
encoders/test_target, test_decoder, test_predictor, test_trunk, and
test_aerojepa. 63 tests covering constructor validation, signature
checks (drops of target_coords / gen_params / context_tokens per
the API redesign), forward / forward_batched shapes, every optional
decoder feature (pressure split head, SIREN refinement, wall gate,
extra SDF features), predictor broadcasting paths, trunk wiring
(mask head on/off), and the top-level model contract (physicsnemo
Module subclass, plain-tensor forward, no-grad predict,
single-arg build_target_token_coords, chunked CPU-offload decode).

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Documents the new physicsnemo.experimental.models.aerojepa
subpackage: the AeroJEPA top-level model and its composable
subcomponents (context/target encoders, decoder, predictor,
trunk).

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…ame surface_main_feat

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…mup builders)

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…recipe

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…s in README

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…recipe to CHANGELOG

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…est split

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
…sion bump

Signed-off-by: fgiral000 <fa.giral@alumnos.upm.es>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Jun 1, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 1, 2026

Greptile Summary

This PR introduces the AeroJEPA model — a Joint-Embedding Predictive Architecture for 3D aerodynamic surrogate modeling — along with all its building blocks (encoders, decoder, predictor, layers, losses) under physicsnemo.experimental, plus a full Hydra-driven SuperWing tutorial recipe and 193 unit tests.

  • Model core (aerojepa.py, trunk.py, predictor.py, decoder.py): The context/target encoder–predictor–decoder pipeline is well-structured; batched and single-sample forward paths handle edge cases correctly.
  • Training recipe (train.py, runtime.py): The validation path in _run_epoch builds full autograd graphs without a torch.no_grad() guard, and get_autocast_context exposes fp16 without a paired GradScaler.
  • masked_mean: Returns (B, F) for rank-3 no-mask input but documents (B, 1, F).

Important Files Changed

Filename Overview
physicsnemo/experimental/models/aerojepa/aerojepa.py Top-level AeroJEPA Module composing trunk + predictor; forward/predict/decode_field_chunked paths look correct; build_target_token_coords uses a private _tokenize_single method (noqa-suppressed).
physicsnemo/experimental/models/aerojepa/trunk.py AeroJEPATrunk wiring encoder/decoder; encode_context, decode_queries, forward_single/forward_batch all look correct.
physicsnemo/experimental/models/aerojepa/decoder.py QueryTokenDecoder with chunked cross-attention, SIREN options, wall-velocity gate, and batched forward; logic appears sound.
physicsnemo/experimental/models/aerojepa/predictor.py PrototypeTokenJEPAHead with interleaved self/cross attention; batch handling and conditioning logic look correct.
physicsnemo/experimental/models/aerojepa/layers/token_utils.py Batch flattening, k-NN, and TokenSet utilities; masked_mean has a docstring/implementation shape inconsistency for rank-3 no-mask input (returns (B,F) not (B,1,F) as documented).
examples/cfd/external_aerodynamics/aerojepa/train.py Hydra training entry point; validation forward pass in _run_epoch builds unnecessary autograd graphs because there is no torch.no_grad() guard when is_train=False, wasting GPU memory.
examples/cfd/external_aerodynamics/aerojepa/src/training/runtime.py get_autocast_context enables fp16 autocast without a paired GradScaler; safe with the default bf16 config but could silently corrupt training if users set precision: fp16.

Comments Outside Diff (1)

  1. physicsnemo/experimental/models/aerojepa/layers/token_utils.py, line 1350-1356 (link)

    P2 masked_mean return shape mismatch between mask=None and mask≠None paths for rank-3 input

    The docstring states the function returns (B, 1, F) for rank-3 input, but the mask is None branch uses features.mean(dim=1) (no keepdim) and actually returns (B, F). The masked branch correctly uses keepdim=True and returns (B, 1, F). This inconsistency could cause silent shape mismatches if a caller passes rank-3 features without a mask and expects the documented (B, 1, F) layout.

Reviews (1): Last reviewed commit: "changelog: move SuperWing recipe bullet ..." | Re-trigger Greptile

Comment on lines +95 to +107
if precision_l == "fp16":
return torch.autocast(device_type="cuda", dtype=torch.float16)
return contextlib.nullcontext()


def build_lr_scheduler(
optimizer: torch.optim.Optimizer,
*,
name: str,
epochs: int,
steps_per_epoch: int,
warmup_epochs: float = 5.0,
warmup_steps: int | None = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 fp16 + autocast without GradScaler

get_autocast_context enables torch.autocast with torch.float16 when precision="fp16", but the training loop in train.py contains no torch.cuda.amp.GradScaler. Without the scaler, fp16 gradients can overflow to inf/nan and the optimizer step silently corrupts parameters. The default config uses bf16 (which shares fp32's dynamic range and doesn't need a scaler), so normal runs are unaffected, but any user who changes precision: fp16 in the training YAML will experience silent training failures.

Comment on lines +312 to +345
loss_cfg: DictConfig,
epoch: int,
max_batches: int | None,
) -> dict[str, float]:
is_train = optimizer is not None
model.train(is_train)

totals: dict[str, float] = {
"loss": 0.0,
"recon": 0.0,
"latent": 0.0,
"sigreg": 0.0,
}
n_samples = 0

for batch_idx, batch in enumerate(loader):
if max_batches is not None and batch_idx >= int(max_batches):
break
batch = move_batch_to_device(batch, device)
if is_train:
optimizer.zero_grad(set_to_none=True)

sample_losses: list[torch.Tensor] = []
for sample_idx in range(int(batch["context_pos"].shape[0])):
sample = _slice_batch_sample(batch, sample_idx)
with get_autocast_context(device, precision):
pred_field, pred_features, target_tokens, _, _ = _forward_sample(
model, sample
)
loss, parts = _compute_total_loss(
pred_field=pred_field,
query_target=sample["query_target"],
pred_features=pred_features,
target_tokens=target_tokens,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Validation forward pass builds unnecessary computation graphs

_run_epoch runs _forward_sample and _compute_total_loss without a torch.no_grad() guard when is_train=False. PyTorch therefore builds and retains the full autograd graph for every validation sample, but backward() is never called. The graph is held until sample_losses is reset at the next batch boundary, so peak extra memory is one batch's computation graph. With eval_batch_size=1 this is modest, but larger evaluation batch sizes or models could trigger OOM. Wrapping the inner loop body (or at minimum the _forward_sample call) with torch.no_grad() or torch.inference_mode() when not is_train would eliminate this overhead.

@peterdsharpe
Copy link
Copy Markdown
Collaborator

Hi @fgiral000, thanks for the PR! To keep PR size reviewable, would it be possible to:

a) split this PR up into two separate PRs, one of which adds the model ("PR 1"), and a later follow-on that adds the example ("PR 2").

b) In PR 1, please re-use shared PhysicsNeMo tooling where possible. (E.g., _gpu_knn.py should re-use existing KNN implementations in physicsnemo.nn.functional; conditioning MLPs should use FullyConnected, many losses duplicate existing code)

c) All functions should use jaxtyping annotations for tensor shapes. Please use Literal types for enumerations rather than str, etc.

d) In PR 2, if possible, please add AeroJEPA as an example within ./examples/external_aerodynamics/unified_external_aero_recipe/, rather than as a standalone aerojepa folder.

@peterdsharpe
Copy link
Copy Markdown
Collaborator

Actually, it might be worth splitting out a third PR as well for addition of the SuperWing dataset utils.

@mnabian mnabian self-requested a review June 1, 2026 18:40
Updated README.md for AeroJEPA tutorial for title case. And a few edits for clarity
Comment thread CHANGELOG.md
field decoder (collectively the :class:`AeroJEPATrunk`), and a
JEPA predictor head (:class:`PrototypeTokenJEPAHead`) into a single
`physicsnemo.core.module.Module`. The forward entry takes
context positions / features, operating conditions, and query
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
context positions / features, operating conditions, and query
context positions and features, operating conditions, and query

Comment thread CHANGELOG.md
JEPA predictor head (:class:`PrototypeTokenJEPAHead`) into a single
`physicsnemo.core.module.Module`. The forward entry takes
context positions / features, operating conditions, and query
positions, derives target-token coordinates internally via
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
positions, derives target-token coordinates internally via
positions, derives target-token coordinates internally using

Comment thread CHANGELOG.md
positions, derives target-token coordinates internally via
``build_target_token_coords``, and returns the decoded field at
the queries. ``predict`` is a no-grad convenience wrapper;
``encode_geometry`` / ``encode_geometry_and_flow`` /
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
``encode_geometry`` / ``encode_geometry_and_flow`` /
``encode_geometry`` , ``encode_geometry_and_flow`` ,

Comment thread CHANGELOG.md
``build_target_token_coords``, and returns the decoded field at
the queries. ``predict`` is a no-grad convenience wrapper;
``encode_geometry`` / ``encode_geometry_and_flow`` /
``predict_field_tokens`` / ``decode_field`` /
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
``predict_field_tokens`` / ``decode_field`` /
``predict_field_tokens`` , ``decode_field``, and

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants