forked from mindspore-Ecosystem/mindspore
!48282 Enable return multiple output including dict
Merge pull request !48282 from LiangZhibo/return
This commit is contained in:
commit
c3859a7047
|
@ -66,6 +66,24 @@ std::shared_ptr<T> GetAbstract(const AnfNodePtr &node) {
|
|||
return dyn_cast<T>(node->abstract());
|
||||
}
|
||||
|
||||
bool CheckContainsDict(const AbstractBasePtr &abs) {
|
||||
if (abs == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (abs->isa<AbstractDictionary>()) {
|
||||
return true;
|
||||
}
|
||||
if (abs->isa<AbstractSequence>()) {
|
||||
auto abs_seq = abs->cast<AbstractSequencePtr>();
|
||||
const auto &elements = abs_seq->elements();
|
||||
if (std::any_of(elements.begin(), elements.end(),
|
||||
[](const AbstractBasePtr &element) { return CheckContainsDict(element); })) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// BaseRewriter provides a common framework for data struct simplify.
|
||||
// ===========================================================================
|
||||
|
@ -164,7 +182,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
public:
|
||||
using ThisClass = SimplifyDataStructuresRewriter;
|
||||
SimplifyDataStructuresRewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager)
|
||||
: BaseRewriter(root_graph, manager), is_dict_output_{IsDictOutput()} {}
|
||||
: BaseRewriter(root_graph, manager), is_dict_output_{HasDictOutput()} {}
|
||||
~SimplifyDataStructuresRewriter() override = default;
|
||||
|
||||
protected:
|
||||
|
@ -347,13 +365,9 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
return new_node;
|
||||
}
|
||||
|
||||
bool IsDictOutput() const {
|
||||
bool HasDictOutput() const {
|
||||
const AnfNodePtr &output = root_graph_->output();
|
||||
auto abs_dict = GetAbstract<AbstractDictionary>(output);
|
||||
if (abs_dict != nullptr) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
return CheckContainsDict(output->abstract());
|
||||
}
|
||||
|
||||
// DictSetItem --> PyExecute()
|
||||
|
|
|
@ -274,18 +274,71 @@ std::string ToOrdinal(const size_t &i) {
|
|||
return std::to_string(i) + suffix;
|
||||
}
|
||||
|
||||
py::object GetPyExecuteOutput(const AnfNodePtr &output) {
|
||||
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
|
||||
if (support_fallback_runtime) {
|
||||
std::function<AnfNodePtr(const AnfNodePtr &)> get_real_output = [&get_real_output](const AnfNodePtr &node) {
|
||||
AnfNodePtr GetRealOutput(const AnfNodePtr &node) {
|
||||
constexpr size_t real_node_index = 1;
|
||||
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
|
||||
const auto cnode = dyn_cast<CNode>(node);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
return get_real_output(cnode->input(1));
|
||||
return GetRealOutput(cnode->input(real_node_index));
|
||||
}
|
||||
return node;
|
||||
};
|
||||
const auto &real_output = get_real_output(output);
|
||||
}
|
||||
|
||||
bool ContainPyExecuteOutputData(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->has_user_data<kernel::PyExecuteOutputData>()) {
|
||||
return true;
|
||||
}
|
||||
auto abs = node->abstract();
|
||||
if (abs == nullptr || !abs->isa<abstract::AbstractSequence>()) {
|
||||
return false;
|
||||
}
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto inputs = cnode->inputs();
|
||||
if (std::any_of(inputs.begin(), inputs.end(),
|
||||
[](const AnfNodePtr &input) { return ContainPyExecuteOutputData(input); })) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
py::object GetVectorRefOutputDataWithPyExecuteObject(const AnfNodePtr &node, const BaseRef &value) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto real_node = GetRealOutput(node);
|
||||
MS_EXCEPTION_IF_NULL(real_node);
|
||||
auto abs = real_node->abstract();
|
||||
if (!abs->isa<abstract::AbstractSequence>() || !real_node->isa<CNode>()) {
|
||||
if (real_node->has_user_data<kernel::PyExecuteOutputData>()) {
|
||||
// None case will consider later.
|
||||
const auto &output_data = real_node->user_data<kernel::PyExecuteOutputData>();
|
||||
return output_data->obj;
|
||||
}
|
||||
return BaseRefToPyData(value, abs);
|
||||
}
|
||||
auto abs_seq = utils::cast<abstract::AbstractSequencePtr>(abs);
|
||||
auto value_seq = utils::cast<VectorRef>(value);
|
||||
if (abs_seq->size() != value_seq.size()) {
|
||||
MS_LOG(EXCEPTION) << "abs size and value size not match. abs size: " << abs_seq->size()
|
||||
<< ", value size: " << value_seq.size();
|
||||
}
|
||||
|
||||
size_t seq_size = abs_seq->size();
|
||||
// List output will be convert to PyExecute real_node, only need to consider tuple here.
|
||||
py::tuple ret = py::tuple(seq_size);
|
||||
auto real_cnode = real_node->cast<CNodePtr>();
|
||||
for (size_t i = 0; i < seq_size; ++i) {
|
||||
ret[i] = GetVectorRefOutputDataWithPyExecuteObject(real_cnode->input(i + 1), value_seq[i]);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
py::object GetPyExecuteOutput(const AnfNodePtr &output, const BaseRef &value) {
|
||||
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
|
||||
if (support_fallback_runtime) {
|
||||
const auto &real_output = GetRealOutput(output);
|
||||
MS_LOG(INFO) << "Real output: " << real_output << ", " << real_output->DebugString()
|
||||
<< ", has \'PyExecuteOutputData\': " << real_output->has_user_data<kernel::PyExecuteOutputData>();
|
||||
if (real_output->has_user_data<kernel::PyExecuteOutputData>()) {
|
||||
|
@ -297,6 +350,16 @@ py::object GetPyExecuteOutput(const AnfNodePtr &output) {
|
|||
return res_obj;
|
||||
}
|
||||
}
|
||||
// Handle multiple input case.
|
||||
auto real_output_abs = real_output->abstract();
|
||||
MS_EXCEPTION_IF_NULL(real_output_abs);
|
||||
if (real_output_abs->isa<AbstractTuple>() && ContainPyExecuteOutputData(real_output)) {
|
||||
MS_LOG(DEBUG) << "Contains PyExecute output data.";
|
||||
if (!utils::isa<VectorRef>(value)) {
|
||||
MS_LOG(EXCEPTION) << "When the output is tuple, value should be vector ref.";
|
||||
}
|
||||
return GetVectorRefOutputDataWithPyExecuteObject(real_output, value);
|
||||
}
|
||||
}
|
||||
return py::none();
|
||||
}
|
||||
|
@ -1347,15 +1410,16 @@ py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_o
|
|||
MS_LOG(DEBUG) << "Eval run" << ms_context->backend_policy();
|
||||
const auto &output = execute_info->func_graph->output();
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
|
||||
const auto &output_abs = output->abstract();
|
||||
MS_EXCEPTION_IF_NULL(output_abs);
|
||||
BaseRef value;
|
||||
for (int64_t i = 0; i < vm_loop; i++) {
|
||||
BaseRef value = (*run)(execute_info->arg_list);
|
||||
value = (*run)(execute_info->arg_list);
|
||||
res = BaseRefToPyData(value, output_abs);
|
||||
}
|
||||
|
||||
// Replace the output if it's not Tensor, but Python data.
|
||||
const auto &py_res = GetPyExecuteOutput(output);
|
||||
const auto &py_res = GetPyExecuteOutput(output, value);
|
||||
if (py_res != py::none()) {
|
||||
return py_res;
|
||||
}
|
||||
|
|
|
@ -227,6 +227,56 @@ def test_dict_get_3():
|
|||
assert out == {'y': ms.Tensor(np.array(1), ms.int64), 'a': 'a', 'b': 'c'}
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_multiple_return_contains_dict():
|
||||
"""
|
||||
Feature: Return multiple outputs including dict.
|
||||
Description: Support dict return.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms.jit
|
||||
def dict_net_2():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = x.get('a')
|
||||
y_tensor = ms.Tensor([y])
|
||||
z = dict(a=y_tensor)
|
||||
return y, z, (1, 2)
|
||||
|
||||
out = dict_net_2()
|
||||
assert len(out) == 3
|
||||
assert out[0] == 1
|
||||
assert out[1] == {'a': ms.Tensor(np.array(1), ms.int64)}
|
||||
assert out[2] == (1, 2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_multiple_return_contains_dict_2():
|
||||
"""
|
||||
Feature: Return multiple outputs including dict.
|
||||
Description: Support dict return.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms.jit
|
||||
def dict_net_2(a):
|
||||
x = {'a': a, 'b': 2}
|
||||
return a, (x, (1, 2))
|
||||
|
||||
out = dict_net_2(ms.Tensor([1]))
|
||||
assert len(out) == 2
|
||||
assert out[0] == ms.Tensor([1])
|
||||
assert len(out[1]) == 2
|
||||
assert out[1][0] == {'a': ms.Tensor([1], ms.int64), 'b': 2}
|
||||
assert out[1][1] == (1, 2)
|
||||
|
||||
|
||||
def weight_variable():
|
||||
"""weight initial"""
|
||||
return TruncatedNormal(0.02)
|
||||
|
|
Loading…
Reference in New Issue