Skip to content

Commit f5d7bcc

Browse files
committed
really fix the perf regression in the maxwell_gpu_s example
1 parent 170226a commit f5d7bcc

8 files changed

Lines changed: 228 additions & 110 deletions

File tree

examples/nvexec/maxwell/snr.cuh

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,19 @@ STDEXEC_PRAGMA_IGNORE_GNU("-Wmissing-braces")
6060

6161
namespace ex = stdexec;
6262

63-
namespace repeat_n_detail
63+
namespace _repeat_n
6464
{
6565
template <class Sender, class Closure>
66-
struct repeat_n_sender_t;
67-
} // namespace repeat_n_detail
66+
struct sender;
67+
} // namespace _repeat_n
6868

6969
struct repeat_n_t
7070
{
7171
template <ex::sender Sender, ex::__sender_adaptor_closure Closure>
7272
auto operator()(Sender __sndr, std::size_t n, Closure closure) const noexcept
73-
-> repeat_n_detail::repeat_n_sender_t<Sender, Closure>
73+
-> _repeat_n::sender<Sender, Closure>
7474
{
75-
return repeat_n_detail::repeat_n_sender_t<Sender, Closure>{
75+
return _repeat_n::sender<Sender, Closure>{
7676
{},
7777
{closure, n},
7878
std::move(__sndr)
@@ -88,9 +88,8 @@ struct repeat_n_t
8888

8989
inline constexpr repeat_n_t repeat_n{};
9090

91-
namespace repeat_n_detail
91+
namespace _repeat_n
9292
{
93-
9493
template <class OpT>
9594
class receiver_2_t
9695
{
@@ -245,11 +244,11 @@ namespace repeat_n_detail
245244
};
246245

247246
template <class Sender, class Closure>
248-
struct repeat_n_sender_t
247+
struct sender
249248
{
250249
using sender_concept = ex::sender_t;
251250

252-
template <class>
251+
template <class, class Env>
253252
static consteval auto get_completion_signatures() noexcept
254253
{
255254
return ex::completion_signatures<
@@ -260,11 +259,11 @@ namespace repeat_n_detail
260259
>();
261260
}
262261

263-
template <ex::__decays_to<repeat_n_sender_t> Self, ex::receiver Receiver>
262+
template <ex::__decays_to<sender> Self, ex::receiver Receiver>
264263
STDEXEC_EXPLICIT_THIS_BEGIN(auto connect)(this Self&& self, Receiver r)
265-
-> repeat_n_detail::operation_state_t<Sender, Closure, Receiver>
264+
-> _repeat_n::operation_state_t<Sender, Closure, Receiver>
266265
{
267-
return repeat_n_detail::operation_state_t<Sender, Closure, Receiver>(
266+
return _repeat_n::operation_state_t<Sender, Closure, Receiver>(
268267
static_cast<Self&&>(self).sender_,
269268
static_cast<Self&&>(self).data_.first,
270269
static_cast<Receiver&&>(r),
@@ -283,26 +282,25 @@ namespace repeat_n_detail
283282
std::pair<Closure, std::size_t> data_;
284283
Sender sender_;
285284
};
286-
} // namespace repeat_n_detail
285+
} // namespace _repeat_n
287286

288287
namespace STDEXEC
289288
{
290289
template <class Sender, class Closure>
291-
inline constexpr std::size_t
292-
__structured_binding_size_v<repeat_n_detail::repeat_n_sender_t<Sender, Closure>> = 3;
290+
inline constexpr std::size_t __structured_binding_size_v<_repeat_n::sender<Sender, Closure>> = 3;
293291
} // namespace STDEXEC
294292

295293
#if STDEXEC_CUDA_COMPILATION()
296294
// A CUDA stream implementation of repeat_n
297295
namespace nv::execution::_strm
298296
{
299-
namespace repeat_n
297+
namespace _repeat_n
300298
{
301299
template <class OpT>
302300
class receiver_2_t : public stream_receiver_base
303301
{
304-
using Sender = OpT::PredSender;
305-
using Receiver = OpT::Receiver;
302+
using Sender = OpT::child_t;
303+
using Receiver = OpT::receiver_t;
306304

307305
OpT& op_state_;
308306

@@ -353,11 +351,15 @@ namespace nv::execution::_strm
353351
template <class OpT>
354352
class receiver_1_t : public stream_receiver_base
355353
{
356-
using Receiver = OpT::Receiver;
354+
using Receiver = OpT::receiver_t;
357355

358356
OpT& op_state_;
359357

360358
public:
359+
explicit receiver_1_t(OpT& op_state)
360+
: op_state_(op_state)
361+
{}
362+
361363
void set_value() noexcept
362364
{
363365
using inner_op_state_t = OpT::inner_op_state_t;
@@ -394,15 +396,13 @@ namespace nv::execution::_strm
394396
{
395397
return op_state_.make_env();
396398
}
397-
398-
explicit receiver_1_t(OpT& op_state)
399-
: op_state_(op_state)
400-
{}
401399
};
402400

403401
template <class PredSender, class Closure, class Receiver>
404402
struct operation_state_t : _strm::opstate_base<Receiver>
405403
{
404+
using receiver_t = Receiver;
405+
using child_t = PredSender;
406406
using Scheduler = std::invoke_result_t<ex::get_completion_scheduler_t<ex::set_value_t>,
407407
ex::env_of_t<PredSender>,
408408
ex::env_of_t<Receiver>>;
@@ -445,7 +445,7 @@ namespace nv::execution::_strm
445445
static_cast<Receiver&&>(rcvr),
446446
ex::get_completion_scheduler<ex::set_value_t>(ex::get_env(pred_sender),
447447
ex::get_env(rcvr))
448-
.context_state_)
448+
.ctx_)
449449
, scheduler_(ex::get_completion_scheduler<ex::set_value_t>(ex::get_env(pred_sender),
450450
ex::get_env(rcvr)))
451451
, closure_(closure)
@@ -458,7 +458,7 @@ namespace nv::execution::_strm
458458
};
459459

460460
template <class Sender, class Closure>
461-
struct sender_t
461+
struct sender
462462
{
463463
using sender_concept = ex::sender_t;
464464

@@ -467,12 +467,12 @@ namespace nv::execution::_strm
467467
ex::set_error_t(std::exception_ptr),
468468
ex::set_error_t(cudaError_t)>;
469469

470-
template <ex::__decays_to<sender_t> Self, ex::receiver Receiver>
470+
template <ex::__decays_to<sender> Self, ex::receiver Receiver>
471471
requires(ex::sender_to<Sender, Receiver>)
472472
STDEXEC_EXPLICIT_THIS_BEGIN(auto connect)(this Self&& self, Receiver r)
473-
-> nvexec::_strm::repeat_n::operation_state_t<Sender, Closure, Receiver>
473+
-> nvexec::_strm::_repeat_n::operation_state_t<Sender, Closure, Receiver>
474474
{
475-
return nvexec::_strm::repeat_n::operation_state_t<Sender, Closure, Receiver>(
475+
return nvexec::_strm::_repeat_n::operation_state_t<Sender, Closure, Receiver>(
476476
static_cast<Self&&>(self).sender_,
477477
static_cast<Self&&>(self).closure_,
478478
static_cast<Receiver&&>(r),
@@ -490,21 +490,20 @@ namespace nv::execution::_strm
490490
Closure closure_;
491491
std::size_t n_{};
492492
};
493-
} // namespace repeat_n
493+
} // namespace _repeat_n
494494

495495
template <>
496-
struct transform_sender_for<repeat_n_t>
496+
struct transform_sender_for<::repeat_n_t>
497497
{
498498
template <class Env, class Data, class Sender>
499-
auto operator()(Env const &, ex::__ignore, Data&& data, Sender sndr) const
499+
auto operator()(Env const &, ::repeat_n_t, Data&& data, Sender sndr) const
500500
{
501-
static_assert(sizeof(Env) == 0);
502-
auto& [closure, n] = data;
503-
using closure_t = decltype(closure);
501+
auto& [closure, count] = data;
502+
using closure_t = decltype(closure);
504503

505-
return repeat_n::sender_t<Sender, closure_t>(static_cast<Sender&&>(sndr),
506-
ex::__forward_like<Data>(closure),
507-
n);
504+
return _strm::_repeat_n::sender<Sender, closure_t>(static_cast<Sender&&>(sndr),
505+
ex::__forward_like<Data>(closure),
506+
count);
508507
}
509508
};
510509
} // namespace nv::execution::_strm
@@ -528,8 +527,8 @@ auto maxwell_eqs_snr(float dt,
528527
ex::scheduler auto&& computer)
529528
{
530529
return ex::just()
531-
| repeat_n(n_iterations,
532-
ex::on(computer,
530+
| ex::on(computer,
531+
repeat_n(n_iterations,
533532
ex::bulk(ex::par, accessor.cells, update_h(accessor))
534533
| ex::bulk(ex::par, accessor.cells, update_e(time, dt, accessor))))
535534
| ex::then(dump_vtk(write_results, accessor));

include/exec/detail/basic_sequence.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ namespace exec = experimental::execution;
113113

114114
namespace STDEXEC::__detail
115115
{
116+
template <auto _DescriptorFn>
117+
extern __result_of<_DescriptorFn> __desc_of_v<exec::__seqexpr<_DescriptorFn>>;
118+
116119
template <auto _DescriptorFn>
117120
extern __declfn_t<__minvoke<__result_of<_DescriptorFn>, __q<exec::__basic_sequence_sender_t>>>
118121
__demangle_v<exec::__seqexpr<_DescriptorFn>>;

include/exec/static_thread_pool.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,6 +1281,14 @@ namespace experimental::execution
12811281
using _bulk_opstate_t =
12821282
_static_thread_pool::_bulk_opstate<Parallelize, Shape, Fun, Sender, Receiver>;
12831283

1284+
explicit _bulk_sender(_static_thread_pool& pool, Sender sndr, Shape shape, Fun fun)
1285+
noexcept(__nothrow_move_constructible<Sender, Fun>)
1286+
: pool_(pool)
1287+
, sndr_(static_cast<Sender&&>(sndr))
1288+
, shape_(shape)
1289+
, fun_(static_cast<Fun&&>(fun))
1290+
{}
1291+
12841292
template <__decays_to<_bulk_sender> Self, receiver Receiver>
12851293
requires receiver_of<Receiver, _completions_t<Self, env_of_t<Receiver>>>
12861294
STDEXEC_EXPLICIT_THIS_BEGIN(auto connect)(this Self&& self, Receiver rcvr)
@@ -1310,6 +1318,7 @@ namespace experimental::execution
13101318
return STDEXEC::get_env(sndr_);
13111319
}
13121320

1321+
private:
13131322
_static_thread_pool& pool_;
13141323
Sender sndr_;
13151324
Shape shape_;

include/nvexec/stream/common.cuh

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,20 @@ namespace nv::execution
9898
struct stream_domain : STDEXEC::default_domain
9999
{
100100
template <::exec::sender_for Sender, class Tag = STDEXEC::tag_of_t<Sender>, class Env>
101-
requires STDEXEC::__applicable<_strm::transform_sender_for<Tag>, Sender, Env const &>
101+
requires STDEXEC::__callable<STDEXEC::__structured_apply_t,
102+
_strm::transform_sender_for<Tag>,
103+
Sender,
104+
Env const &>
102105
static auto transform_sender(STDEXEC::set_value_t, Sender&& sndr, Env const & env)
106+
noexcept(STDEXEC::__nothrow_callable<STDEXEC::__structured_apply_t,
107+
_strm::transform_sender_for<Tag>,
108+
Sender,
109+
Env const &>)
110+
103111
{
104-
return STDEXEC::__apply(_strm::transform_sender_for<Tag>{}, static_cast<Sender&&>(sndr), env);
112+
return STDEXEC::__structured_apply(_strm::transform_sender_for<Tag>{},
113+
static_cast<Sender&&>(sndr),
114+
env);
105115
}
106116

107117
template <class Tag, STDEXEC::sender Sender, class... Args>
@@ -115,7 +125,6 @@ namespace nv::execution
115125

116126
namespace _strm
117127
{
118-
119128
#if STDEXEC_HAS_BUILTIN(__is_reference)
120129
template <class... Ts>
121130
concept trivially_copyable = ((STDEXEC_IS_TRIVIALLY_COPYABLE(Ts) || __is_reference(Ts)) && ...);
@@ -701,8 +710,9 @@ namespace nv::execution
701710
else
702711
{
703712
// pass a cudaError_t by value:
704-
continuation_kernel<OuterReceiver, Error>
705-
<<<1, 1, 0, get_stream()>>>(static_cast<OuterReceiver&&>(rcvr_), set_error_t(), status);
713+
continuation_kernel<<<1, 1, 0, get_stream()>>>(static_cast<OuterReceiver&&>(rcvr_),
714+
set_error_t(),
715+
cudaError_t(status));
706716
}
707717
}
708718

include/nvexec/stream/schedule_from.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ namespace nv::execution
3838
template <class Sender, class Receiver>
3939
struct opstate : _strm::opstate_base<Receiver>
4040
{
41-
using env_t = _strm::opstate_base<Receiver>::env_t;
4241
struct receiver;
42+
using env_t = _strm::opstate_base<Receiver>::env_t;
4343
using variant_t = variant_storage_t<Sender, env_t>;
4444
using task_t = continuation_task<receiver, variant_t>;
4545
using enqueue_receiver_t = stream_enqueue_receiver<env_t, variant_t>;
@@ -75,7 +75,7 @@ namespace nv::execution
7575
opstate& opstate_;
7676
};
7777

78-
opstate(Sender&& sender, Receiver&& rcvr, context ctx)
78+
opstate(Sender&& sndr, Receiver&& rcvr, context ctx)
7979
: _strm::opstate_base<Receiver>(static_cast<Receiver&&>(rcvr), ctx)
8080
, ctx_(ctx)
8181
, storage_(host_allocate<variant_t>(this->status_, ctx.pinned_resource_))
@@ -88,7 +88,7 @@ namespace nv::execution
8888
.release())
8989
, env_(host_allocate(this->status_, ctx_.pinned_resource_, this->make_env()))
9090
, inner_op_{
91-
connect(static_cast<Sender&&>(sender),
91+
connect(static_cast<Sender&&>(sndr),
9292
enqueue_receiver_t{env_.get(), storage_.get(), task_, ctx_.hub_->producer()})}
9393
{
9494
if (this->status_ == cudaSuccess)

include/stdexec/__detail/__config.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -801,16 +801,16 @@ namespace STDEXEC
801801
_NAME))
802802

803803
# define STDEXEC_EXPLICIT_THIS_EAT_this
804-
# define STDEXEC_EXPLICIT_THIS_MANGLE_auto auto STDEXEC_EXPLICIT_THIS_MANGLE STDEXEC_PP_LPAREN
805-
# define STDEXEC_EXPLICIT_THIS_MANGLE_void void STDEXEC_EXPLICIT_THIS_MANGLE STDEXEC_PP_LPAREN
806-
# define STDEXEC_EXPLICIT_THIS_MANGLE_bool bool STDEXEC_EXPLICIT_THIS_MANGLE STDEXEC_PP_LPAREN
804+
# define STDEXEC_EXPLICIT_THIS_MANGLE_auto auto STDEXEC_EXPLICIT_THIS_MANGLE STDEXEC_PP_LPAREN()
805+
# define STDEXEC_EXPLICIT_THIS_MANGLE_void void STDEXEC_EXPLICIT_THIS_MANGLE STDEXEC_PP_LPAREN()
806+
# define STDEXEC_EXPLICIT_THIS_MANGLE_bool bool STDEXEC_EXPLICIT_THIS_MANGLE STDEXEC_PP_LPAREN()
807807

808808
# define STDEXEC_EXPLICIT_THIS_ARGS(...) \
809-
STDEXEC_PP_CAT(STDEXEC_EXPLICIT_THIS_EAT_, __VA_ARGS__) STDEXEC_PP_RPAREN
809+
STDEXEC_PP_CAT(STDEXEC_EXPLICIT_THIS_EAT_, __VA_ARGS__) STDEXEC_PP_RPAREN()
810810

811811
# define STDEXEC_EXPLICIT_THIS_BEGIN(...) \
812812
static STDEXEC_PP_EXPAND(STDEXEC_PP_CAT(STDEXEC_EXPLICIT_THIS_MANGLE_, __VA_ARGS__) \
813-
STDEXEC_PP_RPAREN) STDEXEC_PP_LPAREN STDEXEC_EXPLICIT_THIS_ARGS
813+
STDEXEC_PP_RPAREN()) STDEXEC_PP_LPAREN() STDEXEC_EXPLICIT_THIS_ARGS
814814

815815
# define STDEXEC_EXPLICIT_THIS_END(_NAME) \
816816
template <class... _Ts> \

include/stdexec/__detail/__preprocessor.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
#define STDEXEC_PP_STRINGIZE_I(...) #__VA_ARGS__
1919
#define STDEXEC_PP_STRINGIZE(...) STDEXEC_PP_STRINGIZE_I(__VA_ARGS__)
2020

21-
#define STDEXEC_PP_LPAREN (
22-
#define STDEXEC_PP_RPAREN )
23-
#define STDEXEC_PP_PARENS ()
24-
#define STDEXEC_PP_COMMA ,
21+
#define STDEXEC_PP_LPAREN() (
22+
#define STDEXEC_PP_RPAREN() )
23+
#define STDEXEC_PP_PARENS() ()
24+
#define STDEXEC_PP_COMMA() ,
2525

2626
#define STDEXEC_PP_CAT_I(_XP, ...) _XP##__VA_ARGS__
2727
#define STDEXEC_PP_CAT(_XP, ...) STDEXEC_PP_CAT_I(_XP, __VA_ARGS__)
@@ -83,7 +83,7 @@
8383

8484
#define STDEXEC_PP_FOR_EACH_AGAIN() STDEXEC_PP_FOR_EACH_HELPER
8585
#define STDEXEC_PP_FOR_EACH_HELPER(_MACRO, _A1, ...) \
86-
_MACRO(_A1) __VA_OPT__(STDEXEC_PP_FOR_EACH_AGAIN STDEXEC_PP_PARENS(_MACRO, __VA_ARGS__)) /**/
86+
_MACRO(_A1) __VA_OPT__(STDEXEC_PP_FOR_EACH_AGAIN STDEXEC_PP_PARENS()(_MACRO, __VA_ARGS__)) /**/
8787
#define STDEXEC_PP_FOR_EACH(_MACRO, ...) \
8888
__VA_OPT__(STDEXEC_PP_EXPAND_R(STDEXEC_PP_FOR_EACH_HELPER(_MACRO, __VA_ARGS__)))
8989

@@ -94,7 +94,7 @@
9494
#define STDEXEC_PP_BACK_AGAIN() STDEXEC_PP_BACK_I
9595
#define STDEXEC_PP_BACK_I(_A1, ...) \
9696
STDEXEC_PP_FRONT(__VA_OPT__(, ) _A1, ) \
97-
__VA_OPT__(STDEXEC_PP_BACK_AGAIN STDEXEC_PP_PARENS(__VA_ARGS__))
97+
__VA_OPT__(STDEXEC_PP_BACK_AGAIN STDEXEC_PP_PARENS()(__VA_ARGS__))
9898
#define STDEXEC_PP_BACK(...) __VA_OPT__(STDEXEC_PP_EXPAND_R(STDEXEC_PP_BACK_I(__VA_ARGS__)))
9999

100100
#define STDEXEC_PP_TAIL(_IGN, ...) __VA_ARGS__

0 commit comments

Comments
 (0)