7 changes: 4 additions & 3 deletions transformer_engine/common/common.h
@@ -313,6 +313,9 @@ struct GroupedTensor {
   SimpleTensor columnwise_amax;
   SimpleTensor scale; // for FP8-DS only
 
+  NVTEScalingMode scaling_mode;
+  size_t num_tensors;
+
   // Shape information (OPTIONAL - empty if dimension is uniform across all tensors)
   // first_dims[i] = first dimension of tensor i (empty if all tensors have same first dim)
   // last_dims[i] = last dimension of tensor i (empty if all tensors have same last dim)
@@ -330,8 +333,6 @@ struct GroupedTensor {
   // Always 2D with positive dimensions
   NVTEShape logical_shape;
 
-  NVTEScalingMode scaling_mode;
-  size_t num_tensors;
   NVTEGroupedTensor nvte_tensor;
 
   GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
@@ -342,12 +343,12 @@
         amax(),
         columnwise_amax(),
         scale(),
+        scaling_mode(scaling_mode),
         num_tensors(num_tensors),
         first_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
         last_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
         tensor_offsets(nullptr, std::vector<size_t>{0}, DType::kInt64),
         logical_shape(nvte_make_shape(nullptr, 1)),
-        scaling_mode(scaling_mode),
         nvte_tensor(0) {}
 
   explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; }
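The field reordering above is not just cosmetic: C++ initializes non-static members in declaration order, not in the order they appear in the constructor's initializer list, so declaring `scaling_mode` and `num_tensors` before the shape members keeps the two orders in sync and silences `-Wreorder`. A minimal sketch of the rule (illustrative names, not TE code):

```cpp
#include <cstddef>

// Members are initialized in *declaration* order; if the initializer list
// disagrees, compilers warn (-Wreorder) and any cross-member dependency
// could read an uninitialized value.
struct Grouped {
  int scaling_mode;         // declared first  -> initialized first
  std::size_t num_tensors;  // declared second -> initialized second

  Grouped(int scaling_mode, std::size_t num_tensors)
      // Initializer list written in the same order as the declarations.
      : scaling_mode(scaling_mode), num_tensors(num_tensors) {}
};

int main() {
  Grouped g(0, 4);
  return 0;
}
```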
@@ -250,7 +250,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   fe::graph::SDPA_attributes sdpa_options;
   sdpa_options = fe::graph::SDPA_attributes()
                      .set_name("flash_attention")
-                     .set_is_inference(false)
                      .set_generate_stats(generate_stats)
                      .set_causal_mask(is_causal)
                      .set_causal_mask_bottom_right(is_bottom_right)
2 changes: 1 addition & 1 deletion transformer_engine/common/fused_attn/fused_attn_fp8.cu
@@ -1810,7 +1810,7 @@ void fused_attn_fp8_fwd_impl_v1(
   fe::graph::SDPA_fp8_attributes sdpa_options;
   sdpa_options = fe::graph::SDPA_fp8_attributes()
                      .set_name("sdpa_fp8")
-                     .set_is_inference(false)
+                     .set_generate_stats(true)
                      .set_causal_mask(is_causal)
                      .set_attn_scale(attn_scale);
 
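Both fused-attention hunks migrate off the same deprecated cudnn-frontend setter. Assuming the usual semantics of that deprecation (`set_generate_stats(b)` replacing `set_is_inference(!b)`), behavior is unchanged: `set_is_inference(false)` and `set_generate_stats(true)` both request the softmax statistics tensor that the backward pass consumes, and in the arbitrary-seqlen path the already-present `set_generate_stats(generate_stats)` call makes the removed line redundant.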
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -548,6 +548,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
 
   ~CommOverlap() {}
 
+  using transformer_engine::CommOverlapCore::copy_into_buffer;
   void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
 
   at::Tensor get_buffer(bool local_chunk = false,
@@ -569,6 +570,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
 
   ~CommOverlapP2P() {}
 
+  using transformer_engine::CommOverlapP2PBase::copy_into_buffer;
   void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
 
   at::Tensor get_buffer(bool local_chunk = false,
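The `using` declarations are needed because a derived-class member with the same name hides every base-class overload, even those with different signatures; re-exporting the base overloads keeps them callable on the derived type. A minimal sketch of the name-hiding rule (illustrative names, not TE code):

```cpp
#include <iostream>

struct Base {
  void copy_into_buffer(int n) { std::cout << "Base(" << n << ")\n"; }
};

struct Derived : Base {
  // Without this using-declaration, the Derived overload below hides
  // Base::copy_into_buffer(int) and d.copy_into_buffer(42) fails to compile.
  using Base::copy_into_buffer;
  void copy_into_buffer(double x, bool local_chunk = false) {
    std::cout << "Derived(" << x << ", " << local_chunk << ")\n";
  }
};

int main() {
  Derived d;
  d.copy_into_buffer(42);   // Base overload, visible via the using-declaration
  d.copy_into_buffer(1.5);  // Derived overload
  return 0;
}
```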
12 changes: 8 additions & 4 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -492,8 +492,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0,
           py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true,
           py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false)
-      .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
-           py::arg("local_chunk") = false)
+      .def("copy_into_buffer",
+           static_cast<void (CommOverlap::*)(const at::Tensor &, bool)>(
+               &CommOverlap::copy_into_buffer),
+           py::arg("input"), py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
            py::arg("shape") = std::nullopt)
       .def("get_communication_stream", &CommOverlap::get_communication_stream);
@@ -510,8 +512,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1,
           py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false,
           py::arg("use_ce") = true, py::arg("aggregate") = false)
-      .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
-           py::arg("local_chunk") = false)
+      .def("copy_into_buffer",
+           static_cast<void (CommOverlapP2P::*)(const at::Tensor &, bool)>(
+               &CommOverlapP2P::copy_into_buffer),
+           py::arg("input"), py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
            py::arg("shape") = std::nullopt)
       .def("get_communication_stream", &CommOverlapP2P::get_communication_stream);
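Once the base-class overloads are visible, a plain `&CommOverlap::copy_into_buffer` no longer names a unique function, so the binding selects the Torch-facing overload with an explicit cast to its pointer-to-member type. A self-contained sketch of the pattern (hypothetical `Widget` class, not the TE bindings); `py::overload_cast` is pybind11's equivalent shorthand:

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;

struct Widget {
  void copy_into_buffer(int n) { /* overload 1 */ }
  void copy_into_buffer(double x, bool local_chunk) { /* overload 2 */ }
};

PYBIND11_MODULE(example, m) {
  py::class_<Widget>(m, "Widget")
      .def(py::init<>())
      // &Widget::copy_into_buffer alone is ambiguous; the cast names the
      // exact pointer-to-member type of the overload being bound.
      .def("copy_into_buffer",
           static_cast<void (Widget::*)(double, bool)>(&Widget::copy_into_buffer),
           py::arg("x"), py::arg("local_chunk") = false)
      // Equivalent, arguably more readable spelling:
      .def("copy_into_buffer", py::overload_cast<int>(&Widget::copy_into_buffer),
           py::arg("n"));
}
```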