diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 91695186181..e5890939dc6 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -202,6 +202,14 @@ std::vector GetAllOutputWithIndexInner(const AnfNodePtr &node) if (AnfUtils::IsRealCNodeKernel(node)) { outputs_num = AnfAlgo::GetOutputTensorNum(node); } + + // If the node is a call, the outputs num should get from the abstract. + if (AnfAlgo::IsCallNode(node)) { + auto abstract = node->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + outputs_num = AnfAlgo::GetOutputNumByAbstract(abstract); + } + // The output may be the tuple of node, so need visit all the outputs of node. for (size_t i = 0; i < outputs_num; ++i) { // Maybe this scene: tupleGetItem + depend + makeTuple, can be done correctly in VisitKernelWithReturnType. @@ -371,11 +379,21 @@ std::vector AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node size_t AnfRuntimeAlgorithm::GetOutputNumByAbstract(const AbstractBasePtr &node_abstract) { MS_EXCEPTION_IF_NULL(node_abstract); + size_t result = 0; + if (node_abstract->isa()) { + auto csr_tensor_abstract = node_abstract->cast(); + MS_EXCEPTION_IF_NULL(csr_tensor_abstract); + result += GetOutputNumByAbstract(csr_tensor_abstract->indptr()); + result += GetOutputNumByAbstract(csr_tensor_abstract->indices()); + result += GetOutputNumByAbstract(csr_tensor_abstract->values()); + result += GetOutputNumByAbstract(csr_tensor_abstract->dense_shape()); + return result; + } + if (!node_abstract->isa()) { return 1; } - size_t result = 0; auto tuple_abstract = node_abstract->cast(); MS_EXCEPTION_IF_NULL(tuple_abstract); const auto &sub_abstracts = tuple_abstract->elements(); diff --git a/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc b/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc index 0564bfaf417..23ad4b983b3 100644 --- a/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc @@ -134,14 +134,14 @@ void ExitActor::CopyDeviceAddress(OpContext *const context) { std::vector new_device_tensors; for (size_t i = 0; i < input_device_tensors_.size(); ++i) { auto input_device_tensor = input_device_tensors_[i]; - MS_EXCEPTION_IF_NULL(input_device_tensor); - const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex(); - MS_EXCEPTION_IF_NULL(node_with_index.first); if (!is_need_copy_device_tensors_[i]) { new_device_tensors.emplace_back(input_device_tensor); continue; } + MS_EXCEPTION_IF_NULL(input_device_tensor); + const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex(); + MS_EXCEPTION_IF_NULL(node_with_index.first); MS_EXCEPTION_IF_NULL(device_contexts_[i]); // Create the new device tensor to take over the input_device_tensors which are the outputs of kernel graphs. auto new_device_tensor = device_contexts_[i]->CreateDeviceAddress( diff --git a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc index 913cf259cf7..3fb3a81fb11 100644 --- a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc @@ -142,6 +142,9 @@ void PrepareDataForValue(const ValuePtr &value, const KernelWithIndex &node_with } else if (value->isa()) { type = kNumberTypeInt64; (reinterpret_cast(host_addr.get()))[0] = GetValue(value); + } else if (value->isa()) { + type = kNumberTypeInt32; + (reinterpret_cast(host_addr.get()))[0] = GetValue(value); } else { MS_LOG(EXCEPTION) << "Invalid value:" << value->ToString(); } diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.cc b/mindspore/ccsrc/runtime/framework/control_node_parser.cc index ccbf2e5ec3e..0dde2a4cba1 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.cc @@ -16,6 +16,7 @@ #include "runtime/framework/control_node_parser.h" #include "runtime/framework/actor/actor_common.h" +#include "utils/convert_utils.h" #include "abstract/utils.h" #include "ir/tensor.h" @@ -31,6 +32,12 @@ bool IsPartial(const AnfNodePtr &node) { AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial); } +// Check if node is a value node need to create a device tensor. +bool IsFrontValueNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + return node->isa() && (!IsValueNode(node)) && (!IsValueNode(node)); +} + // Get funcgraph in partial structure. // Depth represents the number of layers of the call. When the first input of the call node is a call node, // the funcgraph in the return value of the inner call needs to be returned. @@ -195,6 +202,64 @@ KernelWithIndex FetchRealInputNode(const KernelWithIndex &node_with_index) { return {}; } +// Fetch all the output index in the sub-abstract of abstract. +std::set FetchRealIndexByAbstract(const AbstractBasePtr &abstract, size_t index) { + MS_EXCEPTION_IF_NULL(abstract); + AbstractBasePtr dst_abstract = abstract; + size_t pre_abstract_num = 0; + + // Fetch the dest abstract by index, and the abstracts num before the dest abstract. + if (abstract->isa()) { + auto csr_abs = abstract->cast(); + MS_EXCEPTION_IF_NULL(csr_abs); + switch (index) { + case kCsrTensorIndPtrIndex: + dst_abstract = csr_abs->indptr(); + pre_abstract_num = kCsrTensorIndPtrIndex; + break; + case kCsrTensorIndicesIndex: + dst_abstract = csr_abs->indices(); + pre_abstract_num = kCsrTensorIndicesIndex; + break; + case kCsrTensorValuesIndex: + dst_abstract = csr_abs->values(); + pre_abstract_num = kCsrTensorValuesIndex; + break; + case kCsrTensorDenseShapeIndex: + dst_abstract = csr_abs->dense_shape(); + pre_abstract_num = kCsrTensorDenseShapeIndex; + break; + default: + MS_LOG(EXCEPTION) << "Invalid index:" << index << " for abstract:" << abstract->ToString(); + break; + } + } else if (abstract->isa()) { + auto tuple_abstract = abstract->cast(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + const auto &sub_abstracts = tuple_abstract->elements(); + if (sub_abstracts.size() <= index) { + MS_LOG(EXCEPTION) << "Invalid index:" << index << " for abstract:" << abstract->ToString(); + } + for (size_t i = 0; i < index; ++i) { + pre_abstract_num += AnfAlgo::GetOutputNumByAbstract(sub_abstracts[i]); + } + dst_abstract = sub_abstracts[index]; + } else { + if (index != 0) { + MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString(); + } + } + MS_EXCEPTION_IF_NULL(dst_abstract); + + // Fetch real output index. + size_t ouput_num = AnfAlgo::GetOutputNumByAbstract(dst_abstract); + std::set real_indexs; + for (size_t i = pre_abstract_num; i < ouput_num + pre_abstract_num; ++i) { + real_indexs.emplace(i); + } + return real_indexs; +} + // Get all the real parameters corresponding to node. void FetchRealParameterByNode(const KernelWithIndex &node, std::set *real_parameters, std::set *invalid_call_nodes, @@ -384,7 +449,7 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index MS_EXCEPTION_IF_NULL(front_node); const auto &node_value = front_node->cast()->value(); - if ((!node_value->isa()) && (!node_value->isa()) && (!node_value->isa())) { + if (node_value->isa() || node_value->isa()) { return; } @@ -468,6 +533,7 @@ std::vector FetchInputNodeByNode(const AnfNodePtr &node) { const auto &node_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple}); auto real_node = node_with_index.first; + size_t real_index = node_with_index.second; MS_EXCEPTION_IF_NULL(real_node); std::vector results; // 2. MakeTuple. @@ -481,13 +547,51 @@ std::vector FetchInputNodeByNode(const AnfNodePtr &node) { return results; } - // 3. One output node. + // 3. kPrimMakeCSRTensor. + if (IsCsrNode(real_node)) { + const auto &cnode = real_node->cast(); + const auto &inputs = cnode->inputs(); + if (inputs.size() <= kMakeCSRTensorInputStartPos) { + MS_LOG(EXCEPTION) << "Invalid make csr tensor node:" << cnode->DebugString(); + } + + // Fetch output put index. + const auto &prim_node = inputs[0]->cast(); + MS_EXCEPTION_IF_NULL(prim_node); + const auto &prim_value = prim_node->value()->cast(); + MS_EXCEPTION_IF_NULL(prim_value); + const auto &src_node = inputs[kMakeCSRTensorInputStartPos]; + MS_EXCEPTION_IF_NULL(src_node); + + const auto iter = sparse_attr_map.find(prim_value->name()); + // Csr node from the make csr tensor node. + if (AnfAlgo::CheckPrimitiveType(src_node, prim::kPrimMakeCSRTensor)) { + const auto &make_csr_tensor_cnode = src_node->cast(); + const auto &csr_tensor_inputs = make_csr_tensor_cnode->inputs(); + if (csr_tensor_inputs.size() <= kMakeCSRTensorInputNum) { + MS_LOG(EXCEPTION) << "Invalid make csr tensor node:" << cnode->DebugString(); + } + const auto &sub_results = FetchInputNodeByNode(csr_tensor_inputs[iter->second + kMakeCSRTensorInputStartPos]); + results.insert(results.end(), sub_results.begin(), sub_results.end()); + } else { + // Csr node from parameter or call node. + auto abstract = src_node->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + auto real_indexs = FetchRealIndexByAbstract(abstract, iter->second); + (void)std::transform(real_indexs.begin(), real_indexs.end(), std::back_inserter(results), + [&src_node](const auto &index) { return KernelWithIndex(src_node, index); }); + } + return results; + } + + // 4. One output node. const auto &abstract = real_node->abstract(); - if (abstract == nullptr || (!abstract->isa())) { + if (abstract == nullptr || + ((!abstract->isa()) && (!abstract->isa()))) { if (abstract == nullptr) { MS_LOG(WARNING) << "Empty abstract for node:" << real_node->DebugString(); } - return {AnfAlgo::VisitKernelWithReturnType(real_node, 0)}; + return {AnfAlgo::VisitKernelWithReturnType(real_node, real_index)}; } // 4. Abstract is Tuple. @@ -528,6 +632,17 @@ bool HasAbstractRef(const AnfNodePtr &node) { return (abs != nullptr) && abs->isa(); } +bool IsCsrNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + return AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetIndptr) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetIndices) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetValues) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetDenseShape); +} + KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, KernelGraph *const graph) { MS_EXCEPTION_IF_NULL(graph); const auto &front_node = graph->GetFrontAnfByBackendAnf(backend_node); @@ -589,6 +704,22 @@ std::vector FetchInputNodeByCNode(const AnfNodePtr &node) { abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index) { MS_EXCEPTION_IF_NULL(abstract); + if (abstract->isa()) { + auto csr_abs = abstract->cast(); + MS_EXCEPTION_IF_NULL(csr_abs); + if (index == kCsrTensorIndPtrIndex) { + return csr_abs->indptr(); + } else if (index == kCsrTensorIndicesIndex) { + return csr_abs->indices(); + } else if (index == kCsrTensorValuesIndex) { + return csr_abs->values(); + } else if (index >= kCsrTensorDenseShapeIndex) { + return FetchAbstractByIndex(csr_abs->dense_shape(), index - kCsrTensorDenseShapeIndex); + } else { + MS_LOG(EXCEPTION) << "Invalid index:" << index << " for abstract:" << abstract->ToString(); + } + } + if (!abstract->isa()) { if (index != 0) { MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString(); @@ -906,13 +1037,20 @@ void ControlNodeParser::ParseDeviceContextForCallNode(const std::vectorcast(); MS_EXCEPTION_IF_NULL(cnode); const auto &inputs = cnode->inputs(); - if (inputs.size() - kCallInputStartPos > iter->second.size()) { - MS_LOG(EXCEPTION) << "Invalid input size:" << inputs.size() << " context size:" << iter->second.size() + size_t call_input_num = 0; + for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { + const auto &abstract = inputs[i]->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + call_input_num += AnfAlgo::GetOutputNumByAbstract(abstract); + } + + if (call_input_num > iter->second.size()) { + MS_LOG(EXCEPTION) << "Invalid input size:" << call_input_num << " context size:" << iter->second.size() << "for funcgraph" << func_graph->ToString() << " for call node:" << cnode->DebugString(); } // Fetch the device contexts for the real parameters on the call node. - for (size_t i = iter->second.size() - inputs.size() + kCallInputStartPos; i < iter->second.size(); ++i) { + for (size_t i = iter->second.size() - call_input_num; i < iter->second.size(); ++i) { if (i >= iter->second.size()) { MS_LOG(EXCEPTION) << "Invalid device context index:" << i << " for funcgraph:" << func_graph->ToString() << " device context size:" << iter->second.size() @@ -1076,7 +1214,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &contr for (const auto &formal_to_real_parameter : formal_to_real_parameters_) { for (const auto &real_parameter_with_index : formal_to_real_parameter.second) { const auto &real_parameter = real_parameter_with_index.first; - if (!real_parameter->isa() || IsValueNode(real_parameter)) { + if (!IsFrontValueNode(real_parameter)) { continue; } @@ -1097,7 +1235,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &contr for (const auto &front_to_backend_parameters : front_to_backend_parameters_) { const auto &front_node = front_to_backend_parameters.first.first; MS_EXCEPTION_IF_NULL(front_node); - if (front_node->isa() && (!front_to_backend_parameters.second.empty())) { + if (IsFrontValueNode(front_node) && (!front_to_backend_parameters.second.empty())) { const auto &backend_parameter = front_to_backend_parameters.second.begin()->first; const auto &device_context = front_to_backend_parameters.second.begin()->second; CreateDeviceTensorForValueNode(front_to_backend_parameters.first, backend_parameter, device_context); @@ -1107,18 +1245,21 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &contr // Create device tensors for those value nodes which direct return by a return node. for (const auto &control_node : control_nodes) { - if (!AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) { + if ((!AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) && (!AnfAlgo::IsCallNode(control_node))) { continue; } auto input_with_indexs = FetchInputNodeByCNode(control_node); auto iter = control_node_to_device_contexts_.find(control_node); - if (iter == control_node_to_device_contexts_.end() || iter->second.size() != input_with_indexs.size()) { - MS_LOG(EXCEPTION) << "Invalid device context for control node:" << control_node->DebugString(); + if (iter == control_node_to_device_contexts_.end() || iter->second.size() < input_with_indexs.size()) { + MS_LOG(EXCEPTION) << "Invalid device context for control node:" << control_node->DebugString() + << " need:" << input_with_indexs.size() << " current:" + << (iter == control_node_to_device_contexts_.end() ? "null" + : std::to_string(iter->second.size())); } for (size_t i = 0; i < input_with_indexs.size(); ++i) { const auto &input_with_index = input_with_indexs[i]; - if (input_with_index.first->isa() && (!IsValueNode(input_with_index.first)) && + if (IsFrontValueNode(input_with_index.first) && front_value_nodes_.find({input_with_index, iter->second[i]}) == front_value_nodes_.end()) { CreateDeviceTensorForFrontNode(input_with_index, iter->second[i]); front_value_nodes_.emplace(input_with_index, iter->second[i]); diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.h b/mindspore/ccsrc/runtime/framework/control_node_parser.h index 9a79cd29eed..9d87c9f675f 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.h +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.h @@ -56,6 +56,13 @@ constexpr size_t kCNodeInputStartPos = 1; constexpr size_t kReturnInputPos = 1; constexpr size_t kSingleControlNode = 1; +constexpr size_t kCsrTensorIndPtrIndex = 0; +constexpr size_t kCsrTensorIndicesIndex = 1; +constexpr size_t kCsrTensorValuesIndex = 2; +constexpr size_t kCsrTensorDenseShapeIndex = 3; +constexpr size_t kMakeCSRTensorInputStartPos = 1; +constexpr size_t kMakeCSRTensorInputNum = 4; + const char kEntranceActorNameSuffix[] = "_EntranceActor"; const char kExitActorNameSuffix[] = "_ExitActor"; const char kStackActorNameSuffix[] = "_StackActor"; @@ -87,6 +94,8 @@ using KernelGraphGroupInfoPtr = std::shared_ptr; // Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph, // it is determined whether it is a weight. bool HasAbstractRef(const AnfNodePtr &node); +// Check whether the node is a csr node. +bool IsCsrNode(const AnfNodePtr &node); // Get the front node corresponding to the backend node, if the front node is not a parameter node, return the // corresponding cnode. KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, KernelGraph *const graph); diff --git a/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc b/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc index 85f514418e4..c2047de654c 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc @@ -257,7 +257,7 @@ std::vector ControlNodeScheduler::BuildExitActor(const GraphCompil std::vector device_contexts; for (const auto &node_with_context : kernel_graph_group_info->front_output_nodes_) { - if (HasAbstractMonad(node_with_context.first.first)) { + if (HasAbstractMonad(node_with_context.first.first) || IsCsrNode(node_with_context.first.first)) { continue; } // Collect inputs of exit actor. @@ -666,7 +666,7 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont branch_id); } } - if (abstract->isa()) { + if (real_abstract->isa()) { to_actor->input_partials_num_++; } else { to_actor->input_datas_num_++; diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index b7ec598de63..b5a28651711 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -954,13 +954,24 @@ BaseRef MindRTBackend::ConstructOutputByAbstract(const abstract::AbstractBasePtr MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position << " need:" << outputs_num << " total:" << output_tensors.size(); } + VectorRef outputs; + + if (abstract->isa()) { + auto csr_tensor_abstract = abstract->cast(); + MS_EXCEPTION_IF_NULL(csr_tensor_abstract); + outputs.emplace_back(ConstructOutputByAbstract(csr_tensor_abstract->indptr(), output_tensors, output_position)); + outputs.emplace_back(ConstructOutputByAbstract(csr_tensor_abstract->indices(), output_tensors, output_position)); + outputs.emplace_back(ConstructOutputByAbstract(csr_tensor_abstract->values(), output_tensors, output_position)); + outputs.emplace_back( + ConstructOutputByAbstract(csr_tensor_abstract->dense_shape(), output_tensors, output_position)); + return outputs; + } if (!abstract->isa()) { (*output_position)++; return output_tensors[(*output_position) - 1]; } - VectorRef outputs; auto tuple_abstract = abstract->cast(); MS_EXCEPTION_IF_NULL(tuple_abstract); const auto &sub_abstracts = tuple_abstract->elements();