From 3eae63e4e9807ba400e6b0655babf0d740d8e3ca Mon Sep 17 00:00:00 2001 From: wilfChen Date: Sat, 10 Oct 2020 17:38:38 +0800 Subject: [PATCH] gpu no broadcast kernel dim exceed --- .../kernel_compiler/gpu/math/broadcast_gpu_kernel.h | 12 +++++++++--- .../gpu/math/broadcast_grad_gpu_kernel.h | 12 +++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h index b739969a3f2..db9af174576 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h @@ -83,17 +83,23 @@ class BroadcastOpGpuKernel : public GpuKernel { rhs_shape_.resize(MAX_DIMS, 1); output_shape_.resize(MAX_DIMS, 1); for (size_t i = 0; i < shape3.size(); i++) { - output_shape_[i] = shape3[i]; + if (need_broadcast_) { + output_shape_[i] = shape3[i]; + } output_num_ *= shape3[i]; } int lhs_offset = shape3.size() - shape1.size(); for (size_t j = 0; j < shape1.size(); j++) { - lhs_shape_[j + lhs_offset] = shape1[j]; + if (need_broadcast_) { + lhs_shape_[j + lhs_offset] = shape1[j]; + } input1_num_ *= shape1[j]; } int rhs_offset = shape3.size() - shape2.size(); for (size_t k = 0; k < shape2.size(); k++) { - rhs_shape_[k + rhs_offset] = shape2[k]; + if (need_broadcast_) { + rhs_shape_[k + rhs_offset] = shape2[k]; + } input2_num_ *= shape2[k]; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h index 0d7d9bb2cf3..7a7a1761fef 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h @@ -78,17 +78,23 @@ class BroadcastOpGradGpuKernel : public GpuKernel { } for (size_t i = 0; i < shape3.size(); i++) { - dy_shape_[i] = shape3[i]; + if (need_broadcast_) { + dy_shape_[i] = shape3[i]; + } output_num_ *= shape3[i]; } int x1_offset = shape3.size() - shape1.size(); for (size_t i = 0; i < shape1.size(); i++) { - x1_shape_[i + x1_offset] = shape1[i]; + if (need_broadcast_) { + x1_shape_[i + x1_offset] = shape1[i]; + } input1_num_ *= shape1[i]; } int x2_offset = shape3.size() - shape2.size(); for (size_t i = 0; i < shape2.size(); i++) { - x2_shape_[i + x2_offset] = shape2[i]; + if (need_broadcast_) { + x2_shape_[i + x2_offset] = shape2[i]; + } input2_num_ *= shape2[i]; }