forked from mindspore-Ecosystem/mindspore
!16569 GPU fix lstm
From: @VectorSL Reviewed-by: @limingqi107,@cristoval Signed-off-by: @cristoval
This commit is contained in:
commit
4f8e4faa08
|
@ -148,7 +148,7 @@ class LstmGpuKernel : public GpuKernel {
|
|||
if (weight_size != weight_size_) {
|
||||
MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " .";
|
||||
}
|
||||
int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1};
|
||||
int w_dims[3] = {SizeToInt(weight_size_ / sizeof(T)), 1, 1};
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims),
|
||||
"set w_desc failed");
|
||||
|
|
|
@ -163,7 +163,7 @@ class LstmGradDataGpuKernel : public GpuKernel {
|
|||
if (weight_size != weight_size_) {
|
||||
MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " .";
|
||||
}
|
||||
int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1};
|
||||
int w_dims[3] = {SizeToInt(weight_size_ / sizeof(T)), 1, 1};
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims),
|
||||
"set w_desc failed");
|
||||
|
|
|
@ -138,7 +138,7 @@ class LstmGradWeightGpuKernel : public GpuKernel {
|
|||
if (weight_size != weight_size_) {
|
||||
MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " .";
|
||||
}
|
||||
int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1};
|
||||
int w_dims[3] = {SizeToInt(weight_size_ / sizeof(T)), 1, 1};
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudnnSetFilterNdDescriptor(dw_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims),
|
||||
"set dw_desc failed");
|
||||
|
|
Loading…
Reference in New Issue