merge global dict to local dict

This commit is contained in:
lianliguang 2023-08-25 10:16:49 +08:00
parent 28a2cf9855
commit da6a478b78
7 changed files with 273 additions and 135 deletions

View File

@ -107,6 +107,7 @@
"mindspore/tests/ut/cpp/python_input/gtest_input/mindir/mindir_test.py" "unused-variable"
"mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parse_primitive.py" "super-init-not-called"
"mindspore/tests/ut/cpp/python_input/gtest_input/pre_activate" "unused-variable"
"mindspore/tests/ut/cpp/python_input/gtest_input/optimizer/pyinterpret_dict_convert_test.py" "unused-variable"
"mindspore/tests/ut/cpp/python_input/gtest_input/tbe" "unused-variable"
"mindspore/tests/ut/python" "c-extension-no-member"
"mindspore/tests/ut/python/train/summary/test_summary_abnormal_input.py" "bare-except"

View File

@ -36,30 +36,163 @@ namespace mindspore {
/* namespace to support opt */
namespace opt {
namespace {
py::object CallPythonPushGlobalParams(const py::object &dict) {
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
constexpr auto python_merge_dict = "merge_global_params";
return python_adapter::CallPyModFn(mod, python_merge_dict, dict);
CNodePtr Transform(const CNodePtr &cnode, const FuncGraphManagerPtr &manager);
AnfNodePtr NewValueNodeWithAbstract(const ValuePtr &value) {
auto value_node = NewValueNode(value);
value_node->set_abstract(value->ToAbstract());
return value_node;
}
void FuncGraphToPyData(const ValueDictionaryPtr &value_dict, py::object *global_params_dict) {
MS_EXCEPTION_IF_NULL(value_dict);
MS_EXCEPTION_IF_NULL(global_params_dict);
for (const auto &element : value_dict->value()) {
const auto &element_name = element.first;
const auto &element_abs = element.second;
if (element_abs->IsFromTypeId(FuncGraph::kTypeId)) {
auto fg = element_abs->cast<FuncGraphPtr>();
MS_EXCEPTION_IF_NULL(fg);
auto wrapper_obj = fg->python_obj();
if (wrapper_obj != nullptr && wrapper_obj->isa<parse::PyObjectWrapper>()) {
auto fn_py_obj = wrapper_obj->cast_ptr<parse::PyObjectWrapper>()->obj();
(*global_params_dict)[ValueToPyData(element_name)] = fn_py_obj;
MS_LOG(DEBUG) << "Found python function object for " << element_name << ", add it to global dict.";
}
AnfNodePtr FuncGraphToPyData(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<ValueNode>()) {
return node;
}
auto value_node = node->cast_ptr<ValueNode>();
auto value = value_node->value();
if (value->IsFromTypeId(FuncGraph::kTypeId)) {
auto fg = value->cast_ptr<FuncGraph>();
MS_EXCEPTION_IF_NULL(fg);
auto wrapper_obj = fg->python_obj();
if (wrapper_obj != nullptr && wrapper_obj->isa<parse::PyObjectWrapper>()) {
return NewValueNode(
std::make_shared<parse::InterpretedObject>(wrapper_obj->cast_ptr<parse::PyObjectWrapper>()->obj()));
}
}
return;
return node;
}
std::vector<AnfNodePtr> ConvertValueTupleToList(const AnfNodePtr &node) {
if ((!IsValueNode<ValueTuple>(node) && !IsPrimitiveCNode(node, prim::kPrimMakeTuple))) {
MS_LOG(INTERNAL_EXCEPTION) << "The dictionary's keys and values should be a tuple, but got " << node->DebugString();
}
std::vector<AnfNodePtr> node_list;
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
auto cnode = node->cast_ptr<CNode>();
auto inputs = cnode->inputs();
std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(node_list));
return node_list;
}
auto tuple_value = GetValueNode<ValueTuplePtr>(node);
auto value_list = tuple_value->value();
std::transform(value_list.begin(), value_list.end(), std::back_inserter(node_list),
[](const ValuePtr &value) -> AnfNodePtr { return NewValueNodeWithAbstract(value); });
return node_list;
}
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> UnZippedDict(const AnfNodePtr &dict_node) {
MS_EXCEPTION_IF_NULL(dict_node);
if (dict_node->isa<ValueNode>()) {
auto dict_value = GetValueNode<ValueDictionaryPtr>(dict_node);
if (dict_node == nullptr) {
MS_LOG(INTERNAL_EXCEPTION) << "The PyInterpret local dict or global dict should be a dictionary, but got "
<< dict_node->DebugString();
}
std::vector<AnfNodePtr> keys;
std::vector<AnfNodePtr> values;
for (auto item : dict_value->value()) {
(void)keys.emplace_back(NewValueNodeWithAbstract(item.first));
(void)values.emplace_back(NewValueNodeWithAbstract(item.second));
}
return std::make_pair(keys, values);
}
auto make_dict_node = dict_node->cast_ptr<CNode>();
if (!IsPrimitiveCNode(dict_node, prim::kPrimMakeDict)) {
MS_LOG(INTERNAL_EXCEPTION) << "The PyInterpret local dict or global dict should be a dictionary, but got "
<< dict_node->DebugString();
}
constexpr auto kMakeDictKeysInputIndex = 1;
constexpr auto kMakeDictValueInputIndex = 2;
auto keys_input = make_dict_node->input(kMakeDictKeysInputIndex);
auto values_input = make_dict_node->input(kMakeDictValueInputIndex);
auto keys_list = ConvertValueTupleToList(keys_input);
auto values_list = ConvertValueTupleToList(values_input);
return std::make_pair(keys_list, values_list);
}
std::set<std::string> GetLocalKeySet(const std::vector<AnfNodePtr> &key_node_list) {
std::set<std::string> key_set;
std::transform(key_node_list.begin(), key_node_list.end(), std::inserter(key_set, key_set.begin()),
[](const AnfNodePtr &node) -> std::string {
auto abs = node->abstract();
MS_EXCEPTION_IF_NULL(abs);
auto value = abs->BuildValue();
MS_EXCEPTION_IF_NULL(value);
return GetValue<std::string>(value);
});
return key_set;
}
// Merge global dict to local dict and return merged key and value
std::pair<AnfNodePtr, AnfNodePtr> MergeGlobalDictToLocal(const AnfNodePtr &global_dict_node,
const AnfNodePtr &local_dict_node,
const FuncGraphPtr &func_graph,
const FuncGraphManagerPtr &manager) {
MS_EXCEPTION_IF_NULL(global_dict_node);
MS_EXCEPTION_IF_NULL(local_dict_node);
auto [global_keys, global_values] = UnZippedDict(global_dict_node);
auto [local_keys, local_values] = UnZippedDict(local_dict_node);
auto local_dict_keys_set = GetLocalKeySet(local_keys);
std::vector<AnfNodePtr> local_keys_inputs{NewValueNode(prim::kPrimMakeTuple)};
std::vector<AnfNodePtr> local_value_inputs{NewValueNode(prim::kPrimMakeTuple)};
for (size_t index = 0; index < global_keys.size(); ++index) {
auto global_key = global_keys.at(index);
MS_EXCEPTION_IF_NULL(global_key);
auto key = GetValueNode<StringImmPtr>(global_key);
if (local_dict_keys_set.find(GetValue<std::string>(key)) != local_dict_keys_set.end()) {
MS_LOG(INFO) << "The global dict has the same name with local dict.:" << key->ToString();
continue;
}
MS_LOG(DEBUG) << "The global key " << global_key->DebugString() << ", value "
<< global_values.at(index)->DebugString() << ". merged in local dict.";
(void)local_keys_inputs.emplace_back(global_key);
(void)local_value_inputs.emplace_back(FuncGraphToPyData(global_values.at(index)));
}
std::copy(local_keys.begin(), local_keys.end(), std::back_inserter(local_keys_inputs));
std::transform(local_values.begin(), local_values.end(), std::back_inserter(local_value_inputs),
[&manager, &func_graph](const AnfNodePtr &node) -> AnfNodePtr {
if (!IsPrimitiveCNode(node, prim::kPrimPyInterpret)) {
return node;
}
auto trans_node = Transform(node->cast<CNodePtr>(), manager);
(void)manager->Replace(node, trans_node);
return trans_node;
});
return std::make_pair(func_graph->NewCNode(local_keys_inputs), func_graph->NewCNode(local_value_inputs));
}
CNodePtr Transform(const CNodePtr &cnode, const FuncGraphManagerPtr &manager) {
constexpr auto input_index_one = 1;
constexpr auto input_index_two = 2;
constexpr auto input_index_three = 3;
auto new_cnode = std::make_shared<CNode>(*cnode);
new_cnode->CloneUserData(cnode);
new_cnode->set_input(0, NewValueNode(prim::kPrimPyExecute));
if (!IsValueNode<parse::Script>(cnode->input(input_index_one))) {
MS_LOG(INTERNAL_EXCEPTION) << "The first input should be a Script, but got "
<< cnode->input(input_index_one)->DebugString();
}
const auto &script = GetValueNode<std::shared_ptr<parse::Script>>(cnode->input(input_index_one));
const auto &script_str = script->script();
const auto &script_strimm_node = NewValueNode(std::make_shared<StringImm>(script_str));
new_cnode->set_input(input_index_one, script_strimm_node);
auto global_dict_node = cnode->input(input_index_two);
auto local_dict_node = cnode->input(input_index_three);
auto [local_dict_keys, local_dict_values] =
MergeGlobalDictToLocal(global_dict_node, local_dict_node, cnode->func_graph(), manager);
new_cnode->set_input(input_index_two, local_dict_keys);
new_cnode->set_input(input_index_three, local_dict_values);
// Record the PyExecute node.
InterpretNodeRecorder::GetInstance().PushPyExecuteNode(new_cnode);
return new_cnode;
}
} // namespace
@ -68,109 +201,19 @@ void FuncGraphToPyData(const ValueDictionaryPtr &value_dict, py::object *global_
// -->
// PyExecute(script, local_dict_keys, local_dict_values),
// with side-effect operation:
// Push global_dict into global parameters list.
// (So it requires no same key name.)
// Merge global_dict to local dict.
// If there are arguments in global dict and local dict use local dict argument instead of global dict.
bool PyInterpretToExecute(const pipeline::ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
auto manager = resource->manager();
const auto &all_nodes = manager->all_nodes();
MS_EXCEPTION_IF_NULL(manager);
auto transact = manager->Transact();
constexpr auto input_index_one = 1;
constexpr auto input_index_two = 2;
constexpr auto input_index_three = 3;
for (const auto &node : all_nodes) {
if (!IsPrimitiveCNode(node, prim::kPrimPyInterpret)) {
continue;
const auto all_nodes = manager->all_nodes();
for (const auto node : all_nodes) {
if (IsPrimitiveCNode(node, prim::kPrimPyInterpret)) {
auto trans_node = Transform(node->cast<CNodePtr>(), manager);
transact.Replace(node, trans_node);
}
const auto &cnode = node->cast<CNodePtr>();
MS_LOG(DEBUG) << "cnode: " << cnode->DebugString();
auto new_cnode = std::make_shared<CNode>(*cnode);
new_cnode->CloneUserData(cnode);
new_cnode->set_input(0, NewValueNode(prim::kPrimPyExecute));
if (!IsValueNode<parse::Script>(cnode->input(input_index_one))) {
MS_LOG(INTERNAL_EXCEPTION) << "The first input should be a Script, but got "
<< cnode->input(input_index_one)->DebugString();
}
const auto &script = GetValueNode<std::shared_ptr<parse::Script>>(cnode->input(input_index_one));
const auto &script_str = script->script();
const auto &script_strimm_node = NewValueNode(std::make_shared<StringImm>(script_str));
new_cnode->set_input(input_index_one, script_strimm_node);
if (!IsValueNode<ValueDictionary>(cnode->input(input_index_two))) {
MS_LOG(INTERNAL_EXCEPTION) << "The second input should be a dictionary, but got "
<< cnode->input(input_index_two)->DebugString();
}
const auto &global_dict = GetValueNode<ValueDictionaryPtr>(cnode->input(input_index_two));
auto value_dict = global_dict->cast<ValueDictionaryPtr>();
py::object py_global_dict = ValueToPyData(global_dict);
FuncGraphToPyData(value_dict, &py_global_dict);
MS_LOG(DEBUG) << "py_global_dict: " << py::str(py_global_dict);
(void)CallPythonPushGlobalParams(py_global_dict);
const auto &three_input = cnode->input(input_index_three);
if (!IsPrimitiveCNode(three_input, prim::kPrimMakeDict) && !IsValueNode<ValueDictionary>(three_input)) {
MS_LOG(INTERNAL_EXCEPTION) << "The 3rd input should be a dictionary, but got " << three_input->DebugString();
}
AnfNodePtr local_dict_keys = nullptr;
AnfNodePtr local_dict_values = nullptr;
if (IsValueNode<ValueDictionary>(three_input)) {
const auto &local_dict = GetValueNode<ValueDictionaryPtr>(three_input);
auto value_local_dict = local_dict->cast<ValueDictionaryPtr>();
if (value_local_dict->value().empty()) {
auto str_value = std::make_shared<StringImm>("None");
std::vector<ValuePtr> none_value{str_value};
const auto none_tuple = std::make_shared<ValueTuple>(none_value);
auto none_tuple_node = NewValueNode(none_tuple);
local_dict_keys = none_tuple_node;
local_dict_values = none_tuple_node;
} else {
std::vector<ValuePtr> key_vec;
std::vector<ValuePtr> value_vec;
for (auto key_value : value_local_dict->value()) {
(void)key_vec.emplace_back(key_value.first);
(void)value_vec.emplace_back(key_value.second);
}
local_dict_keys = NewValueNode(std::make_shared<ValueTuple>(key_vec));
local_dict_values = NewValueNode(std::make_shared<ValueTuple>(value_vec));
}
} else {
const auto &local_dict_cnode = dyn_cast<CNode>(three_input);
MS_EXCEPTION_IF_NULL(local_dict_cnode);
local_dict_keys = local_dict_cnode->input(input_index_one);
local_dict_values = local_dict_cnode->input(input_index_two);
}
if ((!IsValueNode<ValueTuple>(local_dict_keys) && !IsPrimitiveCNode(local_dict_keys, prim::kPrimMakeTuple)) ||
(!IsValueNode<ValueTuple>(local_dict_values) && !IsPrimitiveCNode(local_dict_values, prim::kPrimMakeTuple))) {
MS_LOG(INTERNAL_EXCEPTION) << "The dictionary's keys and values should be a tuple, but got "
<< three_input->DebugString();
}
if (local_dict_values->isa<CNode>()) {
// Handle values and convert InterpretedObject element.
const auto &make_tuple_cnode = dyn_cast<CNode>(local_dict_values);
MS_EXCEPTION_IF_NULL(make_tuple_cnode);
const auto fg = make_tuple_cnode->func_graph();
for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
// Convert InterpretedObject value node to PyExecute CNode.
const auto &input = make_tuple_cnode->input(i);
const auto &value = GetValueNode<parse::InterpretedObjectPtr>(input);
if (value != nullptr) {
const auto interpreted_value = dyn_cast<parse::InterpretedObject>(value);
const std::string &key = interpreted_value->name();
const py::object &obj = interpreted_value->obj();
const auto &interpreted_node = fallback::ConvertPyObjectToPyExecute(fg, key, obj, input, true);
interpreted_node->set_debug_info(input->debug_info());
(void)transact.Replace(input, interpreted_node);
}
}
}
new_cnode->set_input(input_index_two, local_dict_keys);
new_cnode->set_input(input_index_three, local_dict_values);
(void)transact.Replace(cnode, new_cnode);
// Record the PyExecute node.
InterpretNodeRecorder::GetInstance().PushPyExecuteNode(new_cnode);
}
transact.Commit();
return true;

View File

@ -66,12 +66,6 @@ struct PyExecuteUserDataCatcherRegister {
} // namespace pyexecute_user_data_catcher
} // namespace abstract
static py::object CallPythonGetGlobalParams() {
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
constexpr auto python_get_dict = "get_global_params";
return python_adapter::CallPyModFn(mod, python_get_dict);
}
bool ContainStubTensor(const py::object &obj) {
if (py::isinstance<py::list>(obj)) {
auto list_obj = py::cast<py::list>(obj);
@ -156,9 +150,8 @@ class PyExecuteInitializer {
}
}
const auto &py_script = py::str(script_str->value());
const auto &global_dict = CallPythonGetGlobalParams();
auto params = py::tuple(number_two);
params[0] = global_dict;
params[0] = py::dict();
params[1] = local_dict;
MS_LOG(DEBUG) << "Python script: " << py_script << ", local_dict: " << local_dict;
try {

View File

@ -26,7 +26,7 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
is_class_type, check_obj_bool, python_isinstance, ms_isinstance, convert_to_ms_csrtensor,
convert_to_ms_cootensor, convert_class_to_function, convert_cell_list_to_sequence, is_cell_list,
get_obj_from_sequence, get_type, is_class_member_recursive, merge_global_params, get_global_params,
get_obj_from_sequence, get_type, is_class_member_recursive, get_global_params,
get_adapter_tensor_attr, get_obj_defined_from_obj_type,
is_from_third_party_library, get_const_abs, get_const_round,
get_const_len, is_adapter_tensor_class, is_adapter_parameter_class)

View File

@ -822,15 +822,9 @@ def get_script_id_attrs(script):
return res
def merge_global_params(global_dict):
"""Merge the global parameter."""
logger.debug(f'merge global_dict: {global_dict}')
_global_params.update(global_dict)
def get_global_params():
"""Get the global parameter."""
logger.debug(f'get global_dict: {_global_params}')
logger.debug(f"get global_dict: {_global_params}")
return _global_params

View File

@ -17,6 +17,7 @@
#include <memory>
#include "common/common_test.h"
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "ir/anf.h"
@ -25,6 +26,7 @@
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/py_interpret_to_execute.h"
#include "include/common/debug/draw.h"
namespace mindspore {
@ -38,6 +40,55 @@ class TestOptOptimizer : public UT::Common {
irpass::OptimizeIRPassLib irpass;
};
class TestPyInterpretToPyExecute : public BackendCommon {
public:
TestPyInterpretToPyExecute() : getPyFun("gtest_input.optimizer.pyinterpret_dict_convert_test", true) {}
~TestPyInterpretToPyExecute() override = default;
UT::PyFuncGraphFetcher getPyFun;
void ChangeStringToScript(const pipeline::ResourcePtr &resource) {
auto trans = resource->manager();
MS_EXCEPTION_IF_NULL(trans);
auto nodes = trans->all_nodes();
for (const auto &node : nodes) {
if (IsPrimitiveCNode(node, prim::kPrimPyInterpret)) {
auto constexpr kScriptInputIdx = 1;
auto cnode = node->cast<CNodePtr>();
auto script_str_node = cnode->input(kScriptInputIdx);
auto script_string = GetValueNode<StringImmPtr>(script_str_node);
auto script = script_string->value();
auto script_node = NewValueNode(std::make_shared<parse::Script>(script));
cnode->set_input(kScriptInputIdx, script_node);
}
if (node->isa<ValueNode>()) {
auto value_node = node->cast_ptr<ValueNode>();
auto value = value_node->value();
value_node->set_abstract(value->ToAbstract());
}
}
}
};
/// Feature: Test global dict merge to local dict.
/// Description: PyInterpret convert to PyExecute and merged arguments with local.
/// Expectation: success.
TEST_F(TestPyInterpretToPyExecute, test_pyinterpret_to_pyexecute) {
FuncGraphPtr before = getPyFun.CallAndParseRet("py_interpret_to_py_execute_test", "before");
ASSERT_TRUE(nullptr != before);
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
res->set_func_graph(before);
auto manager = res->manager();
manager->KeepRoots({before});
ChangeStringToScript(res);
PyInterpretToExecute(res);
auto correct_graph = getPyFun.CallAndParseRet("py_interpret_to_py_execute_test", "after");
EXPECT_NE(correct_graph, nullptr);
EXPECT_TRUE(CheckEqualGraph(before, correct_graph));
}
TEST_F(TestOptOptimizer, test_step_opt) {
FuncGraphPtr before = getPyFun("test_expandJ");

View File

@ -0,0 +1,56 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
py_interpret = Primitive("PyInterpret")
py_execute = Primitive("PyExecute")
make_tuple = Primitive("MakeTuple")
make_dict = Primitive("make_dict")
class FnDict:
def __init__(self):
self.fnDict = {}
def __call__(self, fn):
self.fnDict[fn.__name__] = fn
def __getitem__(self, name):
return self.fnDict[name]
def py_interpret_to_py_execute_test(tag):
""" test_split_bn_fusion """
fns = FnDict()
@fns
def before():
global_key = make_tuple("g_a", "g_b")
global_value = make_tuple(1, 2)
local_key = make_tuple("g_a", "a", "b")
local_value = make_tuple(3, 4, 5)
global_dict = make_dict(global_key, global_value)
local_dict = make_dict(local_key, local_value)
output = py_interpret("func(g_a, g_b, a, b)", global_dict, local_dict)
return output
@fns
def after():
local_key = make_tuple("g_b", "g_a", "a", "b")
local_value = make_tuple(2, 3, 4, 5)
output = py_execute("func(g_a, g_b, a, b)", local_key, local_value)
return output
return fns[tag]