diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index 970b7aef6c..99a2985d5e 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -313,6 +313,9 @@ struct GroupedTensor {
   SimpleTensor columnwise_amax;
   SimpleTensor scale;  // for FP8-DS only
 
+  NVTEScalingMode scaling_mode;
+  size_t num_tensors;
+
   // Shape information (OPTIONAL - empty if dimension is uniform across all tensors)
   // first_dims[i] = first dimension of tensor i (empty if all tensors have same first dim)
   // last_dims[i] = last dimension of tensor i (empty if all tensors have same last dim)
@@ -330,8 +333,6 @@ struct GroupedTensor {
   // Always 2D with positive dimensions
   NVTEShape logical_shape;
 
-  NVTEScalingMode scaling_mode;
-  size_t num_tensors;
   NVTEGroupedTensor nvte_tensor;
 
   GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
@@ -342,12 +343,12 @@ struct GroupedTensor {
         amax(),
         columnwise_amax(),
         scale(),
+        scaling_mode(scaling_mode),
         num_tensors(num_tensors),
         first_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
         last_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
         tensor_offsets(nullptr, std::vector<size_t>{0}, DType::kInt64),
         logical_shape(nvte_make_shape(nullptr, 1)),
-        scaling_mode(scaling_mode),
         nvte_tensor(0) {}
 
   explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; }
diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index 53023361e4..d13ed97de1 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -250,7 +250,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   fe::graph::SDPA_attributes sdpa_options;
   sdpa_options = fe::graph::SDPA_attributes()
                      .set_name("flash_attention")
-                     .set_is_inference(false)
                      .set_generate_stats(generate_stats)
                      .set_causal_mask(is_causal)
                      .set_causal_mask_bottom_right(is_bottom_right)
diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu
index f886ec77f4..fe859b0b22 100644
--- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu
@@ -1810,7 +1810,7 @@ void fused_attn_fp8_fwd_impl_v1(
   fe::graph::SDPA_fp8_attributes sdpa_options;
   sdpa_options = fe::graph::SDPA_fp8_attributes()
                      .set_name("sdpa_fp8")
-                     .set_is_inference(false)
+                     .set_generate_stats(true)
                      .set_causal_mask(is_causal)
                      .set_attn_scale(attn_scale);
 
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index f7cf32eaf6..e0ea3d6b78 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -548,6 +548,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
   ~CommOverlap() {}
 
+  using transformer_engine::CommOverlapCore::copy_into_buffer;
   void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
 
   at::Tensor get_buffer(bool local_chunk = false,
                         std::optional<std::vector<int64_t>> shape = std::nullopt);
@@ -569,6 +570,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
   ~CommOverlapP2P() {}
 
+  using transformer_engine::CommOverlapP2PBase::copy_into_buffer;
   void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
 
   at::Tensor get_buffer(bool local_chunk = false,
                         std::optional<std::vector<int64_t>> shape = std::nullopt);
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 79dd9ea5ce..1e907d9bc0 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -492,8 +492,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0,
            py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false,
            py::arg("rs_overlap_first_gemm") = false)
-      .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
-           py::arg("local_chunk") = false)
+      .def("copy_into_buffer",
+           static_cast<void (CommOverlap::*)(const at::Tensor &, bool)>(
+               &CommOverlap::copy_into_buffer),
+           py::arg("input"), py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
            py::arg("shape") = std::nullopt)
       .def("get_communication_stream", &CommOverlap::get_communication_stream);
@@ -510,8 +512,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1,
           py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, py::arg("use_ce") = true,
           py::arg("aggregate") = false)
-      .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
-           py::arg("local_chunk") = false)
+      .def("copy_into_buffer",
+           static_cast<void (CommOverlapP2P::*)(const at::Tensor &, bool)>(
+               &CommOverlapP2P::copy_into_buffer),
+           py::arg("input"), py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
            py::arg("shape") = std::nullopt)
       .def("get_communication_stream", &CommOverlapP2P::get_communication_stream);
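Note on the extensions.h and pybind.cpp hunks: in C++, an overload declared in a derived class hides every base-class overload with the same name, so the using-declarations re-expose the base-class copy_into_buffer overloads, and the static_cast in the bindings selects one member out of the resulting overload set. Below is a minimal, self-contained sketch of this pattern, not TransformerEngine code; the Base/Derived names and signatures are hypothetical.

// name_hiding_sketch.cpp -- illustrates derived-class name hiding and the fix.
#include <iostream>

struct Base {
  // Base overload that would otherwise be hidden by Derived's declaration.
  void copy_into_buffer(int chunk_id, bool local) {
    std::cout << "Base::copy_into_buffer(int, bool)\n";
  }
};

struct Derived : Base {
  // Without this using-declaration, the overload below hides ALL
  // Base::copy_into_buffer overloads, not just same-signature ones.
  using Base::copy_into_buffer;
  void copy_into_buffer(bool local) {
    std::cout << "Derived::copy_into_buffer(bool)\n";
  }
};

int main() {
  Derived d;
  d.copy_into_buffer(true);      // resolves to Derived::copy_into_buffer(bool)
  d.copy_into_buffer(42, true);  // resolves to the Base overload via `using`

  // &Derived::copy_into_buffer now names an overload set, so passing it to a
  // binding layer (e.g. pybind11's .def, as in the hunks above) requires an
  // explicit member-function-pointer cast to pick one overload:
  auto mf = static_cast<void (Derived::*)(bool)>(&Derived::copy_into_buffer);
  (d.*mf)(false);
}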