diff --git a/Folder.DotSettings b/Folder.DotSettings new file mode 100644 index 0000000..069ccf5 --- /dev/null +++ b/Folder.DotSettings @@ -0,0 +1,2 @@ + + <NamingElement Priority="6" Title="Parameters"><Descriptor Static="Indeterminate" Constexpr="Indeterminate" Const="Indeterminate" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="function parameter" /><type Name="lambda parameter" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="" Suffix="" Style="aaBb"><ExtraRule Prefix="_" Suffix="" Style="aaBb" /></Policy></NamingElement> \ No newline at end of file diff --git a/cmake/dependencies/FindCUDNN.cmake b/cmake/dependencies/FindCUDNN.cmake index a150310..fd77eea 100644 --- a/cmake/dependencies/FindCUDNN.cmake +++ b/cmake/dependencies/FindCUDNN.cmake @@ -76,4 +76,8 @@ if(CUDNN_FOUND) endif() endif() +if (CUDNN_FOUND AND CUDNN_VERSION VERSION_LESS "8.0") + message(FATAL_ERROR "Flashlight requires cuDNN >= 8.0, found ${CUDNN_VERSION}") +endif() + mark_as_advanced(CUDNN_ROOT CUDNN_INCLUDE_DIR CUDNN_LIBRARY CUDNN_VERSION) diff --git a/cmake/utils/flashlightConfig.cmake.in b/cmake/utils/flashlightConfig.cmake.in index 2bf2550..9d73423 100644 --- a/cmake/utils/flashlightConfig.cmake.in +++ b/cmake/utils/flashlightConfig.cmake.in @@ -49,7 +49,7 @@ if (@FL_BUILD_STANDALONE@) endif() if (@FL_USE_CUDA@) if (@FL_USE_CUDNN@) - find_dependency(CUDNN 7.1) + find_dependency(CUDNN 8) endif() if (@FL_BUILD_DISTRIBUTED@) find_dependency(NCCL) diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp index 9f6b315..2538f88 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp @@ -48,15 +48,15 @@ namespace { if(minAxis == 0) { modeOut = CUDNN_BATCHNORM_PER_ACTIVATION; - inDescDimsOut = Shape( + inDescDimsOut = Shape{ { 1, 1, nfeatures, static_cast(input.elements() / nfeatures) } - ); 
- wtDescDimsOut = Shape({1, 1, nfeatures}); + }; + wtDescDimsOut = Shape{1, 1, nfeatures}; } else { modeOut = CUDNN_BATCHNORM_SPATIAL; #if CUDNN_VERSION >= 7003 @@ -67,15 +67,15 @@ namespace { int batchsz = 1; for(int i = maxAxis + 1; i < input.ndim(); ++i) batchsz *= input.dim(i); - inDescDimsOut = Shape( + inDescDimsOut = Shape{ { 1, static_cast(input.elements() / (nfeatures * batchsz)), nfeatures, batchsz, } - ); - wtDescDimsOut = Shape({1, 1, nfeatures}); + }; + wtDescDimsOut = Shape{1, 1, nfeatures}; } } @@ -101,7 +101,7 @@ Tensor CudnnAutogradExtension::batchnorm( ); FL_TENSOR_DTYPES_MATCH_CHECK(weight, bias, runningMean, runningVar); - auto output = Tensor(input.shape(), input.type()); + auto output = Tensor{input.shape(), input.type()}; cudnnBatchNormMode_t mode; Shape inDescDims, wtDescDims; @@ -122,8 +122,8 @@ Tensor CudnnAutogradExtension::batchnorm( fl::dtype scalarsType = input.type() == fl::dtype::f16 ? fl::dtype::f32 : input.type(); - auto inDesc = TensorDescriptor(input.type(), inDescDims); - auto wtDesc = TensorDescriptor(weightArray.type(), wtDescDims); + auto inDesc = TensorDescriptor{input.type(), inDescDims}; + auto wtDesc = TensorDescriptor{weightArray.type(), wtDescDims}; { DevicePtr inRaw(input); @@ -140,8 +140,8 @@ Tensor CudnnAutogradExtension::batchnorm( ); if(train) { - saveMean = Tensor({wtDescDims[2]}, scalarsType); - saveVar = Tensor({wtDescDims[2]}, scalarsType); + saveMean = Tensor{{wtDescDims[2]}, scalarsType}; + saveVar = Tensor{{wtDescDims[2]}, scalarsType}; DevicePtr saveMeanRaw(saveMean); DevicePtr saveVarRaw(saveVar); @@ -223,13 +223,13 @@ std::tuple CudnnAutogradExtension::batchnormBackward( const void* one1 = kOne(scalarsType); const void* zero0 = kZero(scalarsType); - auto iDesc = TensorDescriptor(input.type(), inDescDims); - auto wDesc = TensorDescriptor(wt.type(), wtDescDims); + auto iDesc = TensorDescriptor{input.type(), inDescDims}; + auto wDesc = TensorDescriptor{wt.type(), wtDescDims}; // CuDNN doesn't support 
calculating only the gradients // required for batchnorm - auto gradIn = Tensor(input.shape(), input.type()); - auto gradWt = Tensor(wt.shape(), wt.type()); - auto gradBs = Tensor(wt.shape(), wt.type()); + auto gradIn = Tensor{input.shape(), input.type()}; + auto gradWt = Tensor{wt.shape(), wt.type()}; + auto gradBs = Tensor{wt.shape(), wt.type()}; { DevicePtr iRaw(input); DevicePtr wRaw(wt); diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp index bb89e61..b9f63c6 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp @@ -314,9 +314,9 @@ Tensor CudnnAutogradExtension::conv2d( auto hasBias = bias.elements() > 0; - auto inDesc = TensorDescriptor(input); - auto wtDesc = FilterDescriptor(weights); - auto convDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups); + auto inDesc = TensorDescriptor{input}; + auto wtDesc = FilterDescriptor{weights}; + auto convDesc = ConvDescriptor{input.type(), px, py, sx, sy, dx, dy, groups}; if(input.type() == fl::dtype::f16) CUDNN_CHECK_ERR( cudnnSetConvolutionMathType( @@ -339,8 +339,8 @@ Tensor CudnnAutogradExtension::conv2d( odims.data() ) ); - auto output = Tensor({odims[3], odims[2], odims[1], odims[0]}, input.type()); - auto outDesc = TensorDescriptor(output); + auto output = Tensor{{odims[3], odims[2], odims[1], odims[0]}, input.type()}; + auto outDesc = TensorDescriptor{output}; auto handle = getCudnnHandle(); const auto& cudnnStream = getCudnnStream(); @@ -357,7 +357,7 @@ Tensor CudnnAutogradExtension::conv2d( try { wspace = - Tensor({static_cast(fwdAlgoBestPerf.memory)}, fl::dtype::b8); + Tensor{{static_cast(fwdAlgoBestPerf.memory)}, fl::dtype::b8}; } catch(const std::exception&) { fwdAlgoBestPerf.algo = kFwdDefaultAlgo; CUDNN_CHECK_ERR( @@ -372,7 +372,7 @@ Tensor CudnnAutogradExtension::conv2d( ) ); wspace = - Tensor({static_cast(fwdAlgoBestPerf.memory)}, 
fl::dtype::b8); + Tensor{{static_cast(fwdAlgoBestPerf.memory)}, fl::dtype::b8}; } { DevicePtr inPtr(input); @@ -405,7 +405,7 @@ Tensor CudnnAutogradExtension::conv2d( ); if(hasBias) { - auto bsDesc = TensorDescriptor(bias); + auto bsDesc = TensorDescriptor{bias}; DevicePtr bsPtr(bias); // ensure cudnn compute stream waits on stream of bias tensor relativeSync(cudnnStream, {bias}); @@ -453,10 +453,10 @@ Tensor CudnnAutogradExtension::conv2dBackwardData( // benchmarking suggests input or weight casting should occur, these // descriptors may not be used/new ones with the correct types will be // used instead. - auto iDesc = TensorDescriptor(input); - auto wDesc = FilterDescriptor(weight); - auto cDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups); - auto oDesc = TensorDescriptor(gradOutput); + auto iDesc = TensorDescriptor{input}; + auto wDesc = FilterDescriptor{weight}; + auto cDesc = ConvDescriptor{input.type(), px, py, sx, sy, dx, dy, groups}; + auto oDesc = TensorDescriptor{gradOutput}; setDefaultMathType(cDesc, input); @@ -491,10 +491,10 @@ Tensor CudnnAutogradExtension::conv2dBackwardData( Tensor ws; try { - ws = Tensor( + ws = Tensor{ {static_cast(bwdDataAlgoBestPerf.memory)}, fl::dtype::b8 - ); + }; } catch(const std::exception&) { bwdDataAlgoBestPerf.algo = kBwdDataDefaultAlgo; CUDNN_CHECK_ERR( @@ -508,13 +508,13 @@ Tensor CudnnAutogradExtension::conv2dBackwardData( &bwdDataAlgoBestPerf.memory ) ); - ws = Tensor( + ws = Tensor{ {static_cast(bwdDataAlgoBestPerf.memory)}, fl::dtype::b8 - ); + }; } - auto gradInput = Tensor(inTensor.shape(), inTensor.type()); + auto gradInput = Tensor{inTensor.shape(), inTensor.type()}; { DevicePtr gradInputPtr(gradInput); DevicePtr gradResultPtr(gradOutputTensor); @@ -577,11 +577,11 @@ Tensor CudnnAutogradExtension::conv2dBackwardData( /* incrementCount = */ false ); - auto iDescF32 = TensorDescriptor(inTensorF32); - auto wDescF32 = FilterDescriptor(wtTensorF32); + auto iDescF32 = TensorDescriptor{inTensorF32}; 
+ auto wDescF32 = FilterDescriptor{wtTensorF32}; auto cDescF32 = - ConvDescriptor(fl::dtype::f32, px, py, sx, sy, dx, dy, groups); - auto oDescF32 = TensorDescriptor(gradOutputTensorF32); + ConvDescriptor{fl::dtype::f32, px, py, sx, sy, dx, dy, groups}; + auto oDescF32 = TensorDescriptor{gradOutputTensorF32}; // core bwd data computation dataGradBenchmark->audit( [&dataGradOut, @@ -671,10 +671,10 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( // benchmarking suggests input or weight casting should occur, these // descriptors may not be used/new ones with the correct types will be // used instead. - auto iDesc = TensorDescriptor(input); - auto wDesc = FilterDescriptor(weight); - auto cDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups); - auto oDesc = TensorDescriptor(gradOutput); + auto iDesc = TensorDescriptor{input}; + auto wDesc = FilterDescriptor{weight}; + auto cDesc = ConvDescriptor{input.type(), px, py, sx, sy, dx, dy, groups}; + auto oDesc = TensorDescriptor{gradOutput}; setDefaultMathType(cDesc, input); @@ -708,10 +708,10 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( Tensor ws; try { - ws = Tensor( + ws = Tensor{ {static_cast(bwdFilterAlgoBestPerf.memory)}, fl::dtype::b8 - ); + }; } catch(const std::exception&) { bwdFilterAlgoBestPerf.algo = kBwdFilterDefaultAlgo; CUDNN_CHECK_ERR( @@ -725,13 +725,13 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( &bwdFilterAlgoBestPerf.memory ) ); - ws = Tensor( + ws = Tensor{ {static_cast(bwdFilterAlgoBestPerf.memory)}, fl::dtype::b8 - ); + }; } - auto gradWeight = Tensor(wtTensor.shape(), wtTensor.type()); + auto gradWeight = Tensor{wtTensor.shape(), wtTensor.type()}; { DevicePtr gradWeightPtr(gradWeight); DevicePtr gradResultPtr(gradOutputTensor); @@ -794,11 +794,11 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( /* incrementCount = */ false ); - auto iDescF32 = TensorDescriptor(inTensorF32); - auto wDescF32 = FilterDescriptor(wtTensorF32); + 
auto iDescF32 = TensorDescriptor{inTensorF32}; + auto wDescF32 = FilterDescriptor{wtTensorF32}; auto cDescF32 = - ConvDescriptor(fl::dtype::f32, px, py, sx, sy, dx, dy, groups); - auto oDescF32 = TensorDescriptor(gradOutputTensorF32); + ConvDescriptor{fl::dtype::f32, px, py, sx, sy, dx, dy, groups}; + auto oDescF32 = TensorDescriptor{gradOutputTensorF32}; // core bwd data computation filterGradBenchmark->audit( [&filterGradOut, @@ -860,13 +860,13 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( const Tensor& bsTensor, const Tensor& gradOutput, const TensorDescriptor& oDesc) -> Tensor { - auto gradBias = Tensor(bsTensor.shape(), bsTensor.type()); + auto gradBias = Tensor{bsTensor.shape(), bsTensor.type()}; { DevicePtr gradBiasPtr(gradBias); DevicePtr gradResultPtr(gradOutput); // ensure cudnn compute stream waits on gradient tensor streams relativeSync(cudnnStream, {gradOutput, gradBias}); - auto bDesc = TensorDescriptor(bsTensor); + auto bDesc = TensorDescriptor{bsTensor}; CUDNN_CHECK_ERR( cudnnConvolutionBackwardBias( hndl, @@ -911,7 +911,7 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( }, /* incrementCount = */ false ); - auto oDescF32 = TensorDescriptor(gradOutputF32); + auto oDescF32 = TensorDescriptor{gradOutputF32}; // Perform bias gradient computation biasGradBenchmark->audit( [&biasGradOut, diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp index 305a6cc..eaac140 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp @@ -16,13 +16,11 @@ namespace fl { std::shared_ptr CudnnAutogradExtension::createBenchmarkOptions() { return std::make_shared( std::make_shared>( - std::vector( - { - KernelMode::F32, - KernelMode::F32_ALLOW_CONVERSION, - KernelMode::F16 - } - ), + std::vector{ + KernelMode::F32, + 
KernelMode::F32_ALLOW_CONVERSION, + KernelMode::F16 + }, fl::kDynamicBenchmarkDefaultCount ) ); diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp index 82cadcb..804bc6f 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp @@ -25,16 +25,16 @@ struct DeviceHandle { std::shared_ptr stream; explicit DeviceHandle(std::shared_ptr _stream) : cudnnHandle(nullptr), - stream(_stream) { + stream(_stream) { CUDNN_CHECK_ERR(cudnnCreate(&cudnnHandle)); CUDNN_CHECK_ERR(cudnnSetStream(cudnnHandle, stream->handle())); } ~DeviceHandle() { if(cudnnHandle) { -// See https://git.io/fNQnM - sometimes, at exit, the CUDA context -// (or something) is already destroyed by the time a handle gets destroyed -// because of an issue with the destruction order. + // See https://git.io/fNQnM - sometimes, at exit, the CUDA context + // (or something) is already destroyed by the time a handle gets destroyed + // because of an issue with the destruction order. 
#ifdef NO_CUDNN_DESTROY_HANDLE #else CUDNN_CHECK_ERR(cudnnDestroy(cudnnHandle)); @@ -43,16 +43,16 @@ struct DeviceHandle { } }; -const float kFloatZero = 0.0; -const float kFloatOne = 1.0; +constexpr float kFloatZero = 0.0; +constexpr float kFloatOne = 1.0; -const double kDoubleZero = 0.0; -const double kDoubleOne = 1.0; +constexpr double kDoubleZero = 0.0; +constexpr double kDoubleOne = 1.0; // TODO: move this to CudnnAutogradExtension if we make it a singleton std::unordered_map handles; -const DeviceHandle& getActiveDeviceHandle() { +DeviceHandle const& getActiveDeviceHandle() { auto& manager = fl::DeviceManager::getInstance(); auto& cudaDevice = manager.getActiveDevice(fl::DeviceType::CUDA).impl(); @@ -88,57 +88,42 @@ namespace fl { void cudnnCheckErr(cudnnStatus_t status) { if(status == CUDNN_STATUS_SUCCESS) return; - const char* err = cudnnGetErrorString(status); + char const* err = cudnnGetErrorString(status); switch(status) { - case CUDNN_STATUS_BAD_PARAM: - throw std::invalid_argument(err); - default: - throw std::runtime_error(err); + case CUDNN_STATUS_BAD_PARAM: throw std::invalid_argument(err); + default: throw std::runtime_error(err); } } -cudnnDataType_t cudnnMapToType(const fl::dtype& t) { +cudnnDataType_t cudnnMapToType(fl::dtype const& t) { switch(t) { - case fl::dtype::f16: - return CUDNN_DATA_HALF; - case fl::dtype::f32: - return CUDNN_DATA_FLOAT; - case fl::dtype::f64: - return CUDNN_DATA_DOUBLE; - default: - throw std::invalid_argument("unsupported data type for cuDNN"); + case fl::dtype::f16: return CUDNN_DATA_HALF; + case fl::dtype::f32: return CUDNN_DATA_FLOAT; + case fl::dtype::f64: return CUDNN_DATA_DOUBLE; + default: throw std::invalid_argument("unsupported data type for cuDNN"); } } -cudnnPoolingMode_t cudnnMapToPoolingMode(const PoolingMode mode) { +cudnnPoolingMode_t cudnnMapToPoolingMode(PoolingMode const mode) { switch(mode) { - case PoolingMode::MAX: - return CUDNN_POOLING_MAX; - case PoolingMode::AVG_INCLUDE_PADDING: - return 
CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - case PoolingMode::AVG_EXCLUDE_PADDING: - return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; - default: - throw std::invalid_argument("unsupported pooling mode for cuDNN"); + case PoolingMode::MAX: return CUDNN_POOLING_MAX; + case PoolingMode::AVG_INCLUDE_PADDING: return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + case PoolingMode::AVG_EXCLUDE_PADDING: return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; + default: throw std::invalid_argument("unsupported pooling mode for cuDNN"); } } -cudnnRNNMode_t cudnnMapToRNNMode(const RnnMode mode) { +cudnnRNNMode_t cudnnMapToRNNMode(RnnMode const mode) { switch(mode) { - case RnnMode::RELU: - return CUDNN_RNN_RELU; - case RnnMode::TANH: - return CUDNN_RNN_TANH; - case RnnMode::LSTM: - return CUDNN_LSTM; - case RnnMode::GRU: - return CUDNN_GRU; - default: - throw std::invalid_argument("unsupported RNN mode for cuDNN"); + case RnnMode::RELU: return CUDNN_RNN_RELU; + case RnnMode::TANH: return CUDNN_RNN_TANH; + case RnnMode::LSTM: return CUDNN_LSTM; + case RnnMode::GRU: return CUDNN_GRU; + default: throw std::invalid_argument("unsupported RNN mode for cuDNN"); } } -TensorDescriptor::TensorDescriptor(const fl::dtype type, const Shape& flDims) { +TensorDescriptor::TensorDescriptor(fl::dtype const type, Shape const& flDims) { CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&descriptor)); cudnnDataType_t cudnntype = cudnnMapToType(type); @@ -165,7 +150,7 @@ TensorDescriptor::TensorDescriptor(const fl::dtype type, const Shape& flDims) { ); } -TensorDescriptor::TensorDescriptor(const Tensor& input) { +TensorDescriptor::TensorDescriptor(Tensor const& input) { CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&descriptor)); cudnnDataType_t cudnntype = cudnnMapToType(input.type()); @@ -194,21 +179,19 @@ TensorDescriptor::TensorDescriptor(const Tensor& input) { ); } -TensorDescriptor::~TensorDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyTensorDescriptor(descriptor)); -} 
+TensorDescriptor::~TensorDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyTensorDescriptor(descriptor)); } TensorDescriptorArray::TensorDescriptorArray( int size, - const fl::dtype type, - const Shape& dims + fl::dtype const type, + Shape const& dims ) { - desc_vec.reserve(size); + _descVec.reserve(size); for(int i = 0; i < size; i++) { - desc_vec.emplace_back(type, dims); - desc_raw_vec.push_back(desc_vec.back().descriptor); + _descVec.emplace_back(type, dims); + _descRawVec.push_back(_descVec.back().descriptor); } - descriptors = desc_raw_vec.data(); + descriptors = _descRawVec.data(); } TensorDescriptorArray::~TensorDescriptorArray() = default; @@ -241,11 +224,9 @@ PoolingDescriptor::PoolingDescriptor( ); } -PoolingDescriptor::~PoolingDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyPoolingDescriptor(descriptor)); -} +PoolingDescriptor::~PoolingDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyPoolingDescriptor(descriptor)); } -FilterDescriptor::FilterDescriptor(const Tensor& input) { +FilterDescriptor::FilterDescriptor(Tensor const& input) { CUDNN_CHECK_ERR(cudnnCreateFilterDescriptor(&descriptor)); cudnnDataType_t cudnntype = cudnnMapToType(input.type()); @@ -267,121 +248,146 @@ FilterDescriptor::FilterDescriptor(const Tensor& input) { ); } -FilterDescriptor::~FilterDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyFilterDescriptor(descriptor)); -} +FilterDescriptor::~FilterDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyFilterDescriptor(descriptor)); } -DropoutDescriptor::DropoutDescriptor(float drop_prob) { +DropoutDescriptor::DropoutDescriptor(float dropProb) { CUDNN_CHECK_ERR(cudnnCreateDropoutDescriptor(&descriptor)); - auto cudnnHandle = getCudnnHandle(); - unsigned long long seed = 0; - size_t state_size; - CUDNN_CHECK_ERR(cudnnDropoutGetStatesSize(cudnnHandle, &state_size)); - auto& dropout_states = getDropoutStates(); - if(dropout_states.isEmpty()) { - dropout_states = - Tensor({static_cast(state_size)}, fl::dtype::b8); - DevicePtr statesraw(dropout_states); + + auto const 
cudnnHandle = getCudnnHandle(); + constexpr unsigned long long seed = 0; + size_t stateSize; + + CUDNN_CHECK_ERR(cudnnDropoutGetStatesSize(cudnnHandle, &stateSize)); + + auto& dropoutStates = getDropoutStates(); + + if(dropoutStates.isEmpty()) { + dropoutStates = + Tensor{{static_cast(stateSize)}, fl::dtype::b8}; + DevicePtr statesraw(dropoutStates); CUDNN_CHECK_ERR( cudnnSetDropoutDescriptor( descriptor, cudnnHandle, - drop_prob, + dropProb, statesraw.get(), - state_size, + stateSize, seed ) ); - } else { - DevicePtr statesraw(dropout_states); -// See https://git.io/fp9oo for an explanation. -#if CUDNN_VERSION >= 7000 + } + else { + DevicePtr statesraw(dropoutStates); CUDNN_CHECK_ERR( cudnnRestoreDropoutDescriptor( descriptor, cudnnHandle, - drop_prob, + dropProb, statesraw.get(), - state_size, + stateSize, seed ) ); -#else - auto dropout_struct = reinterpret_cast(descriptor); - dropout_struct->dropout = drop_prob; - dropout_struct->nstates = state_size; - dropout_struct->states = statesraw.get(); -#endif } } -DropoutDescriptor::~DropoutDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyDropoutDescriptor(descriptor)); -} +DropoutDescriptor::~DropoutDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyDropoutDescriptor(descriptor)); } Tensor& DropoutDescriptor::getDropoutStates() { - thread_local Tensor dropout_states; - return dropout_states; + thread_local Tensor dropoutStates; + return dropoutStates; } RNNDescriptor::RNNDescriptor( fl::dtype type, - int hidden_size, - int num_layers, + int inputSize, + int hiddenSize, + int numLayers, RnnMode mode, bool bidirectional, DropoutDescriptor& dropout ) { - CUDNN_CHECK_ERR(cudnnCreateRNNDescriptor(&descriptor)); - - auto cudnnHandle = getCudnnHandle(); + CUDNN_CHECK_ERR(cudnnCreateRNNDescriptor(&_handle)); - cudnnRNNInputMode_t in_mode = CUDNN_LINEAR_INPUT; + constexpr auto inMode = CUDNN_LINEAR_INPUT; + constexpr auto algo = CUDNN_RNN_ALGO_STANDARD; - cudnnDirectionMode_t dir = - bidirectional ? 
CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; + auto const dir = bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; - cudnnRNNMode_t cell = cudnnMapToRNNMode(mode); - cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; - cudnnDataType_t cudnntype = cudnnMapToType(type); + auto const cell = cudnnMapToRNNMode(mode); + auto const dataType = cudnnMapToType(type); -#if CUDNN_VERSION >= 7000 && CUDNN_VERSION < 8000 CUDNN_CHECK_ERR( - cudnnSetRNNDescriptor( - cudnnHandle, - descriptor, - hidden_size, - num_layers, - dropout.descriptor, - in_mode, - dir, - cell, + //https://docs.nvidia.com/deeplearning/cudnn/archives/cudnn-892/api/index.html#cudnnSetRNNDescriptor_v8 + cudnnSetRNNDescriptor_v8( + _handle, algo, - cudnntype + cell, + CUDNN_RNN_DOUBLE_BIAS, //TODO review; double is default for old cudnn + dir, + inMode, + dataType, + dataType, // math precision + mathType(type), + inputSize, + hiddenSize, + hiddenSize, //projection size (unused) + numLayers, + dropout.descriptor, + 0 ) ); -#else +} + +RNNDescriptor::~RNNDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyRNNDescriptor(_handle)); } + +} + +namespace fl { + + +RNNDataDescriptor::RNNDataDescriptor(fl::dtype type, Shape const& dims) { + create(); + auto const inputSize = dims.ndim() > 0 ? static_cast(dims[0]) : 1; + auto const batchSize = dims.ndim() > 1 ? static_cast(dims[1]) : 1; + auto const maxSeqSize = dims.ndim() > 2 ? 
static_cast(dims[2]) : 1; + + std::vector seqSizes(batchSize, maxSeqSize); + + set( + type, + inputSize, + maxSeqSize, + seqSizes + ); +} + +RNNDataDescriptor::~RNNDataDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyRNNDataDescriptor(_handle)); } +void RNNDataDescriptor::create() { CUDNN_CHECK_ERR(cudnnCreateRNNDataDescriptor(&_handle)); } +void RNNDataDescriptor::set( + fl::dtype type, + int inputSize, + int maxSeqSize, + std::span sequenceSizes +) const { CUDNN_CHECK_ERR( - cudnnSetRNNDescriptor_v6( - cudnnHandle, - descriptor, - hidden_size, - num_layers, - dropout.descriptor, - in_mode, - dir, - cell, - algo, - cudnntype + cudnnSetRNNDataDescriptor( + _handle, + cudnnMapToType(type), + CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED, + maxSeqSize, + sequenceSizes.size(), //batch size + inputSize, + sequenceSizes.data(), + nullptr //no padding ) ); -#endif } -RNNDescriptor::~RNNDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyRNNDescriptor(descriptor)); } +namespace fl { + ConvDescriptor::ConvDescriptor( fl::dtype type, int px, @@ -413,35 +419,27 @@ ConvDescriptor::ConvDescriptor( CUDNN_CHECK_ERR(cudnnSetConvolutionGroupCount(descriptor, groups)); } -ConvDescriptor::~ConvDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyConvolutionDescriptor(descriptor)); -} +ConvDescriptor::~ConvDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyConvolutionDescriptor(descriptor)); } cudnnHandle_t getCudnnHandle() { return getActiveDeviceHandle().cudnnHandle; } -const CUDAStream& getCudnnStream() { return *getActiveDeviceHandle().stream; } +CUDAStream const& getCudnnStream() { return *getActiveDeviceHandle().stream; } -const void* kOne(const fl::dtype t) { +void const* kOne(fl::dtype const t) { switch(t) { case fl::dtype::f16: - case fl::dtype::f32: - return &kFloatOne; - case fl::dtype::f64: - return &kDoubleOne; - default: - throw std::invalid_argument("unsupported data type for cuDNN"); + case fl::dtype::f32: return &kFloatOne; + case fl::dtype::f64: return &kDoubleOne; + default: throw 
std::invalid_argument("unsupported data type for cuDNN"); } } -const void* kZero(const fl::dtype t) { +void const* kZero(fl::dtype const t) { switch(t) { case fl::dtype::f16: - case fl::dtype::f32: - return &kFloatZero; - case fl::dtype::f64: - return &kDoubleZero; - default: - throw std::invalid_argument("unsupported data type for cuDNN"); + case fl::dtype::f32: return &kFloatZero; + case fl::dtype::f64: return &kDoubleZero; + default: throw std::invalid_argument("unsupported data type for cuDNN"); } } diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h index fca9969..ad83be3 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h @@ -13,13 +13,15 @@ #include "flashlight/fl/runtime/CUDAStream.h" #include "flashlight/fl/tensor/TensorBase.h" +#include + namespace fl { class TensorDescriptor { public: - explicit TensorDescriptor(const Tensor& a); + explicit TensorDescriptor(Tensor const& a); - TensorDescriptor(const fl::dtype type, const Shape& af_dims); + TensorDescriptor(fl::dtype const type, Shape const& afDims); cudnnTensorDescriptor_t descriptor; ~TensorDescriptor(); @@ -27,19 +29,19 @@ class TensorDescriptor { class TensorDescriptorArray { public: - TensorDescriptorArray(int size, const fl::dtype type, const Shape& dims); + TensorDescriptorArray(int size, fl::dtype const type, Shape const& dims); cudnnTensorDescriptor_t* descriptors; ~TensorDescriptorArray(); private: - std::vector desc_vec; - std::vector desc_raw_vec; + std::vector _descVec; + std::vector _descRawVec; }; class FilterDescriptor { public: - explicit FilterDescriptor(const Tensor& a); + explicit FilterDescriptor(Tensor const& input); cudnnFilterDescriptor_t descriptor; ~FilterDescriptor(); }; @@ -77,7 +79,7 @@ class PoolingDescriptor { class DropoutDescriptor { public: - explicit DropoutDescriptor(float drop_prob); + explicit 
DropoutDescriptor(float dropProb); cudnnDropoutDescriptor_t descriptor; ~DropoutDescriptor(); @@ -88,28 +90,65 @@ class RNNDescriptor { public: RNNDescriptor( fl::dtype type, - int hidden_size, - int num_layers, + int inputSize, + int hiddenSize, + int numLayers, RnnMode mode, bool bidirectional, DropoutDescriptor& dropout ); - cudnnRNNDescriptor_t descriptor; ~RNNDescriptor(); + +private: + cudnnRNNDescriptor_t _handle = nullptr; + + static constexpr auto mathType(fl::dtype type) { + return type == fl::dtype::f16 ? CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION : CUDNN_DEFAULT_MATH; + } + +public: + /** + * @return descriptor handle + */ + constexpr auto get() const { return _handle; } }; + +class RNNDataDescriptor { +public: + RNNDataDescriptor( + fl::dtype type, + Shape const& dims + ); + + ~RNNDataDescriptor(); + +private: + void create(); + void set(dtype type, int inputSize, int maxSeqSize, std::span sequenceSizes) const; + + cudnnRNNDataDescriptor_t _handle = nullptr; + +public: + /** + * @return descriptor handle + */ + constexpr auto get() const { return _handle; } +}; + + #define CUDNN_CHECK_ERR(expr) ::fl::cudnnCheckErr((expr)) void cudnnCheckErr(cudnnStatus_t status); -cudnnDataType_t cudnnMapToType(const fl::dtype& t); +cudnnDataType_t cudnnMapToType(fl::dtype const& t); -const void* kOne(const fl::dtype t); +void const* kOne(fl::dtype const t); -const void* kZero(const fl::dtype t); +void const* kZero(fl::dtype const t); // TODO: move this to CudnnAutogradExtension if we make it a singleton cudnnHandle_t getCudnnHandle(); -const CUDAStream& getCudnnStream(); +CUDAStream const& getCudnnStream(); } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp index 24b08c8..ef838bb 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp @@ -25,10 +25,10 @@ Tensor CudnnAutogradExtension::pool2d( const PoolingMode 
mode, std::shared_ptr ) { - auto inDesc = TensorDescriptor(input); + auto inDesc = TensorDescriptor{input}; // init pooling descriptor - auto poolDesc = PoolingDescriptor(wx, wy, sx, sy, px, py, mode); + auto poolDesc = PoolingDescriptor{wx, wy, sx, sy, px, py, mode}; // init output descriptor auto ix = input.dim(0); @@ -36,7 +36,7 @@ Tensor CudnnAutogradExtension::pool2d( auto ox = 1 + (ix + 2 * px - wx) / sx; auto oy = 1 + (iy + 2 * py - wy) / sy; - auto output = Tensor( + auto output = Tensor{ { ox, oy, @@ -44,8 +44,8 @@ Tensor CudnnAutogradExtension::pool2d( input.ndim() < 4 ? 1 : input.dim(3) }, input.type() - ); - auto outDesc = TensorDescriptor(output); + }; + auto outDesc = TensorDescriptor{output}; { DevicePtr inputraw(input); DevicePtr resultraw(output); @@ -90,11 +90,11 @@ Tensor CudnnAutogradExtension::pool2dBackward( const PoolingMode mode, std::shared_ptr ) { - auto i_desc = TensorDescriptor(input); - auto o_desc = TensorDescriptor(poolOutput); - auto p_desc = PoolingDescriptor(wx, wy, sx, sy, px, py, mode); + auto i_desc = TensorDescriptor{input}; + auto o_desc = TensorDescriptor{poolOutput}; + auto p_desc = PoolingDescriptor{wx, wy, sx, sy, px, py, mode}; - auto gradInput = Tensor(input.shape(), input.type()); + auto gradInput = Tensor{input.shape(), input.type()}; auto hndl = getCudnnHandle(); const auto& cudnnStream = getCudnnStream(); diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp index 17b242b..568e4d5 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp @@ -15,58 +15,34 @@ namespace fl { namespace { - size_t getWorkspaceSize( - cudnnHandle_t handle, - const RNNDescriptor& rnnDesc, - const int seqLength, - const TensorDescriptorArray& xDescs - ) { - size_t workspaceSize; - CUDNN_CHECK_ERR( - cudnnGetRNNWorkspaceSize( - handle, - rnnDesc.descriptor, - seqLength, - xDescs.descriptors, - &workspaceSize - ) - 
); - return workspaceSize; - } + struct temp_space_sizes { + size_t size; + size_t reserveSize; + }; - size_t getReserveSize( + temp_space_sizes rnnTempSpaceSizes( cudnnHandle_t handle, - const RNNDescriptor& rnnDesc, - const int seqLength, - const TensorDescriptorArray& xDescs + RNNDescriptor const& rnnDescriptor, + RNNDataDescriptor const& xDescriptor, + cudnnForwardMode_t mode ) { - size_t reserveSize; + temp_space_sizes sizes{}; + CUDNN_CHECK_ERR( - cudnnGetRNNTrainingReserveSize( + cudnnGetRNNTempSpaceSizes( handle, - rnnDesc.descriptor, - seqLength, - xDescs.descriptors, - &reserveSize + rnnDescriptor.get(), + mode, + xDescriptor.get(), + &sizes.size, + &sizes.reserveSize ) ); - return reserveSize; - } - void setCudnnRnnMathType(const Tensor& input, const RNNDescriptor& rnnDesc) { - if(input.type() == fl::dtype::f16) - CUDNN_CHECK_ERR( - cudnnSetRNNMatrixMathType( - rnnDesc.descriptor, - CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION - ) - ); - else - CUDNN_CHECK_ERR( - cudnnSetRNNMatrixMathType(rnnDesc.descriptor, CUDNN_DEFAULT_MATH) - ); + return sizes; } + struct CudnnRnnAutogradPayload : public detail::AutogradPayloadData { Tensor reserveSpace; }; @@ -74,15 +50,15 @@ namespace { } // namespace std::tuple CudnnAutogradExtension::rnn( - const Tensor& input, - const Tensor& hiddenStateIn, - const Tensor& cellStateIn, - const Tensor& weights, - const int hiddenSize, - const int numLayers, - const RnnMode mode, - const bool bidirectional, - const float dropProb, + Tensor const& input, + Tensor const& hiddenStateIn, + Tensor const& cellStateIn, + Tensor const& weights, + int const hiddenSize, + int const numLayers, + RnnMode const mode, + bool const bidirectional, + float const dropProb, std::shared_ptr autogradPayload ) { FL_TENSOR_DTYPES_MATCH_CHECK(input, hiddenStateIn, cellStateIn, weights); @@ -93,25 +69,33 @@ std::tuple CudnnAutogradExtension::rnn( autogradPayload->data = payload; Tensor x = input.asContiguousTensor(); + RNNDataDescriptor xDesc{x.type(), 
x.shape()}; + Tensor hiddenState = hiddenStateIn.asContiguousTensor(); Tensor cellState = cellStateIn.asContiguousTensor(); - DropoutDescriptor dropout(dropProb); - RNNDescriptor rnnDesc( - input.type(), hiddenSize, numLayers, mode, bidirectional, dropout); - setCudnnRnnMathType(input, rnnDesc); + DropoutDescriptor dropout{dropProb}; - auto dims = x.shape(); + auto const& dims = x.shape(); int inputSize = dims[0]; int batchSize = dims.ndim() < 2 ? 1 : dims[1]; int seqLength = dims.ndim() < 3 ? 1 : dims[2]; + + RNNDescriptor rnnDesc{ + input.type(), + inputSize, + hiddenSize, + numLayers, + mode, + bidirectional, + dropout + }; + + int totalLayers = numLayers * (bidirectional ? 2 : 1); int outSize = hiddenSize * (bidirectional ? 2 : 1); - TensorDescriptorArray xDescs( - seqLength, x.type(), {1, 1, inputSize, batchSize}); - if(!hiddenState.isEmpty()) { auto hxDims = hiddenState.shape(); int hxHiddenSize = hxDims[0]; @@ -119,31 +103,31 @@ std::tuple CudnnAutogradExtension::rnn( int hxTotalLayers = hiddenState.ndim() < 3 ? 
1 : hxDims[2]; if( - !(hxHiddenSize == hiddenSize && hxBatchSize == batchSize - && hxTotalLayers == totalLayers) + hxHiddenSize != hiddenSize || hxBatchSize != batchSize + || hxTotalLayers != totalLayers ) throw std::invalid_argument("invalid hidden state dims for RNN"); } if( !cellState.isEmpty() - && !(mode == RnnMode::LSTM && cellState.dim(0) == hiddenSize - && cellState.dim(1) == batchSize && cellState.dim(2) == totalLayers) + && (mode != RnnMode::LSTM || cellState.dim(0) != hiddenSize + || cellState.dim(1) != batchSize || cellState.dim(2) != totalLayers) ) throw std::invalid_argument("invalid cell state dims for RNN"); Shape hDims = {1, hiddenSize, batchSize, totalLayers}; - TensorDescriptor hxDesc(x.type(), hDims); - TensorDescriptor cxDesc(x.type(), hDims); + TensorDescriptor hxDesc{x.type(), hDims}; + TensorDescriptor cxDesc{x.type(), hDims}; auto handle = getCudnnHandle(); - const auto& cudnnStream = getCudnnStream(); + auto const& cudnnStream = getCudnnStream(); size_t paramSize; CUDNN_CHECK_ERR( cudnnGetRNNParamsSize( handle, - rnnDesc.descriptor, + rnnDesc._handle, xDescs.descriptors[0], &paramSize, cudnnMapToType(weights.type()) @@ -155,24 +139,24 @@ std::tuple CudnnAutogradExtension::rnn( ); FilterDescriptor wDesc(weights); - Tensor y({outSize, batchSize, seqLength}, input.type()); - TensorDescriptorArray yDesc(seqLength, y.type(), {1, 1, outSize, batchSize}); + Tensor y{{outSize, batchSize, seqLength}, input.type()}; + TensorDescriptorArray yDesc{seqLength, y.type(), {1, 1, outSize, batchSize}}; - Tensor hy({hiddenSize, batchSize, totalLayers}, x.type()); - TensorDescriptor hyDesc(x.type(), hDims); + Tensor hy{{hiddenSize, batchSize, totalLayers}, x.type()}; + TensorDescriptor hyDesc{x.type(), hDims}; - Tensor cy; + Tensor cy{}; if(mode == RnnMode::LSTM) - cy = Tensor(hy.shape(), x.type()); + cy = Tensor{hy.shape(), x.type()}; - TensorDescriptor cyDesc(x.type(), hDims); + TensorDescriptor cyDesc{x.type(), hDims}; - size_t workspaceSize = 
 getWorkspaceSize(handle, rnnDesc, seqLength, xDescs); - size_t reserveSize = getReserveSize(handle, rnnDesc, seqLength, xDescs); + constexpr auto forwardMode = CUDNN_FWD_MODE_TRAINING; + auto [workspaceSize, reserveSize] = rnnTempSpaceSizes(handle, rnnDesc, xDesc, forwardMode); Tensor workspace({static_cast(workspaceSize)}, fl::dtype::b8); // Space must be reused between forward and backward for cuDNN - payload->reserveSpace = Tensor({static_cast(reserveSize)}, fl::dtype::b8); + payload->reserveSpace = Tensor{{static_cast(reserveSize)}, fl::dtype::b8}; { auto contiguousX = x.asContiguousTensor(); @@ -190,34 +174,40 @@ std::tuple CudnnAutogradExtension::rnn( relativeSync( cudnnStream, { - contiguousX, hiddenState, cellState, contiguousWeights, y, hy, cy, - workspace, payload->reserveSpace, + contiguousX, + hiddenState, + cellState, + contiguousWeights, + y, + hy, + cy, + workspace, + payload->reserveSpace, } ); CUDNN_CHECK_ERR( - cudnnRNNForwardTraining( + cudnnRNNForward( handle, - rnnDesc.descriptor, + rnnDesc.get(), + forwardMode, - seqLength, + nullptr, - xDescs.descriptors, + xDesc.get(), xRaw.get(), - hxDesc.descriptor, - hxRaw.get(), - cxDesc.descriptor, - cxRaw.get(), - wDesc.descriptor, - wRaw.get(), - yDesc.descriptors, - yRaw.get(), - hyDesc.descriptor, - hyRaw.get(), - cyDesc.descriptor, - cyRaw.get(), - workspaceRaw.get(), + yDesc.get(), + yRaw.get(), + hxDesc.descriptor, + hxRaw.get(), + hyRaw.get(), + cxDesc.descriptor, + cxRaw.get(), + cyRaw.get(), + paramSize, + wRaw.get(), + workspaceSize, - reserveSpaceRaw.get(), - reserveSize + workspaceRaw.get(), + reserveSize, + reserveSpaceRaw.get() ) ); } @@ -228,17 +218,17 @@ std::tuple CudnnAutogradExtension::rnn( } std::tuple CudnnAutogradExtension::rnnBackward( - const Tensor& input, - const Tensor& hiddenState, - const Tensor& cellState, - const Tensor& weights, - const std::shared_ptr gradData, - const Tensor& output, - const int numLayers, - const int hiddenSize, - const RnnMode mode, - const bool bidirectional, - const float dropProb, + Tensor const& input, + 
Tensor const& hiddenState, + Tensor const& cellState, + Tensor const& weights, + std::shared_ptr const gradData, + Tensor const& output, + int const numLayers, + int const hiddenSize, + RnnMode const mode, + bool const bidirectional, + float const dropProb, std::shared_ptr autogradPayload ) { if(!autogradPayload) @@ -249,7 +239,7 @@ std::tuple CudnnAutogradExtension::rnnBackward( std::static_pointer_cast(autogradPayload->data); auto handle = getCudnnHandle(); - const auto& cudnnStream = getCudnnStream(); + auto const& cudnnStream = getCudnnStream(); auto x = input.asContiguousTensor(); auto& y = output; @@ -261,29 +251,32 @@ std::tuple CudnnAutogradExtension::rnnBackward( int totalLayers = numLayers * (bidirectional ? 2 : 1); int outSize = hiddenSize * (bidirectional ? 2 : 1); - DropoutDescriptor dropout(dropProb); - RNNDescriptor rnnDesc(input.type(), hiddenSize, numLayers, mode, bidirectional, dropout); + DropoutDescriptor dropout{dropProb}; + RNNDescriptor rnnDesc{input.type(), hiddenSize, numLayers, mode, bidirectional, dropout}; setCudnnRnnMathType(input, rnnDesc); - TensorDescriptorArray yDesc(seqLength, y.type(), {1, 1, outSize, batchSize}); - TensorDescriptorArray dyDesc(seqLength, y.type(), {1, 1, outSize, batchSize}); + TensorDescriptorArray yDesc{seqLength, y.type(), {1, 1, outSize, batchSize}}; + TensorDescriptorArray dyDesc{seqLength, y.type(), {1, 1, outSize, batchSize}}; Shape hDims = {1, hiddenSize, batchSize, totalLayers}; - TensorDescriptor dhyDesc(x.type(), hDims); - TensorDescriptor dcyDesc(x.type(), hDims); - TensorDescriptor hxDesc(x.type(), hDims); - TensorDescriptor cxDesc(x.type(), hDims); + TensorDescriptor dhyDesc{x.type(), hDims}; + TensorDescriptor dcyDesc{x.type(), hDims}; + TensorDescriptor hxDesc{x.type(), hDims}; + TensorDescriptor cxDesc{x.type(), hDims}; - Tensor dhx(hiddenState.shape(), hiddenState.type()); - Tensor dcx(cellState.shape(), cellState.type()); - TensorDescriptor dhxDesc(x.type(), hDims); - TensorDescriptor 
dcxDesc(x.type(), hDims); + Tensor dhx{hiddenState.shape(), hiddenState.type()}; + Tensor dcx{cellState.shape(), cellState.type()}; + TensorDescriptor dhxDesc{x.type(), hDims}; + TensorDescriptor dcxDesc{x.type(), hDims}; FilterDescriptor wDesc(weights); - Tensor dx(input.shape(), input.type()); - TensorDescriptorArray dxDescs( - seqLength, dx.type(), {1, 1, inputSize, batchSize}); + Tensor dx{input.shape(), input.type()}; + TensorDescriptorArray dxDescs{ + seqLength, + dx.type(), + {1, 1, inputSize, batchSize} + }; size_t workspaceSize = getWorkspaceSize(handle, rnnDesc, seqLength, dxDescs); @@ -325,7 +318,7 @@ std::tuple CudnnAutogradExtension::rnnBackward( CUDNN_CHECK_ERR( cudnnRNNBackwardData( handle, - rnnDesc.descriptor, + rnnDesc._handle, seqLength, yDesc.descriptors, yRaw.get(), @@ -357,17 +350,20 @@ std::tuple CudnnAutogradExtension::rnnBackward( if(input.type() == fl::dtype::f16) CUDNN_CHECK_ERR( - cudnnSetRNNMatrixMathType( - rnnDesc.descriptor, - CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION - ) - ); + cudnnSetRNNMatrixMathType( + rnnDesc._handle, + CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION + ) + ); else CUDNN_CHECK_ERR( - cudnnSetRNNMatrixMathType(rnnDesc.descriptor, CUDNN_DEFAULT_MATH) - ); - TensorDescriptorArray xDescs( - seqLength, x.type(), {1, 1, inputSize, batchSize}); + cudnnSetRNNMatrixMathType(rnnDesc._handle, CUDNN_DEFAULT_MATH) + ); + TensorDescriptorArray xDescs{ + seqLength, + x.type(), + {1, 1, inputSize, batchSize} + }; Tensor dw = fl::full(weights.shape(), 0, weights.type()); FilterDescriptor dwDesc(dw); @@ -382,7 +378,7 @@ std::tuple CudnnAutogradExtension::rnnBackward( CUDNN_CHECK_ERR( cudnnRNNBackwardWeights( handle, - rnnDesc.descriptor, + rnnDesc._handle, seqLength, xDescs.descriptors, xRaw.get(),