!47975 Bugfix for the LSTM GPU kernels

Merge pull request !47975 from chengang/fix_gpu_lstm_2
This commit is contained in:
i-robot 2023-01-17 11:50:33 +00:00 committed by Gitee
commit 71ceaf8653
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 3 additions and 3 deletions

View File

@@ -65,6 +65,7 @@ int LstmGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
DestroyTensorDescGrp();
auto input_shape = inputs[kIndex0]->GetShapeVector();
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input") ||

View File

@@ -84,7 +84,6 @@ class LstmGpuKernelMod : public NativeGpuKernelMod {
}
void CreateTensorDescGrp() {
DestroyTensorDescGrp();
int x_dims[3]{batch_size_, input_size_, 1};
int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1};

View File

@@ -56,6 +56,7 @@ int LstmGradDataGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
DestroyTensorDescGrp();
auto input_shape = inputs[kIndex0]->GetShapeVector();
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input");

View File

@@ -134,7 +134,6 @@ class LstmGradDataGpuKernelMod : public NativeGpuKernelMod {
}
void CreateTensorDescGrp() {
DestroyTensorDescGrp();
int x_dims[3]{batch_size_, input_size_, 1};
int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1};

View File

@@ -56,6 +56,7 @@ int LstmGradWeightGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, con
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
DestroyTensorDescGrp();
auto input_shape = inputs[kIndex0]->GetShapeVector();
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input");

View File

@@ -111,7 +111,6 @@ class LstmGradWeightGpuKernelMod : public NativeGpuKernelMod {
}
void CreateTensorDescGrp() {
DestroyTensorDescGrp();
int x_dims[3]{batch_size_, input_size_, 1};
int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1};