-
Notifications
You must be signed in to change notification settings - Fork 242
Expand file tree
/
Copy pathupon_error.cuh
More file actions
223 lines (192 loc) · 7.21 KB
/
upon_error.cuh
File metadata and controls
223 lines (192 loc) · 7.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
/*
* Copyright (c) 2022 NVIDIA Corporation
*
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://llvm.org/LICENSE.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// clang-format Language: Cpp
#pragma once
#include "../../stdexec/execution.hpp"
#include "common.cuh"
#include <concepts>
#include <cstddef>
#include <type_traits>
#include <utility>
#include <cuda/std/utility>
STDEXEC_PRAGMA_PUSH()
STDEXEC_PRAGMA_IGNORE_EDG(cuda_compile)
namespace nv::execution::_strm
{
namespace _upon_error
{
// Single-thread kernel (launch_bounds(1)) that calls the user's error handler
// `fun` with the error argument(s) and discards whatever it returns.
//
// `Args` is supplied explicitly by the launching code (receiver::set_error
// instantiates it as `Error&&`), while `Fun` is deduced from the argument.
// Both the callable and the arguments must satisfy the project's
// `trivially_copyable` concept so they can be passed as kernel parameters.
template <class... Args, class Fun>
STDEXEC_ATTRIBUTE(launch_bounds(1))
__global__ void _upon_error_kernel(Fun handler, Args... errs)
{
  static_assert(trivially_copyable<Fun, Args...>);
  // `handler` is a by-value parameter, so this cast is exactly a move: the
  // handler is invoked as an rvalue, matching `std::invocable<Fun, Error>`.
  static_cast<Fun&&>(handler)(static_cast<Args&&>(errs)...);
}
// Single-thread kernel (launch_bounds(1)) that calls the user's error handler
// `fn` with the error argument(s) and constructs its return value in place at
// `result`.
//
// `Args` is supplied explicitly by the launching code (receiver::set_error
// instantiates it as `Error&&`); `Fun` and `ResultT` are deduced from the
// arguments. `result` must point at uninitialized storage large enough for a
// `ResultT` (receiver::set_error passes the operation state's temp storage).
//
// The object is created with the global placement form
// `::new (static_cast<void*>(p))`, the same form `std::construct_at` uses.
// A plain `new (p) T(...)` first looks up `operator new` in `T`'s class
// scope. There, any class-specific `operator new` would hide the global
// placement overload and break compilation for such result types.
template <class... Args, class Fun, class ResultT>
STDEXEC_ATTRIBUTE(launch_bounds(1))
__global__ void _upon_error_kernel_with_result(Fun fn, ResultT* result, Args... args)
{
  static_assert(trivially_copyable<Fun, Args...>);
  ::new (static_cast<void*>(result)) ResultT(::cuda::std::move(fn)(static_cast<Args&&>(args)...));
}
// Stream receiver installed between the predecessor sender and the downstream
// receiver by `upon_error_sender`.
//
// - Value and stopped completions are forwarded unchanged through the
//   operation state.
// - An error completion launches a 1x1 kernel on the operation state's stream
//   that calls `Fun` with the error. The operation then completes with
//   `set_value` carrying Fun's result (or no value when Fun returns void).
// - If the launch is reported as failed, the operation completes with
//   `set_error(cudaError_t)` instead.
template <std::size_t MemoryAllocationSize, class Receiver, class Fun>
struct receiver : stream_receiver_base
{
  using receiver_concept = STDEXEC::receiver_t;
  using env_t = _strm::opstate_base<Receiver>::env_t;

  // Byte count fixed at instantiation. upon_error_sender passes
  // max_result_size<Receiver>::value here; presumably the operation state uses
  // it to size `temp_storage_` for Fun's result — confirm in common.cuh.
  static constexpr std::size_t memory_allocation_size() noexcept
  {
    return MemoryAllocationSize;
  }

  // Takes ownership of the handler and keeps a reference to the operation
  // state that provides the stream, temp storage and completion forwarding.
  explicit receiver(Fun handler, _strm::opstate_base<Receiver>& base)
    : f_(static_cast<Fun&&>(handler))
    , opstate_(base)
  {}

  // Values are not touched by upon_error: forward them as-is.
  template <class... Vals>
  void set_value(Vals&&... vals) noexcept
  {
    opstate_.propagate_completion_signal(set_value_t(), static_cast<Vals&&>(vals)...);
  }

  // Runs `Fun(error)` in a kernel on the operation's stream and completes with
  // its result. Only launch-time failures are detected here: the stream is
  // not synchronized, so faults raised while the kernel runs are not observed
  // by this function.
  //
  // NOTE(review): cudaPeekAtLastError reads the runtime's last error without
  // clearing it, so a reported launch error stays visible to later checks.
  template <class Error>
    requires std::invocable<Fun, Error>
  void set_error(Error&& error) noexcept
  {
    using result_t = std::invoke_result_t<Fun, Error>;
    cudaStream_t const strm = opstate_.get_stream();

    if constexpr (!std::is_same_v<void, result_t>)
    {
      // Fun produces a value: construct it on the device inside the operation
      // state's temporary storage, then hand it downstream by rvalue reference.
      using decayed_result_t = __decay_t<result_t>;
      decayed_result_t* const slot = static_cast<decayed_result_t*>(opstate_.temp_storage_);

      _upon_error_kernel_with_result<Error&&>
        <<<1, 1, 0, strm>>>(std::move(f_), slot, static_cast<Error&&>(error));

      cudaError_t status = STDEXEC_LOG_CUDA_API(cudaPeekAtLastError());
      if (status != cudaSuccess)
      {
        opstate_.propagate_completion_signal(STDEXEC::set_error, std::move(status));
        return;
      }

      // The result object must be destroyed later, not here: it is still
      // referenced by the value completion sent below.
      opstate_.defer_temp_storage_destruction(slot);
      opstate_.propagate_completion_signal(STDEXEC::set_value, std::move(*slot));
    }
    else
    {
      // Fun returns void: the error is turned into an empty value completion.
      _upon_error_kernel<Error&&>
        <<<1, 1, 0, strm>>>(std::move(f_), static_cast<Error&&>(error));

      cudaError_t status = STDEXEC_LOG_CUDA_API(cudaPeekAtLastError());
      if (status != cudaSuccess)
      {
        opstate_.propagate_completion_signal(STDEXEC::set_error, std::move(status));
        return;
      }

      opstate_.propagate_completion_signal(STDEXEC::set_value);
    }
  }

  // Cancellation is forwarded unchanged.
  void set_stopped() noexcept
  {
    opstate_.propagate_completion_signal(set_stopped_t());
  }

  // The environment comes from the owning operation state.
  [[nodiscard]]
  auto get_env() const noexcept -> env_t
  {
    return opstate_.make_env();
  }

  Fun f_;                                   // user's error handler
  _strm::opstate_base<Receiver>& opstate_;  // owning stream operation state
};
} // namespace _upon_error
// Stream-context implementation of `upon_error(sndr, fun)`.
//
// Wraps a predecessor `Sender` and an error handler `Fun`. When connected, the
// predecessor is connected to an `_upon_error::receiver`, which runs `Fun` on
// the predecessor's error inside a kernel on the operation's stream.
template <class Sender, class Fun>
struct upon_error_sender : stream_sender_base
{
// Folds `result_size_for<Fun, Error...>` over every set_error completion of
// `Sender` (in the environment of `Receiver`) using `maxsize`. The resulting
// `::value` is passed to the receiver as its MemoryAllocationSize. Presumably
// this is the largest byte size needed to hold Fun's result for any error
// type — confirm against result_size_for / maxsize in common.cuh.
template <class Receiver>
requires sender_in<Sender, env_of_t<Receiver>>
struct max_result_size
: STDEXEC::__gather_completions_of_t<set_error_t,
Sender,
env_of_t<Receiver>,
__mbind_front<result_size_for, Fun>,
maxsize>
{};
// The receiver that the predecessor sender is connected to.
template <class Receiver>
using receiver_t = _upon_error::receiver<max_result_size<Receiver>::value, Receiver, Fun>;
// Transform applied to each error completion of `Sender`. Judging by the
// helper's name, it yields a set_value signature for Fun's result on `Error`
// — confirm against __set_value_from_t.
template <class Error>
using _set_error_t = __set_value_from_t<Fun, Error>;
// Completions of this sender are derived from the predecessor's completions
// (cvref-qualified like `Self`):
//   - set_error_t(cudaError_t) is added, for failed kernel launches;
//   - value completions use the library's default (__default_set_value);
//   - error completions are rewritten through _set_error_t.
template <class Self, class... Env>
using completion_signatures = __transform_completion_signatures_t<
__completion_signatures_of_t<__copy_cvref_t<Self, Sender>, Env...>,
completion_signatures<set_error_t(cudaError_t)>,
__cmplsigs::__default_set_value,
_set_error_t>;
// Takes ownership of the predecessor sender and the handler (both moved in).
explicit upon_error_sender(Sender sndr, Fun fun)
noexcept(__nothrow_move_constructible<Sender, Fun>)
: sndr_(static_cast<Sender&&>(sndr))
, fun_(static_cast<Fun&&>(fun))
{}
// Connects the predecessor (forwarded with Self's value category) into a
// stream operation state. The lambda is the factory that builds the
// intermediate receiver from the operation-state base it is given.
// NOTE(review): `self.fun_` names an lvalue, so Fun is copied even when
// `self` is an rvalue. Forwarding it would save a copy only if the factory is
// invoked exactly once — confirm in stream_opstate before changing.
template <__decays_to<upon_error_sender> Self, receiver Receiver>
requires receiver_of<Receiver, completion_signatures<Self, env_of_t<Receiver>>>
STDEXEC_EXPLICIT_THIS_BEGIN(auto connect)(this Self&& self, Receiver rcvr)
-> stream_opstate_t<__copy_cvref_t<Self, Sender>, receiver_t<Receiver>, Receiver>
{
return stream_opstate<__copy_cvref_t<Self, Sender>>(
static_cast<Self&&>(self).sndr_,
static_cast<Receiver&&>(rcvr),
[&](_strm::opstate_base<Receiver>& stream_provider) -> receiver_t<Receiver>
{ return receiver_t<Receiver>(self.fun_, stream_provider); });
}
STDEXEC_EXPLICIT_THIS_END(connect)
// Compile-time query: an empty object of the completion_signatures type above.
template <__decays_to<upon_error_sender> Self, class... Env>
static consteval auto get_completion_signatures() -> completion_signatures<Self, Env...>
{
return {};
}
// Sender attributes are built from a pointer to the wrapped predecessor.
auto get_env() const noexcept -> stream_sender_attrs<Sender>
{
return {&sndr_};
}
private:
Sender sndr_;  // predecessor sender
Fun fun_;      // error handler passed to each connected receiver
};
template <>
struct transform_sender_for<STDEXEC::upon_error_t>
{
  // Lowers `upon_error(sndr, fun)` for the CUDA stream context.
  //
  // - When `Sender` does not satisfy `stream_completing_sender<Sender, Env>`,
  //   returns the result of `_no_stream_scheduler_in_env<...>()`. Presumably
  //   this is a diagnostic sender; confirm in common.cuh.
  // - Otherwise wraps a decay-copy of the sender, together with `fun`, in an
  //   `upon_error_sender`.
  //
  // The environment is only used at compile time. The second argument
  // (presumably the algorithm tag) is ignored.
  template <class Env, class Fun, class Sender>
  auto operator()(Env const &, __ignore, Fun handler, Sender&& src) const
  {
    if constexpr (!stream_completing_sender<Sender, Env>)
    {
      return _strm::_no_stream_scheduler_in_env<STDEXEC::upon_error_t, Sender, Env>();
    }
    else
    {
      return upon_error_sender<__decay_t<Sender>, Fun>{static_cast<Sender&&>(src),
                                                       static_cast<Fun&&>(handler)};
    }
  }
};
} // namespace nv::execution::_strm
// Short alias for the nv::execution namespace.
namespace nvexec = nv::execution;
namespace STDEXEC::__detail
{
// Specialization of the library's __demangle_v hook for upon_error_sender.
// It maps upon_error_sender<Sender, Fun> to upon_error_sender with
// __demangle_t applied to the wrapped Sender (Fun is left as-is).
// Declaration only (`extern`, no definition): presumably it is used solely in
// unevaluated type computations, e.g. for readable type names in diagnostics.
// Confirm against the primary __demangle_v template.
template <class Sender, class Fun>
extern __declfn_t<nvexec::_strm::upon_error_sender<__demangle_t<Sender>, Fun>>
__demangle_v<nvexec::_strm::upon_error_sender<Sender, Fun>>;
} // namespace STDEXEC::__detail
STDEXEC_PRAGMA_POP()