!48282 Enable return multiple output including dict

Merge pull request !48282 from LiangZhibo/return
This commit is contained in:
i-robot 2023-02-02 08:37:21 +00:00 committed by Gitee
commit c3859a7047
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 148 additions and 20 deletions

View File

@ -66,6 +66,24 @@ std::shared_ptr<T> GetAbstract(const AnfNodePtr &node) {
return dyn_cast<T>(node->abstract());
}
bool CheckContainsDict(const AbstractBasePtr &abs) {
if (abs == nullptr) {
return false;
}
if (abs->isa<AbstractDictionary>()) {
return true;
}
if (abs->isa<AbstractSequence>()) {
auto abs_seq = abs->cast<AbstractSequencePtr>();
const auto &elements = abs_seq->elements();
if (std::any_of(elements.begin(), elements.end(),
[](const AbstractBasePtr &element) { return CheckContainsDict(element); })) {
return true;
}
}
return false;
}
// ===========================================================================
// BaseRewriter provides a common framework for data struct simplify.
// ===========================================================================
@ -164,7 +182,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
public:
using ThisClass = SimplifyDataStructuresRewriter;
SimplifyDataStructuresRewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager)
: BaseRewriter(root_graph, manager), is_dict_output_{IsDictOutput()} {}
: BaseRewriter(root_graph, manager), is_dict_output_{HasDictOutput()} {}
~SimplifyDataStructuresRewriter() override = default;
protected:
@ -347,13 +365,9 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
return new_node;
}
bool IsDictOutput() const {
bool HasDictOutput() const {
const AnfNodePtr &output = root_graph_->output();
auto abs_dict = GetAbstract<AbstractDictionary>(output);
if (abs_dict != nullptr) {
return true;
}
return false;
return CheckContainsDict(output->abstract());
}
// DictSetItem --> PyExecute()

View File

@ -274,18 +274,71 @@ std::string ToOrdinal(const size_t &i) {
return std::to_string(i) + suffix;
}
py::object GetPyExecuteOutput(const AnfNodePtr &output) {
AnfNodePtr GetRealOutput(const AnfNodePtr &node) {
constexpr size_t real_node_index = 1;
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
const auto cnode = dyn_cast<CNode>(node);
MS_EXCEPTION_IF_NULL(cnode);
return GetRealOutput(cnode->input(real_node_index));
}
return node;
}
bool ContainPyExecuteOutputData(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->has_user_data<kernel::PyExecuteOutputData>()) {
return true;
}
auto abs = node->abstract();
if (abs == nullptr || !abs->isa<abstract::AbstractSequence>()) {
return false;
}
if (!node->isa<CNode>()) {
return false;
}
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
if (std::any_of(inputs.begin(), inputs.end(),
[](const AnfNodePtr &input) { return ContainPyExecuteOutputData(input); })) {
return true;
}
return false;
}
py::object GetVectorRefOutputDataWithPyExecuteObject(const AnfNodePtr &node, const BaseRef &value) {
MS_EXCEPTION_IF_NULL(node);
auto real_node = GetRealOutput(node);
MS_EXCEPTION_IF_NULL(real_node);
auto abs = real_node->abstract();
if (!abs->isa<abstract::AbstractSequence>() || !real_node->isa<CNode>()) {
if (real_node->has_user_data<kernel::PyExecuteOutputData>()) {
// None case will consider later.
const auto &output_data = real_node->user_data<kernel::PyExecuteOutputData>();
return output_data->obj;
}
return BaseRefToPyData(value, abs);
}
auto abs_seq = utils::cast<abstract::AbstractSequencePtr>(abs);
auto value_seq = utils::cast<VectorRef>(value);
if (abs_seq->size() != value_seq.size()) {
MS_LOG(EXCEPTION) << "abs size and value size not match. abs size: " << abs_seq->size()
<< ", value size: " << value_seq.size();
}
size_t seq_size = abs_seq->size();
// List output will be convert to PyExecute real_node, only need to consider tuple here.
py::tuple ret = py::tuple(seq_size);
auto real_cnode = real_node->cast<CNodePtr>();
for (size_t i = 0; i < seq_size; ++i) {
ret[i] = GetVectorRefOutputDataWithPyExecuteObject(real_cnode->input(i + 1), value_seq[i]);
}
return ret;
}
py::object GetPyExecuteOutput(const AnfNodePtr &output, const BaseRef &value) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (support_fallback_runtime) {
std::function<AnfNodePtr(const AnfNodePtr &)> get_real_output = [&get_real_output](const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
const auto cnode = dyn_cast<CNode>(node);
MS_EXCEPTION_IF_NULL(cnode);
return get_real_output(cnode->input(1));
}
return node;
};
const auto &real_output = get_real_output(output);
const auto &real_output = GetRealOutput(output);
MS_LOG(INFO) << "Real output: " << real_output << ", " << real_output->DebugString()
<< ", has \'PyExecuteOutputData\': " << real_output->has_user_data<kernel::PyExecuteOutputData>();
if (real_output->has_user_data<kernel::PyExecuteOutputData>()) {
@ -297,6 +350,16 @@ py::object GetPyExecuteOutput(const AnfNodePtr &output) {
return res_obj;
}
}
// Handle multiple input case.
auto real_output_abs = real_output->abstract();
MS_EXCEPTION_IF_NULL(real_output_abs);
if (real_output_abs->isa<AbstractTuple>() && ContainPyExecuteOutputData(real_output)) {
MS_LOG(DEBUG) << "Contains PyExecute output data.";
if (!utils::isa<VectorRef>(value)) {
MS_LOG(EXCEPTION) << "When the output is tuple, value should be vector ref.";
}
return GetVectorRefOutputDataWithPyExecuteObject(real_output, value);
}
}
return py::none();
}
@ -1347,15 +1410,16 @@ py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_o
MS_LOG(DEBUG) << "Eval run" << ms_context->backend_policy();
const auto &output = execute_info->func_graph->output();
MS_EXCEPTION_IF_NULL(output);
const auto &output_abs = output->abstract();
MS_EXCEPTION_IF_NULL(output_abs);
BaseRef value;
for (int64_t i = 0; i < vm_loop; i++) {
BaseRef value = (*run)(execute_info->arg_list);
value = (*run)(execute_info->arg_list);
res = BaseRefToPyData(value, output_abs);
}
// Replace the output if it's not Tensor, but Python data.
const auto &py_res = GetPyExecuteOutput(output);
const auto &py_res = GetPyExecuteOutput(output, value);
if (py_res != py::none()) {
return py_res;
}

