diff --git a/Folder.DotSettings b/Folder.DotSettings
new file mode 100644
index 0000000..069ccf5
--- /dev/null
+++ b/Folder.DotSettings
@@ -0,0 +1,2 @@
+
+ <NamingElement Priority="6" Title="Parameters"><Descriptor Static="Indeterminate" Constexpr="Indeterminate" Const="Indeterminate" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="function parameter" /><type Name="lambda parameter" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="" Suffix="" Style="aaBb"><ExtraRule Prefix="_" Suffix="" Style="aaBb" /></Policy></NamingElement>
\ No newline at end of file
diff --git a/cmake/dependencies/FindCUDNN.cmake b/cmake/dependencies/FindCUDNN.cmake
index a150310..fd77eea 100644
--- a/cmake/dependencies/FindCUDNN.cmake
+++ b/cmake/dependencies/FindCUDNN.cmake
@@ -76,4 +76,8 @@ if(CUDNN_FOUND)
endif()
endif()
+if (CUDNN_FOUND AND CUDNN_VERSION VERSION_LESS "8.0")
+ message(FATAL_ERROR "Flashlight requires cuDNN >= 8.0, found ${CUDNN_VERSION}")
+endif()
+
mark_as_advanced(CUDNN_ROOT CUDNN_INCLUDE_DIR CUDNN_LIBRARY CUDNN_VERSION)
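The same floor can also be asserted at compile time from cudnn_version.h, since the 8.x series encodes its version as MAJOR * 1000 + MINOR * 100 + PATCHLEVEL. A minimal sketch of such a guard, illustrative rather than part of this patch:

    #include <cudnn.h>  // pulls in cudnn_version.h and CUDNN_VERSION

    // cuDNN 8.x encodes e.g. 8.0.5 as 8005
    static_assert(CUDNN_VERSION >= 8000, "Flashlight requires cuDNN >= 8.0");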
diff --git a/cmake/utils/flashlightConfig.cmake.in b/cmake/utils/flashlightConfig.cmake.in
index 2bf2550..9d73423 100644
--- a/cmake/utils/flashlightConfig.cmake.in
+++ b/cmake/utils/flashlightConfig.cmake.in
@@ -49,7 +49,7 @@ if (@FL_BUILD_STANDALONE@)
endif()
if (@FL_USE_CUDA@)
if (@FL_USE_CUDNN@)
- find_dependency(CUDNN 7.1)
+ find_dependency(CUDNN 8)
endif()
if (@FL_BUILD_DISTRIBUTED@)
find_dependency(NCCL)
diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp
index 9f6b315..2538f88 100644
--- a/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp
+++ b/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp
@@ -48,15 +48,15 @@ namespace {
if(minAxis == 0) {
modeOut = CUDNN_BATCHNORM_PER_ACTIVATION;
- inDescDimsOut = Shape(
+ inDescDimsOut = Shape{
{
1,
1,
nfeatures,
            static_cast<long long>(input.elements() / nfeatures)
}
- );
- wtDescDimsOut = Shape({1, 1, nfeatures});
+ };
+ wtDescDimsOut = Shape{1, 1, nfeatures};
} else {
modeOut = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION >= 7003
@@ -67,15 +67,15 @@ namespace {
int batchsz = 1;
for(int i = maxAxis + 1; i < input.ndim(); ++i)
batchsz *= input.dim(i);
- inDescDimsOut = Shape(
+ inDescDimsOut = Shape{
{
1,
            static_cast<long long>(input.elements() / (nfeatures * batchsz)),
nfeatures,
batchsz,
}
- );
- wtDescDimsOut = Shape({1, 1, nfeatures});
+ };
+ wtDescDimsOut = Shape{1, 1, nfeatures};
}
}
@@ -101,7 +101,7 @@ Tensor CudnnAutogradExtension::batchnorm(
);
FL_TENSOR_DTYPES_MATCH_CHECK(weight, bias, runningMean, runningVar);
- auto output = Tensor(input.shape(), input.type());
+ auto output = Tensor{input.shape(), input.type()};
cudnnBatchNormMode_t mode;
Shape inDescDims, wtDescDims;
@@ -122,8 +122,8 @@ Tensor CudnnAutogradExtension::batchnorm(
fl::dtype scalarsType =
input.type() == fl::dtype::f16 ? fl::dtype::f32 : input.type();
- auto inDesc = TensorDescriptor(input.type(), inDescDims);
- auto wtDesc = TensorDescriptor(weightArray.type(), wtDescDims);
+ auto inDesc = TensorDescriptor{input.type(), inDescDims};
+ auto wtDesc = TensorDescriptor{weightArray.type(), wtDescDims};
{
DevicePtr inRaw(input);
@@ -140,8 +140,8 @@ Tensor CudnnAutogradExtension::batchnorm(
);
if(train) {
- saveMean = Tensor({wtDescDims[2]}, scalarsType);
- saveVar = Tensor({wtDescDims[2]}, scalarsType);
+ saveMean = Tensor{{wtDescDims[2]}, scalarsType};
+ saveVar = Tensor{{wtDescDims[2]}, scalarsType};
DevicePtr saveMeanRaw(saveMean);
DevicePtr saveVarRaw(saveVar);
@@ -223,13 +223,13 @@ std::tuple<Tensor, Tensor, Tensor> CudnnAutogradExtension::batchnormBackward(
const void* one1 = kOne(scalarsType);
const void* zero0 = kZero(scalarsType);
- auto iDesc = TensorDescriptor(input.type(), inDescDims);
- auto wDesc = TensorDescriptor(wt.type(), wtDescDims);
+ auto iDesc = TensorDescriptor{input.type(), inDescDims};
+ auto wDesc = TensorDescriptor{wt.type(), wtDescDims};
// CuDNN doesn't support calculating only the gradients
// required for batchnorm
- auto gradIn = Tensor(input.shape(), input.type());
- auto gradWt = Tensor(wt.shape(), wt.type());
- auto gradBs = Tensor(wt.shape(), wt.type());
+ auto gradIn = Tensor{input.shape(), input.type()};
+ auto gradWt = Tensor{wt.shape(), wt.type()};
+ auto gradBs = Tensor{wt.shape(), wt.type()};
{
DevicePtr iRaw(input);
DevicePtr wRaw(wt);
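For reference on the reshapes in the hunks above: cuDNN batch normalization only accepts 4-D descriptors, so arbitrary flashlight shapes get folded down. A worked example under the spatial branch, assuming a WHCN input of shape (8, 8, 16, 4) with the feature axis at 2 (so minAxis > 0, nfeatures = 16, batchsz = 4):

    // elements = 8 * 8 * 16 * 4 = 4096
    // inDescDims = {1, 4096 / (16 * 4), 16, 4} = {1, 64, 16, 4}
    // wtDescDims = {1, 1, 16}

Both spatial axes collapse into the second descriptor dimension while features and batch stay separate, so statistics are computed per feature, as CUDNN_BATCHNORM_SPATIAL requires.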
diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp
index bb89e61..b9f63c6 100644
--- a/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp
+++ b/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp
@@ -314,9 +314,9 @@ Tensor CudnnAutogradExtension::conv2d(
auto hasBias = bias.elements() > 0;
- auto inDesc = TensorDescriptor(input);
- auto wtDesc = FilterDescriptor(weights);
- auto convDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups);
+ auto inDesc = TensorDescriptor{input};
+ auto wtDesc = FilterDescriptor{weights};
+ auto convDesc = ConvDescriptor{input.type(), px, py, sx, sy, dx, dy, groups};
if(input.type() == fl::dtype::f16)
CUDNN_CHECK_ERR(
cudnnSetConvolutionMathType(
@@ -339,8 +339,8 @@ Tensor CudnnAutogradExtension::conv2d(
odims.data()
)
);
- auto output = Tensor({odims[3], odims[2], odims[1], odims[0]}, input.type());
- auto outDesc = TensorDescriptor(output);
+ auto output = Tensor{{odims[3], odims[2], odims[1], odims[0]}, input.type()};
+ auto outDesc = TensorDescriptor{output};
auto handle = getCudnnHandle();
const auto& cudnnStream = getCudnnStream();
@@ -357,7 +357,7 @@ Tensor CudnnAutogradExtension::conv2d(
try {
wspace =
-          Tensor({static_cast<long long>(fwdAlgoBestPerf.memory)}, fl::dtype::b8);
+          Tensor{{static_cast<long long>(fwdAlgoBestPerf.memory)}, fl::dtype::b8};
} catch(const std::exception&) {
fwdAlgoBestPerf.algo = kFwdDefaultAlgo;
CUDNN_CHECK_ERR(
@@ -372,7 +372,7 @@ Tensor CudnnAutogradExtension::conv2d(
)
);
wspace =
-          Tensor({static_cast<long long>(fwdAlgoBestPerf.memory)}, fl::dtype::b8);
+          Tensor{{static_cast<long long>(fwdAlgoBestPerf.memory)}, fl::dtype::b8};
}
{
DevicePtr inPtr(input);
@@ -405,7 +405,7 @@ Tensor CudnnAutogradExtension::conv2d(
);
if(hasBias) {
- auto bsDesc = TensorDescriptor(bias);
+ auto bsDesc = TensorDescriptor{bias};
DevicePtr bsPtr(bias);
// ensure cudnn compute stream waits on stream of bias tensor
relativeSync(cudnnStream, {bias});
@@ -453,10 +453,10 @@ Tensor CudnnAutogradExtension::conv2dBackwardData(
// benchmarking suggests input or weight casting should occur, these
// descriptors may not be used/new ones with the correct types will be
// used instead.
- auto iDesc = TensorDescriptor(input);
- auto wDesc = FilterDescriptor(weight);
- auto cDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups);
- auto oDesc = TensorDescriptor(gradOutput);
+ auto iDesc = TensorDescriptor{input};
+ auto wDesc = FilterDescriptor{weight};
+ auto cDesc = ConvDescriptor{input.type(), px, py, sx, sy, dx, dy, groups};
+ auto oDesc = TensorDescriptor{gradOutput};
setDefaultMathType(cDesc, input);
@@ -491,10 +491,10 @@ Tensor CudnnAutogradExtension::conv2dBackwardData(
Tensor ws;
try {
- ws = Tensor(
+ ws = Tensor{
          {static_cast<long long>(bwdDataAlgoBestPerf.memory)},
fl::dtype::b8
- );
+ };
} catch(const std::exception&) {
bwdDataAlgoBestPerf.algo = kBwdDataDefaultAlgo;
CUDNN_CHECK_ERR(
@@ -508,13 +508,13 @@ Tensor CudnnAutogradExtension::conv2dBackwardData(
&bwdDataAlgoBestPerf.memory
)
);
- ws = Tensor(
+ ws = Tensor{
          {static_cast<long long>(bwdDataAlgoBestPerf.memory)},
fl::dtype::b8
- );
+ };
}
- auto gradInput = Tensor(inTensor.shape(), inTensor.type());
+ auto gradInput = Tensor{inTensor.shape(), inTensor.type()};
{
DevicePtr gradInputPtr(gradInput);
DevicePtr gradResultPtr(gradOutputTensor);
@@ -577,11 +577,11 @@ Tensor CudnnAutogradExtension::conv2dBackwardData(
/* incrementCount = */ false
);
- auto iDescF32 = TensorDescriptor(inTensorF32);
- auto wDescF32 = FilterDescriptor(wtTensorF32);
+ auto iDescF32 = TensorDescriptor{inTensorF32};
+ auto wDescF32 = FilterDescriptor{wtTensorF32};
auto cDescF32 =
- ConvDescriptor(fl::dtype::f32, px, py, sx, sy, dx, dy, groups);
- auto oDescF32 = TensorDescriptor(gradOutputTensorF32);
+ ConvDescriptor{fl::dtype::f32, px, py, sx, sy, dx, dy, groups};
+ auto oDescF32 = TensorDescriptor{gradOutputTensorF32};
// core bwd data computation
dataGradBenchmark->audit(
[&dataGradOut,
@@ -671,10 +671,10 @@ std::pair<Tensor, Tensor> CudnnAutogradExtension::conv2dBackwardFilterBias(
// benchmarking suggests input or weight casting should occur, these
// descriptors may not be used/new ones with the correct types will be
// used instead.
- auto iDesc = TensorDescriptor(input);
- auto wDesc = FilterDescriptor(weight);
- auto cDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups);
- auto oDesc = TensorDescriptor(gradOutput);
+ auto iDesc = TensorDescriptor{input};
+ auto wDesc = FilterDescriptor{weight};
+ auto cDesc = ConvDescriptor{input.type(), px, py, sx, sy, dx, dy, groups};
+ auto oDesc = TensorDescriptor{gradOutput};
setDefaultMathType(cDesc, input);
@@ -708,10 +708,10 @@ std::pair<Tensor, Tensor> CudnnAutogradExtension::conv2dBackwardFilterBias(
Tensor ws;
try {
- ws = Tensor(
+ ws = Tensor{
          {static_cast<long long>(bwdFilterAlgoBestPerf.memory)},
fl::dtype::b8
- );
+ };
} catch(const std::exception&) {
bwdFilterAlgoBestPerf.algo = kBwdFilterDefaultAlgo;
CUDNN_CHECK_ERR(
@@ -725,13 +725,13 @@ std::pair<Tensor, Tensor> CudnnAutogradExtension::conv2dBackwardFilterBias(
&bwdFilterAlgoBestPerf.memory
)
);
- ws = Tensor(
+ ws = Tensor{
          {static_cast<long long>(bwdFilterAlgoBestPerf.memory)},
fl::dtype::b8
- );
+ };
}
- auto gradWeight = Tensor(wtTensor.shape(), wtTensor.type());
+ auto gradWeight = Tensor{wtTensor.shape(), wtTensor.type()};
{
DevicePtr gradWeightPtr(gradWeight);
DevicePtr gradResultPtr(gradOutputTensor);
@@ -794,11 +794,11 @@ std::pair<Tensor, Tensor> CudnnAutogradExtension::conv2dBackwardFilterBias(
/* incrementCount = */ false
);
- auto iDescF32 = TensorDescriptor(inTensorF32);
- auto wDescF32 = FilterDescriptor(wtTensorF32);
+ auto iDescF32 = TensorDescriptor{inTensorF32};
+ auto wDescF32 = FilterDescriptor{wtTensorF32};
auto cDescF32 =
- ConvDescriptor(fl::dtype::f32, px, py, sx, sy, dx, dy, groups);
- auto oDescF32 = TensorDescriptor(gradOutputTensorF32);
+ ConvDescriptor{fl::dtype::f32, px, py, sx, sy, dx, dy, groups};
+ auto oDescF32 = TensorDescriptor{gradOutputTensorF32};
// core bwd data computation
filterGradBenchmark->audit(
[&filterGradOut,
@@ -860,13 +860,13 @@ std::pair<Tensor, Tensor> CudnnAutogradExtension::conv2dBackwardFilterBias(
const Tensor& bsTensor,
const Tensor& gradOutput,
const TensorDescriptor& oDesc) -> Tensor {
- auto gradBias = Tensor(bsTensor.shape(), bsTensor.type());
+ auto gradBias = Tensor{bsTensor.shape(), bsTensor.type()};
{
DevicePtr gradBiasPtr(gradBias);
DevicePtr gradResultPtr(gradOutput);
// ensure cudnn compute stream waits on gradient tensor streams
relativeSync(cudnnStream, {gradOutput, gradBias});
- auto bDesc = TensorDescriptor(bsTensor);
+ auto bDesc = TensorDescriptor{bsTensor};
CUDNN_CHECK_ERR(
cudnnConvolutionBackwardBias(
hndl,
@@ -911,7 +911,7 @@ std::pair<Tensor, Tensor> CudnnAutogradExtension::conv2dBackwardFilterBias(
},
/* incrementCount = */ false
);
- auto oDescF32 = TensorDescriptor(gradOutputF32);
+ auto oDescF32 = TensorDescriptor{gradOutputF32};
// Perform bias gradient computation
biasGradBenchmark->audit(
[&biasGradOut,
diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp
index 305a6cc..eaac140 100644
--- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp
+++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp
@@ -16,13 +16,11 @@ namespace fl {
std::shared_ptr<DynamicBenchmark> CudnnAutogradExtension::createBenchmarkOptions() {
  return std::make_shared<DynamicBenchmark>(
      std::make_shared<DynamicBenchmarkOptions<KernelMode>>(
-          std::vector<KernelMode>(
- {
- KernelMode::F32,
- KernelMode::F32_ALLOW_CONVERSION,
- KernelMode::F16
- }
- ),
+          std::vector<KernelMode>{
+ KernelMode::F32,
+ KernelMode::F32_ALLOW_CONVERSION,
+ KernelMode::F16
+ },
fl::kDynamicBenchmarkDefaultCount
)
);
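A caveat on the Tensor(...) to Tensor{...} conversion running through this diff: list-initialization and constructor-call syntax can diverge when an initializer_list constructor exists, the classic case being std::vector. A self-contained illustration (not flashlight code):

    #include <vector>

    int main() {
      std::vector<int> a(3, 1); // size + value ctor: {1, 1, 1}
      std::vector<int> b{3, 1}; // initializer_list:  {3, 1}
      // braces also reject narrowing, e.g. std::vector<int> c{1LL << 40}
      // would be ill-formed rather than silently truncated
      return a.size() == 3 && b.size() == 2 ? 0 : 1;
    }

The conversions above are safe because Tensor and the descriptor types take Shape/dtype arguments rather than initializer lists, but each call site merits that one-time check.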
diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp
index 82cadcb..804bc6f 100644
--- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp
+++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp
@@ -25,16 +25,16 @@ struct DeviceHandle {
  std::shared_ptr<const CUDAStream> stream;
  explicit DeviceHandle(std::shared_ptr<const CUDAStream> _stream) : cudnnHandle(nullptr),
- stream(_stream) {
+ stream(_stream) {
CUDNN_CHECK_ERR(cudnnCreate(&cudnnHandle));
CUDNN_CHECK_ERR(cudnnSetStream(cudnnHandle, stream->handle()));
}
~DeviceHandle() {
if(cudnnHandle) {
-// See https://git.io/fNQnM - sometimes, at exit, the CUDA context
-// (or something) is already destroyed by the time a handle gets destroyed
-// because of an issue with the destruction order.
+ // See https://git.io/fNQnM - sometimes, at exit, the CUDA context
+ // (or something) is already destroyed by the time a handle gets destroyed
+ // because of an issue with the destruction order.
#ifdef NO_CUDNN_DESTROY_HANDLE
#else
CUDNN_CHECK_ERR(cudnnDestroy(cudnnHandle));
@@ -43,16 +43,16 @@ struct DeviceHandle {
}
};
-const float kFloatZero = 0.0;
-const float kFloatOne = 1.0;
+constexpr float kFloatZero = 0.0;
+constexpr float kFloatOne = 1.0;
-const double kDoubleZero = 0.0;
-const double kDoubleOne = 1.0;
+constexpr double kDoubleZero = 0.0;
+constexpr double kDoubleOne = 1.0;
// TODO: move this to CudnnAutogradExtension if we make it a singleton
std::unordered_map handles;
-const DeviceHandle& getActiveDeviceHandle() {
+DeviceHandle const& getActiveDeviceHandle() {
auto& manager = fl::DeviceManager::getInstance();
auto& cudaDevice =
manager.getActiveDevice(fl::DeviceType::CUDA).impl();
@@ -88,57 +88,42 @@ namespace fl {
void cudnnCheckErr(cudnnStatus_t status) {
if(status == CUDNN_STATUS_SUCCESS)
return;
- const char* err = cudnnGetErrorString(status);
+ char const* err = cudnnGetErrorString(status);
switch(status) {
- case CUDNN_STATUS_BAD_PARAM:
- throw std::invalid_argument(err);
- default:
- throw std::runtime_error(err);
+ case CUDNN_STATUS_BAD_PARAM: throw std::invalid_argument(err);
+ default: throw std::runtime_error(err);
}
}
-cudnnDataType_t cudnnMapToType(const fl::dtype& t) {
+cudnnDataType_t cudnnMapToType(fl::dtype const& t) {
switch(t) {
- case fl::dtype::f16:
- return CUDNN_DATA_HALF;
- case fl::dtype::f32:
- return CUDNN_DATA_FLOAT;
- case fl::dtype::f64:
- return CUDNN_DATA_DOUBLE;
- default:
- throw std::invalid_argument("unsupported data type for cuDNN");
+ case fl::dtype::f16: return CUDNN_DATA_HALF;
+ case fl::dtype::f32: return CUDNN_DATA_FLOAT;
+ case fl::dtype::f64: return CUDNN_DATA_DOUBLE;
+ default: throw std::invalid_argument("unsupported data type for cuDNN");
}
}
-cudnnPoolingMode_t cudnnMapToPoolingMode(const PoolingMode mode) {
+cudnnPoolingMode_t cudnnMapToPoolingMode(PoolingMode const mode) {
switch(mode) {
- case PoolingMode::MAX:
- return CUDNN_POOLING_MAX;
- case PoolingMode::AVG_INCLUDE_PADDING:
- return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
- case PoolingMode::AVG_EXCLUDE_PADDING:
- return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
- default:
- throw std::invalid_argument("unsupported pooling mode for cuDNN");
+ case PoolingMode::MAX: return CUDNN_POOLING_MAX;
+ case PoolingMode::AVG_INCLUDE_PADDING: return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+ case PoolingMode::AVG_EXCLUDE_PADDING: return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
+ default: throw std::invalid_argument("unsupported pooling mode for cuDNN");
}
}
-cudnnRNNMode_t cudnnMapToRNNMode(const RnnMode mode) {
+cudnnRNNMode_t cudnnMapToRNNMode(RnnMode const mode) {
switch(mode) {
- case RnnMode::RELU:
- return CUDNN_RNN_RELU;
- case RnnMode::TANH:
- return CUDNN_RNN_TANH;
- case RnnMode::LSTM:
- return CUDNN_LSTM;
- case RnnMode::GRU:
- return CUDNN_GRU;
- default:
- throw std::invalid_argument("unsupported RNN mode for cuDNN");
+ case RnnMode::RELU: return CUDNN_RNN_RELU;
+ case RnnMode::TANH: return CUDNN_RNN_TANH;
+ case RnnMode::LSTM: return CUDNN_LSTM;
+ case RnnMode::GRU: return CUDNN_GRU;
+ default: throw std::invalid_argument("unsupported RNN mode for cuDNN");
}
}
-TensorDescriptor::TensorDescriptor(const fl::dtype type, const Shape& flDims) {
+TensorDescriptor::TensorDescriptor(fl::dtype const type, Shape const& flDims) {
CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&descriptor));
cudnnDataType_t cudnntype = cudnnMapToType(type);
@@ -165,7 +150,7 @@ TensorDescriptor::TensorDescriptor(const fl::dtype type, const Shape& flDims) {
);
}
-TensorDescriptor::TensorDescriptor(const Tensor& input) {
+TensorDescriptor::TensorDescriptor(Tensor const& input) {
CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&descriptor));
cudnnDataType_t cudnntype = cudnnMapToType(input.type());
@@ -194,21 +179,19 @@ TensorDescriptor::TensorDescriptor(const Tensor& input) {
);
}
-TensorDescriptor::~TensorDescriptor() {
- CUDNN_CHECK_ERR(cudnnDestroyTensorDescriptor(descriptor));
-}
+TensorDescriptor::~TensorDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyTensorDescriptor(descriptor)); }
TensorDescriptorArray::TensorDescriptorArray(
int size,
- const fl::dtype type,
- const Shape& dims
+ fl::dtype const type,
+ Shape const& dims
) {
- desc_vec.reserve(size);
+ _descVec.reserve(size);
for(int i = 0; i < size; i++) {
- desc_vec.emplace_back(type, dims);
- desc_raw_vec.push_back(desc_vec.back().descriptor);
+ _descVec.emplace_back(type, dims);
+ _descRawVec.push_back(_descVec.back().descriptor);
}
- descriptors = desc_raw_vec.data();
+ descriptors = _descRawVec.data();
}
TensorDescriptorArray::~TensorDescriptorArray() = default;
@@ -241,11 +224,9 @@ PoolingDescriptor::PoolingDescriptor(
);
}
-PoolingDescriptor::~PoolingDescriptor() {
- CUDNN_CHECK_ERR(cudnnDestroyPoolingDescriptor(descriptor));
-}
+PoolingDescriptor::~PoolingDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyPoolingDescriptor(descriptor)); }
-FilterDescriptor::FilterDescriptor(const Tensor& input) {
+FilterDescriptor::FilterDescriptor(Tensor const& input) {
CUDNN_CHECK_ERR(cudnnCreateFilterDescriptor(&descriptor));
cudnnDataType_t cudnntype = cudnnMapToType(input.type());
@@ -267,121 +248,146 @@ FilterDescriptor::FilterDescriptor(const Tensor& input) {
);
}
-FilterDescriptor::~FilterDescriptor() {
- CUDNN_CHECK_ERR(cudnnDestroyFilterDescriptor(descriptor));
-}
+FilterDescriptor::~FilterDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyFilterDescriptor(descriptor)); }
-DropoutDescriptor::DropoutDescriptor(float drop_prob) {
+DropoutDescriptor::DropoutDescriptor(float dropProb) {
CUDNN_CHECK_ERR(cudnnCreateDropoutDescriptor(&descriptor));
- auto cudnnHandle = getCudnnHandle();
- unsigned long long seed = 0;
- size_t state_size;
- CUDNN_CHECK_ERR(cudnnDropoutGetStatesSize(cudnnHandle, &state_size));
- auto& dropout_states = getDropoutStates();
- if(dropout_states.isEmpty()) {
- dropout_states =
-        Tensor({static_cast<long long>(state_size)}, fl::dtype::b8);
- DevicePtr statesraw(dropout_states);
+
+ auto const cudnnHandle = getCudnnHandle();
+ constexpr unsigned long long seed = 0;
+ size_t stateSize;
+
+ CUDNN_CHECK_ERR(cudnnDropoutGetStatesSize(cudnnHandle, &stateSize));
+
+ auto& dropoutStates = getDropoutStates();
+
+ if(dropoutStates.isEmpty()) {
+ dropoutStates =
+        Tensor{{static_cast<long long>(stateSize)}, fl::dtype::b8};
+ DevicePtr statesraw(dropoutStates);
CUDNN_CHECK_ERR(
cudnnSetDropoutDescriptor(
descriptor,
cudnnHandle,
- drop_prob,
+ dropProb,
statesraw.get(),
- state_size,
+ stateSize,
seed
)
);
- } else {
- DevicePtr statesraw(dropout_states);
-// See https://git.io/fp9oo for an explanation.
-#if CUDNN_VERSION >= 7000
+ }
+ else {
+ DevicePtr statesraw(dropoutStates);
CUDNN_CHECK_ERR(
cudnnRestoreDropoutDescriptor(
descriptor,
cudnnHandle,
- drop_prob,
+ dropProb,
statesraw.get(),
- state_size,
+ stateSize,
seed
)
);
-#else
- auto dropout_struct = reinterpret_cast(descriptor);
- dropout_struct->dropout = drop_prob;
- dropout_struct->nstates = state_size;
- dropout_struct->states = statesraw.get();
-#endif
}
}
-DropoutDescriptor::~DropoutDescriptor() {
- CUDNN_CHECK_ERR(cudnnDestroyDropoutDescriptor(descriptor));
-}
+DropoutDescriptor::~DropoutDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyDropoutDescriptor(descriptor)); }
Tensor& DropoutDescriptor::getDropoutStates() {
- thread_local Tensor dropout_states;
- return dropout_states;
+ thread_local Tensor dropoutStates;
+ return dropoutStates;
}
RNNDescriptor::RNNDescriptor(
fl::dtype type,
- int hidden_size,
- int num_layers,
+ int inputSize,
+ int hiddenSize,
+ int numLayers,
RnnMode mode,
bool bidirectional,
DropoutDescriptor& dropout
) {
- CUDNN_CHECK_ERR(cudnnCreateRNNDescriptor(&descriptor));
-
- auto cudnnHandle = getCudnnHandle();
+ CUDNN_CHECK_ERR(cudnnCreateRNNDescriptor(&_handle));
- cudnnRNNInputMode_t in_mode = CUDNN_LINEAR_INPUT;
+ constexpr auto inMode = CUDNN_LINEAR_INPUT;
+ constexpr auto algo = CUDNN_RNN_ALGO_STANDARD;
- cudnnDirectionMode_t dir =
- bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
+ auto const dir = bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
- cudnnRNNMode_t cell = cudnnMapToRNNMode(mode);
- cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD;
- cudnnDataType_t cudnntype = cudnnMapToType(type);
+ auto const cell = cudnnMapToRNNMode(mode);
+ auto const dataType = cudnnMapToType(type);
-#if CUDNN_VERSION >= 7000 && CUDNN_VERSION < 8000
CUDNN_CHECK_ERR(
- cudnnSetRNNDescriptor(
- cudnnHandle,
- descriptor,
- hidden_size,
- num_layers,
- dropout.descriptor,
- in_mode,
- dir,
- cell,
+      // https://docs.nvidia.com/deeplearning/cudnn/archives/cudnn-892/api/index.html#cudnnSetRNNDescriptor_v8
+ cudnnSetRNNDescriptor_v8(
+ _handle,
algo,
- cudnntype
+ cell,
+          CUDNN_RNN_DOUBLE_BIAS, // TODO: review; double bias matches the pre-v8 default
+ dir,
+ inMode,
+ dataType,
+ dataType, // math precision
+ mathType(type),
+ inputSize,
+ hiddenSize,
+          hiddenSize, // projSize == hiddenSize disables the projection layer
+ numLayers,
+ dropout.descriptor,
+          0 // auxFlags; 0 == CUDNN_RNN_PADDED_IO_DISABLED
)
);
-#else
+}
+
+RNNDescriptor::~RNNDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyRNNDescriptor(_handle)); }
+
+RNNDataDescriptor::RNNDataDescriptor(fl::dtype type, Shape const& dims) {
+ create();
+  auto const inputSize = dims.ndim() > 0 ? static_cast<int>(dims[0]) : 1;
+  auto const batchSize = dims.ndim() > 1 ? static_cast<int>(dims[1]) : 1;
+  auto const maxSeqSize = dims.ndim() > 2 ? static_cast<int>(dims[2]) : 1;
+
+  std::vector<int> seqSizes(batchSize, maxSeqSize);
+
+ set(
+ type,
+ inputSize,
+ maxSeqSize,
+ seqSizes
+ );
+}
+
+RNNDataDescriptor::~RNNDataDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyRNNDataDescriptor(_handle)); }
+void RNNDataDescriptor::create() { CUDNN_CHECK_ERR(cudnnCreateRNNDataDescriptor(&_handle)); }
+void RNNDataDescriptor::set(
+ fl::dtype type,
+ int inputSize,
+ int maxSeqSize,
+    std::span<int const> sequenceSizes
+) const {
CUDNN_CHECK_ERR(
- cudnnSetRNNDescriptor_v6(
- cudnnHandle,
- descriptor,
- hidden_size,
- num_layers,
- dropout.descriptor,
- in_mode,
- dir,
- cell,
- algo,
- cudnntype
+ cudnnSetRNNDataDescriptor(
+ _handle,
+ cudnnMapToType(type),
+ CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
+ maxSeqSize,
+          static_cast<int>(sequenceSizes.size()), // batch size
+ inputSize,
+ sequenceSizes.data(),
+          nullptr // no padding fill
)
);
-#endif
}
-RNNDescriptor::~RNNDescriptor() {
- CUDNN_CHECK_ERR(cudnnDestroyRNNDescriptor(descriptor));
}
+
ConvDescriptor::ConvDescriptor(
fl::dtype type,
int px,
@@ -413,35 +419,27 @@ ConvDescriptor::ConvDescriptor(
CUDNN_CHECK_ERR(cudnnSetConvolutionGroupCount(descriptor, groups));
}
-ConvDescriptor::~ConvDescriptor() {
- CUDNN_CHECK_ERR(cudnnDestroyConvolutionDescriptor(descriptor));
-}
+ConvDescriptor::~ConvDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyConvolutionDescriptor(descriptor)); }
cudnnHandle_t getCudnnHandle() { return getActiveDeviceHandle().cudnnHandle; }
-const CUDAStream& getCudnnStream() { return *getActiveDeviceHandle().stream; }
+CUDAStream const& getCudnnStream() { return *getActiveDeviceHandle().stream; }
-const void* kOne(const fl::dtype t) {
+void const* kOne(fl::dtype const t) {
switch(t) {
case fl::dtype::f16:
- case fl::dtype::f32:
- return &kFloatOne;
- case fl::dtype::f64:
- return &kDoubleOne;
- default:
- throw std::invalid_argument("unsupported data type for cuDNN");
+ case fl::dtype::f32: return &kFloatOne;
+ case fl::dtype::f64: return &kDoubleOne;
+ default: throw std::invalid_argument("unsupported data type for cuDNN");
}
}
-const void* kZero(const fl::dtype t) {
+void const* kZero(fl::dtype const t) {
switch(t) {
case fl::dtype::f16:
- case fl::dtype::f32:
- return &kFloatZero;
- case fl::dtype::f64:
- return &kDoubleZero;
- default:
- throw std::invalid_argument("unsupported data type for cuDNN");
+ case fl::dtype::f32: return &kFloatZero;
+ case fl::dtype::f64: return &kDoubleZero;
+ default: throw std::invalid_argument("unsupported data type for cuDNN");
}
}
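The core of this file's migration: cuDNN 7 RNN entry points took arrays of per-timestep 4-D tensor descriptors, while v8 takes a single cudnnRNNDataDescriptor_t describing the whole sequence batch. A standalone sketch of the setup that RNNDataDescriptor wraps, assuming full-length sequences and with error handling reduced to a throw:

    #include <cudnn.h>
    #include <stdexcept>
    #include <vector>

    cudnnRNNDataDescriptor_t makeSeqMajorDesc(
        cudnnDataType_t type, int inputSize, int batchSize, int seqLength) {
      cudnnRNNDataDescriptor_t desc;
      if (cudnnCreateRNNDataDescriptor(&desc) != CUDNN_STATUS_SUCCESS)
        throw std::runtime_error("cudnnCreateRNNDataDescriptor failed");
      // one length per batch item; all full length here, as in the ctor above
      std::vector<int> lengths(batchSize, seqLength);
      if (cudnnSetRNNDataDescriptor(
              desc,
              type,
              CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED, // time-major, no packing
              seqLength,       // maxSeqLength
              batchSize,
              inputSize,       // vectorSize
              lengths.data(),
              /*paddingFill=*/nullptr) != CUDNN_STATUS_SUCCESS)
        throw std::runtime_error("cudnnSetRNNDataDescriptor failed");
      return desc;
    }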
diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h
index fca9969..ad83be3 100644
--- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h
+++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h
@@ -13,13 +13,15 @@
#include "flashlight/fl/runtime/CUDAStream.h"
#include "flashlight/fl/tensor/TensorBase.h"
+#include <span>
+
namespace fl {
class TensorDescriptor {
public:
- explicit TensorDescriptor(const Tensor& a);
+ explicit TensorDescriptor(Tensor const& a);
- TensorDescriptor(const fl::dtype type, const Shape& af_dims);
+  TensorDescriptor(fl::dtype const type, Shape const& flDims);
cudnnTensorDescriptor_t descriptor;
~TensorDescriptor();
@@ -27,19 +29,19 @@ class TensorDescriptor {
class TensorDescriptorArray {
public:
- TensorDescriptorArray(int size, const fl::dtype type, const Shape& dims);
+ TensorDescriptorArray(int size, fl::dtype const type, Shape const& dims);
cudnnTensorDescriptor_t* descriptors;
~TensorDescriptorArray();
private:
- std::vector desc_vec;
- std::vector desc_raw_vec;
+ std::vector _descVec;
+ std::vector _descRawVec;
};
class FilterDescriptor {
public:
- explicit FilterDescriptor(const Tensor& a);
+ explicit FilterDescriptor(Tensor const& input);
cudnnFilterDescriptor_t descriptor;
~FilterDescriptor();
};
@@ -77,7 +79,7 @@ class PoolingDescriptor {
class DropoutDescriptor {
public:
- explicit DropoutDescriptor(float drop_prob);
+ explicit DropoutDescriptor(float dropProb);
cudnnDropoutDescriptor_t descriptor;
~DropoutDescriptor();
@@ -88,28 +90,65 @@ class RNNDescriptor {
public:
RNNDescriptor(
fl::dtype type,
- int hidden_size,
- int num_layers,
+ int inputSize,
+ int hiddenSize,
+ int numLayers,
RnnMode mode,
bool bidirectional,
DropoutDescriptor& dropout
);
- cudnnRNNDescriptor_t descriptor;
~RNNDescriptor();
+
+private:
+ cudnnRNNDescriptor_t _handle = nullptr;
+
+ static constexpr auto mathType(fl::dtype type) {
+ return type == fl::dtype::f16 ? CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION : CUDNN_DEFAULT_MATH;
+ }
+
+public:
+ /**
+ * @return descriptor handle
+ */
+ constexpr auto get() const { return _handle; }
};
+
+class RNNDataDescriptor {
+public:
+ RNNDataDescriptor(
+ fl::dtype type,
+ Shape const& dims
+ );
+
+ ~RNNDataDescriptor();
+
+private:
+ void create();
+  void set(dtype type, int inputSize, int maxSeqSize, std::span<int const> sequenceSizes) const;
+
+ cudnnRNNDataDescriptor_t _handle = nullptr;
+
+public:
+ /**
+ * @return descriptor handle
+ */
+ constexpr auto get() const { return _handle; }
+};
+
+
#define CUDNN_CHECK_ERR(expr) ::fl::cudnnCheckErr((expr))
void cudnnCheckErr(cudnnStatus_t status);
-cudnnDataType_t cudnnMapToType(const fl::dtype& t);
+cudnnDataType_t cudnnMapToType(fl::dtype const& t);
-const void* kOne(const fl::dtype t);
+void const* kOne(fl::dtype const t);
-const void* kZero(const fl::dtype t);
+void const* kZero(fl::dtype const t);
// TODO: move this to CudnnAutogradExtension if we make it a singleton
cudnnHandle_t getCudnnHandle();
-const CUDAStream& getCudnnStream();
+CUDAStream const& getCudnnStream();
} // namespace fl
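One portability note on this header: std::span, used by RNNDataDescriptor::set, requires C++20, hence the <span> include added above. It is a non-owning view, so callers can hand over a vector or a pointer plus length without copying; a minimal illustration:

    #include <numeric>
    #include <span>
    #include <vector>

    int total(std::span<int const> xs) {
      return std::accumulate(xs.begin(), xs.end(), 0);
    }

    int main() {
      std::vector<int> seqSizes(8, 16);       // 8 sequences of length 16
      return total(seqSizes) == 128 ? 0 : 1;  // vector converts implicitly
    }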
diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp
index 24b08c8..ef838bb 100644
--- a/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp
+++ b/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp
@@ -25,10 +25,10 @@ Tensor CudnnAutogradExtension::pool2d(
const PoolingMode mode,
    std::shared_ptr<detail::AutogradPayload>
) {
- auto inDesc = TensorDescriptor(input);
+ auto inDesc = TensorDescriptor{input};
// init pooling descriptor
- auto poolDesc = PoolingDescriptor(wx, wy, sx, sy, px, py, mode);
+ auto poolDesc = PoolingDescriptor{wx, wy, sx, sy, px, py, mode};
// init output descriptor
auto ix = input.dim(0);
@@ -36,7 +36,7 @@ Tensor CudnnAutogradExtension::pool2d(
auto ox = 1 + (ix + 2 * px - wx) / sx;
auto oy = 1 + (iy + 2 * py - wy) / sy;
- auto output = Tensor(
+ auto output = Tensor{
{
ox,
oy,
@@ -44,8 +44,8 @@ Tensor CudnnAutogradExtension::pool2d(
input.ndim() < 4 ? 1 : input.dim(3)
},
input.type()
- );
- auto outDesc = TensorDescriptor(output);
+ };
+ auto outDesc = TensorDescriptor{output};
{
DevicePtr inputraw(input);
DevicePtr resultraw(output);
@@ -90,11 +90,11 @@ Tensor CudnnAutogradExtension::pool2dBackward(
const PoolingMode mode,
    std::shared_ptr<detail::AutogradPayload>
) {
- auto i_desc = TensorDescriptor(input);
- auto o_desc = TensorDescriptor(poolOutput);
- auto p_desc = PoolingDescriptor(wx, wy, sx, sy, px, py, mode);
+ auto i_desc = TensorDescriptor{input};
+ auto o_desc = TensorDescriptor{poolOutput};
+ auto p_desc = PoolingDescriptor{wx, wy, sx, sy, px, py, mode};
- auto gradInput = Tensor(input.shape(), input.type());
+ auto gradInput = Tensor{input.shape(), input.type()};
auto hndl = getCudnnHandle();
const auto& cudnnStream = getCudnnStream();
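The output extents computed in pool2d follow cuDNN's pooling arithmetic, out = 1 + (in + 2*pad - window) / stride, with integer (floor) division. A compile-time check of two cases, mirroring the ox/oy expressions above:

    // illustrative only; mirrors `ox = 1 + (ix + 2 * px - wx) / sx` in pool2d
    constexpr int pooledExtent(int in, int window, int stride, int pad) {
      return 1 + (in + 2 * pad - window) / stride;
    }
    static_assert(pooledExtent(32, 3, 2, 1) == 16);
    static_assert(pooledExtent(28, 2, 2, 0) == 14);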
diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp
index 17b242b..568e4d5 100644
--- a/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp
+++ b/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp
@@ -15,58 +15,34 @@
namespace fl {
namespace {
- size_t getWorkspaceSize(
- cudnnHandle_t handle,
- const RNNDescriptor& rnnDesc,
- const int seqLength,
- const TensorDescriptorArray& xDescs
- ) {
- size_t workspaceSize;
- CUDNN_CHECK_ERR(
- cudnnGetRNNWorkspaceSize(
- handle,
- rnnDesc.descriptor,
- seqLength,
- xDescs.descriptors,
- &workspaceSize
- )
- );
- return workspaceSize;
- }
+  struct TempSpaceSizes {
+ size_t size;
+ size_t reserveSize;
+ };
- size_t getReserveSize(
+  TempSpaceSizes rnnTempSpaceSizes(
cudnnHandle_t handle,
- const RNNDescriptor& rnnDesc,
- const int seqLength,
- const TensorDescriptorArray& xDescs
+ RNNDescriptor const& rnnDescriptor,
+ RNNDataDescriptor const& xDescriptor,
+ cudnnForwardMode_t mode
) {
- size_t reserveSize;
+    TempSpaceSizes sizes{};
+
CUDNN_CHECK_ERR(
- cudnnGetRNNTrainingReserveSize(
+ cudnnGetRNNTempSpaceSizes(
handle,
- rnnDesc.descriptor,
- seqLength,
- xDescs.descriptors,
- &reserveSize
+ rnnDescriptor.get(),
+ mode,
+ xDescriptor.get(),
+ &sizes.size,
+ &sizes.reserveSize
)
);
- return reserveSize;
- }
- void setCudnnRnnMathType(const Tensor& input, const RNNDescriptor& rnnDesc) {
- if(input.type() == fl::dtype::f16)
- CUDNN_CHECK_ERR(
- cudnnSetRNNMatrixMathType(
- rnnDesc.descriptor,
- CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION
- )
- );
- else
- CUDNN_CHECK_ERR(
- cudnnSetRNNMatrixMathType(rnnDesc.descriptor, CUDNN_DEFAULT_MATH)
- );
+ return sizes;
}
+
struct CudnnRnnAutogradPayload : public detail::AutogradPayloadData {
Tensor reserveSpace;
};
@@ -74,15 +50,15 @@ namespace {
} // namespace
std::tuple<Tensor, Tensor, Tensor> CudnnAutogradExtension::rnn(
- const Tensor& input,
- const Tensor& hiddenStateIn,
- const Tensor& cellStateIn,
- const Tensor& weights,
- const int hiddenSize,
- const int numLayers,
- const RnnMode mode,
- const bool bidirectional,
- const float dropProb,
+ Tensor const& input,
+ Tensor const& hiddenStateIn,
+ Tensor const& cellStateIn,
+ Tensor const& weights,
+ int const hiddenSize,
+ int const numLayers,
+ RnnMode const mode,
+ bool const bidirectional,
+ float const dropProb,
    std::shared_ptr<detail::AutogradPayload> autogradPayload
) {
FL_TENSOR_DTYPES_MATCH_CHECK(input, hiddenStateIn, cellStateIn, weights);
@@ -93,25 +69,33 @@ std::tuple CudnnAutogradExtension::rnn(
autogradPayload->data = payload;
Tensor x = input.asContiguousTensor();
+ RNNDataDescriptor xDesc{x.type(), x.shape()};
+
Tensor hiddenState = hiddenStateIn.asContiguousTensor();
Tensor cellState = cellStateIn.asContiguousTensor();
- DropoutDescriptor dropout(dropProb);
- RNNDescriptor rnnDesc(
- input.type(), hiddenSize, numLayers, mode, bidirectional, dropout);
- setCudnnRnnMathType(input, rnnDesc);
+ DropoutDescriptor dropout{dropProb};
- auto dims = x.shape();
+ auto const& dims = x.shape();
int inputSize = dims[0];
int batchSize = dims.ndim() < 2 ? 1 : dims[1];
int seqLength = dims.ndim() < 3 ? 1 : dims[2];
+
+ RNNDescriptor rnnDesc{
+ input.type(),
+ inputSize,
+ hiddenSize,
+ numLayers,
+ mode,
+ bidirectional,
+ dropout
+ };
+
int totalLayers = numLayers * (bidirectional ? 2 : 1);
int outSize = hiddenSize * (bidirectional ? 2 : 1);
- TensorDescriptorArray xDescs(
- seqLength, x.type(), {1, 1, inputSize, batchSize});
-
if(!hiddenState.isEmpty()) {
auto hxDims = hiddenState.shape();
int hxHiddenSize = hxDims[0];
@@ -119,31 +103,31 @@ std::tuple CudnnAutogradExtension::rnn(
int hxTotalLayers = hiddenState.ndim() < 3 ? 1 : hxDims[2];
if(
- !(hxHiddenSize == hiddenSize && hxBatchSize == batchSize
- && hxTotalLayers == totalLayers)
+ hxHiddenSize != hiddenSize || hxBatchSize != batchSize
+ || hxTotalLayers != totalLayers
)
throw std::invalid_argument("invalid hidden state dims for RNN");
}
if(
!cellState.isEmpty()
- && !(mode == RnnMode::LSTM && cellState.dim(0) == hiddenSize
- && cellState.dim(1) == batchSize && cellState.dim(2) == totalLayers)
+ && (mode != RnnMode::LSTM || cellState.dim(0) != hiddenSize
+ || cellState.dim(1) != batchSize || cellState.dim(2) != totalLayers)
)
throw std::invalid_argument("invalid cell state dims for RNN");
Shape hDims = {1, hiddenSize, batchSize, totalLayers};
- TensorDescriptor hxDesc(x.type(), hDims);
- TensorDescriptor cxDesc(x.type(), hDims);
+ TensorDescriptor hxDesc{x.type(), hDims};
+ TensorDescriptor cxDesc{x.type(), hDims};
auto handle = getCudnnHandle();
- const auto& cudnnStream = getCudnnStream();
+ auto const& cudnnStream = getCudnnStream();
  size_t paramSize;
  CUDNN_CHECK_ERR(
-      cudnnGetRNNParamsSize(
+      // v8 sizes the weight space from the RNN descriptor alone, so the
+      // per-step x descriptors (removed above) are no longer needed here
+      cudnnGetRNNWeightSpaceSize(
          handle,
-          rnnDesc.descriptor,
-          xDescs.descriptors[0],
-          &paramSize,
-          cudnnMapToType(weights.type())
+          rnnDesc.get(),
+          &paramSize
@@ -155,24 +139,24 @@ std::tuple CudnnAutogradExtension::rnn(
);
-  FilterDescriptor wDesc(weights);
+  // no filter descriptor: the v8 forward call consumes the flat weight space
-  Tensor y({outSize, batchSize, seqLength}, input.type());
-  TensorDescriptorArray yDesc(seqLength, y.type(), {1, 1, outSize, batchSize});
+  Tensor y{{outSize, batchSize, seqLength}, input.type()};
+  // v8 takes an RNN data descriptor for the output sequence as well
+  RNNDataDescriptor yDesc{y.type(), y.shape()};
- Tensor hy({hiddenSize, batchSize, totalLayers}, x.type());
- TensorDescriptor hyDesc(x.type(), hDims);
+ Tensor hy{{hiddenSize, batchSize, totalLayers}, x.type()};
- Tensor cy;
+ Tensor cy{};
if(mode == RnnMode::LSTM)
- cy = Tensor(hy.shape(), x.type());
+ cy = Tensor{hy.shape(), x.type()};
- TensorDescriptor cyDesc(x.type(), hDims);
- size_t workspaceSize = getWorkspaceSize(handle, rnnDesc, seqLength, xDescs);
- size_t reserveSize = getReserveSize(handle, rnnDesc, seqLength, xDescs);
+ constexpr auto forwardMode = CUDNN_FWD_MODE_TRAINING;
+ auto [workspaceSize, reserveSize] = rnnTempSpaceSizes(handle, rnnDesc, xDesc, forwardMode);
  Tensor workspace({static_cast<long long>(workspaceSize)}, fl::dtype::b8);
// Space must be reused between forward and backward for cuDNN
-  payload->reserveSpace = Tensor({static_cast<long long>(reserveSize)}, fl::dtype::b8);
+  payload->reserveSpace = Tensor{{static_cast<long long>(reserveSize)}, fl::dtype::b8};
{
auto contiguousX = x.asContiguousTensor();
@@ -190,34 +174,40 @@ std::tuple CudnnAutogradExtension::rnn(
relativeSync(
cudnnStream,
{
- contiguousX, hiddenState, cellState, contiguousWeights, y, hy, cy,
- workspace, payload->reserveSpace,
+ contiguousX,
+ hiddenState,
+ cellState,
+ contiguousWeights,
+ y,
+ hy,
+ cy,
+ workspace,
+ payload->reserveSpace,
}
);
+      // cuDNN 8 takes device-resident int32 sequence lengths, one per batch
+      // item; all sequences here are full length
+      Tensor seqLengths = fl::full({batchSize}, seqLength, fl::dtype::s32);
+      DevicePtr seqLengthsRaw(seqLengths);
+      relativeSync(cudnnStream, {seqLengths});
      CUDNN_CHECK_ERR(
-          cudnnRNNForwardTraining(
+          // https://docs.nvidia.com/deeplearning/cudnn/archives/cudnn-892/api/index.html#cudnnRNNForward
+          cudnnRNNForward(
            handle,
-            rnnDesc.descriptor,
-            seqLength,
-            xDescs.descriptors,
+            rnnDesc.get(),
+            forwardMode,
+            static_cast<const int32_t*>(seqLengthsRaw.get()),
+            xDesc.get(),
            xRaw.get(),
-            hxDesc.descriptor,
-            hxRaw.get(),
-            cxDesc.descriptor,
-            cxRaw.get(),
-            wDesc.descriptor,
-            wRaw.get(),
-            yDesc.descriptors,
-            yRaw.get(),
-            hyDesc.descriptor,
-            hyRaw.get(),
-            cyDesc.descriptor,
-            cyRaw.get(),
-            workspaceRaw.get(),
+            yDesc.get(),
+            yRaw.get(),
+            hxDesc.descriptor, // v8 uses one descriptor for hx and hy
+            hxRaw.get(),
+            hyRaw.get(),
+            cxDesc.descriptor, // and one for cx and cy
+            cxRaw.get(),
+            cyRaw.get(),
+            paramSize, // weight space size from cudnnGetRNNWeightSpaceSize
+            wRaw.get(),
            workspaceSize,
-            reserveSpaceRaw.get(),
-            reserveSize
+            workspaceRaw.get(),
+            reserveSize,
+            reserveSpaceRaw.get()
          )
      );
}
@@ -228,17 +218,17 @@ std::tuple CudnnAutogradExtension::rnn(
}
std::tuple<Tensor, Tensor, Tensor, Tensor> CudnnAutogradExtension::rnnBackward(
- const Tensor& input,
- const Tensor& hiddenState,
- const Tensor& cellState,
- const Tensor& weights,
-    const std::shared_ptr<detail::RNNGradData> gradData,
- const Tensor& output,
- const int numLayers,
- const int hiddenSize,
- const RnnMode mode,
- const bool bidirectional,
- const float dropProb,
+ Tensor const& input,
+ Tensor const& hiddenState,
+ Tensor const& cellState,
+ Tensor const& weights,
+    std::shared_ptr<detail::RNNGradData> const gradData,
+ Tensor const& output,
+ int const numLayers,
+ int const hiddenSize,
+ RnnMode const mode,
+ bool const bidirectional,
+ float const dropProb,
    std::shared_ptr<detail::AutogradPayload> autogradPayload
) {
if(!autogradPayload)
@@ -249,7 +239,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> CudnnAutogradExtension::rnnBackward(
      std::static_pointer_cast<CudnnRnnAutogradPayload>(autogradPayload->data);
auto handle = getCudnnHandle();
- const auto& cudnnStream = getCudnnStream();
+ auto const& cudnnStream = getCudnnStream();
auto x = input.asContiguousTensor();
auto& y = output;
@@ -261,29 +251,32 @@ std::tuple CudnnAutogradExtension::rnnBackward(
int totalLayers = numLayers * (bidirectional ? 2 : 1);
int outSize = hiddenSize * (bidirectional ? 2 : 1);
-  DropoutDescriptor dropout(dropProb);
-  RNNDescriptor rnnDesc(input.type(), hiddenSize, numLayers, mode, bidirectional, dropout);
-  setCudnnRnnMathType(input, rnnDesc);
+  DropoutDescriptor dropout{dropProb};
+  // the v8 descriptor needs inputSize and sets the math type in its ctor,
+  // replacing the removed setCudnnRnnMathType helper
+  RNNDescriptor rnnDesc{input.type(), inputSize, hiddenSize, numLayers, mode, bidirectional, dropout};
- TensorDescriptorArray yDesc(seqLength, y.type(), {1, 1, outSize, batchSize});
- TensorDescriptorArray dyDesc(seqLength, y.type(), {1, 1, outSize, batchSize});
+ TensorDescriptorArray yDesc{seqLength, y.type(), {1, 1, outSize, batchSize}};
+ TensorDescriptorArray dyDesc{seqLength, y.type(), {1, 1, outSize, batchSize}};
Shape hDims = {1, hiddenSize, batchSize, totalLayers};
- TensorDescriptor dhyDesc(x.type(), hDims);
- TensorDescriptor dcyDesc(x.type(), hDims);
- TensorDescriptor hxDesc(x.type(), hDims);
- TensorDescriptor cxDesc(x.type(), hDims);
+ TensorDescriptor dhyDesc{x.type(), hDims};
+ TensorDescriptor dcyDesc{x.type(), hDims};
+ TensorDescriptor hxDesc{x.type(), hDims};
+ TensorDescriptor cxDesc{x.type(), hDims};
- Tensor dhx(hiddenState.shape(), hiddenState.type());
- Tensor dcx(cellState.shape(), cellState.type());
- TensorDescriptor dhxDesc(x.type(), hDims);
- TensorDescriptor dcxDesc(x.type(), hDims);
+ Tensor dhx{hiddenState.shape(), hiddenState.type()};
+ Tensor dcx{cellState.shape(), cellState.type()};
+ TensorDescriptor dhxDesc{x.type(), hDims};
+ TensorDescriptor dcxDesc{x.type(), hDims};
FilterDescriptor wDesc(weights);
- Tensor dx(input.shape(), input.type());
- TensorDescriptorArray dxDescs(
- seqLength, dx.type(), {1, 1, inputSize, batchSize});
+ Tensor dx{input.shape(), input.type()};
+ TensorDescriptorArray dxDescs{
+ seqLength,
+ dx.type(),
+ {1, 1, inputSize, batchSize}
+ };
-  size_t workspaceSize =
-      getWorkspaceSize(handle, rnnDesc, seqLength, dxDescs);
+  // the backward path still drives the legacy v7 entry points, so query the
+  // v7 workspace size inline now that the getWorkspaceSize helper is gone
+  size_t workspaceSize;
+  CUDNN_CHECK_ERR(
+      cudnnGetRNNWorkspaceSize(
+          handle,
+          rnnDesc.get(),
+          seqLength,
+          dxDescs.descriptors,
+          &workspaceSize
+      )
+  );
@@ -325,7 +318,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> CudnnAutogradExtension::rnnBackward(
CUDNN_CHECK_ERR(
cudnnRNNBackwardData(
handle,
- rnnDesc.descriptor,
+          rnnDesc.get(),
seqLength,
yDesc.descriptors,
yRaw.get(),
@@ -357,17 +350,20 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> CudnnAutogradExtension::rnnBackward(
-  if(input.type() == fl::dtype::f16)
-    CUDNN_CHECK_ERR(
-        cudnnSetRNNMatrixMathType(
-            rnnDesc.descriptor,
-            CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION
-        )
-    );
-  else
-    CUDNN_CHECK_ERR(
-        cudnnSetRNNMatrixMathType(rnnDesc.descriptor, CUDNN_DEFAULT_MATH)
-    );
-    TensorDescriptorArray xDescs(
-        seqLength, x.type(), {1, 1, inputSize, batchSize});
+  // the math type is fixed by the RNNDescriptor ctor (cudnnSetRNNDescriptor_v8),
+  // so the per-call cudnnSetRNNMatrixMathType toggling is dropped
+ TensorDescriptorArray xDescs{
+ seqLength,
+ x.type(),
+ {1, 1, inputSize, batchSize}
+ };
Tensor dw = fl::full(weights.shape(), 0, weights.type());
FilterDescriptor dwDesc(dw);
@@ -382,7 +378,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> CudnnAutogradExtension::rnnBackward(
CUDNN_CHECK_ERR(
cudnnRNNBackwardWeights(
handle,
- rnnDesc.descriptor,
+          rnnDesc.get(),
seqLength,
xDescs.descriptors,
xRaw.get(),
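Note that the backward pass above still drives the legacy v7 entry points (cudnnRNNBackwardData / cudnnRNNBackwardWeights with TensorDescriptorArray), which cuDNN 8 retains only in deprecated form. When this port continues, the v8 replacements collapse each call much as the forward path did; a sketch of the data-gradient shape, with the argument grouping paraphrased from the cuDNN 8 API reference rather than copied from a build:

    // cudnnRNNBackwardData_v8(handle, rnnDesc, devSeqLengths,
    //     yDesc, y, dy,          // one RNN data descriptor covers y and dy
    //     xDesc, dx,
    //     hDesc, hx, dhy, dhx,   // one 3-D descriptor for all h arguments
    //     cDesc, cx, dcy, dcx,
    //     weightSpaceSize, weightSpace,
    //     workSpaceSize, workSpace, reserveSpaceSize, reserveSpace);

cudnnRNNBackwardWeights_v8 additionally takes a cudnnWgradMode_t (CUDNN_WGRAD_MODE_ADD) and accumulates into the flat weight-space buffer instead of going through a filter descriptor.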