diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.h index 4a445630675..086a06c3466 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.h @@ -91,13 +91,14 @@ class ROIAlignGradGpuFwdKernel : public GpuKernel { roi_end_mode_ = 1; // Get channels, height & width + batch_size_ = xdiff_shape_[0]; channels_ = xdiff_shape_[1]; height_ = xdiff_shape_[2]; width_ = xdiff_shape_[3]; // Get output_shape - output_shape_ = {roi_rows_, channels_, height_, width_}; - output_size_ = roi_rows_ * channels_ * height_ * width_ * sizeof(T); + output_shape_ = {batch_size_, channels_, height_, width_}; + output_size_ = batch_size_ * channels_ * height_ * width_ * sizeof(T); InitSizeLists(); return true; @@ -120,6 +121,7 @@ class ROIAlignGradGpuFwdKernel : public GpuKernel { int roi_rows_; int roi_cols_; + int batch_size_; int channels_; int height_; int width_;