!16569 GPU fix lstm

From: @VectorSL
Reviewed-by: @limingqi107,@cristoval
Signed-off-by: @cristoval
This commit is contained in:
mindspore-ci-bot 2021-05-20 10:01:39 +08:00 committed by Gitee
commit 4f8e4faa08
3 changed files with 3 additions and 3 deletions

View File

@ -148,7 +148,7 @@ class LstmGpuKernel : public GpuKernel {
if (weight_size != weight_size_) {
MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " .";
}
int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1};
int w_dims[3] = {SizeToInt(weight_size_ / sizeof(T)), 1, 1};
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims),
"set w_desc failed");

View File

@ -163,7 +163,7 @@ class LstmGradDataGpuKernel : public GpuKernel {
if (weight_size != weight_size_) {
MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " .";
}
int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1};
int w_dims[3] = {SizeToInt(weight_size_ / sizeof(T)), 1, 1};
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims),
"set w_desc failed");

View File

@ -138,7 +138,7 @@ class LstmGradWeightGpuKernel : public GpuKernel {
if (weight_size != weight_size_) {
MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " .";
}
int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1};
int w_dims[3] = {SizeToInt(weight_size_ / sizeof(T)), 1, 1};
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetFilterNdDescriptor(dw_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims),
"set dw_desc failed");