From 4b710763e9764abb6813f37ae8fb6697b70a46bc Mon Sep 17 00:00:00 2001 From: gaoyong10 Date: Sat, 25 Dec 2021 14:59:20 +0800 Subject: [PATCH] valuetuple support monad --- .../framework/actor/data_prepare_actor.cc | 3 +- .../runtime/framework/control_node_parser.cc | 60 +++++++++++++++---- 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc index 3fb3a81fb11..4d0a1860a3b 100644 --- a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc @@ -146,7 +146,8 @@ void PrepareDataForValue(const ValuePtr &value, const KernelWithIndex &node_with type = kNumberTypeInt32; (reinterpret_cast(host_addr.get()))[0] = GetValue(value); } else { - MS_LOG(EXCEPTION) << "Invalid value:" << value->ToString(); + std::string error_info = "Invalid value:" + value->ToString(); + SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); } auto type_size = GetTypeByte(TypeIdToType(type)); diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.cc b/mindspore/ccsrc/runtime/framework/control_node_parser.cc index a20d258c59b..4cd10844fe0 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.cc @@ -34,9 +34,28 @@ bool IsPartial(const AnfNodePtr &node) { } // Check if node is a value node need to create a device tensor. -bool IsFrontValueNode(const AnfNodePtr &node) { +bool IsFrontValueNode(const KernelWithIndex &node_with_index) { + const auto &node = node_with_index.first; + size_t index = node_with_index.second; MS_EXCEPTION_IF_NULL(node); - return node->isa() && (!IsValueNode(node)) && (!IsValueNode(node)); + if (!node->isa() || IsValueNode(node) || IsValueNode(node)) { + return false; + } + + if (!IsValueNode(node)) { + return !HasAbstractMonad(node); + } + + const auto &abstract = node->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + auto tuple_abstract = abstract->cast(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + const auto &sub_abstracts = tuple_abstract->elements(); + if (sub_abstracts.size() <= index) { + MS_LOG(EXCEPTION) << "Invalid index:" << index << " for tuple value node:" << node->DebugString(); + } + MS_EXCEPTION_IF_NULL(sub_abstracts[index]); + return !sub_abstracts[index]->isa(); } // Get funcgraph in partial structure. @@ -599,7 +618,6 @@ std::vector FetchInputNodeByNode(const AnfNodePtr &node) { } // 5 Other. - size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract); if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimTupleGetItem)) { const auto &get_item_cnode = real_node->cast(); MS_EXCEPTION_IF_NULL(get_item_cnode); @@ -625,9 +643,23 @@ std::vector FetchInputNodeByNode(const AnfNodePtr &node) { real_indexs.begin(), real_indexs.end(), std::back_inserter(results), [&get_item_src_node](const auto &index) { return KernelWithIndex(get_item_src_node, index); }); } - } else { - for (size_t i = 0; i < output_num; ++i) { - results.emplace_back(real_node, i); + return results; + } + + size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract); + if (output_num == 1) { + results.emplace_back(real_node, 0); + return results; + } + + auto tuple_abstract = abstract->cast(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + const auto &sub_abstracts = tuple_abstract->elements(); + size_t index = 0; + for (const auto &sub_abstract : sub_abstracts) { + MS_EXCEPTION_IF_NULL(sub_abstract); + if (!sub_abstract->isa()) { + results.emplace_back(real_node, index++); } } return results; @@ -1011,8 +1043,13 @@ void ControlNodeParser::ParseDeviceContextForFuncGraph(const std::vectormanager()->func_graphs(); for (auto sub_graph : sub_graphs) { if (func_graph_to_device_contexts_.find(sub_graph) == func_graph_to_device_contexts_.end()) { - func_graph_to_device_contexts_[sub_graph] = - std::vector(sub_graph->parameters().size(), default_context); + size_t output_num = 0; + for (const auto ¶meter : sub_graph->parameters()) { + const auto &abstract = parameter->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + output_num += AnfAlgo::GetOutputNumByAbstract(abstract); + } + func_graph_to_device_contexts_[sub_graph] = std::vector(output_num, default_context); } } } @@ -1277,8 +1314,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &contr for (const auto &formal_to_real_parameter : formal_to_real_parameters_) { for (const auto &real_parameter_with_index : formal_to_real_parameter.second) { - const auto &real_parameter = real_parameter_with_index.first; - if (!IsFrontValueNode(real_parameter)) { + if (!IsFrontValueNode(real_parameter_with_index)) { continue; } @@ -1299,7 +1335,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &contr for (const auto &front_to_backend_parameters : front_to_backend_parameters_) { const auto &front_node = front_to_backend_parameters.first.first; MS_EXCEPTION_IF_NULL(front_node); - if (IsFrontValueNode(front_node) && (!front_to_backend_parameters.second.empty())) { + if (IsFrontValueNode(front_to_backend_parameters.first) && (!front_to_backend_parameters.second.empty())) { const auto &backend_parameter = front_to_backend_parameters.second.begin()->first; const auto &device_context = front_to_backend_parameters.second.begin()->second; CreateDeviceTensorForValueNode(front_to_backend_parameters.first, backend_parameter, device_context); @@ -1323,7 +1359,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &contr } for (size_t i = 0; i < input_with_indexs.size(); ++i) { const auto &input_with_index = input_with_indexs[i]; - if (IsFrontValueNode(input_with_index.first) && + if (IsFrontValueNode(input_with_index) && front_value_nodes_.find({input_with_index, iter->second[i]}) == front_value_nodes_.end()) { CreateDeviceTensorForFrontNode(input_with_index, iter->second[i]); front_value_nodes_.emplace(input_with_index, iter->second[i]);