diff --git a/csrc/gemm_4bit.cu b/csrc/gemm_4bit.cu index 557acabb4..144f0e8d7 100644 --- a/csrc/gemm_4bit.cu +++ b/csrc/gemm_4bit.cu @@ -15,15 +15,23 @@ // 16-entry cache indexed by device ID. num_sms==0 means not yet populated. // Static storage is zero-initialized, so all entries start unpopulated (num_sms==0). GpuProps get_gpu_props() { - static GpuProps cache[16]; + static GpuProps cache[16] = {}; int dev = 0; cudaGetDevice(&dev); - if (dev < 16 && cache[dev].num_sms == 0) { - cudaDeviceGetAttribute(&cache[dev].num_sms, cudaDevAttrMultiProcessorCount, dev); - cudaDeviceGetAttribute(&cache[dev].cc_major, cudaDevAttrComputeCapabilityMajor, dev); - cudaDeviceGetAttribute(&cache[dev].cc_minor, cudaDevAttrComputeCapabilityMinor, dev); - } - return cache[dev]; + + if (dev < 16 && cache[dev].num_sms != 0) + return cache[dev]; + + GpuProps props = {}; + props.device_index = dev; + cudaDeviceGetAttribute(&props.num_sms, cudaDevAttrMultiProcessorCount, dev); + cudaDeviceGetAttribute(&props.cc_major, cudaDevAttrComputeCapabilityMajor, dev); + cudaDeviceGetAttribute(&props.cc_minor, cudaDevAttrComputeCapabilityMinor, dev); + + if (dev < 16) + cache[dev] = props; + + return props; } /// @brief Fused 4-bit dequantize + GEMM. Computes out[M,N] = A[M,K] @ B[N,K]^T + bias. diff --git a/csrc/gemm_4bit_common.cuh b/csrc/gemm_4bit_common.cuh index 3febc8741..1d640887f 100644 --- a/csrc/gemm_4bit_common.cuh +++ b/csrc/gemm_4bit_common.cuh @@ -5,7 +5,7 @@ // GPU properties queried once per device and cached in gemm_4bit.cu. // Passed through dispatch into MMA launchers to avoid repeated cudaGetDevice calls. struct GpuProps { - int num_sms, cc_major, cc_minor; + int device_index, num_sms, cc_major, cc_minor; }; #include diff --git a/csrc/gemm_4bit_sm75.cu b/csrc/gemm_4bit_sm75.cu index be5ec99a3..4f9bcaa9a 100644 --- a/csrc/gemm_4bit_sm75.cu +++ b/csrc/gemm_4bit_sm75.cu @@ -317,14 +317,16 @@ static void launch_tile( int M, int N, int K, int blocksize, int quant_type, + GpuProps gpu, cudaStream_t stream // clang-format on ) { constexpr int smem = smem_bytes_for(); - static bool cfg = false; - if (!cfg) { + static bool cfg[16] = {}; + if (gpu.device_index >= 16 || !cfg[gpu.device_index]) { cudaFuncSetAttribute(gemm_4bit_sm75_m16n8k8, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); - cfg = true; + if (gpu.device_index < 16) + cfg[gpu.device_index] = true; } dim3 grid((M + MT - 1) / MT, (N + NT - 1) / NT); gemm_4bit_sm75_m16n8k8<<>>( @@ -396,7 +398,7 @@ void launch_gemm_4bit_sm75_m16n8k8( // clang-format off #define LAUNCH_SM75(MT, NT) \ - launch_tile(A, B, absmax, absmax_8bit, absmax_code, absmax_offset, C, bias, M, N, K, blocksize, quant_type, stream) + launch_tile(A, B, absmax, absmax_8bit, absmax_code, absmax_offset, C, bias, M, N, K, blocksize, quant_type, gpu, stream) if (mt == 32 && nt == 64) LAUNCH_SM75(32, 64); else if (mt == 32 && nt == 128) LAUNCH_SM75(32, 128); diff --git a/csrc/gemm_4bit_sm80.cu b/csrc/gemm_4bit_sm80.cu index e4c74460b..3339a93da 100644 --- a/csrc/gemm_4bit_sm80.cu +++ b/csrc/gemm_4bit_sm80.cu @@ -469,14 +469,16 @@ static void launch_tile( int M, int N, int K, int blocksize, int quant_type, + GpuProps gpu, cudaStream_t stream // clang-format on ) { constexpr int smem = smem_bytes_for(); - static bool cfg = false; - if (!cfg) { + static bool cfg[16] = {}; + if (gpu.device_index >= 16 || !cfg[gpu.device_index]) { cudaFuncSetAttribute(gemm_4bit_sm80_m16n8k16, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); - cfg = true; + if (gpu.device_index < 16) + cfg[gpu.device_index] = true; } dim3 grid((M + MT - 1) / MT, (N + NT - 1) / NT); gemm_4bit_sm80_m16n8k16<<>>( @@ -662,7 +664,7 @@ void launch_gemm_4bit_sm80_m16n8k16( // clang-format off #define LAUNCH_SM80(MT, NT, KC) \ - launch_tile(A, B, absmax, absmax_8bit, absmax_code, absmax_offset, C, bias, M, N, K, blocksize, quant_type, stream) + launch_tile(A, B, absmax, absmax_8bit, absmax_code, absmax_offset, C, bias, M, N, K, blocksize, quant_type, gpu, stream) if (kc == 64) { if (mt == 32 && nt == 64) LAUNCH_SM80( 32, 64, 64);