Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 107 additions & 93 deletions include/nvexec/stream/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -488,27 +488,31 @@ namespace nv::execution
template <class Env, class Variant>
struct stream_enqueue_receiver
{
Env* env_;
Variant* variant_;
queue::task_base* task_;
queue::producer producer_;

public:
using receiver_concept = STDEXEC::receiver_t;

explicit stream_enqueue_receiver(Env const * env,
Variant* variant,
queue::task_base* task,
queue::producer producer)
: env_(env)
, variant_(variant)
, task_(task)
, producer_(producer)
{}

template <class... Args>
STDEXEC_ATTRIBUTE(host, device)
void set_value(Args&&... args) noexcept
{
variant_->template emplace<decayed_tuple_t<set_value_t, Args...>>(set_value_t(),
static_cast<Args&&>(
args)...);
using tuple_t = decayed_tuple_t<set_value_t, Args...>;
variant_->template emplace<tuple_t>(set_value_t(), static_cast<Args&&>(args)...);
producer_(task_);
}

STDEXEC_ATTRIBUTE(host, device) void set_stopped() noexcept
{
variant_->template emplace<decayed_tuple_t<set_stopped_t>>(set_stopped_t());
using tuple_t = decayed_tuple_t<set_stopped_t>;
variant_->template emplace<tuple_t>(set_stopped_t());
producer_(task_);
}

Expand All @@ -519,14 +523,13 @@ namespace nv::execution
if constexpr (__decays_to<Error, std::exception_ptr>)
{
// What is `exception_ptr` but death pending
variant_->template emplace<decayed_tuple_t<set_error_t, cudaError_t>>(STDEXEC::set_error,
cudaErrorUnknown);
using tuple_t = decayed_tuple_t<set_error_t, cudaError_t>;
variant_->template emplace<tuple_t>(STDEXEC::set_error, cudaErrorUnknown);
}
else
{
variant_->template emplace<decayed_tuple_t<set_error_t, Error>>(set_error_t(),
static_cast<Error&&>(
err));
using tuple_t = decayed_tuple_t<set_error_t, Error>;
variant_->template emplace<tuple_t>(set_error_t(), static_cast<Error&&>(err));
}
producer_(task_);
}
Expand All @@ -536,15 +539,11 @@ namespace nv::execution
return *env_;
}

stream_enqueue_receiver(Env* env,
Variant* variant,
queue::task_base* task,
queue::producer producer)
: env_(env)
, variant_(variant)
, task_(task)
, producer_(producer)
{}
private:
Env const * env_;
Variant* variant_;
queue::task_base* task_;
queue::producer producer_;
};

template <class Receiver, class... Args, class Tag>
Expand All @@ -558,16 +557,10 @@ namespace nv::execution
template <class Receiver, class Variant>
struct continuation_task : queue::task_base
{
Receiver rcvr_;
Variant* variant_;
cudaStream_t stream_{};
std::pmr::memory_resource* pinned_resource_{};
cudaError_t status_{cudaSuccess};

continuation_task(Receiver rcvr,
Variant* variant,
cudaStream_t stream,
std::pmr::memory_resource* pinned_resource) noexcept
explicit continuation_task(Receiver rcvr,
Variant* variant,
cudaStream_t stream,
std::pmr::memory_resource* pinned_resource) noexcept
: rcvr_{rcvr}
, variant_{variant}
, stream_{stream}
Expand Down Expand Up @@ -606,6 +599,18 @@ namespace nv::execution
status_ = STDEXEC_LOG_CUDA_API(cudaMemsetAsync(this->atom_next_, 0, ptr_size, stream_));
}
}

cudaError_t status() const noexcept
{
return status_;
}

private:
Receiver rcvr_;
Variant* variant_;
cudaStream_t stream_{};
std::pmr::memory_resource* pinned_resource_{};
cudaError_t status_{cudaSuccess};
};

template <class Env>
Expand Down Expand Up @@ -695,6 +700,7 @@ namespace nv::execution
}
}

[[nodiscard]]
auto make_env() const noexcept -> env_t
{
return make_stream_env(get_env(rcvr_), get_stream_provider());
Expand Down Expand Up @@ -738,10 +744,12 @@ namespace nv::execution
stream_provider stream_provider_;
};

template <class OuterReceiver>
template <class OpState, class Env = decltype(__declval<OpState&>().make_env())>
struct propagate_receiver : stream_receiver_base
{
opstate_base<OuterReceiver>& opstate_;
explicit propagate_receiver(OpState& opstate) noexcept
: opstate_(opstate)
{}

template <class... Args>
void set_value(Args&&... args) noexcept
Expand All @@ -760,10 +768,14 @@ namespace nv::execution
opstate_.propagate_completion_signal(set_stopped_t());
}

auto get_env() const noexcept -> decltype(auto)
[[nodiscard]]
auto get_env() const noexcept -> Env
{
return opstate_.make_env();
}

private:
OpState& opstate_;
};

template <class CvSender, class InnerReceiver, class OuterReceiver>
Expand All @@ -780,38 +792,6 @@ namespace nv::execution
__if_c<stream_sender<CvSender, env_t>, InnerReceiver, stream_enqueue_receiver_t>;
using inner_opstate_t = connect_result_t<CvSender, intermediate_receiver_t>;

void start() & noexcept
{
started_.test_and_set(::cuda::std::memory_order::relaxed);

if (this->stream_provider_.status_ != cudaSuccess)
{
// Couldn't allocate memory for opstate state, complete with error
this->propagate_completion_signal(STDEXEC::set_error,
std::move(this->stream_provider_.status_));
return;
}

if constexpr (stream_receiver<InnerReceiver>)
{
if (InnerReceiver::memory_allocation_size())
{
STDEXEC_TRY
{
this->temp_storage_ = this->ctx_.managed_resource_->allocate(
InnerReceiver::memory_allocation_size());
}
STDEXEC_CATCH_ALL
{
this->propagate_completion_signal(STDEXEC::set_error, cudaErrorMemoryAllocation);
return;
}
}
}

STDEXEC::start(inner_op_);
}

