modify the format info of tensorAdd

This commit is contained in:
limingqi107 2020-08-29 14:39:51 +08:00
parent 1b6f628bd0
commit 109e2e9bcc
1 changed files with 3 additions and 3 deletions

View File

@ -56,9 +56,9 @@ class BroadcastOpGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
GetOpType(kernel_node);
auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0);
auto shape1 = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto shape2 = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
auto shape3 = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
need_broadcast_ = IsBroadcast(shape1, shape2);
if (need_broadcast_ && shape1.size() > 7) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";