diff --git a/examples/nvexec/maxwell/snr.cuh b/examples/nvexec/maxwell/snr.cuh index dc4e13808..0497d3e1e 100644 --- a/examples/nvexec/maxwell/snr.cuh +++ b/examples/nvexec/maxwell/snr.cuh @@ -21,6 +21,8 @@ #include "common.cuh" #include "stdexec/execution.hpp" // IWYU pragma: export +#include "exec/repeat_n.hpp" // IWYU pragma: export + namespace ex = stdexec; #if STDEXEC_CUDA_COMPILATION() @@ -69,20 +71,20 @@ namespace _repeat_n struct repeat_n_t { template - auto operator()(Sender __sndr, std::size_t n, Closure closure) const noexcept + auto operator()(Sender __sndr, std::size_t count, Closure closure) const noexcept -> _repeat_n::sender { return _repeat_n::sender{ {}, - {closure, n}, + {closure, count}, std::move(__sndr) }; } template - auto operator()(std::size_t n, Closure closure) const + auto operator()(std::size_t count, Closure closure) const { - return ex::__closure(*this, n, closure); + return ex::__closure(*this, count, closure); } }; @@ -90,140 +92,70 @@ inline constexpr repeat_n_t repeat_n{}; namespace _repeat_n { - template - class receiver_2_t + template + class receiver { - using Sender = OpT::child_t; - using Receiver = OpT::receiver_t; - - OpT& op_state_; - + using receiver_t = OpState::receiver_t; public: using receiver_concept = ex::receiver_t; - void set_value() noexcept - { - using inner_op_state_t = OpT::inner_op_state_t; - - op_state_.i_++; - - if (op_state_.i_ == op_state_.n_) - { - ex::set_value(std::move(op_state_.rcvr_)); - return; - } - - auto sch = ex::get_scheduler(ex::get_env(op_state_.rcvr_)); - inner_op_state_t& inner_op_state = op_state_.inner_op_state_.emplace(ex::__emplace_from{ - [&]() noexcept - { - return ex::connect(ex::schedule(sch) | op_state_.closure_, receiver_2_t{op_state_}); - }}); - - ex::start(inner_op_state); - } - - template - void set_error(Error&& err) noexcept - { - ex::set_error(std::move(op_state_.rcvr_), static_cast(err)); - } - - void set_stopped() noexcept - { - ex::set_stopped(std::move(op_state_.rcvr_)); - } - - [[nodiscard]] - auto get_env() const noexcept -> ex::env_of_t - { - return ex::get_env(op_state_.rcvr_); - } - - explicit receiver_2_t(OpT& op_state) - : op_state_(op_state) + explicit receiver(OpState& op_state) + : opstate_(op_state) {} - }; - - template - class receiver_1_t - { - using Receiver = OpT::receiver_t; - - OpT& op_state_; - - public: - using receiver_concept = ex::receiver_t; void set_value() noexcept { - using inner_op_state_t = OpT::inner_op_state_t; - - if (op_state_.n_) + if (opstate_.count_ == 0) { - inner_op_state_t& inner_op_state = op_state_.inner_op_state_.emplace( - ex::__emplace_from{[this]() noexcept - { - auto sch = ex::get_scheduler(ex::get_env(op_state_.rcvr_)); - return ex::connect(ex::schedule(sch) | op_state_.closure_, - receiver_2_t{op_state_}); - }}); - - ex::start(inner_op_state); + ex::set_value(std::move(opstate_.rcvr_)); } else { - ex::set_value(std::move(op_state_.rcvr_)); + --opstate_.count_; + ex::start(opstate_._connect()); } } template void set_error(Error&& err) noexcept { - ex::set_error(std::move(op_state_.rcvr_), static_cast(err)); + ex::set_error(std::move(opstate_.rcvr_), static_cast(err)); } void set_stopped() noexcept { - ex::set_stopped(std::move(op_state_.rcvr_)); + ex::set_stopped(std::move(opstate_.rcvr_)); } [[nodiscard]] - auto get_env() const noexcept -> ex::env_of_t + auto get_env() const noexcept -> ex::env_of_t { - return ex::get_env(op_state_.rcvr_); + return ex::get_env(opstate_.rcvr_); } - explicit receiver_1_t(OpT& op_state) - : op_state_(op_state) - {} + private: + OpState& opstate_; }; - template - struct operation_state_t + template + struct opstate { - using receiver_t = Receiver; - using child_t = PredSender; - using Scheduler = std::invoke_result_t>; - using InnerSender = std::invoke_result_t>; - - using predecessor_op_state_t = - ex::connect_result_t>; - using inner_op_state_t = ex::connect_result_t>; - - PredSender pred_sender_; - Closure closure_; - Receiver rcvr_; - std::optional pred_op_state_; - std::optional inner_op_state_; - std::size_t n_{}; - std::size_t i_{}; + opstate(CvSender&& sndr, Closure closure, Receiver&& rcvr, std::size_t count) + : rcvr_(static_cast(rcvr)) + , count_(count) + , closure_(std::move(closure)) + , sched_(_get_scheduler(sndr)) + { + pred_opstate_.emplace(ex::__emplace_from{ + [&]() noexcept + { return ex::connect(static_cast(sndr), receiver{*this}); }}); + } void start() & noexcept { - if (n_) + if (count_ > 0) { - ex::start(*pred_op_state_); + ex::start(*pred_opstate_); } else { @@ -231,16 +163,36 @@ namespace _repeat_n } } - operation_state_t(PredSender&& pred_sender, Closure closure, Receiver&& rcvr, std::size_t n) - : pred_sender_{static_cast(pred_sender)} - , closure_(closure) - , rcvr_(rcvr) - , n_(n) + private: + friend receiver; + + using receiver_t = Receiver; + using scheduler_t = std::invoke_result_t, + ex::env_of_t, + ex::env_of_t>; + using inner_sender_t = std::invoke_result_t>; + using pred_opstate_t = ex::connect_result_t>; + using inner_opstate_t = ex::connect_result_t>; + + auto& _connect() { - pred_op_state_.emplace(ex::__emplace_from{ + return inner_opstate_.emplace(ex::__emplace_from{ [&]() noexcept - { return ex::connect(static_cast(pred_sender_), receiver_1_t{*this}); }}); + { return ex::connect(closure_(ex::schedule(sched_)), receiver{*this}); }}); } + + scheduler_t _get_scheduler(CvSender const & sndr) noexcept + { + return ex::get_completion_scheduler(ex::get_env(sndr), + ex::get_env(this->rcvr_)); + } + + Receiver rcvr_; + std::size_t count_; + Closure closure_; + scheduler_t sched_; + std::optional pred_opstate_; + std::optional inner_opstate_; }; template @@ -261,26 +213,25 @@ namespace _repeat_n template Self, ex::receiver Receiver> STDEXEC_EXPLICIT_THIS_BEGIN(auto connect)(this Self&& self, Receiver r) - -> _repeat_n::operation_state_t + -> _repeat_n::opstate { - return _repeat_n::operation_state_t( - static_cast(self).sender_, - static_cast(self).data_.first, - static_cast(r), - self.data_.second); + return _repeat_n::opstate(static_cast(self).sndr_, + static_cast(self).data_.first, + static_cast(r), + self.data_.second); } STDEXEC_EXPLICIT_THIS_END(connect) [[nodiscard]] auto get_env() const noexcept -> ex::env_of_t { - return ex::get_env(sender_); + return ex::get_env(sndr_); } STDEXEC_ATTRIBUTE(no_unique_address, maybe_unused) repeat_n_t tag_; std::pair data_; - Sender sender_; + Sender sndr_; }; } // namespace _repeat_n @@ -296,128 +247,59 @@ namespace nv::execution::_strm { namespace _repeat_n { - template - class receiver_2_t : public stream_receiver_base + template + class receiver : public stream_receiver_base { - using Sender = OpT::child_t; - using Receiver = OpT::receiver_t; - - OpT& op_state_; - - public: - void set_value() noexcept - { - using inner_op_state_t = OpT::inner_op_state_t; - - op_state_.i_++; - - if (op_state_.i_ == op_state_.n_) - { - op_state_.propagate_completion_signal(ex::set_value); - return; - } - - inner_op_state_t& inner_op_state = op_state_.inner_op_state_.emplace(ex::__emplace_from{ - [&]() noexcept - { - return ex::connect(op_state_.closure_(ex::schedule(op_state_.scheduler_)), - receiver_2_t{op_state_}); - }}); - - ex::start(inner_op_state); - } - - template - void set_error(Error&& err) noexcept - { - op_state_.propagate_completion_signal(set_error_t(), static_cast(err)); - } - - void set_stopped() noexcept - { - op_state_.propagate_completion_signal(set_stopped_t()); - } - - auto get_env() const noexcept -> OpT::env_t - { - return op_state_.make_env(); - } - - explicit receiver_2_t(OpT& op_state) - : op_state_(op_state) - {} - }; - - template - class receiver_1_t : public stream_receiver_base - { - using Receiver = OpT::receiver_t; - - OpT& op_state_; - public: - explicit receiver_1_t(OpT& op_state) - : op_state_(op_state) + explicit receiver(OpState& op_state) + : opstate_(op_state) {} void set_value() noexcept { - using inner_op_state_t = OpT::inner_op_state_t; - - if (op_state_.n_) + if (opstate_.count_ == 0) { - inner_op_state_t& inner_op_state = op_state_.inner_op_state_.emplace(ex::__emplace_from{ - [&]() noexcept - { - return ex::connect(op_state_.closure_(ex::schedule(op_state_.scheduler_)), - receiver_2_t{op_state_}); - }}); - - ex::start(inner_op_state); + opstate_.propagate_completion_signal(ex::set_value); } else { - op_state_.propagate_completion_signal(set_value_t()); + --opstate_.count_; + ex::start(opstate_._connect()); } } template void set_error(Error&& err) noexcept { - op_state_.propagate_completion_signal(set_error_t(), static_cast(err)); + opstate_.propagate_completion_signal(set_error_t(), static_cast(err)); } void set_stopped() noexcept { - op_state_.propagate_completion_signal(set_stopped_t()); + opstate_.propagate_completion_signal(set_stopped_t()); } - auto get_env() const noexcept -> OpT::env_t + auto get_env() const noexcept -> OpState::env_t { - return op_state_.make_env(); + return opstate_.make_env(); } + + private: + OpState& opstate_; }; - template - struct operation_state_t : _strm::opstate_base + template + struct opstate : _strm::opstate_base { - using receiver_t = Receiver; - using child_t = PredSender; - using Scheduler = std::invoke_result_t, - ex::env_of_t, - ex::env_of_t>; - using InnerSender = std::invoke_result_t>; - - using predecessor_op_state_t = - ex::connect_result_t>; - using inner_op_state_t = ex::connect_result_t>; - - Scheduler scheduler_; - Closure closure_; - std::optional pred_op_state_; - std::optional inner_op_state_; - std::size_t n_{}; - std::size_t i_{}; + explicit opstate(CvSender&& sndr, Closure closure, Receiver&& rcvr, std::size_t count) + : _strm::opstate_base(std::move(rcvr), _get_scheduler(sndr).ctx_) + , count_(count) + , closure_(std::move(closure)) + , sched_(_get_scheduler(sndr)) + { + pred_opstate_.emplace(ex::__emplace_from{ + [&]() noexcept { return ex::connect(static_cast(sndr), receiver{*this}); }}); + } void start() & noexcept { @@ -425,36 +307,46 @@ namespace nv::execution::_strm { // Couldn't allocate memory for operation state, complete with error this->propagate_completion_signal(ex::set_error, - std::move(this->stream_provider_.status_)); + cudaError_t(this->stream_provider_.status_)); + } + else if (count_ > 0) + { + ex::start(*pred_opstate_); } else { - if (n_) - { - ex::start(*pred_op_state_); - } - else - { - this->propagate_completion_signal(ex::set_value); - } + this->propagate_completion_signal(ex::set_value); } } - operation_state_t(PredSender&& pred_sender, Closure closure, Receiver&& rcvr, std::size_t n) - : _strm::opstate_base( - static_cast(rcvr), - ex::get_completion_scheduler(ex::get_env(pred_sender), - ex::get_env(rcvr)) - .ctx_) - , scheduler_(ex::get_completion_scheduler(ex::get_env(pred_sender), - ex::get_env(rcvr))) - , closure_(closure) - , n_(n) + private: + friend receiver; + + using scheduler_t = std::invoke_result_t, + ex::env_of_t, + ex::env_of_t>; + using inner_sender_t = std::invoke_result_t>; + using pred_opstate_t = ex::connect_result_t>; + using inner_opstate_t = ex::connect_result_t>; + + auto& _connect() { - pred_op_state_.emplace(ex::__emplace_from{ + return inner_opstate_.emplace(ex::__emplace_from{ [&]() noexcept - { return ex::connect(static_cast(pred_sender), receiver_1_t{*this}); }}); + { return ex::connect(closure_(ex::schedule(sched_)), receiver{*this}); }}); } + + scheduler_t _get_scheduler(CvSender const & sndr) noexcept + { + return ex::get_completion_scheduler(ex::get_env(sndr), + ex::get_env(this->rcvr_)); + } + + std::size_t count_; + Closure closure_; + scheduler_t sched_; + std::optional pred_opstate_; + std::optional inner_opstate_; }; template @@ -470,25 +362,25 @@ namespace nv::execution::_strm template Self, ex::receiver Receiver> requires(ex::sender_to) STDEXEC_EXPLICIT_THIS_BEGIN(auto connect)(this Self&& self, Receiver r) - -> nvexec::_strm::_repeat_n::operation_state_t + -> nvexec::_strm::_repeat_n::opstate { - return nvexec::_strm::_repeat_n::operation_state_t( - static_cast(self).sender_, + return nvexec::_strm::_repeat_n::opstate( + static_cast(self).sndr_, static_cast(self).closure_, static_cast(r), - self.n_); + self.count_); } STDEXEC_EXPLICIT_THIS_END(connect) [[nodiscard]] auto get_env() const noexcept -> ex::env_of_t { - return ex::get_env(sender_); + return ex::get_env(sndr_); } - Sender sender_; + Sender sndr_; Closure closure_; - std::size_t n_{}; + std::size_t count_{}; }; } // namespace _repeat_n @@ -526,12 +418,21 @@ auto maxwell_eqs_snr(float dt, fields_accessor accessor, ex::scheduler auto&& computer) { +#if 0 + return ex::on(computer, + exec::repeat_n(ex::just() // + | ex::bulk(ex::par, accessor.cells, update_h(accessor)) + | ex::bulk(ex::par, accessor.cells, update_e(time, dt, accessor)), + n_iterations)) + | ex::then(dump_vtk(write_results, accessor)); +#else return ex::just() | ex::on(computer, repeat_n(n_iterations, ex::bulk(ex::par, accessor.cells, update_h(accessor)) | ex::bulk(ex::par, accessor.cells, update_e(time, dt, accessor)))) | ex::then(dump_vtk(write_results, accessor)); +#endif } void run_snr(float dt, diff --git a/include/nvexec/stream/bulk.cuh b/include/nvexec/stream/bulk.cuh index 5e2594577..d1b67c5d7 100644 --- a/include/nvexec/stream/bulk.cuh +++ b/include/nvexec/stream/bulk.cuh @@ -401,28 +401,35 @@ namespace nv::execution::_strm template <> struct transform_sender_for { - template Sender> + template auto operator()(Env const & env, __ignore, Data data, Sender&& sndr) const { - auto [policy, shape, fun] = static_cast(data); - using shape_t = decltype(shape); - using fun_t = decltype(fun); - auto sched = get_completion_scheduler(get_env(sndr), env); - if constexpr (__std::same_as) + if constexpr (stream_completing_sender) { - // Use the bulk sender for a single GPU - using _sender_t = bulk_sender<__decay_t, shape_t, fun_t>; - return _sender_t{{}, static_cast(sndr), shape, static_cast(fun)}; + auto [policy, shape, fun] = static_cast(data); + using shape_t = decltype(shape); + using fun_t = decltype(fun); + auto sched = get_completion_scheduler(get_env(sndr), env); + if constexpr (__std::same_as) + { + // Use the bulk sender for a single GPU + using _sender_t = bulk_sender<__decay_t, shape_t, fun_t>; + return _sender_t{{}, static_cast(sndr), shape, static_cast(fun)}; + } + else + { + // Use the bulk sender for a multiple GPUs + using _sender_t = multi_gpu_bulk_sender<__decay_t, shape_t, fun_t>; + return _sender_t{{}, + sched.num_devices_, + static_cast(sndr), + shape, + static_cast(fun)}; + } } else { - // Use the bulk sender for a multiple GPUs - using _sender_t = multi_gpu_bulk_sender<__decay_t, shape_t, fun_t>; - return _sender_t{{}, - sched.num_devices_, - static_cast(sndr), - shape, - static_cast(fun)}; + return _strm::_no_stream_scheduler_in_env(); } } }; diff --git a/include/nvexec/stream/common.cuh b/include/nvexec/stream/common.cuh index 0bb263cef..fa0c8b989 100644 --- a/include/nvexec/stream/common.cuh +++ b/include/nvexec/stream/common.cuh @@ -54,7 +54,7 @@ namespace nv::execution }; #if defined(__clang__) && defined(__CUDA__) && !defined(STDEXEC_CLANG_TIDY_INVOKED) - __host__ inline auto get_device_type() noexcept -> device_type + inline __host__ auto get_device_type() noexcept -> device_type { return device_type::host; } @@ -64,7 +64,7 @@ namespace nv::execution return device_type::device; } #else - __host__ __device__ inline auto get_device_type() noexcept -> device_type + inline __host__ __device__ auto get_device_type() noexcept -> device_type { NV_IF_TARGET(NV_IS_HOST, (return device_type::host;), (return device_type::device;)); } @@ -75,6 +75,12 @@ namespace nv::execution return get_device_type() == device_type::device; } + struct stream_context; + struct stream_domain; + + struct CANNOT_DISPATCH_THIS_ALGORITHM_TO_THE_CUDA_STREAM_SCHEDULER; + struct BECAUSE_THERE_IS_NO_CUDA_STREAM_SCHEDULER_IN_THE_ENVIRONMENT; + namespace _strm { // Used by stream_domain to late-customize senders for execution @@ -84,6 +90,49 @@ namespace nv::execution template struct apply_sender_for; + + struct context; + + template + concept gpu_stream_scheduler = + scheduler + && __std::derived_from<__result_of, Scheduler, Env>, + stream_domain> + && requires(Scheduler sched) { + { sched.ctx_ } -> __decays_to; + }; + + template + concept stream_completing_sender = + sender + && gpu_stream_scheduler< + __result_of, env_of_t, Env>, + Env>; + + template + concept has_stream_transform = + STDEXEC::__callable>, + Sender, + Env const &>; + + template + concept has_nothrow_stream_transform = + STDEXEC::__nothrow_callable>, + Sender, + Env const &>; + + template + auto _no_stream_scheduler_in_env() noexcept + { + using namespace STDEXEC; + return __not_a_sender<_WHAT_(CANNOT_DISPATCH_THIS_ALGORITHM_TO_THE_CUDA_STREAM_SCHEDULER), + _WHY_(BECAUSE_THERE_IS_NO_CUDA_STREAM_SCHEDULER_IN_THE_ENVIRONMENT), + _WHERE_(_IN_ALGORITHM_, Tag), + _WITH_PRETTY_SENDER_, + _WITH_ENVIRONMENT_(Env)>{}; + } } // namespace _strm } // namespace nv::execution @@ -91,23 +140,14 @@ namespace nvexec = nv::execution; namespace nv::execution { - struct stream_context; - // The stream_domain is how the stream scheduler customizes the sender algorithms. All of the // algorithms use the current scheduler's domain to transform senders before starting them. struct stream_domain : STDEXEC::default_domain { template <::exec::sender_for Sender, class Tag = STDEXEC::tag_of_t, class Env> - requires STDEXEC::__callable, - Sender, - Env const &> + requires _strm::has_stream_transform static auto transform_sender(STDEXEC::set_value_t, Sender&& sndr, Env const & env) - noexcept(STDEXEC::__nothrow_callable, - Sender, - Env const &>) - + noexcept(_strm::has_nothrow_stream_transform) { return STDEXEC::__structured_apply(_strm::transform_sender_for{}, static_cast(sndr), @@ -278,15 +318,6 @@ namespace nv::execution template struct multi_gpu_bulk_sender; - template - concept gpu_stream_scheduler = - scheduler - && __std::derived_from<__result_of, Scheduler, Env>, - stream_domain> - && requires(Scheduler sched) { - { sched.ctx_ } -> __decays_to; - }; - struct stream_sender_base { using sender_concept = STDEXEC::sender_t; @@ -907,13 +938,6 @@ namespace nv::execution ctx); } - template - concept stream_completing_sender = - sender - && gpu_stream_scheduler< - __result_of, env_of_t, Env>, - Env>; - template using inner_receiver_t = __call_result_t&>; @@ -957,8 +981,10 @@ namespace nv::execution inline constexpr _strm::get_stream_t get_stream{}; #if CUDART_VERSION >= 13'00'0 - __host__ inline cudaError_t - cudaMemPrefetchAsync(const void* dev_ptr, size_t count, int dst_device, cudaStream_t stream = 0) + inline __host__ cudaError_t cudaMemPrefetchAsync(void const * dev_ptr, + size_t count, + int dst_device, + cudaStream_t stream = 0) { return ::cudaMemPrefetchAsync(dev_ptr, count, diff --git a/include/nvexec/stream/ensure_started.cuh b/include/nvexec/stream/ensure_started.cuh index 78c610464..385816a86 100644 --- a/include/nvexec/stream/ensure_started.cuh +++ b/include/nvexec/stream/ensure_started.cuh @@ -409,11 +409,18 @@ namespace nv::execution::_strm template using _sender_t = ensure_started_sender<__decay_t>; - template Sender> - auto operator()(Env const & env, __ignore, __ignore, Sender&& sndr) const -> _sender_t + template + auto operator()(Env const & env, __ignore, __ignore, Sender&& sndr) const { - auto sched = get_completion_scheduler(get_env(sndr), env); - return _sender_t{sched.ctx_, static_cast(sndr)}; + if constexpr (stream_completing_sender) + { + auto sched = get_completion_scheduler(get_env(sndr), env); + return _sender_t{sched.ctx_, static_cast(sndr)}; + } + else + { + return _strm::_no_stream_scheduler_in_env(); + } } }; } // namespace nv::execution::_strm diff --git a/include/nvexec/stream/let_xxx.cuh b/include/nvexec/stream/let_xxx.cuh index 5f567ec5e..7bf837f8b 100644 --- a/include/nvexec/stream/let_xxx.cuh +++ b/include/nvexec/stream/let_xxx.cuh @@ -80,11 +80,9 @@ namespace nv::execution::_strm using _sch_env_t = __result_of<_mk_sch_env, CvSender, Receiver, SetTag>; inline constexpr auto _mk_env2 = - []([[maybe_unused]] - SchEnv const & sch_env, + [](SchEnv const & sch_env, _strm::opstate_base const & opstate) { - //return opstate.make_env(); return __env::__join(sch_env, opstate.make_env()); }; @@ -210,22 +208,25 @@ namespace nv::execution::_strm using _mk_opstate_variant_fn = __mtransform<__muncurry<_mk_opstate_fn_t>, __qq<__variant>>; using _opstate_variant_t = __mapply<_mk_opstate_variant_fn, _result_tuples_t>; using _propagate_receiver_t = _let::_propagate_receiver_t; + using _sch_t = + __result_of, env_of_t, env_of_t>; explicit _opstate(CvSender&& sndr, Receiver rcvr, Fun fun) : _opstate(static_cast(sndr), static_cast(rcvr), static_cast(fun), + get_completion_scheduler(get_env(sndr), get_env(rcvr)), _mk_sch_env(sndr, rcvr, SetTag{})) {} - explicit _opstate(CvSender&& sndr, Receiver&& rcvr, Fun fun, _env2_t env2) + explicit _opstate(CvSender&& sndr, Receiver&& rcvr, Fun fun, _sch_t sch, _env2_t env2) : _opstate_base_t( static_cast(sndr), static_cast(rcvr), [this](__ignore) noexcept { return _receiver_t{{}, this}; }, - get_completion_scheduler(get_env(sndr), get_env(rcvr)).ctx_) + sch.ctx_) , fun_(static_cast(fun)) - , env2_(env2) + , env2_(static_cast<_env2_t&&>(env2)) {} STDEXEC_IMMOVABLE(_opstate); @@ -308,10 +309,18 @@ namespace nv::execution::_strm template struct _transform_let_sender { - template Sender> + template auto operator()(Env const &, __ignore, Fun fn, Sender&& sndr) const { - return let_sender{static_cast(sndr), static_cast(fn), SetTag{}}; + if constexpr (stream_completing_sender) + { + return let_sender{static_cast(sndr), static_cast(fn), SetTag{}}; + } + else + { + using _let_t = decltype(STDEXEC::__let::__let_from_set); + return _strm::_no_stream_scheduler_in_env<_let_t, Sender, Env>(); + } } }; diff --git a/include/nvexec/stream/repeat_n.cuh b/include/nvexec/stream/repeat_n.cuh index 3ecbc01e3..7d36412df 100644 --- a/include/nvexec/stream/repeat_n.cuh +++ b/include/nvexec/stream/repeat_n.cuh @@ -83,14 +83,16 @@ namespace nv::execution::_strm , sched_(std::move(sched)) , count_(count) { - _connect(); + if (count_ != 0) + { + _connect(); + } } - void _connect() + auto& _connect() { inner_opstate_.__emplace_from(STDEXEC::connect, exec::sequence(STDEXEC::schedule(sched_), sndr_), - //STDEXEC::on(sched_, sndr_), receiver{*this}); } @@ -114,8 +116,7 @@ namespace nv::execution::_strm } else { - _connect(); - STDEXEC::start(*inner_opstate_); + STDEXEC::start(_connect()); } } else @@ -167,6 +168,11 @@ namespace nv::execution::_strm STDEXEC::set_error_t(cudaError_t)>(); } + explicit sender(CvSender&& sndr, std::size_t count) + : sndr_(static_cast(sndr)) + , count_(count) + {} + template auto connect(Receiver rcvr) && -> repeat_n::opstate { @@ -186,6 +192,7 @@ namespace nv::execution::_strm return STDEXEC::get_env(sndr_); } + private: CvSender sndr_; // could be a value or a reference std::size_t count_; }; diff --git a/include/nvexec/stream/schedule_from.cuh b/include/nvexec/stream/schedule_from.cuh index b4792247c..4cda0dccf 100644 --- a/include/nvexec/stream/schedule_from.cuh +++ b/include/nvexec/stream/schedule_from.cuh @@ -27,10 +27,6 @@ namespace nv::execution { - struct CANNOT_DISPATCH_THE_SCHEDULE_FROM_ALGORITHM_TO_THE_CUDA_STREAM_SCHEDULER; - struct BECAUSE_THERE_IS_NO_CUDA_STREAM_SCHEDULER_IN_THE_ENVIRONMENT; - struct ADD_A_CONTINUES_ON_TRANSITION_TO_THE_CUDA_STREAM_SCHEDULER_BEFORE_THE_SCHEDULE_FROM_ALGORITHM; - namespace _strm { namespace _schfr @@ -188,15 +184,7 @@ namespace nv::execution } else { - return STDEXEC::__not_a_sender< - STDEXEC::_WHAT_( - CANNOT_DISPATCH_THE_SCHEDULE_FROM_ALGORITHM_TO_THE_CUDA_STREAM_SCHEDULER), - STDEXEC::_WHY_(BECAUSE_THERE_IS_NO_CUDA_STREAM_SCHEDULER_IN_THE_ENVIRONMENT), - STDEXEC::_WHERE_(STDEXEC::_IN_ALGORITHM_, STDEXEC::schedule_from_t), - // STDEXEC::_TO_FIX_THIS_ERROR_( - // ADD_A_CONTINUES_ON_TRANSITION_TO_THE_CUDA_STREAM_SCHEDULER_BEFORE_THE_SCHEDULE_FROM_ALGORITHM), - STDEXEC::_WITH_PRETTY_SENDER_, - STDEXEC::_WITH_ENVIRONMENT_(Env)>{}; + return _strm::_no_stream_scheduler_in_env(); } } }; diff --git a/include/nvexec/stream/split.cuh b/include/nvexec/stream/split.cuh index 8b592e21d..87234c6f6 100644 --- a/include/nvexec/stream/split.cuh +++ b/include/nvexec/stream/split.cuh @@ -391,11 +391,18 @@ namespace nv::execution::_strm template using _sender_t = split_sender<__decay_t>; - template Sender> - auto operator()(Env const & env, __ignore, __ignore, Sender&& sndr) const -> _sender_t + template + auto operator()(Env const & env, __ignore, __ignore, Sender&& sndr) const { - auto sched = get_completion_scheduler(get_env(sndr), env); - return _sender_t{sched.ctx_, static_cast(sndr)}; + if constexpr (stream_completing_sender) + { + auto sched = get_completion_scheduler(get_env(sndr), env); + return _sender_t{sched.ctx_, static_cast(sndr)}; + } + else + { + return _strm::_no_stream_scheduler_in_env, Env>(); + } } }; } // namespace nv::execution::_strm diff --git a/include/nvexec/stream/then.cuh b/include/nvexec/stream/then.cuh index 80f442222..2de508ced 100644 --- a/include/nvexec/stream/then.cuh +++ b/include/nvexec/stream/then.cuh @@ -208,11 +208,18 @@ namespace nv::execution::_strm template <> struct transform_sender_for { - template CvSender> + template auto operator()(Env const &, __ignore, Fn fun, CvSender&& sndr) const { - using _sender_t = then_sender<__decay_t, Fn>; - return _sender_t{static_cast(sndr), static_cast(fun)}; + if constexpr (stream_completing_sender) + { + using _sender_t = then_sender<__decay_t, Fn>; + return _sender_t{static_cast(sndr), static_cast(fun)}; + } + else + { + return _strm::_no_stream_scheduler_in_env(); + } } }; } // namespace nv::execution::_strm diff --git a/include/nvexec/stream/upon_error.cuh b/include/nvexec/stream/upon_error.cuh index f3342a99e..0af1b4264 100644 --- a/include/nvexec/stream/upon_error.cuh +++ b/include/nvexec/stream/upon_error.cuh @@ -195,11 +195,18 @@ namespace nv::execution::_strm template <> struct transform_sender_for { - template Sender> + template auto operator()(Env const &, __ignore, Fun fun, Sender&& sndr) const { - using _sender_t = upon_error_sender<__decay_t, Fun>; - return _sender_t{static_cast(sndr), static_cast(fun)}; + if constexpr (stream_completing_sender) + { + using _sender_t = upon_error_sender<__decay_t, Fun>; + return _sender_t{static_cast(sndr), static_cast(fun)}; + } + else + { + return _strm::_no_stream_scheduler_in_env(); + } } }; } // namespace nv::execution::_strm diff --git a/include/nvexec/stream/upon_stopped.cuh b/include/nvexec/stream/upon_stopped.cuh index a10888a68..202e15df6 100644 --- a/include/nvexec/stream/upon_stopped.cuh +++ b/include/nvexec/stream/upon_stopped.cuh @@ -186,11 +186,18 @@ namespace nv::execution::_strm template <> struct transform_sender_for { - template CvSender> + template auto operator()(Env const &, __ignore, Fun fun, CvSender&& sndr) const { - using _sender_t = upon_stopped_sender<__decay_t, Fun>; - return _sender_t{static_cast(sndr), static_cast(fun)}; + if constexpr (stream_completing_sender) + { + using _sender_t = upon_stopped_sender<__decay_t, Fun>; + return _sender_t{static_cast(sndr), static_cast(fun)}; + } + else + { + return _strm::_no_stream_scheduler_in_env(); + } } }; } // namespace nv::execution::_strm diff --git a/include/nvexec/stream/when_all.cuh b/include/nvexec/stream/when_all.cuh index 6f6780dc5..68919b875 100644 --- a/include/nvexec/stream/when_all.cuh +++ b/include/nvexec/stream/when_all.cuh @@ -533,15 +533,26 @@ namespace nv::execution::_strm template <> struct transform_sender_for { - template ... CvSenders> + template constexpr auto operator()(Env const &, __ignore, __ignore, CvSenders&&... sndrs) const { - using sender_t = - when_all_sender...>; - return sender_t{ - context{nullptr, nullptr, nullptr, nullptr}, - static_cast(sndrs)... - }; + if constexpr ((stream_completing_sender && ...)) + { + using sender_t = + when_all_sender...>; + return sender_t{ + context{nullptr, nullptr, nullptr, nullptr}, + static_cast(sndrs)... + }; + } + else + { + // Find the first sender that does not have a stream completion scheduler: + STDEXEC_CONSTEXPR_LOCAL bool _map[] = {!stream_completing_sender...}; + STDEXEC_CONSTEXPR_LOCAL std::size_t index = __pos_of(_map, _map + sizeof...(CvSenders)); + using _invalid_sender_t = __m_at_c; + return _strm::_no_stream_scheduler_in_env(); + } } }; diff --git a/include/stdexec/__detail/__any.hpp b/include/stdexec/__detail/__any.hpp index 3598458ba..ad49b0425 100644 --- a/include/stdexec/__detail/__any.hpp +++ b/include/stdexec/__detail/__any.hpp @@ -34,6 +34,7 @@ STDEXEC_PRAGMA_PUSH() STDEXEC_PRAGMA_IGNORE_GNU("-Wredundant-consteval-if") +STDEXEC_PRAGMA_IGNORE_GNU("-Warray-bounds") // NOLINTBEGIN(moderize-use-override) diff --git a/include/stdexec/__detail/__any_allocator.hpp b/include/stdexec/__detail/__any_allocator.hpp index 0e7e15871..5c216b546 100644 --- a/include/stdexec/__detail/__any_allocator.hpp +++ b/include/stdexec/__detail/__any_allocator.hpp @@ -23,6 +23,9 @@ #include "__memory.hpp" #include "__typeinfo.hpp" +STDEXEC_PRAGMA_PUSH() +STDEXEC_PRAGMA_IGNORE_GNU("-Warray-bounds") + namespace STDEXEC { namespace __detail @@ -120,3 +123,5 @@ namespace STDEXEC STDEXEC_HOST_DEVICE_DEDUCTION_GUIDE __any_allocator(std::allocator) -> __any_allocator; } // namespace STDEXEC + +STDEXEC_PRAGMA_POP() diff --git a/include/stdexec/__detail/__variant.hpp b/include/stdexec/__detail/__variant.hpp index 81dc1ca83..2418c49f2 100644 --- a/include/stdexec/__detail/__variant.hpp +++ b/include/stdexec/__detail/__variant.hpp @@ -20,7 +20,6 @@ #include "__type_traits.hpp" #include "__utility.hpp" -#include #include #include #include @@ -346,9 +345,9 @@ namespace STDEXEC noexcept((__nothrow_callable<_Fn, _Us..., __copy_cvref_t<_Self, _Ts>> && ...)) -> __call_result_t<_Fn, _Us..., __copy_cvref_t<_Self, __at_t<0>>> { - using __result_t = __call_result_t<_Fn, _Us..., __copy_cvref_t<_Self, __at_t<0>>>; - - STDEXEC_CONSTEXPR_LOCAL auto __vtable = std::array{ + using __result_t = __call_result_t<_Fn, _Us..., __copy_cvref_t<_Self, __at_t<0>>>; + using __visit_fn_t = decltype(&__var::__visit_alt<0, __result_t, _Fn, _Self, _Us...>); + STDEXEC_CONSTEXPR_LOCAL __visit_fn_t __vtable[] = { &__var::__visit_alt<_Is, __result_t, _Fn, _Self, _Us...>...}; STDEXEC_ASSERT(__self.__index_ != __variant_npos); return (*__vtable[__self.__index_])(static_cast<_Fn &&>(__fn), diff --git a/include/stdexec/stop_token.hpp b/include/stdexec/stop_token.hpp index 23b92e993..6669729c2 100644 --- a/include/stdexec/stop_token.hpp +++ b/include/stdexec/stop_token.hpp @@ -29,6 +29,9 @@ # include // IWYU pragma: export #endif +STDEXEC_PRAGMA_PUSH() +STDEXEC_PRAGMA_IGNORE_GNU("-Warray-bounds") + #if defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_X64)) extern void _mm_pause(); #endif @@ -464,3 +467,5 @@ namespace STDEXEC inplace_stop_source& __stop_source_; }; } // namespace STDEXEC + +STDEXEC_PRAGMA_POP()