[pyTorch] Fix the compilation warnings #2663
base: main
Conversation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Greptile Overview

Greptile Summary

This PR targets compilation warnings in the PyTorch extension and cuDNN frontend fused-attention code paths. Changes include:
The main remaining concern is the FP8 fused-attention path, which now hardcodes set_generate_stats(true).

Confidence Score: 4/5
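For context, a minimal sketch of this FP16/FP8 difference using the cuDNN frontend graph API is shown below. It assumes the frontend version in use exposes set_generate_stats on both SDPA attribute types (as the sequence diagram further down implies); all function and variable names are illustrative, not taken from the PR.

```cpp
#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

// Sketch only: the FP16 path forwards the caller's flag, while the FP8 path
// (as of this PR) hardcodes stats generation on.
void configure_sdpa(bool generate_stats, bool is_causal) {
  auto sdpa_fp16 = fe::graph::SDPA_attributes()
                       .set_name("sdpa_fp16")
                       .set_generate_stats(generate_stats)  // honors the caller's flag
                       .set_causal_mask(is_causal);

  auto sdpa_fp8 = fe::graph::SDPA_fp8_attributes()
                      .set_name("sdpa_fp8")
                      .set_generate_stats(true)  // hardcoded in the current PR
                      .set_causal_mask(is_causal);
  (void)sdpa_fp16;
  (void)sdpa_fp8;
}
```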
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
participant Py as PyTorch Python
participant PB as pybind11 module
participant CO as CommOverlap/CommOverlapP2P (C++)
participant Base as CommOverlap*Base/Core (C++)
participant FA as FusedAttn FP16/FP8 (CUDA)
participant FE as cuDNN Frontend Graph
Py->>PB: instantiate CommOverlap*(...)
PB->>CO: constructor
Py->>PB: copy_into_buffer(tensor, local_chunk)
PB->>CO: copy_into_buffer(const at::Tensor&, bool)
Note over PB,CO: Binding uses static_cast to pick (Tensor,bool)
CO->>Base: (optional) call into base/core overlap implementation
Py->>PB: run fused attention op
PB->>FA: fused_attn_*_fwd_impl(...)
FA->>FE: build SDPA attributes
Note over FE: FP16: set_generate_stats(generate_stats)
Note over FE: FP8: set_generate_stats(true) (current PR)
FE-->>FA: compile/execute graph
FA-->>PB: return outputs (incl. Stats when generated)
PB-->>Py: return tensors
```
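The static_cast note in the diagram refers to the standard pybind11 pattern for disambiguating overloaded member functions. A minimal sketch under assumed signatures follows; the stub class, the second overload, and the argument names are illustrative, not the PR's actual code.

```cpp
#include <torch/extension.h>  // brings in pybind11 and at::Tensor

// Stub standing in for the real CommOverlap; only the overload set matters here.
struct CommOverlap {
  void copy_into_buffer(const at::Tensor &, bool) {}
  void copy_into_buffer(const at::Tensor &, const at::Tensor &) {}  // hypothetical second overload
};

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  py::class_<CommOverlap>(m, "CommOverlap")
      .def(py::init<>())
      // static_cast selects the (const at::Tensor&, bool) overload explicitly;
      // a bare &CommOverlap::copy_into_buffer would be ambiguous.
      .def("copy_into_buffer",
           static_cast<void (CommOverlap::*)(const at::Tensor &, bool)>(
               &CommOverlap::copy_into_buffer),
           py::arg("tensor"), py::arg("local_chunk"));
}
```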
5 files reviewed, 1 comment
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu (Outdated)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>
/te-ci
5 files reviewed, 1 comment
Additional Comments (1)
Description
This PR fixes most compilation warnings in the PyTorch extension. The one still remaining is the ptxas performance advisory about using the multicast.cluster modifier when compiling for architecture 120; fixing it would most likely require a dedicated kernel for SM120.
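A hedged sketch of the kind of per-architecture dispatch that a dedicated SM120 kernel could use follows; the kernel names and bodies are purely illustrative placeholders, not code from this repository.

```cuda
#include <cuda_runtime.h>

// Illustrative kernel: on SM90/SM100 data-center parts this variant could use
// cluster/multicast loads; shown as a plain copy here for brevity.
__global__ void copy_kernel_multicast(float *dst, const float *src, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dst[i] = src[i];
}

// SM120 variant: avoids the multicast.cluster modifier that triggers the
// ptxas performance advisory on architecture 120.
__global__ void copy_kernel_plain(float *dst, const float *src, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dst[i] = src[i];
}

// Host-side dispatch on the device's compute capability.
void launch_copy(float *dst, const float *src, int n, cudaStream_t stream) {
  int dev = 0;
  cudaGetDevice(&dev);
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, dev);
  const int sm = prop.major * 10 + prop.minor;
  const dim3 block(256), grid((n + 255) / 256);
  if (sm == 120) {
    copy_kernel_plain<<<grid, block, 0, stream>>>(dst, src, n);
  } else {
    copy_kernel_multicast<<<grid, block, 0, stream>>>(dst, src, n);
  }
}
```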
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: