diff --git a/mindspore/ccsrc/frontend/optimizer/clean.cc b/mindspore/ccsrc/frontend/optimizer/clean.cc
index 9feda19dc96..c416711b5b1 100644
--- a/mindspore/ccsrc/frontend/optimizer/clean.cc
+++ b/mindspore/ccsrc/frontend/optimizer/clean.cc
@@ -66,6 +66,24 @@
 template <typename T>
 std::shared_ptr<T> GetAbstract(const AnfNodePtr &node) { return dyn_cast<T>(node->abstract()); }
 
+bool CheckContainsDict(const AbstractBasePtr &abs) {
+  if (abs == nullptr) {
+    return false;
+  }
+  if (abs->isa<AbstractDictionary>()) {
+    return true;
+  }
+  if (abs->isa<AbstractSequence>()) {
+    auto abs_seq = abs->cast<AbstractSequencePtr>();
+    const auto &elements = abs_seq->elements();
+    if (std::any_of(elements.begin(), elements.end(),
+                    [](const AbstractBasePtr &element) { return CheckContainsDict(element); })) {
+      return true;
+    }
+  }
+  return false;
+}
+
 // ===========================================================================
 // BaseRewriter provides a common framework for data struct simplify.
 // ===========================================================================
@@ -164,7 +182,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
  public:
   using ThisClass = SimplifyDataStructuresRewriter;
   SimplifyDataStructuresRewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager)
-      : BaseRewriter(root_graph, manager), is_dict_output_{IsDictOutput()} {}
+      : BaseRewriter(root_graph, manager), is_dict_output_{HasDictOutput()} {}
   ~SimplifyDataStructuresRewriter() override = default;
 
  protected:
@@ -347,13 +365,9 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
     return new_node;
   }
 
-  bool IsDictOutput() const {
+  bool HasDictOutput() const {
     const AnfNodePtr &output = root_graph_->output();
-    auto abs_dict = GetAbstract<AbstractDictionary>(output);
-    if (abs_dict != nullptr) {
-      return true;
-    }
-    return false;
+    return CheckContainsDict(output->abstract());
   }
 
   // DictSetItem --> PyExecute()
diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc
index fd20389e871..f7f60459680 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -274,18 +274,71 @@ std::string ToOrdinal(const size_t &i) {
   return std::to_string(i) + suffix;
 }
 
-py::object GetPyExecuteOutput(const AnfNodePtr &output) {
+AnfNodePtr GetRealOutput(const AnfNodePtr &node) {
+  constexpr size_t real_node_index = 1;
+  if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
+    const auto cnode = dyn_cast<CNode>(node);
+    MS_EXCEPTION_IF_NULL(cnode);
+    return GetRealOutput(cnode->input(real_node_index));
+  }
+  return node;
+}
+
+bool ContainPyExecuteOutputData(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (node->has_user_data<PyExecuteOutputData>()) {
+    return true;
+  }
+  auto abs = node->abstract();
+  if (abs == nullptr || !abs->isa<abstract::AbstractSequence>()) {
+    return false;
+  }
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  auto inputs = cnode->inputs();
+  if (std::any_of(inputs.begin(), inputs.end(),
+                  [](const AnfNodePtr &input) { return ContainPyExecuteOutputData(input); })) {
+    return true;
+  }
+  return false;
+}
+
+py::object GetVectorRefOutputDataWithPyExecuteObject(const AnfNodePtr &node, const BaseRef &value) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto real_node = GetRealOutput(node);
+  MS_EXCEPTION_IF_NULL(real_node);
+  auto abs = real_node->abstract();
+  if (!abs->isa<abstract::AbstractSequence>() || !real_node->isa<CNode>()) {
+    if (real_node->has_user_data<PyExecuteOutputData>()) {
+      // The None case will be handled later.
+      const auto &output_data = real_node->user_data<PyExecuteOutputData>();
+      return output_data->obj;
+    }
+    return BaseRefToPyData(value, abs);
+  }
+  auto abs_seq = utils::cast<abstract::AbstractSequencePtr>(abs);
+  auto value_seq = utils::cast<VectorRef>(value);
+  if (abs_seq->size() != value_seq.size()) {
+    MS_LOG(EXCEPTION) << "The abstract size and the value size do not match. Abstract size: " << abs_seq->size()
+                      << ", value size: " << value_seq.size();
+  }
+
+  size_t seq_size = abs_seq->size();
+  // A list output has already been converted to a PyExecute node, so only tuples need to be handled here.
+  py::tuple ret = py::tuple(seq_size);
+  auto real_cnode = real_node->cast<CNodePtr>();
+  for (size_t i = 0; i < seq_size; ++i) {
+    ret[i] = GetVectorRefOutputDataWithPyExecuteObject(real_cnode->input(i + 1), value_seq[i]);
+  }
+  return ret;
+}
+
+py::object GetPyExecuteOutput(const AnfNodePtr &output, const BaseRef &value) {
   static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
   if (support_fallback_runtime) {
-    std::function<AnfNodePtr(const AnfNodePtr &)> get_real_output = [&get_real_output](const AnfNodePtr &node) {
-      if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
-        const auto cnode = dyn_cast<CNode>(node);
-        MS_EXCEPTION_IF_NULL(cnode);
-        return get_real_output(cnode->input(1));
-      }
-      return node;
-    };
-    const auto &real_output = get_real_output(output);
+    const auto &real_output = GetRealOutput(output);
     MS_LOG(INFO) << "Real output: " << real_output << ", " << real_output->DebugString()
                  << ", has \'PyExecuteOutputData\': " << real_output->has_user_data<PyExecuteOutputData>();
     if (real_output->has_user_data<PyExecuteOutputData>()) {
@@ -297,6 +350,16 @@ py::object GetPyExecuteOutput(const AnfNodePtr &output) {
         return res_obj;
       }
     }
+    // Handle the multiple-output case.
+    auto real_output_abs = real_output->abstract();
+    MS_EXCEPTION_IF_NULL(real_output_abs);
+    if (real_output_abs->isa<abstract::AbstractSequence>() && ContainPyExecuteOutputData(real_output)) {
+      MS_LOG(DEBUG) << "Contains PyExecute output data.";
+      if (!utils::isa<VectorRef>(value)) {
+        MS_LOG(EXCEPTION) << "When the output is a tuple, the value should be a VectorRef.";
+      }
+      return GetVectorRefOutputDataWithPyExecuteObject(real_output, value);
+    }
   }
   return py::none();
 }