template <class ReceiverProvider>
requires stream_sender<CvSender, env_t>
opstate(CvSender&& sender,
Expand Down Expand Up @@ -846,7 +826,7 @@ namespace nv::execution
{
if (this->stream_provider_.status_ == cudaSuccess)
{
this->stream_provider_.status_ = task_->status_;
this->stream_provider_.status_ = task_->status();
}
}

Expand All @@ -870,6 +850,39 @@ namespace nv::execution

STDEXEC_IMMOVABLE(opstate);

void start() & noexcept
{
started_.test_and_set(::cuda::std::memory_order::relaxed);

if (this->stream_provider_.status_ != cudaSuccess)
{
// Couldn't allocate memory for opstate state, complete with error
this->propagate_completion_signal(STDEXEC::set_error,
std::move(this->stream_provider_.status_));
return;
}

if constexpr (stream_receiver<InnerReceiver>)
{
if (InnerReceiver::memory_allocation_size())
{
STDEXEC_TRY
{
this->temp_storage_ = this->ctx_.managed_resource_->allocate(
InnerReceiver::memory_allocation_size());
}
STDEXEC_CATCH_ALL
{
this->propagate_completion_signal(STDEXEC::set_error, cudaErrorMemoryAllocation);
return;
}
}
}

STDEXEC::start(inner_op_);
}

private:
host_ptr_t<variant_t> storage_;
task_t* task_{};
::cuda::std::atomic_flag started_{};
Expand All @@ -880,60 +893,61 @@ namespace nv::execution
template <class CvSender, class OuterReceiver>
requires stream_receiver<OuterReceiver>
using exit_opstate_t =
_strm::opstate<CvSender, propagate_receiver<OuterReceiver>, OuterReceiver>;
_strm::opstate<CvSender, propagate_receiver<opstate_base<OuterReceiver>>, OuterReceiver>;

template <class Sender, class OuterReceiver>
auto exit_opstate(Sender&& sndr, OuterReceiver rcvr, context ctx) noexcept
-> exit_opstate_t<Sender, OuterReceiver>
template <class CvSender, class OuterReceiver>
auto exit_opstate(CvSender&& sndr, OuterReceiver rcvr, context ctx) noexcept
-> exit_opstate_t<CvSender, OuterReceiver>
{
return exit_opstate_t<Sender, OuterReceiver>(
static_cast<Sender&&>(sndr),
return exit_opstate_t<CvSender, OuterReceiver>(
static_cast<CvSender&&>(sndr),
static_cast<OuterReceiver&&>(rcvr),
[](opstate_base<OuterReceiver>& op) -> propagate_receiver<OuterReceiver>
{ return propagate_receiver<OuterReceiver>{{}, op}; },
[](opstate_base<OuterReceiver>& op) noexcept
{ return propagate_receiver<opstate_base<OuterReceiver>>(op); },
ctx);
}

template <class Sender, class E>
template <class Sender, class Env>
concept stream_completing_sender =
sender<Sender>
&& gpu_stream_scheduler<
__result_of<get_completion_scheduler<set_value_t>, env_of_t<Sender>, E>,
E>;
__result_of<get_completion_scheduler<set_value_t>, env_of_t<Sender>, Env>,
Env>;

template <class InnerReceiverProvider, class OuterReceiver>
using inner_receiver_t = __call_result_t<InnerReceiverProvider, opstate_base<OuterReceiver>&>;

template <class CvSender, class InnerReceiver, class OuterReceiver>
using stream_opstate_t = _strm::opstate<CvSender, InnerReceiver, OuterReceiver>;

template <class Sender, class OuterReceiver, class ReceiverProvider>
requires stream_completing_sender<Sender, env_of_t<OuterReceiver>>
auto
stream_opstate(Sender&& sndr, OuterReceiver&& out_receiver, ReceiverProvider receiver_provider)
-> stream_opstate_t<Sender, inner_receiver_t<ReceiverProvider, OuterReceiver>, OuterReceiver>
template <class CvSender, class OuterReceiver, class ReceiverProvider>
requires stream_completing_sender<CvSender, env_of_t<OuterReceiver>>
auto stream_opstate(CvSender&& sndr,
OuterReceiver&& out_receiver,
ReceiverProvider receiver_provider)
-> stream_opstate_t<CvSender, inner_receiver_t<ReceiverProvider, OuterReceiver>, OuterReceiver>
{
auto sch = get_completion_scheduler<set_value_t>(get_env(sndr), get_env(out_receiver));
context ctx = sch.ctx_;

return stream_opstate_t<Sender,
return stream_opstate_t<CvSender,
inner_receiver_t<ReceiverProvider, OuterReceiver>,
OuterReceiver>(static_cast<Sender&&>(sndr),
OuterReceiver>(static_cast<CvSender&&>(sndr),
static_cast<OuterReceiver&&>(out_receiver),
receiver_provider,
ctx);
}

template <class Sender, class OuterReceiver, class ReceiverProvider>
auto stream_opstate(Sender&& sndr,
template <class CvSender, class OuterReceiver, class ReceiverProvider>
auto stream_opstate(CvSender&& sndr,
OuterReceiver&& out_receiver,
ReceiverProvider receiver_provider,
context ctx)
-> stream_opstate_t<Sender, inner_receiver_t<ReceiverProvider, OuterReceiver>, OuterReceiver>
-> stream_opstate_t<CvSender, inner_receiver_t<ReceiverProvider, OuterReceiver>, OuterReceiver>
{
return stream_opstate_t<Sender,
return stream_opstate_t<CvSender,
inner_receiver_t<ReceiverProvider, OuterReceiver>,
OuterReceiver>(static_cast<Sender&&>(sndr),
OuterReceiver>(static_cast<CvSender&&>(sndr),
static_cast<OuterReceiver&&>(out_receiver),
receiver_provider,
ctx);
Expand Down
Loading
Loading