From 0aaa31764e6ac058be82dd24de3a8a248d7bfe93 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Fri, 19 Feb 2021 15:08:11 +0800 Subject: [PATCH] Do not get shape for monad type --- .../ccsrc/frontend/parallel/graph_util/node_info.cc | 13 ++++++++----- mindspore/ccsrc/frontend/parallel/step_parallel.cc | 8 +++++--- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc index a3d04c56e9f..c19d334c236 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc @@ -78,12 +78,13 @@ std::vector ExtractInputParameterByNode(const CNodePtr &node) { } for (size_t i = 1; i < node_inputs.size(); ++i) { auto input = GetRealInput(node_inputs[i]); - + if (HasAbstractMonad(input)) { + continue; + } if (input->isa()) { auto input_parameter = input->cast(); is_parameter.push_back(ParameterRequireGrad(input_parameter)); - } else if ((input->isa() && !HasAbstractMonad(input)) || IsValueNode(input) || - IsValueNode(input)) { + } else if (input->isa() || IsValueNode(input) || IsValueNode(input)) { is_parameter.push_back(false); } } @@ -174,6 +175,9 @@ std::vector ExtractInputTypeLengthByNode(const CNodePtr &node) { // extract input element length for (auto &input : node_inputs) { + if (HasAbstractMonad(input)) { + continue; + } if (IsValueNode(input)) { auto func_graph = node->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); @@ -182,8 +186,7 @@ std::vector ExtractInputTypeLengthByNode(const CNodePtr &node) { MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; } inputs_type_len.push_back(GetInputsTypeLen(parameters[0])); - } else if ((input->isa() && !HasAbstractMonad(input)) || input->isa() || - IsValueNode(input)) { + } else if (input->isa() || input->isa() || IsValueNode(input)) { // extract input shape from parameter and apply node inputs_type_len.push_back(GetInputsTypeLen(input)); } diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index d5bb19c3f1b..d528e9cf2ba 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -1108,7 +1108,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons FuncGraphManagerPtr manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); for (auto input : node->inputs()) { - if (input->isa() && HasAbstractMonad(input)) { + if (HasAbstractMonad(input)) { node_size--; } } @@ -1414,6 +1414,9 @@ std::vector ExtractShape(const CNodePtr &node) { for (size_t i = 1; i < inputs_size; ++i) { Shapes input_shapes; AnfNodePtr input = all_inputs[i]; + if (HasAbstractMonad(input)) { + continue; + } if (IsValueNode(input)) { auto func_graph = node->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); @@ -1424,8 +1427,7 @@ std::vector ExtractShape(const CNodePtr &node) { std::pair node_pair = std::make_pair(node, SizeToLong(i)); g_RefMap[parameters[0]] = node_pair; input_shapes = GetRefKeyNodeShape(input, func_graph); - } else if ((input->isa() && !HasAbstractMonad(input)) || IsValueNode(input) || - input->isa() || + } else if (input->isa() || IsValueNode(input) || input->isa() || ((IsValueNode(input) || IsValueNode(input)) && (inputs_size == 2))) { input_shapes = GetNodeShape(input); } else {