forked from mindspore-Ecosystem/mindspore
!806 GPU update kernel bn
Merge pull request !806 from VectorSL/gpu-fix-bn
This commit is contained in:
commit
d5d3bf7565
|
@ -82,6 +82,7 @@ class FusedBatchNormGpuKernel : public GpuKernel {
|
|||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
InitResource();
|
||||
cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))];
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 5) {
|
||||
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGpuKernel should be 5";
|
||||
|
@ -112,11 +113,11 @@ class FusedBatchNormGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch_, channel_, height_, width_),
|
||||
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set x desc failed");
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch_, channel_, height_, width_),
|
||||
cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set y desc failed");
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
|
|
|
@ -110,7 +110,7 @@ class FusedBatchNormGradGpuKernel : public GpuKernel {
|
|||
cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set dx desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 1, channel_, 1, 1),
|
||||
cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1),
|
||||
"Set para desc failed");
|
||||
|
||||
InitSizeLists();
|
||||
|
|
Loading…
Reference in New Issue