-
Notifications
You must be signed in to change notification settings - Fork 757
[torch.compile] Bunch of small changes needed for enabling torch.compile #3130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
bfce3a7
0fd25df
afe364b
ee43f56
22f80e4
6910743
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -554,6 +554,12 @@ def get_ub(name: str, use_fp8: bool): | |
| return _ub_communicators[key] | ||
|
|
||
|
|
||
| @torch.compiler.assume_constant_result | ||
| def get_ub_is_fp8(name: str, use_fp8: bool) -> bool: | ||
| """Query is_fp8_ubuf for a named UB communicator; treated as compile-time constant.""" | ||
| return get_ub(name, use_fp8).is_fp8_ubuf() | ||
|
pggPL marked this conversation as resolved.
|
||
|
|
||
|
|
||
| def destroy_ub(): | ||
| """Destroy all allocated userbuffer communicators.""" | ||
| global _ub_communicators, _ub_with_cublasmp, _ub_initialized | ||
|
|
@@ -562,6 +568,9 @@ def destroy_ub(): | |
| _ub_initialized = False | ||
| global layers_atomic_ring_exchange | ||
| layers_atomic_ring_exchange = [] | ||
| # Compiled graphs may have baked is_fp8_ubuf() via assume_constant_result; | ||
| # reset so re-init with different settings doesn't read stale constants. | ||
| torch.compiler.reset() | ||
|
Comment on lines
+571
to
+573
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current helper call sites are all inside Is it possible avoid a process-wide compiler reset on UB teardown, or add a targeted compiled UB test that proves the stale-constant case and justifies this global invalidation? |
||
|
|
||
|
|
||
| def fill_userbuffers_buffer_for_all_gather( | ||
|
|
@@ -1049,7 +1058,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: | |
| if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState): | ||
| return | ||
| if recipe.custom() and isinstance(recipe_state, CustomRecipeState): | ||
| return | ||
| if recipe_state.recipe is recipe: | ||
| return | ||
|
|
||
| # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and | ||
| # 2 (grad_output and grad_input) for bwd | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That file only compiles a local
ToyLinearhelper andtorch.nn.Linearunderte.autocast. It does not instantiate changed in this PRte.Linear,LayerNormLinear, orLayerNormMLP, and it has no UB,sequence_parallel/parallel_mode.What tests would fail without changes to layernorm_linear, layernorm_mlp files?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I fix the issue that the test was not connected to the CI.
Currently it tests only if te.autocast() can be traced inside torch.compile.
This is first of series of PRs and I change here only small things to make next PRs cleaner.