@@ -111,11 +111,31 @@ Status DeformConv<T>::ComputeInternal(OpKernelContext* context) const {
// --- Parallel-image chunking budget ------------------------------------------
// Decides how many images of the batch are processed per chunk by bounding the
// column-buffer scratch memory. bytes_per_image is the per-image scratch size
// (output pixels * (C*kernel_size + M/group) elements); SafeInt guards the
// multiplication against overflow.
111111 // We use a safe max(1, ...) for bytes_per_image to avoid division by zero in edge cases
112112 const size_t bytes_per_image = SafeInt<size_t >(output_image_size) * (C * kernel_size + M / group) * sizeof (T);
113113
114- // Heuristic: limit temp memory to 256MB per chunk to balance parallelism and memory usage.
115- // For small images, this allows up to kMaxParallelImgs (32).
116- // For large images (4K/8K), this restricts parallelism to 1 to prevent OOM.
117- constexpr size_t kMaxTempMemSize = 256 * 1024 * 1024 ;
118- const int max_parallel_imgs_mem = std::max (1 , static_cast <int >(kMaxTempMemSize / std::max (size_t (1 ), bytes_per_image)));
114+ // Heuristic: limit temp memory per chunk to balance parallelism and memory usage.
115+ // Mirrors Conv's approach (conv_8.h): use 90% of free memory (10% fragmentation buffer).
116+ // Tiered cap based on free memory: larger GPUs get higher limits for better parallelism.
// New scheme: query device free memory, keep a 10% headroom, pick a tiered cap
// by free-memory size, then clamp the budget to [32MB, cap]. If the query
// fails (or reports zero), fall back to the previous fixed 256MB budget.
// NOTE(review): cudaMemGetInfo is executed on every ComputeInternal call —
// presumably cheap, but it is a driver call; consider caching if profiling
// shows overhead. TODO confirm.
// NOTE(review): free_mem is DEVICE-wide free memory; the EP's arena allocator
// may already hold (or be unable to grow into) part of it — verify the budget
// against the allocator actually used for the scratch buffer.
117+ size_t effective_max_temp = 256ULL * 1024 * 1024 ; // default fallback
118+ constexpr size_t kMinTempMemSize = 32ULL * 1024 * 1024 ; // 32MB floor
119+ {
120+ size_t free_mem = 0 , total_mem = 0 ;
// Only trust the query on success with a non-zero free report.
121+ if (cudaMemGetInfo (&free_mem, &total_mem) == cudaSuccess && free_mem > 0 ) {
122+ free_mem = static_cast <size_t >(static_cast <double >(free_mem) * 0.9 ); // 10% fragmentation buffer
// NOTE(review): mutable local named with constant-style kName — rename to
// e.g. max_temp_cap to avoid implying constexpr.
123+ size_t kMaxTempMemSize ;
// Tier thresholds compare against the already-discounted (90%) free memory.
124+ if (free_mem > 16ULL * 1024 * 1024 * 1024 ) {
125+ kMaxTempMemSize = 2ULL * 1024 * 1024 * 1024 ; // 16GB+ free → 2GB
126+ } else if (free_mem > 8ULL * 1024 * 1024 * 1024 ) {
127+ kMaxTempMemSize = 1ULL * 1024 * 1024 * 1024 ; // 8-16GB → 1GB
128+ } else if (free_mem > 4ULL * 1024 * 1024 * 1024 ) {
129+ kMaxTempMemSize = 512ULL * 1024 * 1024 ; // 4-8GB → 512MB
130+ } else if (free_mem > 2ULL * 1024 * 1024 * 1024 ) {
131+ kMaxTempMemSize = 256ULL * 1024 * 1024 ; // 2-4GB → 256MB
132+ } else {
133+ kMaxTempMemSize = 128ULL * 1024 * 1024 ; // <2GB → 128MB
134+ }
// effective = clamp(0.9*free, kMinTempMemSize, tier cap).
// NOTE(review): when 0.9*free_mem < 32MB the max() floor raises the budget
// ABOVE available free memory — is over-committing on a near-full device
// intentional, or should the floor be min'ed with free_mem? Confirm.
135+ effective_max_temp = std::max (kMinTempMemSize , std::min (kMaxTempMemSize , free_mem));
136+ }
137+ }
// Images per chunk = budget / per-image scratch; bytes_per_image is forced to
// >= 1 so the division is safe even for degenerate shapes, and the result is
// floored at 1 so at least one image is always processed.
138+ const int max_parallel_imgs_mem = std::max (1 , static_cast <int >(effective_max_temp / std::max (size_t (1 ), bytes_per_image)));
// Cap by the global kMaxParallelImgs, then snap down to a divisor of the batch
// size N so chunks tile the batch evenly.
119139 const int target_parallel_imgs = std::min (kMaxParallelImgs , max_parallel_imgs_mem);
120140
121141 const int n_parallel_imgs = GetGreatestDivisorBelowBound (static_cast <int >(N), target_parallel_imgs);
0 commit comments