merge global dict to local dict
This commit is contained in:
parent
28a2cf9855
commit
da6a478b78
|
@ -107,6 +107,7 @@
|
|||
"mindspore/tests/ut/cpp/python_input/gtest_input/mindir/mindir_test.py" "unused-variable"
|
||||
"mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parse_primitive.py" "super-init-not-called"
|
||||
"mindspore/tests/ut/cpp/python_input/gtest_input/pre_activate" "unused-variable"
|
||||
"mindspore/tests/ut/cpp/python_input/gtest_input/optimizer/pyinterpret_dict_convert_test.py" "unused-variable"
|
||||
"mindspore/tests/ut/cpp/python_input/gtest_input/tbe" "unused-variable"
|
||||
"mindspore/tests/ut/python" "c-extension-no-member"
|
||||
"mindspore/tests/ut/python/train/summary/test_summary_abnormal_input.py" "bare-except"
|
||||
|
|
|
@ -36,30 +36,163 @@ namespace mindspore {
|
|||
/* namespace to support opt */
|
||||
namespace opt {
|
||||
namespace {
|
||||
py::object CallPythonPushGlobalParams(const py::object &dict) {
|
||||
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
|
||||
constexpr auto python_merge_dict = "merge_global_params";
|
||||
return python_adapter::CallPyModFn(mod, python_merge_dict, dict);
|
||||
CNodePtr Transform(const CNodePtr &cnode, const FuncGraphManagerPtr &manager);
|
||||
AnfNodePtr NewValueNodeWithAbstract(const ValuePtr &value) {
|
||||
auto value_node = NewValueNode(value);
|
||||
value_node->set_abstract(value->ToAbstract());
|
||||
return value_node;
|
||||
}
|
||||
|
||||
void FuncGraphToPyData(const ValueDictionaryPtr &value_dict, py::object *global_params_dict) {
|
||||
MS_EXCEPTION_IF_NULL(value_dict);
|
||||
MS_EXCEPTION_IF_NULL(global_params_dict);
|
||||
for (const auto &element : value_dict->value()) {
|
||||
const auto &element_name = element.first;
|
||||
const auto &element_abs = element.second;
|
||||
if (element_abs->IsFromTypeId(FuncGraph::kTypeId)) {
|
||||
auto fg = element_abs->cast<FuncGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto wrapper_obj = fg->python_obj();
|
||||
if (wrapper_obj != nullptr && wrapper_obj->isa<parse::PyObjectWrapper>()) {
|
||||
auto fn_py_obj = wrapper_obj->cast_ptr<parse::PyObjectWrapper>()->obj();
|
||||
(*global_params_dict)[ValueToPyData(element_name)] = fn_py_obj;
|
||||
MS_LOG(DEBUG) << "Found python function object for " << element_name << ", add it to global dict.";
|
||||
}
|
||||
AnfNodePtr FuncGraphToPyData(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<ValueNode>()) {
|
||||
return node;
|
||||
}
|
||||
auto value_node = node->cast_ptr<ValueNode>();
|
||||
auto value = value_node->value();
|
||||
if (value->IsFromTypeId(FuncGraph::kTypeId)) {
|
||||
auto fg = value->cast_ptr<FuncGraph>();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto wrapper_obj = fg->python_obj();
|
||||
if (wrapper_obj != nullptr && wrapper_obj->isa<parse::PyObjectWrapper>()) {
|
||||
return NewValueNode(
|
||||
std::make_shared<parse::InterpretedObject>(wrapper_obj->cast_ptr<parse::PyObjectWrapper>()->obj()));
|
||||
}
|
||||
}
|
||||
return;
|
||||
return node;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> ConvertValueTupleToList(const AnfNodePtr &node) {
|
||||
if ((!IsValueNode<ValueTuple>(node) && !IsPrimitiveCNode(node, prim::kPrimMakeTuple))) {
|
||||
MS_LOG(INTERNAL_EXCEPTION) << "The dictionary's keys and values should be a tuple, but got " << node->DebugString();
|
||||
}
|
||||
std::vector<AnfNodePtr> node_list;
|
||||
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
auto cnode = node->cast_ptr<CNode>();
|
||||
auto inputs = cnode->inputs();
|
||||
std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(node_list));
|
||||
return node_list;
|
||||
}
|
||||
auto tuple_value = GetValueNode<ValueTuplePtr>(node);
|
||||
auto value_list = tuple_value->value();
|
||||
std::transform(value_list.begin(), value_list.end(), std::back_inserter(node_list),
|
||||
[](const ValuePtr &value) -> AnfNodePtr { return NewValueNodeWithAbstract(value); });
|
||||
return node_list;
|
||||
}
|
||||
|
||||
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> UnZippedDict(const AnfNodePtr &dict_node) {
|
||||
MS_EXCEPTION_IF_NULL(dict_node);
|
||||
if (dict_node->isa<ValueNode>()) {
|
||||
auto dict_value = GetValueNode<ValueDictionaryPtr>(dict_node);
|
||||
if (dict_node == nullptr) {
|
||||
MS_LOG(INTERNAL_EXCEPTION) << "The PyInterpret local dict or global dict should be a dictionary, but got "
|
||||
<< dict_node->DebugString();
|
||||
}
|
||||
std::vector<AnfNodePtr> keys;
|
||||
std::vector<AnfNodePtr> values;
|
||||
for (auto item : dict_value->value()) {
|
||||
(void)keys.emplace_back(NewValueNodeWithAbstract(item.first));
|
||||
(void)values.emplace_back(NewValueNodeWithAbstract(item.second));
|
||||
}
|
||||
return std::make_pair(keys, values);
|
||||
}
|
||||
|
||||
auto make_dict_node = dict_node->cast_ptr<CNode>();
|
||||
if (!IsPrimitiveCNode(dict_node, prim::kPrimMakeDict)) {
|
||||
MS_LOG(INTERNAL_EXCEPTION) << "The PyInterpret local dict or global dict should be a dictionary, but got "
|
||||
<< dict_node->DebugString();
|
||||
}
|
||||
constexpr auto kMakeDictKeysInputIndex = 1;
|
||||
constexpr auto kMakeDictValueInputIndex = 2;
|
||||
auto keys_input = make_dict_node->input(kMakeDictKeysInputIndex);
|
||||
auto values_input = make_dict_node->input(kMakeDictValueInputIndex);
|
||||
|
||||
auto keys_list = ConvertValueTupleToList(keys_input);
|
||||
auto values_list = ConvertValueTupleToList(values_input);
|
||||
return std::make_pair(keys_list, values_list);
|
||||
}
|
||||
|
||||
std::set<std::string> GetLocalKeySet(const std::vector<AnfNodePtr> &key_node_list) {
|
||||
std::set<std::string> key_set;
|
||||
std::transform(key_node_list.begin(), key_node_list.end(), std::inserter(key_set, key_set.begin()),
|
||||
[](const AnfNodePtr &node) -> std::string {
|
||||
auto abs = node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
auto value = abs->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
return GetValue<std::string>(value);
|
||||
});
|
||||
return key_set;
|
||||
}
|
||||
|
||||
// Merge global dict to local dict and return merged key and value
|
||||
std::pair<AnfNodePtr, AnfNodePtr> MergeGlobalDictToLocal(const AnfNodePtr &global_dict_node,
|
||||
const AnfNodePtr &local_dict_node,
|
||||
const FuncGraphPtr &func_graph,
|
||||
const FuncGraphManagerPtr &manager) {
|
||||
MS_EXCEPTION_IF_NULL(global_dict_node);
|
||||
MS_EXCEPTION_IF_NULL(local_dict_node);
|
||||
auto [global_keys, global_values] = UnZippedDict(global_dict_node);
|
||||
auto [local_keys, local_values] = UnZippedDict(local_dict_node);
|
||||
|
||||
auto local_dict_keys_set = GetLocalKeySet(local_keys);
|
||||
|
||||
std::vector<AnfNodePtr> local_keys_inputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
std::vector<AnfNodePtr> local_value_inputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
for (size_t index = 0; index < global_keys.size(); ++index) {
|
||||
auto global_key = global_keys.at(index);
|
||||
MS_EXCEPTION_IF_NULL(global_key);
|
||||
auto key = GetValueNode<StringImmPtr>(global_key);
|
||||
if (local_dict_keys_set.find(GetValue<std::string>(key)) != local_dict_keys_set.end()) {
|
||||
MS_LOG(INFO) << "The global dict has the same name with local dict.:" << key->ToString();
|
||||
continue;
|
||||
}
|
||||
MS_LOG(DEBUG) << "The global key " << global_key->DebugString() << ", value "
|
||||
<< global_values.at(index)->DebugString() << ". merged in local dict.";
|
||||
(void)local_keys_inputs.emplace_back(global_key);
|
||||
(void)local_value_inputs.emplace_back(FuncGraphToPyData(global_values.at(index)));
|
||||
}
|
||||
std::copy(local_keys.begin(), local_keys.end(), std::back_inserter(local_keys_inputs));
|
||||
std::transform(local_values.begin(), local_values.end(), std::back_inserter(local_value_inputs),
|
||||
[&manager, &func_graph](const AnfNodePtr &node) -> AnfNodePtr {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimPyInterpret)) {
|
||||
return node;
|
||||
}
|
||||
auto trans_node = Transform(node->cast<CNodePtr>(), manager);
|
||||
(void)manager->Replace(node, trans_node);
|
||||
return trans_node;
|
||||
});
|
||||
return std::make_pair(func_graph->NewCNode(local_keys_inputs), func_graph->NewCNode(local_value_inputs));
|
||||
}
|
||||
|
||||
CNodePtr Transform(const CNodePtr &cnode, const FuncGraphManagerPtr &manager) {
|
||||
constexpr auto input_index_one = 1;
|
||||
constexpr auto input_index_two = 2;
|
||||
constexpr auto input_index_three = 3;
|
||||
auto new_cnode = std::make_shared<CNode>(*cnode);
|
||||
new_cnode->CloneUserData(cnode);
|
||||
new_cnode->set_input(0, NewValueNode(prim::kPrimPyExecute));
|
||||
|
||||
if (!IsValueNode<parse::Script>(cnode->input(input_index_one))) {
|
||||
MS_LOG(INTERNAL_EXCEPTION) << "The first input should be a Script, but got "
|
||||
<< cnode->input(input_index_one)->DebugString();
|
||||
}
|
||||
const auto &script = GetValueNode<std::shared_ptr<parse::Script>>(cnode->input(input_index_one));
|
||||
const auto &script_str = script->script();
|
||||
const auto &script_strimm_node = NewValueNode(std::make_shared<StringImm>(script_str));
|
||||
new_cnode->set_input(input_index_one, script_strimm_node);
|
||||
auto global_dict_node = cnode->input(input_index_two);
|
||||
auto local_dict_node = cnode->input(input_index_three);
|
||||
|
||||
auto [local_dict_keys, local_dict_values] =
|
||||
MergeGlobalDictToLocal(global_dict_node, local_dict_node, cnode->func_graph(), manager);
|
||||
|
||||
new_cnode->set_input(input_index_two, local_dict_keys);
|
||||
new_cnode->set_input(input_index_three, local_dict_values);
|
||||
|
||||
// Record the PyExecute node.
|
||||
InterpretNodeRecorder::GetInstance().PushPyExecuteNode(new_cnode);
|
||||
return new_cnode;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -68,109 +201,19 @@ void FuncGraphToPyData(const ValueDictionaryPtr &value_dict, py::object *global_
|
|||
// -->
|
||||
// PyExecute(script, local_dict_keys, local_dict_values),
|
||||
// with side-effect operation:
|
||||
// Push global_dict into global parameters list.
|
||||
// (So it requires no same key name.)
|
||||
// Merge global_dict to local dict.
|
||||
// If there are arguments in global dict and local dict use local dict argument instead of global dict.
|
||||
bool PyInterpretToExecute(const pipeline::ResourcePtr &resource) {
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
auto manager = resource->manager();
|
||||
const auto &all_nodes = manager->all_nodes();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto transact = manager->Transact();
|
||||
constexpr auto input_index_one = 1;
|
||||
constexpr auto input_index_two = 2;
|
||||
constexpr auto input_index_three = 3;
|
||||
for (const auto &node : all_nodes) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimPyInterpret)) {
|
||||
continue;
|
||||
const auto all_nodes = manager->all_nodes();
|
||||
for (const auto node : all_nodes) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimPyInterpret)) {
|
||||
auto trans_node = Transform(node->cast<CNodePtr>(), manager);
|
||||
transact.Replace(node, trans_node);
|
||||
}
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
MS_LOG(DEBUG) << "cnode: " << cnode->DebugString();
|
||||
auto new_cnode = std::make_shared<CNode>(*cnode);
|
||||
new_cnode->CloneUserData(cnode);
|
||||
new_cnode->set_input(0, NewValueNode(prim::kPrimPyExecute));
|
||||
|
||||
if (!IsValueNode<parse::Script>(cnode->input(input_index_one))) {
|
||||
MS_LOG(INTERNAL_EXCEPTION) << "The first input should be a Script, but got "
|
||||
<< cnode->input(input_index_one)->DebugString();
|
||||
}
|
||||
const auto &script = GetValueNode<std::shared_ptr<parse::Script>>(cnode->input(input_index_one));
|
||||
const auto &script_str = script->script();
|
||||
const auto &script_strimm_node = NewValueNode(std::make_shared<StringImm>(script_str));
|
||||
new_cnode->set_input(input_index_one, script_strimm_node);
|
||||
|
||||
if (!IsValueNode<ValueDictionary>(cnode->input(input_index_two))) {
|
||||
MS_LOG(INTERNAL_EXCEPTION) << "The second input should be a dictionary, but got "
|
||||
<< cnode->input(input_index_two)->DebugString();
|
||||
}
|
||||
const auto &global_dict = GetValueNode<ValueDictionaryPtr>(cnode->input(input_index_two));
|
||||
auto value_dict = global_dict->cast<ValueDictionaryPtr>();
|
||||
py::object py_global_dict = ValueToPyData(global_dict);
|
||||
FuncGraphToPyData(value_dict, &py_global_dict);
|
||||
MS_LOG(DEBUG) << "py_global_dict: " << py::str(py_global_dict);
|
||||
(void)CallPythonPushGlobalParams(py_global_dict);
|
||||
|
||||
const auto &three_input = cnode->input(input_index_three);
|
||||
if (!IsPrimitiveCNode(three_input, prim::kPrimMakeDict) && !IsValueNode<ValueDictionary>(three_input)) {
|
||||
MS_LOG(INTERNAL_EXCEPTION) << "The 3rd input should be a dictionary, but got " << three_input->DebugString();
|
||||
}
|
||||
AnfNodePtr local_dict_keys = nullptr;
|
||||
AnfNodePtr local_dict_values = nullptr;
|
||||
if (IsValueNode<ValueDictionary>(three_input)) {
|
||||
const auto &local_dict = GetValueNode<ValueDictionaryPtr>(three_input);
|
||||
auto value_local_dict = local_dict->cast<ValueDictionaryPtr>();
|
||||
if (value_local_dict->value().empty()) {
|
||||
auto str_value = std::make_shared<StringImm>("None");
|
||||
std::vector<ValuePtr> none_value{str_value};
|
||||
const auto none_tuple = std::make_shared<ValueTuple>(none_value);
|
||||
auto none_tuple_node = NewValueNode(none_tuple);
|
||||
local_dict_keys = none_tuple_node;
|
||||
local_dict_values = none_tuple_node;
|
||||
} else {
|
||||
std::vector<ValuePtr> key_vec;
|
||||
std::vector<ValuePtr> value_vec;
|
||||
for (auto key_value : value_local_dict->value()) {
|
||||
(void)key_vec.emplace_back(key_value.first);
|
||||
(void)value_vec.emplace_back(key_value.second);
|
||||
}
|
||||
local_dict_keys = NewValueNode(std::make_shared<ValueTuple>(key_vec));
|
||||
local_dict_values = NewValueNode(std::make_shared<ValueTuple>(value_vec));
|
||||
}
|
||||
} else {
|
||||
const auto &local_dict_cnode = dyn_cast<CNode>(three_input);
|
||||
MS_EXCEPTION_IF_NULL(local_dict_cnode);
|
||||
local_dict_keys = local_dict_cnode->input(input_index_one);
|
||||
local_dict_values = local_dict_cnode->input(input_index_two);
|
||||
}
|
||||
|
||||
if ((!IsValueNode<ValueTuple>(local_dict_keys) && !IsPrimitiveCNode(local_dict_keys, prim::kPrimMakeTuple)) ||
|
||||
(!IsValueNode<ValueTuple>(local_dict_values) && !IsPrimitiveCNode(local_dict_values, prim::kPrimMakeTuple))) {
|
||||
MS_LOG(INTERNAL_EXCEPTION) << "The dictionary's keys and values should be a tuple, but got "
|
||||
<< three_input->DebugString();
|
||||
}
|
||||
|
||||
if (local_dict_values->isa<CNode>()) {
|
||||
// Handle values and convert InterpretedObject element.
|
||||
const auto &make_tuple_cnode = dyn_cast<CNode>(local_dict_values);
|
||||
MS_EXCEPTION_IF_NULL(make_tuple_cnode);
|
||||
const auto fg = make_tuple_cnode->func_graph();
|
||||
for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
|
||||
// Convert InterpretedObject value node to PyExecute CNode.
|
||||
const auto &input = make_tuple_cnode->input(i);
|
||||
const auto &value = GetValueNode<parse::InterpretedObjectPtr>(input);
|
||||
if (value != nullptr) {
|
||||
const auto interpreted_value = dyn_cast<parse::InterpretedObject>(value);
|
||||
const std::string &key = interpreted_value->name();
|
||||
const py::object &obj = interpreted_value->obj();
|
||||
const auto &interpreted_node = fallback::ConvertPyObjectToPyExecute(fg, key, obj, input, true);
|
||||
interpreted_node->set_debug_info(input->debug_info());
|
||||
(void)transact.Replace(input, interpreted_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
new_cnode->set_input(input_index_two, local_dict_keys);
|
||||
new_cnode->set_input(input_index_three, local_dict_values);
|
||||
(void)transact.Replace(cnode, new_cnode);
|
||||
|
||||
// Record the PyExecute node.
|
||||
InterpretNodeRecorder::GetInstance().PushPyExecuteNode(new_cnode);
|
||||
}
|
||||
transact.Commit();
|
||||
return true;
|
||||
|
|
|
@ -66,12 +66,6 @@ struct PyExecuteUserDataCatcherRegister {
|
|||
} // namespace pyexecute_user_data_catcher
|
||||
} // namespace abstract
|
||||
|
||||
static py::object CallPythonGetGlobalParams() {
|
||||
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
|
||||
constexpr auto python_get_dict = "get_global_params";
|
||||
return python_adapter::CallPyModFn(mod, python_get_dict);
|
||||
}
|
||||
|
||||
bool ContainStubTensor(const py::object &obj) {
|
||||
if (py::isinstance<py::list>(obj)) {
|
||||
auto list_obj = py::cast<py::list>(obj);
|
||||
|
@ -156,9 +150,8 @@ class PyExecuteInitializer {
|
|||
}
|
||||
}
|
||||
const auto &py_script = py::str(script_str->value());
|
||||
const auto &global_dict = CallPythonGetGlobalParams();
|
||||
auto params = py::tuple(number_two);
|
||||
params[0] = global_dict;
|
||||
params[0] = py::dict();
|
||||
params[1] = local_dict;
|
||||
MS_LOG(DEBUG) << "Python script: " << py_script << ", local_dict: " << local_dict;
|
||||
try {
|
||||
|
|
|
@ -26,7 +26,7 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
|
|||
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
|
||||
is_class_type, check_obj_bool, python_isinstance, ms_isinstance, convert_to_ms_csrtensor,
|
||||
convert_to_ms_cootensor, convert_class_to_function, convert_cell_list_to_sequence, is_cell_list,
|
||||
get_obj_from_sequence, get_type, is_class_member_recursive, merge_global_params, get_global_params,
|
||||
get_obj_from_sequence, get_type, is_class_member_recursive, get_global_params,
|
||||
get_adapter_tensor_attr, get_obj_defined_from_obj_type,
|
||||
is_from_third_party_library, get_const_abs, get_const_round,
|
||||
get_const_len, is_adapter_tensor_class, is_adapter_parameter_class)
|
||||
|
|
|
@ -822,15 +822,9 @@ def get_script_id_attrs(script):
|
|||
return res
|
||||
|
||||
|
||||
def merge_global_params(global_dict):
|
||||
"""Merge the global parameter."""
|
||||
logger.debug(f'merge global_dict: {global_dict}')
|
||||
_global_params.update(global_dict)
|
||||
|
||||
|
||||
def get_global_params():
|
||||
"""Get the global parameter."""
|
||||
logger.debug(f'get global_dict: {_global_params}')
|
||||
logger.debug(f"get global_dict: {_global_params}")
|
||||
return _global_params
|
||||
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include <memory>
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "common/backend_common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
|
||||
#include "ir/anf.h"
|
||||
|
@ -25,6 +26,7 @@
|
|||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/irpass/gradient_eliminate.h"
|
||||
#include "frontend/optimizer/py_interpret_to_execute.h"
|
||||
#include "include/common/debug/draw.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -38,6 +40,55 @@ class TestOptOptimizer : public UT::Common {
|
|||
irpass::OptimizeIRPassLib irpass;
|
||||
};
|
||||
|
||||
class TestPyInterpretToPyExecute : public BackendCommon {
|
||||
public:
|
||||
TestPyInterpretToPyExecute() : getPyFun("gtest_input.optimizer.pyinterpret_dict_convert_test", true) {}
|
||||
~TestPyInterpretToPyExecute() override = default;
|
||||
UT::PyFuncGraphFetcher getPyFun;
|
||||
|
||||
void ChangeStringToScript(const pipeline::ResourcePtr &resource) {
|
||||
auto trans = resource->manager();
|
||||
MS_EXCEPTION_IF_NULL(trans);
|
||||
auto nodes = trans->all_nodes();
|
||||
for (const auto &node : nodes) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimPyInterpret)) {
|
||||
auto constexpr kScriptInputIdx = 1;
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto script_str_node = cnode->input(kScriptInputIdx);
|
||||
auto script_string = GetValueNode<StringImmPtr>(script_str_node);
|
||||
auto script = script_string->value();
|
||||
auto script_node = NewValueNode(std::make_shared<parse::Script>(script));
|
||||
cnode->set_input(kScriptInputIdx, script_node);
|
||||
}
|
||||
|
||||
if (node->isa<ValueNode>()) {
|
||||
auto value_node = node->cast_ptr<ValueNode>();
|
||||
auto value = value_node->value();
|
||||
value_node->set_abstract(value->ToAbstract());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Feature: Test global dict merge to local dict.
|
||||
/// Description: PyInterpret convert to PyExecute and merged arguments with local.
|
||||
/// Expectation: success.
|
||||
TEST_F(TestPyInterpretToPyExecute, test_pyinterpret_to_pyexecute) {
|
||||
FuncGraphPtr before = getPyFun.CallAndParseRet("py_interpret_to_py_execute_test", "before");
|
||||
ASSERT_TRUE(nullptr != before);
|
||||
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
|
||||
res->set_func_graph(before);
|
||||
auto manager = res->manager();
|
||||
manager->KeepRoots({before});
|
||||
ChangeStringToScript(res);
|
||||
|
||||
PyInterpretToExecute(res);
|
||||
|
||||
auto correct_graph = getPyFun.CallAndParseRet("py_interpret_to_py_execute_test", "after");
|
||||
EXPECT_NE(correct_graph, nullptr);
|
||||
EXPECT_TRUE(CheckEqualGraph(before, correct_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestOptOptimizer, test_step_opt) {
|
||||
FuncGraphPtr before = getPyFun("test_expandJ");
|
||||
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore.ops import Primitive
|
||||
|
||||
py_interpret = Primitive("PyInterpret")
|
||||
py_execute = Primitive("PyExecute")
|
||||
make_tuple = Primitive("MakeTuple")
|
||||
make_dict = Primitive("make_dict")
|
||||
|
||||
|
||||
class FnDict:
|
||||
def __init__(self):
|
||||
self.fnDict = {}
|
||||
|
||||
def __call__(self, fn):
|
||||
self.fnDict[fn.__name__] = fn
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self.fnDict[name]
|
||||
|
||||
|
||||
def py_interpret_to_py_execute_test(tag):
|
||||
""" test_split_bn_fusion """
|
||||
fns = FnDict()
|
||||
|
||||
@fns
|
||||
def before():
|
||||
global_key = make_tuple("g_a", "g_b")
|
||||
global_value = make_tuple(1, 2)
|
||||
local_key = make_tuple("g_a", "a", "b")
|
||||
local_value = make_tuple(3, 4, 5)
|
||||
global_dict = make_dict(global_key, global_value)
|
||||
local_dict = make_dict(local_key, local_value)
|
||||
output = py_interpret("func(g_a, g_b, a, b)", global_dict, local_dict)
|
||||
return output
|
||||
|
||||
@fns
|
||||
def after():
|
||||
local_key = make_tuple("g_b", "g_a", "a", "b")
|
||||
local_value = make_tuple(2, 3, 4, 5)
|
||||
output = py_execute("func(g_a, g_b, a, b)", local_key, local_value)
|
||||
return output
|
||||
|
||||
return fns[tag]
|
Loading…
Reference in New Issue