Geotransolver 2d 3d #1676
Refactor flare to reduce from three duplicate implementations to one.
…odesl, optionally
|
As part of this PR, we will need to do model checkpoint surgery on our existing checkpoints. Tagging @ktangsali so we can be sure to validate checkpoints against this.
|
/ok to test 931d139 |
Greptile Summary
This PR extends GeoTransolver, FLARE, and related components to support structured 2D and 3D grids (in addition to the existing unstructured mesh path), and wires the new variants into the Darcy example. The refactoring also extracts shared helpers (
Important Files Changed
|
| train_path: //lustre/fsw/portfolios/coreai/users/coreya/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth1.npz | ||
| test_path: //lustre/fsw/portfolios/coreai/users/coreya/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth2.npz |
Hardcoded internal cluster paths
train_path and test_path are set to paths on NVIDIA's internal Lustre filesystem (//lustre/fsw/portfolios/coreai/users/coreya/...). Anyone cloning the repo and running this example will immediately hit a file-not-found error. These should be replaced with placeholder paths such as /path/to/piececonst_r421_N1024_smooth1.npz — matching the documented README — so that users know they need to point these at their own downloaded copies of the dataset.
| dir: . | ||
|
|
||
| output_dir: ./output/ | ||
| run_id: ${hydra:runtime.choices.model}-muon_${precision}_r${resolution}_b${data.batch_size}_s${model.slice_num} |
run_id has the string "muon" hard-coded, so runs using optimizer.type: adamw will still produce a run_id that says muon. Consider interpolating from the config so the recorded name always matches the actual optimizer in use.
| run_id: ${hydra:runtime.choices.model}-muon_${precision}_r${resolution}_b${data.batch_size}_s${model.slice_num} | |
| run_id: ${hydra:runtime.choices.model}-${optimizer.type}_${precision}_r${resolution}_b${data.batch_size}_s${model.slice_num} |
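The effect of the suggested interpolation can be illustrated outside Hydra. The following is a minimal sketch in plain Python, assuming a hypothetical helper `build_run_id` and made-up config values (neither is part of this PR); it only mirrors the name pattern from the suggestion so that the optimizer segment follows the configured optimizer.

```python
def build_run_id(
    model: str,
    optimizer_type: str,
    precision: str,
    resolution: int,
    batch_size: int,
    slice_num: int,
) -> str:
    """Assemble a run name following the pattern in the suggested run_id.

    Hypothetical helper for illustration only; in the example itself the
    name is produced by Hydra interpolation, not by Python code.
    """
    # The optimizer segment comes from the config value instead of the
    # literal string "muon", so the name always matches the optimizer used.
    return (
        f"{model}-{optimizer_type}_{precision}"
        f"_r{resolution}_b{batch_size}_s{slice_num}"
    )


# Illustrative values only (assumptions, not taken from the PR config).
muon_run = build_run_id("geotransolver", "muon", "bf16", 421, 8, 64)
adamw_run = build_run_id("geotransolver", "adamw", "bf16", 421, 8, 64)
```

With the hard-coded variant both runs would share the same name; with the interpolated variant the two names differ only in the optimizer segment.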
| else: | ||
| self.ln_1 = nn.LayerNorm(hidden_dim) | ||
|
|
||
| # Attention layer | ||
| dim_head = hidden_dim // num_heads | ||
| # First match on attention backend, then on spatial shape | ||
| match attention_type: | ||
| case 'GALE': | ||
| self.Attn = GALE( | ||
| hidden_dim, | ||
| heads=num_heads, | ||
| dim_head=hidden_dim // num_heads, | ||
| dropout=dropout, | ||
| slice_num=slice_num, | ||
| use_te=use_te, | ||
| plus=plus, | ||
| context_dim=context_dim, | ||
| concrete_dropout=concrete_dropout, | ||
| state_mixing_mode=state_mixing_mode, | ||
| ) | ||
| if spatial_shape is None: | ||
| self.Attn = GALE( | ||
| hidden_dim, | ||
| heads=num_heads, | ||
| dim_head=dim_head, | ||
| dropout=dropout, | ||
| slice_num=slice_num, | ||
| use_te=use_te, | ||
| plus=plus, | ||
| context_dim=context_dim, | ||
| concrete_dropout=concrete_dropout, | ||
| state_mixing_mode=state_mixing_mode, | ||
| ) | ||
| elif len(spatial_shape) == 2: | ||
| self.Attn = GALEStructuredMesh2D( | ||
| hidden_dim, | ||
| spatial_shape=(int(spatial_shape[0]), int(spatial_shape[1])), | ||
| heads=num_heads, | ||
| dim_head=dim_head, | ||
| dropout=dropout, | ||
| slice_num=slice_num, | ||
| use_te=use_te, | ||
| plus=plus, | ||
| context_dim=context_dim, | ||
| state_mixing_mode=state_mixing_mode, | ||
| ) | ||
| elif len(spatial_shape) == 3: | ||
| self.Attn = GALEStructuredMesh3D( | ||
| hidden_dim, | ||
| spatial_shape=( | ||
| int(spatial_shape[0]), | ||
| int(spatial_shape[1]), | ||
| int(spatial_shape[2]), | ||
| ), | ||
| heads=num_heads, | ||
| dim_head=dim_head, | ||
| dropout=dropout, | ||
| slice_num=slice_num, | ||
| use_te=use_te, | ||
| plus=plus, | ||
| context_dim=context_dim, | ||
| state_mixing_mode=state_mixing_mode, | ||
| ) | ||
| else: | ||
| raise ValueError( | ||
| f"spatial_shape must be None, length-2, or length-3; got {spatial_shape!r}" | ||
| ) | ||
| case 'GALE_FA': | ||
| self.Attn = GALE_FA( | ||
| hidden_dim, | ||
| heads=num_heads, | ||
| dim_head=hidden_dim // num_heads, | ||
| dim_head=dim_head, | ||
| dropout=dropout, | ||
| n_global_queries=slice_num, | ||
| use_te=use_te, |
concrete_dropout silently ignored for structured GALE variants
When spatial_shape is not None, GALE_block selects GALEStructuredMesh2D or GALEStructuredMesh3D, neither of which accepts a concrete_dropout argument. If a caller sets concrete_dropout=True on a structured GALE_block, the option is silently dropped — standard nn.Dropout is used instead and no warning is emitted. At minimum, a warnings.warn when concrete_dropout=True and spatial_shape is not None would make this limitation discoverable.
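One way the suggested warning might look is sketched below. This is an assumption-laden illustration, not code from the PR: the helper name `warn_if_concrete_dropout_ignored` and the message text are hypothetical, and only the standard-library `warnings.warn` call is used. In practice the check would presumably sit in the block's constructor before the structured attention module is selected.

```python
import warnings
from typing import Optional, Sequence


def warn_if_concrete_dropout_ignored(
    concrete_dropout: bool,
    spatial_shape: Optional[Sequence[int]],
) -> bool:
    """Warn when concrete_dropout is requested for a structured grid.

    Hypothetical helper (not part of the PR). Returns True if a warning
    was emitted, False otherwise.
    """
    if concrete_dropout and spatial_shape is not None:
        warnings.warn(
            "concrete_dropout=True is not supported when spatial_shape is "
            "set; standard dropout will be used instead.",
            UserWarning,
            stacklevel=2,  # point the warning at the caller's line
        )
        return True
    return False
```

A caller could invoke `warn_if_concrete_dropout_ignored(concrete_dropout, spatial_shape)` once per block; the unstructured path (`spatial_shape is None`) stays silent because `concrete_dropout` is honored there.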
|
/blossom-ci |
PhysicsNeMo Pull Request
Reopening this Pull Request.
This refactors GeoTransolver, Flare, and some components of Transolver so that they can be used for 2D and 3D cases. The goal is to make these more suitable for structured datasets, and to enable domain parallelism in these cases.
I also enabled them in the Darcy transolver model, just so we have an example for users to test these.
In order to enable ShardTensor for GeoTransolver and Flare, and move them out of experimental, we should get this or a similar refactor in.
Description
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.