View File

@ -227,6 +227,56 @@ def test_dict_get_3():
assert out == {'y': ms.Tensor(np.array(1), ms.int64), 'a': 'a', 'b': 'c'}
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_multiple_return_contains_dict():
"""
Feature: Return multiple outputs including dict.
Description: Support dict return.
Expectation: No exception.
"""
@ms.jit
def dict_net_2():
x = {'a': 1, 'b': 2}
y = x.get('a')
y_tensor = ms.Tensor([y])
z = dict(a=y_tensor)
return y, z, (1, 2)
out = dict_net_2()
assert len(out) == 3
assert out[0] == 1
assert out[1] == {'a': ms.Tensor(np.array(1), ms.int64)}
assert out[2] == (1, 2)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_multiple_return_contains_dict_2():
"""
Feature: Return multiple outputs including dict.
Description: Support dict return.
Expectation: No exception.
"""
@ms.jit
def dict_net_2(a):
x = {'a': a, 'b': 2}
return a, (x, (1, 2))
out = dict_net_2(ms.Tensor([1]))
assert len(out) == 2
assert out[0] == ms.Tensor([1])
assert len(out[1]) == 2
assert out[1][0] == {'a': ms.Tensor([1], ms.int64), 'b': 2}
assert out[1][1] == (1, 2)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)