forked from mindspore-Ecosystem/mindspore
!46714 Refactoring and adding dynamic shape support for CudnnGRU
Merge pull request !46714 from liuluobin/cudnn_gru_dyn_shape
This commit is contained in:
commit
1a8763c65d
|
@ -80,7 +80,7 @@ int DynamicRnnOpBaseMod::Resize(const BaseOperatorPtr &base_operator, const std:
|
|||
ResetResource();
|
||||
auto input_shape = inputs[kIndex0]->GetShapeVector();
|
||||
batch_size_ = static_cast<int>(input_shape[1]);
|
||||
if (batch_size_ == -1) {
|
||||
if (batch_size_ == abstract::Shape::kShapeDimAny) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
seq_lens_.resize(IntToSize(batch_size_));
|
||||
|
@ -164,7 +164,6 @@ template <typename T>
|
|||
bool DynamicRnnOpBaseMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
cuda_stream_ = reinterpret_cast<cudaStream_t>(stream_ptr);
|
||||
|
||||
VARIABLE_NOT_USED(stream_ptr);
|
||||
|
||||
auto x_addr = GetDeviceAddress<T>(inputs, inputs_x_index_);
|
||||
|
|
|
@ -16,27 +16,286 @@
|
|||
|
||||
#include "plugin/device/gpu/kernel/rl/gru_gpu_kernel.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(CudnnGRU,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
GruGpuKernelMod, float)
|
||||
MS_REG_GPU_KERNEL_ONE(CudnnGRU,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
GruGpuKernelMod, half)
|
||||
namespace {
|
||||
constexpr size_t DimOfTensor = 3;
|
||||
constexpr size_t LeastWeightShape = 3;
|
||||
constexpr size_t LeastInputShapeSize = 3;
|
||||
constexpr size_t kInputsXIndex = 0;
|
||||
constexpr size_t kInputsHxIndex = 1;
|
||||
constexpr size_t kInputsWIndex = 2;
|
||||
constexpr size_t kOutputsYIndex = 0;
|
||||
constexpr size_t kOutputsHyIndex = 1;
|
||||
constexpr size_t kOutputsReservedAddrIndex = 2;
|
||||
constexpr size_t kOutputsStatedAddrIndex = 3;
|
||||
constexpr size_t kGruInputsNum = 3;
|
||||
constexpr size_t kGruOutputsNum = 4;
|
||||
constexpr size_t kCudnnGRUInputDim = 3;
|
||||
constexpr size_t kCudnnGRUHDim = 3;
|
||||
constexpr size_t kCudnnGRUWDim = 3;
|
||||
} // namespace
|
||||
bool GruGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGruInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGruOutputsNum, kernel_name_);
|
||||
InitResource();
|
||||
|
||||
if (!GetCudnnDataType(TypeIdLabel(inputs[kInputsXIndex]->GetDtype()), &cudnn_data_type_)) {
|
||||
MS_LOG(ERROR) << kernel_name_ << ": Get cudnn data type failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
|
||||
input_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kInputsXIndex).dtype);
|
||||
input_size_ = static_cast<int>(GetValue<int64_t>(base_operator->GetAttr("input_size")));
|
||||
hidden_size_ = static_cast<int>(GetValue<int64_t>(base_operator->GetAttr("hidden_size")));
|
||||
num_layers_ = static_cast<int>(GetValue<int64_t>(base_operator->GetAttr("num_layers")));
|
||||
has_bias_ = GetValue<bool>(base_operator->GetAttr("has_bias"));
|
||||
bidirectional_ = GetValue<bool>(base_operator->GetAttr("bidirectional"));
|
||||
dropout_ = GetValue<float>(base_operator->GetAttr("dropout"));
|
||||
kernel_func_ = func_list_[index].second;
|
||||
return true;
|
||||
}
|
||||
|
||||
void GruGpuKernelMod::ResetResource() noexcept {
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
reserved_size_ = 0;
|
||||
}
|
||||
|
||||
int GruGpuKernelMod::CheckInputsShape(const std::vector<KernelTensorPtr> &inputs) {
|
||||
auto input_shape = inputs[kInputsXIndex]->GetShapeVector(); // (seq_len, batch_size, input_size)
|
||||
auto hx_shape = inputs[kInputsHxIndex]->GetShapeVector(); // (num_directions * num_layers, batch_size, hidden_size)
|
||||
auto w_shape = inputs[kInputsWIndex]->GetShapeVector();
|
||||
if (IsDynamic(input_shape) || IsDynamic(hx_shape) || IsDynamic(w_shape)) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_dims", input_shape.size(), kEqual, SizeToLong(kCudnnGRUInputDim),
|
||||
kernel_name_);
|
||||
(void)CheckAndConvertUtils::CheckInteger("hx_dims", hx_shape.size(), kEqual, SizeToLong(kCudnnGRUHDim), kernel_name_);
|
||||
(void)CheckAndConvertUtils::CheckInteger("w_dims", w_shape.size(), kEqual, SizeToLong(kCudnnGRUWDim), kernel_name_);
|
||||
if (input_shape[kIndex1] != hx_shape[kIndex1]) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input_shape[1] must be equal to hx_shape[1], but got "
|
||||
<< input_shape[kIndex1] << " and " << hx_shape[kIndex1] << ".";
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_shape[2]", input_shape[kIndex2], kEqual, IntToLong(input_size_),
|
||||
kernel_name_);
|
||||
int64_t real_num_layers = bidirectional_ ? IntToLong(num_layers_ * 2) : IntToLong(num_layers_);
|
||||
(void)CheckAndConvertUtils::CheckInteger("hx_shape[0]", hx_shape[kIndex0], kEqual, real_num_layers, kernel_name_);
|
||||
(void)CheckAndConvertUtils::CheckInteger("hx_shape[2]", hx_shape[kIndex2], kEqual, hidden_size_, kernel_name_);
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
int GruGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
ResetResource();
|
||||
auto ret = CheckInputsShape(inputs);
|
||||
if (ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
auto input_shape = inputs[kInputsXIndex]->GetShapeVector();
|
||||
seq_len_ = LongToInt(input_shape[0]);
|
||||
batch_size_ = LongToInt(input_shape[1]);
|
||||
|
||||
cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
|
||||
cudnnDirectionMode_t direction = bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
|
||||
cudnnRNNMode_t rnn_mode = CUDNN_GRU;
|
||||
cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD;
|
||||
CreateTensorDescGrp();
|
||||
int hx_dims[3]{num_layers_ * (bidirectional_ ? 2 : 1), batch_size_, hidden_size_};
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(
|
||||
cudnnSetTensorNdDescriptorEx(hx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, DimOfTensor, hx_dims),
|
||||
"set hx_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(
|
||||
cudnnSetTensorNdDescriptorEx(cx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, DimOfTensor, hx_dims),
|
||||
"set cx_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(
|
||||
cudnnSetTensorNdDescriptorEx(hy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, DimOfTensor, hx_dims),
|
||||
"set hy_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(
|
||||
cudnnSetTensorNdDescriptorEx(cy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, DimOfTensor, hx_dims),
|
||||
"set cy_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0),
|
||||
"set dropout_desc failed");
|
||||
cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS;
|
||||
#if CUDNN_VERSION < 8000
|
||||
cudnnMathType_t math_type = (cudnn_data_type_ == CUDNN_DATA_HALF) ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(
|
||||
cudnnSetRNNDescriptor_v6(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, input_mode, direction,
|
||||
rnn_mode, algo, cudnn_data_type_),
|
||||
"set rnn_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnSetRNNMatrixMathType(rnn_desc_, math_type), "Set math type failed.");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed");
|
||||
#else
|
||||
cudnnMathType_t math_type = (cudnn_data_type_ == CUDNN_DATA_HALF) ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH;
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(
|
||||
cudnnSetRNNDescriptor_v8(rnn_desc_, algo, rnn_mode, bias_mode, direction, input_mode, cudnn_data_type_,
|
||||
cudnn_data_type_, math_type, input_size_, hidden_size_, hidden_size_, num_layers_,
|
||||
dropout_desc_, 0),
|
||||
"set rnn_desc failed");
|
||||
#endif
|
||||
auto weight_shape = inputs[kInputsWIndex]->GetShapeVector();
|
||||
size_t weight_size = LongToSizeClipNeg(weight_shape[0] * weight_shape[1] * weight_shape[kIndex2]) * input_type_size_;
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(
|
||||
cudnnGetRNNParamsSize(handle_, rnn_desc_, x_desc_[0], &weight_size_, cudnn_data_type_), "get weight_size_ failed");
|
||||
if (weight_size != weight_size_) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the size of weight should be equal to " << weight_size_
|
||||
<< " but got " << weight_size;
|
||||
}
|
||||
int w_dims[3] = {SizeToInt(weight_size_ / input_type_size_), 1, 1};
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(
|
||||
cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, DimOfTensor, w_dims), "set w_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(
|
||||
cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &reserved_size_),
|
||||
"get reserve size failed");
|
||||
|
||||
InitSizeLists();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool GruGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
cuda_stream_ = reinterpret_cast<cudaStream_t>(stream_ptr);
|
||||
VARIABLE_NOT_USED(stream_ptr);
|
||||
|
||||
auto x_addr = GetDeviceAddress<T>(inputs, kInputsXIndex);
|
||||
auto hx_addr = GetDeviceAddress<T>(inputs, kInputsHxIndex);
|
||||
auto cx_addr = nullptr;
|
||||
auto w_addr = GetDeviceAddress<T>(inputs, kInputsWIndex);
|
||||
auto y_addr = GetDeviceAddress<T>(outputs, kOutputsYIndex);
|
||||
auto hy_addr = GetDeviceAddress<T>(outputs, kOutputsHyIndex);
|
||||
auto cy_addr = nullptr;
|
||||
auto reserved_addr = GetDeviceAddress<T>(outputs, kOutputsReservedAddrIndex);
|
||||
auto states_addr = GetDeviceAddress<T>(outputs, kOutputsStatedAddrIndex);
|
||||
void *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);
|
||||
|
||||
if (!states_init_) {
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, states_addr,
|
||||
output_size_list_[kOutputsStatedAddrIndex], 0),
|
||||
"set dropout descriptor failed. Possible reasons: the GPU is out of memory.");
|
||||
states_init_ = true;
|
||||
}
|
||||
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(
|
||||
cudnnRNNForwardTraining(handle_, rnn_desc_, seq_len_, x_desc_.get(), x_addr, hx_desc_, hx_addr, cx_desc_, cx_addr,
|
||||
w_desc_, w_addr, y_desc_.get(), y_addr, hy_desc_, hy_addr, cy_desc_, cy_addr,
|
||||
workspace_addr, workspace_size_list_[0], reserved_addr, reserved_size_),
|
||||
"launch gru kernel failed");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void GruGpuKernelMod::InitResource() {
|
||||
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnCreateTensorDescriptor(&hx_desc_), "create hx_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnCreateTensorDescriptor(&cx_desc_), "create cx_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnCreateFilterDescriptor(&w_desc_), "create w_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnCreateTensorDescriptor(&hy_desc_), "create hy_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnCreateTensorDescriptor(&cy_desc_), "create cy_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnCreateDropoutDescriptor(&dropout_desc_), "create dropout_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnCreateRNNDescriptor(&rnn_desc_), "create rnn_desc failed");
|
||||
}
|
||||
|
||||
void GruGpuKernelMod::DestroyResource() noexcept {
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyRNNDescriptor(rnn_desc_), "destroy rnn_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyDropoutDescriptor(dropout_desc_), "destroy dropout_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(cy_desc_), "destroy cy_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(hy_desc_), "destroy hy_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyFilterDescriptor(w_desc_), "destroy w_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(hx_desc_), "destroy hx_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(cx_desc_), "destroy cx_desc failed");
|
||||
|
||||
for (size_t i = 0; i < IntToSize(seq_len_); ++i) {
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(y_desc_[i]), "destroy y_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(x_desc_[i]), "destroy x_desc failed");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> GruGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, GruGpuKernelFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
void GruGpuKernelMod::InitSizeLists() {
|
||||
size_t x_size = IntToSize(seq_len_ * batch_size_ * input_size_) * input_type_size_;
|
||||
|
||||
size_t h_size = 0;
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnGetTensorSizeInBytes(hx_desc_, &h_size), "get h size failed");
|
||||
|
||||
input_size_list_.push_back(x_size);
|
||||
input_size_list_.push_back(h_size);
|
||||
input_size_list_.push_back(weight_size_);
|
||||
|
||||
size_t y_size = IntToSize(seq_len_ * batch_size_ * hidden_size_ * (bidirectional_ ? 2 : 1)) * input_type_size_;
|
||||
output_size_list_.push_back(y_size);
|
||||
output_size_list_.push_back(h_size);
|
||||
output_size_list_.push_back(reserved_size_);
|
||||
size_t state_size = 0;
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDropoutGetStatesSize(handle_, &state_size), "get dropout states size failed");
|
||||
output_size_list_.push_back(state_size);
|
||||
|
||||
size_t workspace_size = 0;
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(
|
||||
cudnnGetRNNWorkspaceSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &workspace_size),
|
||||
"get workspace size failed");
|
||||
workspace_size_list_.push_back(workspace_size);
|
||||
}
|
||||
|
||||
void GruGpuKernelMod::CreateTensorDescGrp() {
|
||||
int x_dims[3]{batch_size_, input_size_, 1};
|
||||
int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1};
|
||||
|
||||
x_desc_ = std::make_unique<cudnnTensorDescriptor_t[]>(seq_len_);
|
||||
y_desc_ = std::make_unique<cudnnTensorDescriptor_t[]>(seq_len_);
|
||||
for (size_t i = 0; i < IntToSize(seq_len_); ++i) {
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnCreateTensorDescriptor(&x_desc_[i]), "create x_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(
|
||||
cudnnSetTensorNdDescriptorEx(x_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, DimOfTensor, x_dims),
|
||||
"set x_desc failed");
|
||||
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnCreateTensorDescriptor(&y_desc_[i]), "create y_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(
|
||||
cudnnSetTensorNdDescriptorEx(y_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, DimOfTensor, y_dims),
|
||||
"set y_desc failed");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, GruGpuKernelFunc>> GruGpuKernelMod::func_list_ = {
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&GruGpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&GruGpuKernelMod::LaunchKernel<half>}};
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, CudnnGRU, GruGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,19 +20,21 @@
|
|||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/kernel_constants.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t kIndexTwo = 2;
|
||||
constexpr size_t kIndexThree = 3;
|
||||
constexpr size_t DimOfTensor = 3;
|
||||
constexpr size_t LeastWeightShape = 3;
|
||||
constexpr size_t LeastInputShapeSize = 3;
|
||||
template <typename T>
|
||||
class GruGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
||||
class GruGpuKernelMod;
|
||||
using GruGpuKernelFunc = std::function<bool(GruGpuKernelMod *, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
|
||||
|
||||
class GruGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
GruGpuKernelMod()
|
||||
: batch_size_(0),
|
||||
|
@ -62,218 +64,43 @@ class GruGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
VARIABLE_NOT_USED(stream_ptr);
|
||||
auto x_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
auto hx_addr = GetDeviceAddress<T>(inputs, 1);
|
||||
auto cx_addr = nullptr;
|
||||
auto w_addr = GetDeviceAddress<T>(inputs, 2);
|
||||
auto y_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
auto hy_addr = GetDeviceAddress<T>(outputs, 1);
|
||||
auto cy_addr = nullptr;
|
||||
auto reserved_addr = GetDeviceAddress<T>(outputs, 2);
|
||||
auto states_addr = GetDeviceAddress<T>(outputs, 3);
|
||||
void *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);
|
||||
|
||||
if (!states_init_) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, states_addr, output_size_list_[kIndexThree], 0),
|
||||
"set dropout descriptor failed. Possible reasons: the GPU is out of memory.");
|
||||
states_init_ = true;
|
||||
}
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnRNNForwardTraining(handle_, rnn_desc_, seq_len_, x_desc_.get(), x_addr, hx_desc_, hx_addr, cx_desc_, cx_addr,
|
||||
w_desc_, w_addr, y_desc_.get(), y_addr, hy_desc_, hy_addr, cy_desc_, cy_addr,
|
||||
workspace_addr, workspace_size_list_[0], reserved_addr, reserved_size_),
|
||||
"launch gru kernel failed");
|
||||
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
InitResource();
|
||||
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
|
||||
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input");
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
if (input_shape.size() < LeastInputShapeSize) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of input cannot be less than 3, but got "
|
||||
<< input_shape.size();
|
||||
}
|
||||
seq_len_ = LongToInt(input_shape[0]);
|
||||
batch_size_ = LongToInt(input_shape[1]);
|
||||
input_size_ = LongToInt(input_shape[kIndexTwo]);
|
||||
|
||||
input_size_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "input_size"));
|
||||
hidden_size_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "hidden_size"));
|
||||
num_layers_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "num_layers"));
|
||||
has_bias_ = GetAttr<bool>(kernel_node, "has_bias");
|
||||
bidirectional_ = GetAttr<bool>(kernel_node, "bidirectional");
|
||||
dropout_ = GetAttr<float>(kernel_node, "dropout");
|
||||
|
||||
cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
|
||||
cudnnDirectionMode_t direction = bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
|
||||
cudnnRNNMode_t rnn_mode = CUDNN_GRU;
|
||||
cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD;
|
||||
CreateTensorDescGrp();
|
||||
int hx_dims[3]{num_layers_ * (bidirectional_ ? 2 : 1), batch_size_, hidden_size_};
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_, cudnnSetTensorNdDescriptorEx(hx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, DimOfTensor, hx_dims),
|
||||
"set hx_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_, cudnnSetTensorNdDescriptorEx(cx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, DimOfTensor, hx_dims),
|
||||
"set cx_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_, cudnnSetTensorNdDescriptorEx(hy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, DimOfTensor, hx_dims),
|
||||
"set hy_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_, cudnnSetTensorNdDescriptorEx(cy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, DimOfTensor, hx_dims),
|
||||
"set cy_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0),
|
||||
"set dropout_desc failed");
|
||||
cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS;
|
||||
#if CUDNN_VERSION < 8000
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudnnSetRNNDescriptor_v6(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
|
||||
input_mode, direction, rnn_mode, algo, cudnn_data_type_),
|
||||
"set rnn_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed");
|
||||
#else
|
||||
cudnnMathType_t math_type = (cudnn_data_type_ == CUDNN_DATA_HALF) ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudnnSetRNNDescriptor_v8(rnn_desc_, algo, rnn_mode, bias_mode, direction, input_mode,
|
||||
cudnn_data_type_, cudnn_data_type_, math_type, input_size_,
|
||||
hidden_size_, hidden_size_, num_layers_, dropout_desc_, 0),
|
||||
"set rnn_desc failed");
|
||||
#endif
|
||||
auto weight_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
|
||||
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "weight");
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
if (weight_shape.size() < LeastWeightShape) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of weight cannot be less than 3, but got "
|
||||
<< weight_shape.size();
|
||||
}
|
||||
size_t weight_size = LongToSizeClipNeg(weight_shape[0] * weight_shape[1] * weight_shape[kIndexTwo]) * sizeof(T);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudnnGetRNNParamsSize(handle_, rnn_desc_, x_desc_[0], &weight_size_, cudnn_data_type_),
|
||||
"get weight_size_ failed");
|
||||
if (weight_size != weight_size_) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the size of weight should be equal to " << weight_size_
|
||||
<< " but got " << weight_size;
|
||||
}
|
||||
int w_dims[3] = {SizeToInt(weight_size_ / sizeof(T)), 1, 1};
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_, cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, DimOfTensor, w_dims),
|
||||
"set w_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_, cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &reserved_size_),
|
||||
"get reserve size failed");
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
void CreateTensorDescGrp() {
|
||||
int x_dims[3]{batch_size_, input_size_, 1};
|
||||
int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1};
|
||||
|
||||
x_desc_ = std::make_unique<cudnnTensorDescriptor_t[]>(seq_len_);
|
||||
y_desc_ = std::make_unique<cudnnTensorDescriptor_t[]>(seq_len_);
|
||||
for (size_t i = 0; i < IntToSize(seq_len_); ++i) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_[i]), "create x_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensorNdDescriptorEx(x_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, DimOfTensor, x_dims),
|
||||
"set x_desc failed");
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&y_desc_[i]), "create y_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensorNdDescriptorEx(y_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, DimOfTensor, y_dims),
|
||||
"set y_desc failed");
|
||||
}
|
||||
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
|
||||
}
|
||||
|
||||
void DestroyResource() noexcept override {
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyRNNDescriptor(rnn_desc_), "destroy rnn_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyDropoutDescriptor(dropout_desc_),
|
||||
"destroy dropout_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(cy_desc_), "destroy cy_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(hy_desc_), "destroy hy_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyFilterDescriptor(w_desc_), "destroy w_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(hx_desc_), "destroy hx_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(cx_desc_), "destroy cx_desc failed");
|
||||
|
||||
for (size_t i = 0; i < IntToSize(seq_len_); ++i) {
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_desc_[i]), "destroy y_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_[i]), "destroy x_desc failed");
|
||||
}
|
||||
}
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
void DestroyResource() noexcept override;
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
protected:
|
||||
void InitResource() override {
|
||||
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&hx_desc_), "create hx_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&cx_desc_), "create cx_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateFilterDescriptor(&w_desc_), "create w_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&hy_desc_), "create hy_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&cy_desc_), "create cy_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateDropoutDescriptor(&dropout_desc_),
|
||||
"create dropout_desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateRNNDescriptor(&rnn_desc_), "create rnn_desc failed");
|
||||
}
|
||||
void InitSizeLists() override {
|
||||
size_t x_size = IntToSize(seq_len_ * batch_size_ * input_size_) * sizeof(T);
|
||||
|
||||
size_t h_size = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(hx_desc_, &h_size), "get h size failed");
|
||||
|
||||
input_size_list_.push_back(x_size);
|
||||
input_size_list_.push_back(h_size);
|
||||
input_size_list_.push_back(weight_size_);
|
||||
|
||||
size_t y_size = IntToSize(seq_len_ * batch_size_ * hidden_size_ * (bidirectional_ ? 2 : 1)) * sizeof(T);
|
||||
output_size_list_.push_back(y_size);
|
||||
output_size_list_.push_back(h_size);
|
||||
output_size_list_.push_back(reserved_size_);
|
||||
size_t state_size = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnDropoutGetStatesSize(handle_, &state_size),
|
||||
"get dropout states size failed");
|
||||
output_size_list_.push_back(state_size);
|
||||
|
||||
size_t workspace_size = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudnnGetRNNWorkspaceSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &workspace_size),
|
||||
"get workspace size failed");
|
||||
workspace_size_list_.push_back(workspace_size);
|
||||
}
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr);
|
||||
void InitResource() override;
|
||||
|
||||
private:
|
||||
void CreateTensorDescGrp();
|
||||
void ResetResource() noexcept;
|
||||
void InitSizeLists();
|
||||
int CheckInputsShape(const std::vector<KernelTensorPtr> &inputs);
|
||||
|
||||
int batch_size_;
|
||||
int seq_len_;
|
||||
int input_size_;
|
||||
int hidden_size_;
|
||||
int num_layers_;
|
||||
|
||||
bool has_bias_;
|
||||
bool bidirectional_;
|
||||
bool states_init_;
|
||||
bool is_null_input_;
|
||||
float dropout_;
|
||||
|
||||
size_t weight_size_;
|
||||
size_t reserved_size_;
|
||||
size_t input_type_size_; // sizeof(T)
|
||||
GruGpuKernelFunc kernel_func_;
|
||||
static std::vector<std::pair<KernelAttr, GruGpuKernelFunc>> func_list_;
|
||||
|
||||
// input desc
|
||||
std::unique_ptr<cudnnTensorDescriptor_t[]> x_desc_;
|
||||
|
@ -285,8 +112,8 @@ class GruGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
cudnnTensorDescriptor_t hy_desc_;
|
||||
cudnnTensorDescriptor_t cy_desc_;
|
||||
cudnnRNNDescriptor_t rnn_desc_;
|
||||
|
||||
cudnnHandle_t handle_;
|
||||
cudaStream_t cuda_stream_;
|
||||
cudnnDataType_t cudnn_data_type_;
|
||||
};
|
||||
} // namespace kernel
|
||||
|
|
|
@ -638,6 +638,7 @@ GVAR_DEF(PrimitivePtr, kPrimUniqueGrad, std::make_shared<Primitive>("UniqueGrad"
|
|||
GVAR_DEF(PrimitivePtr, kPrimUniqueConsecutive, std::make_shared<Primitive>("UniqueConsecutive"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimExtractImagePatches, std::make_shared<Primitive>("ExtractImagePatches"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimDynamicRNN, std::make_shared<Primitive>("DynamicRNN"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCudnnGRU, std::make_shared<Primitive>("CudnnGRU"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimGRUV2, std::make_shared<Primitive>("GRUV2"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimLSTMV2, std::make_shared<Primitive>("LSTMV2"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimDynamicRNNGrad, std::make_shared<Primitive>("DynamicRNNGrad"));
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/cudnn_gru.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr size_t kCudnnGRUInputDim = 3;
|
||||
constexpr size_t kCudnnGRUHDim = 3;
|
||||
constexpr int64_t kCudnnGRUInputsNum = 3;
|
||||
constexpr auto kCudnnGRURealNumLayers = "real_num_layers";
|
||||
constexpr auto kCudnnGRURealHiddenSize = "real_hidden_size";
|
||||
|
||||
std::unordered_map<std::string, int64_t> CudnnGRUGetAttrMap(const PrimitivePtr &primitive) {
|
||||
std::unordered_map<std::string, int64_t> attr_map;
|
||||
auto input_size_ptr = primitive->GetAttr(kInputSize);
|
||||
MS_EXCEPTION_IF_NULL(input_size_ptr);
|
||||
attr_map[kInputSize] = GetValue<int64_t>(input_size_ptr);
|
||||
|
||||
auto hidden_size_ptr = primitive->GetAttr(kHiddenSize);
|
||||
MS_EXCEPTION_IF_NULL(hidden_size_ptr);
|
||||
auto hidden_size = GetValue<int64_t>(hidden_size_ptr);
|
||||
attr_map[kHiddenSize] = hidden_size;
|
||||
|
||||
auto num_layers_ptr = primitive->GetAttr(kNumLayers);
|
||||
MS_EXCEPTION_IF_NULL(num_layers_ptr);
|
||||
auto num_layers = GetValue<int64_t>(num_layers_ptr);
|
||||
|
||||
auto bidirectional_ptr = primitive->GetAttr(kBidirectional);
|
||||
MS_EXCEPTION_IF_NULL(bidirectional_ptr);
|
||||
auto bidirectional = GetValue<bool>(bidirectional_ptr);
|
||||
|
||||
auto real_hidden_size = bidirectional ? hidden_size * 2 : hidden_size;
|
||||
auto real_num_layers = bidirectional ? num_layers * 2 : num_layers;
|
||||
attr_map[kCudnnGRURealNumLayers] = real_num_layers;
|
||||
attr_map[kCudnnGRURealHiddenSize] = real_hidden_size;
|
||||
return attr_map;
|
||||
}
|
||||
|
||||
abstract::TupleShapePtr CudnnGRUInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kCudnnGRUInputsNum, op_name);
|
||||
auto attr_map = CudnnGRUGetAttrMap(primitive);
|
||||
auto input_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape());
|
||||
auto h_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
|
||||
auto input_shape = input_shape_map[kShape]; // (seq_len, batch_size, input_size)
|
||||
auto h_shape = h_shape_map[kShape]; // (real_num_layers, batch_size, hidden_size)
|
||||
|
||||
int64_t seq_len = abstract::Shape::kShapeDimAny;
|
||||
int64_t batch_size = abstract::Shape::kShapeDimAny;
|
||||
if (!IsDynamicRank(input_shape)) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_dims", SizeToLong(input_shape.size()), kEqual, kCudnnGRUInputDim,
|
||||
op_name);
|
||||
seq_len = input_shape[kInputIndex0];
|
||||
batch_size = input_shape[kInputIndex1];
|
||||
if (input_shape[kInputIndex2] != abstract::Shape::kShapeDimAny) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_shape[2]", input_shape[kInputIndex2], kEqual,
|
||||
attr_map[kInputSize]);
|
||||
}
|
||||
}
|
||||
|
||||
if (!IsDynamicRank(h_shape)) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("h_dims", SizeToLong(h_shape.size()), kEqual, kCudnnGRUHDim, op_name);
|
||||
if (h_shape[kInputIndex0] != abstract::Shape::kShapeDimAny) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("h_shape[0]", h_shape[kInputIndex0], kEqual,
|
||||
attr_map[kCudnnGRURealNumLayers], op_name);
|
||||
}
|
||||
if (h_shape[kInputIndex1] != abstract::Shape::kShapeDimAny) {
|
||||
if (batch_size != abstract::Shape::kShapeDimAny && batch_size != h_shape[kInputIndex1]) {
|
||||
MS_LOG(EXCEPTION) << "For " << op_name << ", input_shape[1] and h_shape[1] should be -1 or equal, but got "
|
||||
<< batch_size << " and " << h_shape[kInputIndex1] << ".";
|
||||
}
|
||||
batch_size = h_shape[kInputIndex1];
|
||||
}
|
||||
if (h_shape[kInputIndex2] != abstract::Shape::kShapeDimAny) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("h_shape[2]", h_shape[kInputIndex2], kEqual, attr_map[kHiddenSize],
|
||||
op_name);
|
||||
}
|
||||
}
|
||||
|
||||
auto output_shape_ptr =
|
||||
std::make_shared<abstract::Shape>(ShapeVector{seq_len, batch_size, attr_map[kCudnnGRURealHiddenSize]});
|
||||
auto hn_shape_ptr =
|
||||
std::make_shared<abstract::Shape>(ShapeVector{attr_map[kCudnnGRURealNumLayers], batch_size, attr_map[kHiddenSize]});
|
||||
auto reserve_shape_ptr = std::make_shared<abstract::Shape>(ShapeVector{1, 1});
|
||||
auto state_shape_ptr = std::make_shared<abstract::Shape>(ShapeVector{1, 1});
|
||||
|
||||
return std::make_shared<abstract::TupleShape>(
|
||||
std::vector<abstract::BaseShapePtr>{output_shape_ptr, hn_shape_ptr, reserve_shape_ptr, state_shape_ptr});
|
||||
}
|
||||
|
||||
TuplePtr CudnnGRUInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set valid_types = {kFloat16, kFloat32};
|
||||
auto op_name = prim->name();
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("input", input_args[kInputIndex0]->BuildType());
|
||||
(void)types.emplace("h", input_args[kInputIndex1]->BuildType());
|
||||
(void)types.emplace("w", input_args[kInputIndex2]->BuildType());
|
||||
auto type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{type, type, type, type});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr CudnnGRUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kCudnnGRUInputsNum, primitive->name());
|
||||
auto infer_shape = CudnnGRUInferShape(primitive, input_args);
|
||||
auto infer_type = CudnnGRUInferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(CudnnGRU, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CudnnGRU, prim::kPrimCudnnGRU, CudnnGRUInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CUDNN_GRU_H_
|
||||
#define MINDSPORE_CORE_OPS_CUDNN_GRU_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCudnnGRU = "CudnnGRU";
|
||||
class MIND_API CudnnGRU : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(CudnnGRU);
|
||||
|
||||
/// \brief Constructor.
|
||||
CudnnGRU() : BaseOperator(kNameCudnnGRU) { InitIOName({"input", "h", "w"}, {"output", "h_n", "reserve", "state"}); }
|
||||
};
|
||||
AbstractBasePtr CudnnGRUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
|
||||
using PrimCudnnGRUPtr = std::shared_ptr<CudnnGRU>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CUDNN_GRU_H_
|
|
@ -75,10 +75,10 @@ abstract::TupleShapePtr GRUV2InferShape(const PrimitivePtr &primitive, const std
|
|||
auto x_shape = x_shape_map[kShape];
|
||||
auto h_shape = h_shape_map[kShape];
|
||||
auto seq_lengths_shape = seq_lengths_shape_map[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("input dims", SizeToLong(x_shape.size()), kEqual, kGRUV2InputSize, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("h dims", SizeToLong(h_shape.size()), kEqual, kGRUV2HSize, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("seq_lengths dims", SizeToLong(seq_lengths_shape.size()), kEqual,
|
||||
kGRUV2SeqLenSize, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input dims", x_shape.size(), kEqual, kGRUV2InputSize, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("h dims", h_shape.size(), kEqual, kGRUV2HSize, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("seq_lengths dims", seq_lengths_shape.size(), kEqual, kGRUV2SeqLenSize,
|
||||
op_name);
|
||||
auto max_seq_lengths = x_shape[0];
|
||||
auto batch_size = x_shape[1];
|
||||
auto input_size = attr_map[kInputSize];
|
||||
|
|
|
@ -403,7 +403,7 @@ class LSTMV2(Primitive):
|
|||
validator.check_value_type("is_train", is_train, (bool,), self.name)
|
||||
|
||||
|
||||
class CudnnGRU(PrimitiveWithInfer):
|
||||
class CudnnGRU(Primitive):
|
||||
"""
|
||||
Performs the Stacked GRU (Gated Recurrent Unit) on the input.
|
||||
|
||||
|
@ -421,7 +421,7 @@ class CudnnGRU(PrimitiveWithInfer):
|
|||
Inputs:
|
||||
- **input** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`) or
|
||||
(batch_size, seq_len, `input_size`).
|
||||
- **h** (tuple) - Tensor of shape (num_directions * `num_layers`, batch_size, `hidden_size`).
|
||||
- **h** (Tensor) - Tensor of shape (num_directions * `num_layers`, batch_size, `hidden_size`).
|
||||
- **w** (Tensor) - The input tensor which states for weights.
|
||||
|
||||
Outputs:
|
||||
|
@ -486,28 +486,6 @@ class CudnnGRU(PrimitiveWithInfer):
|
|||
else:
|
||||
self.num_directions = 1
|
||||
|
||||
def infer_shape(self, x_shape, h_shape, w_shape):
|
||||
validator.check_equal_int(len(x_shape), 3, "x rank", self.name)
|
||||
validator.check_equal_int(x_shape[2], self.input_size, "x[2]", self.name)
|
||||
|
||||
validator.check_equal_int(len(h_shape), 3, "h rank", self.name)
|
||||
|
||||
validator.check_int(h_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h[0]", self.name)
|
||||
validator.check_equal_int(h_shape[1], x_shape[1], "h[1]", self.name)
|
||||
validator.check_int(h_shape[2], self.hidden_size, Rel.EQ, "h[2]", self.name)
|
||||
|
||||
y_shape = (x_shape[0], x_shape[1], self.hidden_size * self.num_directions)
|
||||
|
||||
# set arbitrary shape for reserved space
|
||||
reserved_shape = (1, 1)
|
||||
state_shape = (1, 1)
|
||||
return y_shape, h_shape, reserved_shape, state_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, h_dtype, w_dtype):
|
||||
args = {'x': x_dtype, 'h': h_dtype, 'w': w_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
|
||||
return x_dtype, x_dtype, x_dtype, x_dtype
|
||||
|
||||
|
||||
class PriorityReplayBufferCreate(PrimitiveWithInfer):
|
||||
r"""
|
||||
|
|
Loading…
Reference in New Issue