@@ -1347,15 +1410,16 @@ py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_o
   MS_LOG(DEBUG) << "Eval run" << ms_context->backend_policy();
   const auto &output = execute_info->func_graph->output();
   MS_EXCEPTION_IF_NULL(output);
+
   const auto &output_abs = output->abstract();
   MS_EXCEPTION_IF_NULL(output_abs);
+  BaseRef value;
   for (int64_t i = 0; i < vm_loop; i++) {
-    BaseRef value = (*run)(execute_info->arg_list);
+    value = (*run)(execute_info->arg_list);
     res = BaseRefToPyData(value, output_abs);
   }
-  // Replace the output if it's not Tensor, but Python data.
-  const auto &py_res = GetPyExecuteOutput(output);
+  const auto &py_res = GetPyExecuteOutput(output, value);
   if (py_res != py::none()) {
     return py_res;
   }
diff --git a/tests/st/fallback/test_graph_fallback_runtime.py b/tests/st/fallback/test_graph_fallback_runtime.py
index 59b62661707..c745bc34d96 100644
--- a/tests/st/fallback/test_graph_fallback_runtime.py
+++ b/tests/st/fallback/test_graph_fallback_runtime.py
@@ -227,6 +227,56 @@ def test_dict_get_3():
     assert out == {'y': ms.Tensor(np.array(1), ms.int64), 'a': 'a', 'b': 'c'}
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_multiple_return_contains_dict():
+    """
+    Feature: Return multiple outputs including dict.
+    Description: Support returning a dict as one of multiple outputs.
+    Expectation: No exception.
+    """
+    @ms.jit
+    def dict_net_1():
+        x = {'a': 1, 'b': 2}
+        y = x.get('a')
+        y_tensor = ms.Tensor([y])
+        z = dict(a=y_tensor)
+        return y, z, (1, 2)
+
+    out = dict_net_1()
+    assert len(out) == 3
+    assert out[0] == 1
+    assert out[1] == {'a': ms.Tensor(np.array(1), ms.int64)}
+    assert out[2] == (1, 2)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_multiple_return_contains_dict_2():
+    """
+    Feature: Return multiple outputs including dict.
+    Description: Support returning a dict nested in a tuple output.
+    Expectation: No exception.
+    """
+    @ms.jit
+    def dict_net_2(a):
+        x = {'a': a, 'b': 2}
+        return a, (x, (1, 2))
+
+    out = dict_net_2(ms.Tensor([1]))
+    assert len(out) == 2
+    assert out[0] == ms.Tensor([1])
+    assert len(out[1]) == 2
+    assert out[1][0] == {'a': ms.Tensor([1], ms.int64), 'b': 2}
+    assert out[1][1] == (1, 2)
+
+
 def weight_variable():
     """weight initial"""
     return TruncatedNormal(0.02)