Skip to content

Commit 1222ad4

Browse files
committed
Use GPU free memory in DeformConv temp memory heuristic
1 parent e4d3c51 commit 1222ad4

1 file changed

Lines changed: 25 additions & 5 deletions

File tree

onnxruntime/core/providers/cuda/nn/deform_conv.cc

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,31 @@ Status DeformConv<T>::ComputeInternal(OpKernelContext* context) const {
111111
// We use a safe max(1, ...) for bytes_per_image to avoid division by zero in edge cases
112112
const size_t bytes_per_image = SafeInt<size_t>(output_image_size) * (C * kernel_size + M / group) * sizeof(T);
113113

114-
// Heuristic: limit temp memory to 256MB per chunk to balance parallelism and memory usage.
115-
// For small images, this allows up to kMaxParallelImgs (32).
116-
// For large images (4K/8K), this restricts parallelism to 1 to prevent OOM.
117-
constexpr size_t kMaxTempMemSize = 256 * 1024 * 1024;
118-
const int max_parallel_imgs_mem = std::max(1, static_cast<int>(kMaxTempMemSize / std::max(size_t(1), bytes_per_image)));
114+
// Heuristic: limit temp memory per chunk to balance parallelism and memory usage.
115+
// Mirrors Conv's approach (conv_8.h): use 90% of free memory (10% fragmentation buffer).
116+
// Tiered cap based on free memory: larger GPUs get higher limits for better parallelism.
117+
size_t effective_max_temp = 256ULL * 1024 * 1024; // default fallback
118+
constexpr size_t kMinTempMemSize = 32ULL * 1024 * 1024; // 32MB floor
119+
{
120+
size_t free_mem = 0, total_mem = 0;
121+
if (cudaMemGetInfo(&free_mem, &total_mem) == cudaSuccess && free_mem > 0) {
122+
free_mem = static_cast<size_t>(static_cast<double>(free_mem) * 0.9); // 10% fragmentation buffer
123+
size_t kMaxTempMemSize;
124+
if (free_mem > 16ULL * 1024 * 1024 * 1024) {
125+
kMaxTempMemSize = 2ULL * 1024 * 1024 * 1024; // 16GB+ free → 2GB
126+
} else if (free_mem > 8ULL * 1024 * 1024 * 1024) {
127+
kMaxTempMemSize = 1ULL * 1024 * 1024 * 1024; // 8-16GB → 1GB
128+
} else if (free_mem > 4ULL * 1024 * 1024 * 1024) {
129+
kMaxTempMemSize = 512ULL * 1024 * 1024; // 4-8GB → 512MB
130+
} else if (free_mem > 2ULL * 1024 * 1024 * 1024) {
131+
kMaxTempMemSize = 256ULL * 1024 * 1024; // 2-4GB → 256MB
132+
} else {
133+
kMaxTempMemSize = 128ULL * 1024 * 1024; // <2GB → 128MB
134+
}
135+
effective_max_temp = std::max(kMinTempMemSize, std::min(kMaxTempMemSize, free_mem));
136+
}
137+
}
138+
const int max_parallel_imgs_mem = std::max(1, static_cast<int>(effective_max_temp / std::max(size_t(1), bytes_per_image)));
119139
const int target_parallel_imgs = std::min(kMaxParallelImgs, max_parallel_imgs_mem);
120140

121141
const int n_parallel_imgs = GetGreatestDivisorBelowBound(static_cast<int>(N), target_parallel_imgs);

0 commit comments

Comments
 (0)