Skip to content

Commit 962d395

Browse files
marandje authored and
assistant-librarian[bot] committed
[rocm-libraries] ROCm/rocm-libraries#5014 (commit 21e3a39)
Fix use-after-free for non-trivially-destructible functors in hipstdpar (#5014) ## Motivation Functors must remain observable until the kernel finishes its execution. ## Technical Details Non-trivially-destructible functors are allocated in HIP managed memory and passed to the kernel by pointer instead of by value. ## Test Plan / ## Test Result / ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1 parent aaa49d9 commit 962d395

1 file changed

Lines changed: 59 additions & 1 deletion

File tree

thrust/system/hip/detail/parallel_for.h

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@
4343

4444
# include <thrust/system/hip/detail/util.h>
4545

46+
# include <new>
47+
# include <type_traits>
48+
# include <utility>
49+
4650
THRUST_NAMESPACE_BEGIN
4751

4852
namespace hip_rocprim
@@ -131,6 +135,50 @@ hipError_t THRUST_HIP_RUNTIME_FUNCTION parallel_for(Size num_items, F f, hipStre
131135
}
132136
return hipSuccess;
133137
}
138+
139+
template <class F>
140+
class managed_callable_guard
141+
{
142+
public:
143+
explicit managed_callable_guard(F&& f)
144+
{
145+
hipError_t status = ::hipMallocManaged(reinterpret_cast<void**>(&f_ptr_), sizeof(F));
146+
hip_rocprim::throw_on_error(status, "parallel_for: failed to allocate managed callable");
147+
::new (static_cast<void*>(f_ptr_)) F(::std::move(f));
148+
}
149+
150+
managed_callable_guard(const managed_callable_guard&) = delete;
151+
managed_callable_guard& operator=(const managed_callable_guard&) = delete;
152+
153+
~managed_callable_guard()
154+
{
155+
if (f_ptr_ != nullptr)
156+
{
157+
f_ptr_->~F();
158+
(void) ::hipFree(f_ptr_);
159+
}
160+
}
161+
162+
F* get() const noexcept
163+
{
164+
return f_ptr_;
165+
}
166+
167+
private:
168+
F* f_ptr_ = nullptr;
169+
};
170+
171+
template <class F>
172+
struct callable_proxy
173+
{
174+
F* f_ptr;
175+
176+
template <class... Args>
177+
THRUST_HIP_FUNCTION auto operator()(Args&&... args) const -> decltype((*f_ptr)(::std::forward<Args>(args)...))
178+
{
179+
return (*f_ptr)(::std::forward<Args>(args)...);
180+
}
181+
};
134182
} // namespace __parallel_for
135183

136184
THRUST_EXEC_CHECK_DISABLE
@@ -149,7 +197,17 @@ void THRUST_HOST_DEVICE parallel_for(execution_policy<Derived>& policy, F f, Siz
149197
THRUST_HOST static void par(execution_policy<Derived>& policy, F f, Size count)
150198
{
151199
hipStream_t stream = hip_rocprim::stream(policy);
152-
hipError_t status = __parallel_for::parallel_for(count, f, stream);
200+
if constexpr (!::std::is_trivially_destructible_v<F>)
201+
{
202+
__parallel_for::managed_callable_guard<F> guard(::std::move(f));
203+
hipError_t status = __parallel_for::parallel_for(count, __parallel_for::callable_proxy<F>{guard.get()}, stream);
204+
hip_rocprim::throw_on_error(status, "parallel_for failed");
205+
status = hip_rocprim::synchronize_optional(policy);
206+
hip_rocprim::throw_on_error(status, "parallel_for: failed to synchronize");
207+
return;
208+
}
209+
210+
hipError_t status = __parallel_for::parallel_for(count, f, stream);
153211
hip_rocprim::throw_on_error(status, "parallel_for failed");
154212
status = hip_rocprim::synchronize_optional(policy);
155213
hip_rocprim::throw_on_error(status, "parallel_for: failed to synchronize");

0 commit comments

Comments
 (0)