diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a83cbe3e30..3c106b982e 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -43,8 +43,6 @@ set(NVTE_SPECIFIC_ARCHS) # Check for architecture 100 list(FIND CMAKE_CUDA_ARCHITECTURES "100" arch_100_index) if(NOT arch_100_index EQUAL -1) - list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "100") - list(APPEND NVTE_GENERIC_ARCHS "100") list(APPEND NVTE_SPECIFIC_ARCHS "100a") if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9) list(APPEND NVTE_SPECIFIC_ARCHS "103a") @@ -54,31 +52,57 @@ endif() # Check for architecture 101 (if we see this we are in toolkit <= 12.9) list(FIND CMAKE_CUDA_ARCHITECTURES "101" arch_101_index) if(NOT arch_101_index EQUAL -1) - list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "101") - list(APPEND NVTE_GENERIC_ARCHS "101") list(APPEND NVTE_SPECIFIC_ARCHS "101a") endif() # Check for architecture 110 (if we see this we are in toolkit >= 13.0) list(FIND CMAKE_CUDA_ARCHITECTURES "110" arch_110_index) if(NOT arch_110_index EQUAL -1) - list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "110") - list(APPEND NVTE_GENERIC_ARCHS "110") - list(APPEND NVTE_SPECIFIC_ARCHS "110f") + if(CMAKE_VERSION VERSION_GREATER_EQUAL 4.0.2) + list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "110") + list(APPEND CMAKE_CUDA_ARCHITECTURES "110f") + else() + list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "110") + list(APPEND NVTE_GENERIC_ARCHS "110") + list(APPEND NVTE_SPECIFIC_ARCHS "110f") + endif() endif() # Check for architecture 120 list(FIND CMAKE_CUDA_ARCHITECTURES "120" arch_120_index) if(NOT arch_120_index EQUAL -1) list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "120") - list(APPEND NVTE_GENERIC_ARCHS "120") - if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9) - list(APPEND NVTE_SPECIFIC_ARCHS "120f") + if(CMAKE_VERSION VERSION_GREATER_EQUAL 4.0.2) + if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9) + list(APPEND CMAKE_CUDA_ARCHITECTURES "120f") + else() + list(APPEND NVTE_GENERIC_ARCHS "120") + list(APPEND NVTE_SPECIFIC_ARCHS "120a") + endif() else() - list(APPEND NVTE_SPECIFIC_ARCHS "120a") + if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9) + list(APPEND NVTE_GENERIC_ARCHS "120") + list(APPEND NVTE_SPECIFIC_ARCHS "120f") + else() + list(APPEND NVTE_GENERIC_ARCHS "120") + list(APPEND NVTE_SPECIFIC_ARCHS "120a") + endif() + endif() +endif() + +if(CMAKE_VERSION VERSION_LESS 4.0.2) + if(NOT CMAKE_CUDA_ARCHITECTURES) + message(WARNING + "CMAKE_CUDA_ARCHITECTURES is empty after replacing arch-specific targets. " + "Please upgrade to CMake 4.0.2+ for native 'f' architecture support. " + "Adding sm_75 target in addition to the specified target to avoid configuration " + "errors - this will result in longer build time, but does not affect correctness.") + set(CMAKE_CUDA_ARCHITECTURES 75) endif() endif() +set(NVTE_ARCH_SPECIFIC_TARGETS TRUE) + # cuDNN frontend API set(CUDNN_FRONTEND_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") @@ -193,7 +217,6 @@ foreach(cuda_source IN LISTS transformer_engine_cuda_sources) foreach(arch IN LISTS NVTE_GENERIC_ARCHS) list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}") endforeach() - if(arch_compile_options) set_property( SOURCE ${cuda_source} @@ -204,7 +227,6 @@ foreach(cuda_source IN LISTS transformer_engine_cuda_sources) endif() endforeach() -# Set compile options for CUDA sources with specific architectures foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources) set(arch_compile_options) foreach(arch IN LISTS NVTE_SPECIFIC_ARCHS) @@ -221,6 +243,17 @@ foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources) endif() endforeach() +if(NVTE_ARCH_SPECIFIC_TARGETS) + foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources) + set_property( + SOURCE ${cuda_source} + APPEND + PROPERTY + COMPILE_DEFINITIONS NVTE_HAS_ARCH_SPECIFIC_TARGETS=1 + ) + endforeach() +endif() + if (NVTE_WITH_CUBLASMP) list(APPEND transformer_engine_SOURCES comm_gemm/comm_gemm.cpp) diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 9bcf6e2289..840e0c6a0e 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -31,11 +31,15 @@ struct ArchSpecific { template constexpr static bool compatible() { if constexpr (CurrentArch == id) { +#if !defined(NVTE_HAS_ARCH_SPECIFIC_TARGETS) static_assert(ArchSpecific == CurrentArch, "Compiled for the generic architecture, while utilizing arch-specific " "features. Please compile for smXXXa architecture instead of smXXX " "architecture."); return true; +#else + return ArchSpecific == CurrentArch; +#endif } else { return false; } @@ -49,11 +53,15 @@ struct FamilySpecific { template constexpr static bool compatible() { if constexpr ((CurrentArch / 100) == (id / 100)) { +#if !defined(NVTE_HAS_ARCH_SPECIFIC_TARGETS) static_assert(FamilySpecific == CurrentArch, "Compiled for the generic architecture, while utilizing family-specific " "features. Please compile for smXXXf architecture instead of smXXX " "architecture."); return true; +#else + return FamilySpecific == CurrentArch; +#endif } else { return false; }