forked from mindspore-Ecosystem/mindspore
!47447 Bugfix for the lstm gpu kernels
Merge pull request !47447 from chengang/bugfix_lstm
This commit is contained in:
commit
517021f548
|
@ -65,6 +65,7 @@ int LstmGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve
|
|||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
DestroyTensorDescGrp();
|
||||
|
||||
auto input_shape = inputs[kIndex0]->GetShapeVector();
|
||||
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input") ||
|
||||
|
|
|
@ -84,7 +84,6 @@ class LstmGpuKernelMod : public NativeGpuKernelMod {
|
|||
}
|
||||
|
||||
void CreateTensorDescGrp() {
|
||||
DestroyTensorDescGrp();
|
||||
int x_dims[3]{batch_size_, input_size_, 1};
|
||||
int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1};
|
||||
|
||||
|
|
|
@ -56,6 +56,7 @@ int LstmGradDataGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
|
|||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
DestroyTensorDescGrp();
|
||||
|
||||
auto input_shape = inputs[kIndex0]->GetShapeVector();
|
||||
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input");
|
||||
|
|
|
@ -134,7 +134,6 @@ class LstmGradDataGpuKernelMod : public NativeGpuKernelMod {
|
|||
}
|
||||
|
||||
void CreateTensorDescGrp() {
|
||||
DestroyTensorDescGrp();
|
||||
int x_dims[3]{batch_size_, input_size_, 1};
|
||||
int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1};
|
||||
|
||||
|
|
|
@ -56,6 +56,7 @@ int LstmGradWeightGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, con
|
|||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
DestroyTensorDescGrp();
|
||||
|
||||
auto input_shape = inputs[kIndex0]->GetShapeVector();
|
||||
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input");
|
||||
|
|
|
@ -111,7 +111,6 @@ class LstmGradWeightGpuKernelMod : public NativeGpuKernelMod {
|
|||
}
|
||||
|
||||
void CreateTensorDescGrp() {
|
||||
DestroyTensorDescGrp();
|
||||
int x_dims[3]{batch_size_, input_size_, 1};
|
||||
int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1};
|
||||
|
||||
|
|
Loading…
Reference in New Issue