Support the values of the dict has nested sequences and None.

This commit is contained in:
Margaret_wangrui 2023-02-25 18:33:39 +08:00
parent ed363cde09
commit 3e920a73fc
13 changed files with 146 additions and 87 deletions

View File

@ -16,7 +16,7 @@
* limitations under the License.
*/
#include "frontend/optimizer/clean.h"
#include "frontend/optimizer/fallback_rewriter.h"
#include <iterator>
#include <string>
#include <vector>
@ -98,6 +98,8 @@ class BaseRewriter : protected SimpleRewriter {
bool need_renormalized() const { return need_renormalized_; }
void set_need_renormalized(bool need_renormalized) { need_renormalized_ = need_renormalized; }
virtual bool Execute() {
bool changed = Run();
if (changed) {
@ -941,26 +943,18 @@ class CleanAfterOptARewriter : public BaseRewriter {
{prim::kPrimMakeDict, &ThisClass::ConvertMakeDict},
};
// Convert ValueNode<None> to PyExecute("None", ("None"), ("None")).
AnfNodePtr NoneConvertPyExecute(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto str_value = std::make_shared<StringImm>("None");
auto script_node = NewValueNode(str_value);
script_node->set_abstract(str_value->ToAbstract());
auto empty_tuple = std::vector<ValuePtr>();
auto empty_tuple_value = std::make_shared<ValueTuple>(empty_tuple);
auto local_key_node = NewValueNode(empty_tuple_value);
local_key_node->set_abstract(empty_tuple_value->ToAbstract());
auto local_value_node = NewValueNode(empty_tuple_value);
local_value_node->set_abstract(empty_tuple_value->ToAbstract());
std::vector<ValuePtr> none_value{str_value};
const auto none_tuple = std::make_shared<ValueTuple>(none_value);
auto none_tuple_node = NewValueNode(none_tuple);
AnfNodePtr none_execute_node =
func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimPyExecute), script_node, local_key_node, local_value_node});
ShapeVector shp{abstract::Shape::kShapeRankAny};
auto abs = std::make_shared<abstract::AbstractTensor>(kFloat64, std::make_shared<abstract::Shape>(shp));
none_execute_node->set_abstract(abs);
func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimPyExecute), script_node, none_tuple_node, none_tuple_node});
MS_LOG(DEBUG) << "none_execute_node:" << none_execute_node->DebugString();
return none_execute_node;
}
@ -981,13 +975,9 @@ class CleanAfterOptARewriter : public BaseRewriter {
if (!IsValueNode<None>(input)) {
continue;
}
// Convert ValueNode<None> to PyExecute("None", (), ()).
auto none_py_execute = NoneConvertPyExecute(cur_func);
manager_->Replace(input, none_py_execute);
}
// If the cnode is depend node, need renew the abstract of the cnode.
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && IsPrimitiveCNode(cnode->input(1), prim::kPrimPyExecute)) {
cnode->set_abstract(cnode->input(1)->abstract());
set_need_renormalized(true);
}
}
@ -1029,26 +1019,53 @@ class CleanAfterOptARewriter : public BaseRewriter {
return value;
}
AnfNodePtr ProcessValueSequence(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueSequence>()) {
auto value_seq = value->cast<ValueSequencePtr>();
MS_EXCEPTION_IF_NULL(value_seq);
auto values = value_seq->value();
std::vector<AnfNodePtr> value_seq_inputs{NewValueNode(prim::kPrimMakeTuple)};
for (auto inner_value : values) {
auto inner_value_seq = ProcessValueSequence(inner_value);
(void)value_seq_inputs.emplace_back(inner_value_seq);
}
auto iter_value = root_graph_->NewCNode(value_seq_inputs);
return iter_value;
}
if (value->isa<None>()) {
return NoneConvertPyExecute(root_graph_);
}
return NewValueNode(value);
}
AnfNodePtr PackDictValue(const ValueDictionaryPtr &dict) {
const auto &keys_values = dict->value();
std::vector<AnfNodePtr> value_list{NewValueNode(prim::kPrimMakeTuple)};
for (const auto &key_value : keys_values) {
auto iter_value = ProcessValueSequence(key_value.second);
(void)value_list.emplace_back(iter_value);
}
auto value_tuple_node = root_graph_->NewCNode(value_list);
return value_tuple_node;
}
// dict(k0:v0, k1:v1, ...) --> PyExecute('dict(zip(keys, values))', ...)
AnfNodePtr RebuildValueDict(const ValueNodePtr &value_node, const ValueDictionaryPtr &dict) {
const auto &keys_values = dict->value();
std::vector<ValuePtr> key_list;
key_list.reserve(keys_values.size());
std::vector<ValuePtr> value_list;
value_list.reserve(keys_values.size());
for (const auto &key_value : keys_values) {
(void)key_list.emplace_back(key_value.first);
(void)value_list.emplace_back(key_value.second);
}
// Local parameters values.
// Pack the key tuple.
std::vector<ValuePtr> key_list;
key_list.reserve(keys_values.size());
for (const auto &key_value : keys_values) {
(void)key_list.emplace_back(key_value.first);
}
const auto key_tuple = std::make_shared<ValueTuple>(key_list);
auto key_tuple_node = NewValueNode(key_tuple);
// Pack the value
const auto value_tuple = std::make_shared<ValueTuple>(value_list);
auto value_tuple_node = NewValueNode(value_tuple);
// Pack the value tuple.
auto value_tuple_node = PackDictValue(dict);
// Generate Make Dict PyExecute Node value
auto make_key_tuple_node = ConstructInternalTupleKeysNode(root_graph_, key_tuple_node);
@ -1151,7 +1168,7 @@ bool CleanAfterOptA(const FuncGraphPtr &root, const pipeline::ResourcePtr &resou
CleanAfterOptARewriter rewriter(root, manager);
bool change = rewriter.Execute();
// Renormalize for new PyExecute node.
if (change && rewriter.need_renormalized()) {
if (rewriter.need_renormalized()) {
abstract::AbstractBasePtrList new_args_spec;
std::transform(root->parameters().begin(), root->parameters().end(), std::back_inserter(new_args_spec),
[](const AnfNodePtr &param) -> AbstractBasePtr { return param->abstract(); });

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,8 +16,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CLEAN_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CLEAN_H_
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_FALLBACK_REWRITRER_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_FALLBACK_REWRITRER_H_
#include "ir/anf.h"
#include "frontend/operator/ops.h"
@ -34,4 +34,4 @@ bool CleanAfterOptA(const FuncGraphPtr &root, const pipeline::ResourcePtr &resou
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CLEAN_H_
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_FALLBACK_REWRITRER_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,7 +19,7 @@
#include "frontend/optimizer/irpass/branch_culling.h"
#include "frontend/optimizer/irpass/cast_eliminate.h"
#include "frontend/optimizer/irpass/get_grad_eliminate.h"
#include "frontend/optimizer/irpass/convert.h"
#include "frontend/optimizer/irpass/print_converter.h"
#include "frontend/optimizer/irpass/environ_eliminate.h"
#include "frontend/optimizer/irpass/meta_fg_var_prepare.h"
#include "frontend/optimizer/irpass/inline.h"

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -1553,11 +1553,6 @@ void SetDynamicInputSizeAttr(const CNodePtr &cnode) {
bool is_dyn_input = false;
for (size_t i = 0; i < input_num; ++i) {
auto input_node = common::AnfAlgo::GetInputNode(cnode, i);
// remove monad input
auto abs = input_node->abstract();
if (abs->isa<abstract::AbstractMonad>()) {
continue;
}
if (i < input_obj_types.size() && input_obj_types[i] == kernel::KernelObjectType::TUPLE_UNFOLD) {
auto [input_structural, input_size, dyn_input] = CalOutputTupleSize(input_node);
is_dyn_input |= dyn_input;

View File

@ -30,7 +30,7 @@
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/cse_pass.h"
#include "frontend/optimizer/clean.h"
#include "frontend/optimizer/fallback_rewriter.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/graph_transform.h"
#include "frontend/optimizer/auto_monad_eliminate.h"

View File

@ -22,6 +22,7 @@
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include "pybind11/pybind11.h"
#include "pybind_api/pybind_patch.h"
@ -36,9 +37,34 @@
#include "mindspore/ccsrc/pybind_api/ir/tensor_py.h"
#include "mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
#include "mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
#include "mindspore/ccsrc/pipeline/jit/parse/resolve.h"
namespace py = pybind11;
namespace mindspore {
namespace abstract {
using PyObjectWrapperPtr = std::shared_ptr<parse::PyObjectWrapper>;
namespace pyexecute_user_data_catcher {
std::pair<bool, ValuePtr> PyExecuteUserDataCatcher(const AbstractBasePtr &element_abs) {
MS_EXCEPTION_IF_NULL(element_abs);
if (element_abs->has_user_data<kernel::PyExecuteOutputUserData>()) {
const auto &data = element_abs->user_data<kernel::PyExecuteOutputUserData>();
MS_EXCEPTION_IF_NULL(data);
auto python_obj = std::make_shared<parse::PyObjectWrapper>(data->obj, "graph python obj");
return {true, python_obj};
}
return {false, nullptr};
}
struct PyExecuteUserDataCatcherRegister {
PyExecuteUserDataCatcherRegister() noexcept {
abstract::AbstractBase::set_pyexecute_user_data_catcher(
[](const AbstractBasePtr &element_abs) { return PyExecuteUserDataCatcher(element_abs); });
}
~PyExecuteUserDataCatcherRegister() {}
} pyexecute_user_data_catcher_register;
} // namespace pyexecute_user_data_catcher
} // namespace abstract
static py::object CallPythonGetGlobalParams() {
constexpr auto python_mod_parse = "mindspore._extends.parse"; // The same as PYTHON_MOD_PARSE_MODULE[]
py::module mod = python_adapter::GetPyModule(python_mod_parse);
@ -64,15 +90,6 @@ class PyExecuteInitializer {
const auto &keys_tuple_abs = input_args[1];
const auto &keys_tuple = keys_tuple_abs->BuildValue();
const auto &keys = dyn_cast<ValueSequence>(keys_tuple);
// Process PyExecute("None", (), (), io)
// Since the backend converts the empty tuple into an empty tensor(not keep ValueSequence),
// so special handling of None is required.
if (script->ToString() == "None") {
const auto &output = py::none();
PushPyExecuteOutput(script, output);
const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
return abstract::MakeAbstract(infer_shape, kFloat64);
}
if (keys == nullptr) {
MS_LOG(DEBUG) << "The keys is not tuple value, but got " << keys_tuple->ToString();
const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
@ -90,19 +107,24 @@ class PyExecuteInitializer {
const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
return abstract::MakeAbstract(infer_shape, kFloat64);
}
MS_LOG(DEBUG) << "script: " << script->ToString() << ", keys_tuple: " << keys_tuple->ToString()
MS_LOG(DEBUG) << "The script is: " << script->ToString() << ", keys_tuple: " << keys_tuple->ToString()
<< ", values_tuple: " << values_tuple->ToString();
if (keys->size() != values->size()) {
MS_LOG(EXCEPTION) << "The length of keys(" << keys->size() << ") is not equal of the length of values("
<< values->size() << ").";
}
py::gil_scoped_acquire gil_acquire;
py::dict local_dict;
for (size_t i = 0; i < keys->size(); ++i) {
const auto &key = (*keys)[i];
const auto &key_str = dyn_cast<StringImm>(key);
MS_EXCEPTION_IF_NULL(key_str);
const auto &value = (*values)[i];
MS_LOG(DEBUG) << "input[" << i << "], value : " << value;
MS_LOG(DEBUG) << "input[" << i << "], value : " << value->ToString();
const auto &tuple_abs = values_tuple_abs->cast<abstract::AbstractSequencePtr>();
const auto &value_abs = (*tuple_abs)[i];
if (value_abs->has_user_data<kernel::PyExecuteOutputUserData>()) {
const auto &output_data = value_abs->user_data<kernel::PyExecuteOutputUserData>();
auto obj = output_data->obj;

View File

@ -337,6 +337,12 @@ static ValueNameToConverterVector value_name_to_converter = {
auto interpreted_object = value->cast<parse::InterpretedObjectPtr>();
return interpreted_object->obj();
}},
// parse::PyObjectWrapper
{parse::PyObjectWrapper::kTypeId,
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
auto py_object = value->cast<parse::PyObjectWrapperPtr>();
return py_object->obj();
}},
// None
{None::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::none(); }},
// AnyValue

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -751,12 +751,25 @@ AbstractBasePtrList AbstractSequence::ElementsPartialBroaden() const {
return element_list;
}
std::pair<bool, ValuePtr> GetValueFromUserData(const AbstractBasePtr &element_abs) {
MS_EXCEPTION_IF_NULL(element_abs);
if (abstract::AbstractBase::pyexecute_user_data_catcher()) {
return abstract::AbstractBase::pyexecute_user_data_catcher()(element_abs);
}
return {false, nullptr};
}
template <typename T>
ValuePtr AbstractSequence::ElementsBuildValue() const {
std::vector<ValuePtr> element_value_list;
for (const auto &element : elements_) {
MS_EXCEPTION_IF_NULL(element);
ValuePtr element_value = element->BuildValue();
auto [has_user_data, element_value] = GetValueFromUserData(element);
if (has_user_data && element_value != nullptr) {
element_value_list.push_back(element_value);
continue;
}
element_value = element->BuildValue();
MS_EXCEPTION_IF_NULL(element_value);
if (element_value->isa<AnyValue>()) {
return kAnyValue;

View File

@ -259,6 +259,14 @@ class MS_CORE_API AbstractBase : public Base {
static void set_interpret_bool_checker(InterpretBoolChecker checker) { interpret_bool_checker_ = checker; }
static inline InterpretBoolChecker interpret_bool_checker() { return interpret_bool_checker_; }
/// \brief Process the user date of abstract with PyExecute node.
using PyExecuteUserDataCatcher = std::pair<bool, ValuePtr> (*)(const AbstractBasePtr &element_abs);
static inline PyExecuteUserDataCatcher pyexecute_user_data_catcher_ = nullptr;
static void set_pyexecute_user_data_catcher(PyExecuteUserDataCatcher catcher) {
pyexecute_user_data_catcher_ = catcher;
}
static inline PyExecuteUserDataCatcher pyexecute_user_data_catcher() { return pyexecute_user_data_catcher_; }
std::string name() const { return name_; }
void set_name(const std::string &name) { name_ = name; }

View File

@ -1,4 +1,4 @@
# Copyright 2021-2022 Huawei Technologies Co., Ltd
# Copyright 2021-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
import pytest
from tqdm import tqdm
import numpy as np
import mindspore as ms
import mindspore.nn as nn
@ -80,7 +79,7 @@ def test_auto_monad_layer():
train_net.set_train()
gen_samples = dict()
num_epoch = 21
for epoch in tqdm(range(num_epoch)):
for epoch in range(num_epoch):
loss = []
for _, (batch,) in enumerate(dataloader):
batch = Tensor(batch, dtype=ms.float32)

View File

@ -356,7 +356,6 @@ def test_none_is_output_of_function_with_side_effect():
assert res == 4
@pytest.mark.skip(reason="No support None in dict return.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@ -373,11 +372,34 @@ def test_none_is_input_of_dict_return():
def foo():
x = {'a': 'a', 'b': 'b'}
y = x.get('a')
z = dict(y=y, v=False, w=None)
z = dict(y=y, u=9, v=False, w=None)
return z
out = foo()
assert out == {'y': 'a', 'v': False, 'w': None}
assert out == {'y': 'a', 'u': 9, 'v': False, 'w': None}
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_none_nested_input_of_dict_return():
"""
Feature: Support None.
Description: Support None is input of dict, and the dict is return.
Expectation: No exception.
"""
@jit
def foo():
x = {'a': 'a', 'b': 'b'}
y = x.get('a')
z = dict(y=y, u=9, v=False, w=(None, None), q=[1, (2, None), None])
return z
out = foo()
assert out == {'y': 'a', 'u': 9, 'v': False, 'w': (None, None), 'q': (1, (2, None), None)}
@pytest.mark.level0

View File

@ -194,29 +194,6 @@ def test_dict_return_2():
assert out == {'a': ms.Tensor(np.array(1), ms.int64)}
@pytest.mark.skip(reason="No support None and Scalar in dict.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dict_return_3():
"""
Feature: Return dict.
Description: Support dict return.
Expectation: No exception.
"""
@ms.jit
def dict_net_3():
x = {'a': 'a', 'b': 'b'}
y = x.get('a')
z = dict(y=y, u=9, v=False, w=None)
return z
out = dict_net_3()
assert out == {'y': 'a', 'u': 9, 'v': False, 'w': None}
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training