From 0775db0940dae36ae224da35a0a0100182df383e Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Sat, 29 Aug 2020 14:39:51 +0800 Subject: [PATCH] modify the format info of tensorAdd --- .../backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h index 5c7927a3210..cde22769ea0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h @@ -56,9 +56,9 @@ class BroadcastOpGpuKernel : public GpuKernel { } bool Init(const CNodePtr &kernel_node) override { GetOpType(kernel_node); - auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0); + auto shape1 = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + auto shape2 = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + auto shape3 = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); need_broadcast_ = IsBroadcast(shape1, shape2); if (need_broadcast_ && shape1.size() > 7) { MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";