// Source: forked from mindspore-Ecosystem/mindspore (766 lines, 22 KiB, C++).
/**
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
#include "common/common_test.h"
|
|
#include "common/py_func_graph_fetcher.h"
|
|
#include "ir/dtype.h"
|
|
#include "ir/manager.h"
|
|
#include "ir/func_graph_cloner.h"
|
|
#include "pipeline/jit/parse/parse.h"
|
|
#include "frontend/operator/ops.h"
|
|
#include "utils/log_adapter.h"
|
|
#include "include/common/debug/draw.h"
|
|
#include "utils/label.h"
|
|
|
|
namespace mindspore {
|
|
|
|
namespace {
|
|
// Splits `str` on every occurrence of `pattern`, keeping empty fields:
//   "a,b" -> {"a", "b"};  "a," -> {"a", ""};  "" -> {""}.
// The result always has (number of non-overlapping matches + 1) elements.
// An empty `pattern` cannot delimit anything, so the whole string is returned
// as a single field (guards against an otherwise endless scan).
std::vector<std::string> SplitString(const std::string &str, const std::string &pattern) {
  std::vector<std::string> result;
  if (pattern.empty()) {
    result.push_back(str);
    return result;
  }

  std::string::size_type start = 0;
  std::string::size_type pos = str.find(pattern, start);
  while (pos != std::string::npos) {
    result.push_back(str.substr(start, pos - start));
    start = pos + pattern.size();
    pos = str.find(pattern, start);
  }
  // Whatever follows the last delimiter (possibly empty) is the final field.
  result.push_back(str.substr(start));
  return result;
}
|
|
} // namespace
|
|
using std::dynamic_pointer_cast;
|
|
|
|
using TodoList = std::vector<std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>>;
|
|
using TodoListItem = std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>;
|
|
|
|
class NestingSpecs;
|
|
|
|
class Stage {
|
|
public:
|
|
explicit Stage(std::vector<std::string> specs) {
|
|
for (auto arg : specs) {
|
|
auto spec = SplitString(arg, "=");
|
|
if (spec.size() <= 1) {
|
|
continue;
|
|
}
|
|
std::shared_ptr<NestingSpecs> nesting = std::make_shared<NestingSpecs>(this, spec[1]);
|
|
specs_[ToFullString(spec[0])] = nesting;
|
|
}
|
|
}
|
|
|
|
~Stage() {}
|
|
|
|
std::map<std::string, std::string> &subs() { return subs_; }
|
|
|
|
void set_subs(const std::map<std::string, std::string> &subs) { subs_ = subs; }
|
|
|
|
private:
|
|
std::string ToFullString(std::string s) {
|
|
if (s.find("fv") != std::string::npos) {
|
|
s = s.replace(s.find("fv"), 2, "free_variable");
|
|
}
|
|
|
|
if (s.find("deps") != std::string::npos) {
|
|
s = s.replace(s.find("deps"), 4, "dependencies");
|
|
}
|
|
|
|
return s;
|
|
}
|
|
|
|
std::map<std::string, std::shared_ptr<NestingSpecs>> specs_;
|
|
std::map<std::string, std::string> subs_;
|
|
};
|
|
|
|
class NestingSpecs {
|
|
public:
|
|
NestingSpecs(Stage *stage, std::string specs) : stage_(stage) { ParseSpecs(specs); }
|
|
|
|
~NestingSpecs() {}
|
|
|
|
std::string Name(Any node) {
|
|
std::string name = label_manage::Label(node.cast<AnfNodePtr>()->debug_info());
|
|
if (stage_->subs().find(name) != stage_->subs().end()) {
|
|
return stage_->subs()[name];
|
|
}
|
|
|
|
return name;
|
|
}
|
|
|
|
void Check(std::shared_ptr<DepComputer> results) {
|
|
if (expected_.empty() && expected_recursive_.empty()) {
|
|
return;
|
|
}
|
|
|
|
auto parent = dynamic_pointer_cast<ParentComputer>(results);
|
|
if (parent != nullptr) {
|
|
CheckParent(parent);
|
|
return;
|
|
}
|
|
|
|
auto recursive = dynamic_pointer_cast<RecursiveComputer>(results);
|
|
if (recursive != nullptr) {
|
|
CheckRecursive(recursive);
|
|
return;
|
|
}
|
|
}
|
|
|
|
private:
|
|
void ParseSpecs(std::string specs) {
|
|
if (specs.empty()) {
|
|
return;
|
|
}
|
|
|
|
std::vector<std::string> str_list = SplitString(specs, ";");
|
|
for (auto spec : str_list) {
|
|
spec.erase(0, spec.find_first_not_of(" "));
|
|
spec.erase(spec.find_last_not_of(" ") + 1);
|
|
if (spec.empty()) {
|
|
continue;
|
|
}
|
|
if (spec.find("->") != std::string::npos) {
|
|
auto substr = SplitString(spec, "->");
|
|
ASSERT_GT(substr.size(), 1);
|
|
auto key = substr[0];
|
|
auto value = substr[1];
|
|
if (!value.empty()) {
|
|
expected_[key] = {value};
|
|
}
|
|
} else if (spec.find(":") != std::string::npos) {
|
|
auto substr = SplitString(spec, ":");
|
|
ASSERT_GT(substr.size(), 1);
|
|
auto key = substr[0];
|
|
auto values = SplitString(substr[1], ",");
|
|
std::set<std::string> values_set(values.begin(), values.end());
|
|
if (!values_set.empty()) {
|
|
expected_[key] = values_set;
|
|
}
|
|
} else {
|
|
expected_recursive_[spec] = true;
|
|
}
|
|
}
|
|
}
|
|
|
|
void CheckParent(std::shared_ptr<ParentComputer> results) {
|
|
std::map<std::string, std::set<std::string>> clean_results;
|
|
for (auto &iter : results->parent_analysis()) {
|
|
auto key = iter.first;
|
|
auto value = iter.second;
|
|
if (key == nullptr) {
|
|
continue;
|
|
}
|
|
std::string k = Name(key);
|
|
|
|
std::set<std::string> v;
|
|
if (value != nullptr && !Name(value).empty()) {
|
|
v.insert(Name(value));
|
|
}
|
|
|
|
if (!v.empty()) {
|
|
clean_results[k] = v;
|
|
}
|
|
}
|
|
|
|
ASSERT_EQ(clean_results, expected_);
|
|
}
|
|
|
|
void CheckRecursive(std::shared_ptr<RecursiveComputer> results) {
|
|
std::map<std::string, bool> clean_results;
|
|
for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) {
|
|
auto key = iter->first;
|
|
auto value = iter->second;
|
|
if (key == nullptr) {
|
|
continue;
|
|
}
|
|
std::string k = Name(key);
|
|
|
|
clean_results[k] = value;
|
|
}
|
|
|
|
ASSERT_EQ(clean_results, expected_recursive_);
|
|
}
|
|
|
|
private:
|
|
Stage *stage_;
|
|
std::map<std::string, std::set<std::string>> expected_;
|
|
std::map<std::string, bool> expected_recursive_;
|
|
};
|
|
|
|
bool CheckUsers(std::shared_ptr<FuncGraphManager> manager) {
|
|
for (auto node : manager->all_nodes()) {
|
|
if (node->isa<CNode>()) {
|
|
auto &inputs = node->cast<CNodePtr>()->inputs();
|
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
|
auto inp = inputs[i];
|
|
if (!manager->all_nodes().contains(inp)) {
|
|
return false;
|
|
}
|
|
|
|
if (manager->node_users().find(inp) != manager->node_users().end()) {
|
|
auto users = manager->node_users()[inp];
|
|
if (!users.contains(make_pair(node, i))) {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if (manager->node_users().find(node) != manager->node_users().end()) {
|
|
auto users = manager->node_users()[node];
|
|
for (auto iter = users.begin(); iter != users.end(); ++iter) {
|
|
auto node2 = iter->first;
|
|
auto key = iter->second;
|
|
if (!manager->all_nodes().contains(node2)) {
|
|
return false;
|
|
}
|
|
if (node2->cast<CNodePtr>()->input(key) != node) {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
class TestManager : public UT::Common {
|
|
public:
|
|
TestManager() : getPyFun("gtest_input.ir.manager_test") {}
|
|
|
|
void CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng);
|
|
|
|
public:
|
|
std::vector<PrimitivePtr> swaps;
|
|
UT::PyFuncGraphFetcher getPyFun;
|
|
};
|
|
|
|
// Builds the two-parameter graph (x, y) -> prim(x, y), ending in an explicit
// Return(prim(x, y)) node.
FuncGraphPtr MakeFuncGraph(PrimitivePtr prim) {
  auto graph = std::make_shared<FuncGraph>();
  auto lhs = graph->add_parameter();
  auto rhs = graph->add_parameter();
  auto apply = graph->NewCNode({NewValueNode(prim), lhs, rhs});
  graph->set_return(graph->NewCNode({NewValueNode(prim::kPrimReturn), apply}));
  return graph;
}
|
|
|
|
std::vector<FuncGraphPtr> MakeNestedGraph() {
|
|
/*
|
|
*def f(x):
|
|
* def g():
|
|
* return x
|
|
* return g
|
|
*/
|
|
FuncGraphPtr f = std::make_shared<FuncGraph>();
|
|
FuncGraphPtr fg = std::make_shared<FuncGraph>();
|
|
|
|
ParameterPtr x = f->add_parameter();
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
inputs.push_back(NewValueNode(fg));
|
|
inputs.push_back(NewValueNode(prim::kPrimReturn));
|
|
|
|
CNodePtr cnode_f = f->NewCNode(inputs);
|
|
f->set_return(cnode_f);
|
|
|
|
inputs.clear();
|
|
inputs.push_back(NewValueNode(prim::kPrimReturn));
|
|
inputs.push_back(x);
|
|
CNodePtr cnode_g = fg->NewCNode(inputs);
|
|
fg->set_return(cnode_g);
|
|
|
|
std::vector<FuncGraphPtr> result = {f, fg};
|
|
return result;
|
|
}
|
|
|
|
std::vector<FuncGraphPtr> MakeNestedGraph2() {
|
|
/* build a closure func_graph */
|
|
/*
|
|
*def foo(x, y):
|
|
* def bar(x1):
|
|
* return x1 + y
|
|
* return bar(x)
|
|
*/
|
|
FuncGraphPtr graph_foo = std::make_shared<FuncGraph>();
|
|
ParameterPtr x = graph_foo->add_parameter();
|
|
ParameterPtr y = graph_foo->add_parameter();
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
// build func_graph bar
|
|
FuncGraphPtr graph_bar = std::make_shared<FuncGraph>();
|
|
ParameterPtr x1 = graph_bar->add_parameter();
|
|
inputs.clear();
|
|
inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
|
|
inputs.push_back(x1);
|
|
inputs.push_back(y);
|
|
CNodePtr cnode_add = graph_bar->NewCNode(inputs);
|
|
inputs.clear();
|
|
inputs.push_back(NewValueNode(prim::kPrimReturn));
|
|
inputs.push_back(cnode_add);
|
|
CNodePtr cnode_return = graph_bar->NewCNode(inputs);
|
|
graph_bar->set_return(cnode_return);
|
|
|
|
// build func_graph foo
|
|
inputs.clear();
|
|
inputs.push_back(NewValueNode(graph_bar));
|
|
inputs.push_back(x);
|
|
CNodePtr cnode_graph_bar = graph_foo->NewCNode(inputs);
|
|
|
|
inputs.clear();
|
|
inputs.push_back(NewValueNode(prim::kPrimReturn));
|
|
inputs.push_back(cnode_graph_bar);
|
|
cnode_return = graph_foo->NewCNode(inputs);
|
|
graph_foo->set_return(cnode_return);
|
|
|
|
std::vector<FuncGraphPtr> result = {graph_foo, graph_bar};
|
|
return result;
|
|
}
|
|
|
|
// Add TestManager::CheckManager function to checkout the result
|
|
void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
|
|
auto size = mng->func_graphs().size();
|
|
|
|
ASSERT_EQ(size, mng->free_variables_total().size());
|
|
}
|
|
|
|
// Managing a single hand-built graph (x, y) -> scalar_add(x, y) attaches the
// manager to it and registers exactly that one graph as the only root.
TEST_F(TestManager, test_scalar_add_manual) {
  auto prim_scalar_add = prim::kPrimScalarAdd;
  FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
  auto mng = Manage(func_graph);

  ASSERT_NE(mng, nullptr);
  ASSERT_EQ(func_graph->manager(), mng);
  ASSERT_EQ(1, mng->func_graphs().size());
  ASSERT_EQ(1, mng->roots().size());
}
|
|
|
|
// Replacing a non-return CNode with a parameter rewires its users: the return
// node ends up consuming the parameter directly.
TEST_F(TestManager, test_scalar_replace) {
  auto prim_scalar_add = prim::kPrimScalarAdd;

  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  ParameterPtr x = func_graph->add_parameter();
  ParameterPtr y = func_graph->add_parameter();
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim_scalar_add));
  inputs.push_back(x);
  inputs.push_back(y);
  CNodePtr cnode_add = func_graph->NewCNode(inputs);
  inputs.clear();
  inputs.push_back(NewValueNode(prim::kPrimReturn));
  inputs.push_back(cnode_add);
  CNodePtr cnode_return = func_graph->NewCNode(inputs);
  func_graph->set_return(cnode_return);

  auto mng = Manage(func_graph);
  std::cout << "start " << x->ToString() << std::endl;
  ASSERT_TRUE(mng->Replace(cnode_add, x));
  // cnode_return was Return(cnode_add); after the replacement it is Return(x).
  ASSERT_EQ(cnode_return->input(1), x);
}
|
|
|
|
// MakeNestedGraph builds f(x) and a nested graph g that reads x as a free
// variable. Checks the manager-wide and per-graph bookkeeping for that pair.
TEST_F(TestManager, test_nested_manual) {
  const auto graphs = MakeNestedGraph();
  const auto &f = graphs[0];
  const auto &g = graphs[1];

  auto mng = Manage(f);

  // Manager-wide counts.
  ASSERT_EQ(6, mng->all_nodes().size());
  ASSERT_EQ(2, mng->func_graphs().size());
  ASSERT_EQ(4, mng->node_users().size());
  ASSERT_EQ(1, mng->roots().size());
  CheckAnalysisSize(mng);

  // Nodes owned by each graph.
  ASSERT_EQ(2, f->nodes().size());
  ASSERT_EQ(1, g->nodes().size());

  // Every node that has users has exactly one.
  for (const auto &entry : mng->node_users()) {
    ASSERT_EQ(1, entry.second.size());
  }

  // Only f references another graph (g).
  ASSERT_EQ(1, f->func_graphs_used().size());
  ASSERT_EQ(0, g->func_graphs_used().size());

  // g captures x; f captures nothing.
  ASSERT_EQ(0, f->free_variables().size());
  ASSERT_EQ(1, g->free_variables().size());

  auto fv_total = mng->free_variables_total();
  ASSERT_EQ(0, fv_total[f].size());
  ASSERT_EQ(1, fv_total[g].size());

  // g is referenced by one CNode (in f); f by none.
  ASSERT_EQ(0, f->func_graph_cnodes_index().size());
  ASSERT_EQ(1, g->func_graph_cnodes_index().size());
}
|
|
|
|
// NOTE(review): the unconditional `return;` below ends this test right after
// fetching "test_custom". Everything after it is unreachable and none of the
// assertions run, so the test is effectively disabled.
TEST_F(TestManager, test_deep_nested2_manual) {
|
|
// Fetch the graph for the function "test_custom" through the fixture's getPyFun.
|
|
FuncGraphPtr func_graph = getPyFun("test_custom");
|
|
return;
|
|
|
|
// Unreachable: clone the fetched graph with BasicClone and check the manager's view of the clone.
|
|
FuncGraphPtr gfn = BasicClone(func_graph);
|
|
if (gfn == nullptr) {
|
|
return;
|
|
}
|
|
|
|
auto mng = Manage(gfn);
|
|
|
|
ASSERT_EQ(3, mng->func_graphs().size());
|
|
ASSERT_EQ(1, mng->roots().size());
|
|
ASSERT_EQ(4, gfn->nodes().size());
|
|
ASSERT_EQ(20, mng->all_nodes().size());
|
|
ASSERT_EQ(25, mng->node_users().size());
|
|
CheckAnalysisSize(mng);
|
|
}
|
|
|
|
// Builds three nested graphs by hand and checks the manager's bookkeeping:
//   f(x, y, z):  return fg(x, y)
//   fg(x1, y1):  return h(x1)
//   h(x2):       return scalar_add(z, scalar_add(x2, y1))
// h reads y1 (a parameter of fg) and z (a parameter of f) as free variables.
TEST_F(TestManager, test_deep_nested_manual) {
  FuncGraphPtr f = std::make_shared<FuncGraph>();
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  FuncGraphPtr h = std::make_shared<FuncGraph>();

  ParameterPtr x = f->add_parameter();
  ParameterPtr y = f->add_parameter();
  ParameterPtr z = f->add_parameter();

  // Every return node is Return(value): the Return primitive is input 0.
  CNodePtr cnode_1 = f->NewCNode({NewValueNode(fg), x, y});
  CNodePtr cnode_0 = f->NewCNode({NewValueNode(prim::kPrimReturn), cnode_1});
  f->set_return(cnode_0);

  ParameterPtr x1 = fg->add_parameter();
  ParameterPtr y1 = fg->add_parameter();
  CNodePtr cnode_3 = fg->NewCNode({NewValueNode(h), x1});
  CNodePtr cnode_2 = fg->NewCNode({NewValueNode(prim::kPrimReturn), cnode_3});
  fg->set_return(cnode_2);

  ParameterPtr x2 = h->add_parameter();
  CNodePtr cnode_6 = h->NewCNode({NewValueNode(prim::kPrimScalarAdd), x2, y1});
  CNodePtr cnode_5 = h->NewCNode({NewValueNode(prim::kPrimScalarAdd), z, cnode_6});
  CNodePtr cnode_4 = h->NewCNode({NewValueNode(prim::kPrimReturn), cnode_5});
  h->set_return(cnode_4);

  auto mng = Manage(f);

  ASSERT_EQ(3, mng->func_graphs().size());
  ASSERT_EQ(1, mng->roots().size());
  ASSERT_EQ(20, mng->all_nodes().size());
  CheckAnalysisSize(mng);
}
|
|
|
|
// A parameterless graph whose return node is an empty CNode has no parent once
// it has been added to a manager.
TEST_F(TestManager, test_parent1_manual) {
  FuncGraphPtr fg = std::make_shared<FuncGraph>();

  std::vector<AnfNodePtr> params;
  CNodePtr app = std::make_shared<CNode>(params, fg);
  fg->set_return(app);
  fg->set_parameters(params);

  std::shared_ptr<FuncGraphManager> manager = MakeManager();
  manager->AddFuncGraph(fg, true);
  FuncGraphPtr p = fg->parent();
  // gtest assertion rather than assert(): it still runs in NDEBUG builds and
  // reports through the test framework.
  ASSERT_EQ(p, nullptr);
}
|
|
|
|
// A top-level graph (x, y) -> scalar_add(x, y) has no parent once added to a manager.
TEST_F(TestManager, test_parent_manual) {
  auto prim_scalar_add = prim::kPrimScalarAdd;
  FuncGraphPtr fg = MakeFuncGraph(prim_scalar_add);

  std::shared_ptr<FuncGraphManager> manager = MakeManager();
  manager->AddFuncGraph(fg);
  FuncGraphPtr p = fg->parent();
  // gtest assertion rather than assert(): it still runs in NDEBUG builds.
  ASSERT_EQ(p, nullptr);
}
|
|
|
|
// Placeholder: the specs are declared but never parsed into a Stage, and the
// test makes no assertions.
TEST_F(TestManager, test_flat) {
  std::vector<std::string> specs = {"nodes=X:x", "parents=", "fvs_direct="};
  std::vector<std::shared_ptr<Stage>> stages;
  std::map<std::string, int> size_list;
  size_list["nodes"] = 2;
}
|
|
|
|
// Placeholder: the specs are declared but never parsed into a Stage, and the
// test makes no assertions.
TEST_F(TestManager, test_nested) {
  std::vector<std::string> specs = {"nodes=X:x", "parent=g->X", "fvs_direct=g:x"};
  std::vector<std::shared_ptr<Stage>> stages;
  std::map<std::string, int> size_list;
}
|
|
|
|
// Placeholder: the specs are declared but never parsed into a Stage, and the
// test makes no assertions.
TEST_F(TestManager, test_calls) {
  std::vector<std::string> specs = {"parents=g->X; h->X", "children=X:g,h", "scopes=X:X,g,h; g:g; h:h",
                                    "fvs_direct=h:a", "fvs_total=h:a; g:h"};
  std::vector<std::shared_ptr<Stage>> stages;
  std::map<std::string, int> size_list;
}
|
|
|
|
// Placeholder: the spec is declared but never parsed into a Stage, and the test
// makes no assertions.
TEST_F(TestManager, test_unused_param) {
  std::vector<std::string> specs = {"nodes=X:x,y"};
  std::vector<std::shared_ptr<Stage>> stages;
  std::map<std::string, int> size_list;
}
|
|
|
|
// Replacing a graph's return node is rejected: Replace() reports false.
TEST_F(TestManager, test_cannot_replace_return) {
  FuncGraphPtr fg = getPyFun("test_cannot_replace_return");
  ASSERT_NE(fg, nullptr);

  auto mng = Manage(fg);
  // Check that the manager exists before it is compared or dereferenced.
  ASSERT_NE(mng, nullptr);
  ASSERT_EQ(fg->manager(), mng);

  ASSERT_GT(fg->parameters().size(), 0);
  ASSERT_FALSE(mng->Replace(fg->get_return(), fg->parameters()[0]));
}
|
|
|
|
// MakeManager({fg}, false) leaves fg->manager() untouched, while
// MakeManager({fg}, true) binds fg to the new manager. A later `false` manager
// does not override that binding. All three managers stay alive for the whole test.
TEST_F(TestManager, test_weak_manager) {
  FuncGraphPtr fg = getPyFun("ir_get_fn");

  auto weak_first = MakeManager({fg}, false);
  ASSERT_EQ(fg->manager(), nullptr);

  auto owning = MakeManager({fg}, true);
  ASSERT_EQ(fg->manager(), owning);

  auto weak_second = MakeManager({fg}, false);
  ASSERT_EQ(fg->manager(), owning);
}
|
|
|
|
// MaybeDropFuncGraphs must keep the graph that was passed to Manage().
TEST_F(TestManager, test_drop_root) {
  FuncGraphPtr fg = getPyFun("ir_get_fn");

  auto mng = Manage(fg);
  const auto &managed = mng->func_graphs();
  ASSERT_TRUE(managed.contains(fg));

  FuncGraphSet drop_candidates;
  drop_candidates.add(fg);
  mng->MaybeDropFuncGraphs(drop_candidates);
  ASSERT_TRUE(managed.contains(fg));
}
|
|
|
|
// KeepRoots() discards non-root graphs. KeepRoots({g}) makes g the only root and
// discards everything else.
TEST_F(TestManager, test_keep_roots) {
  FuncGraphPtr fg1 = getPyFun("ir_get_fn");
  FuncGraphPtr fg2 = getPyFun("test_cannot_replace_return");
  ASSERT_NE(fg1, nullptr);
  ASSERT_NE(fg2, nullptr);

  auto mng = Manage(fg1);
  ASSERT_EQ(mng->func_graphs().size(), 1);
  ASSERT_TRUE(mng->func_graphs().contains(fg1));

  // fg2 is added as a non-root graph.
  mng->AddFuncGraph(fg2);
  ASSERT_EQ(mng->func_graphs().size(), 2);
  ASSERT_TRUE(mng->func_graphs().contains(fg2));

  mng->KeepRoots();
  ASSERT_EQ(mng->func_graphs().size(), 1);
  ASSERT_TRUE(mng->func_graphs().contains(fg1));

  mng->KeepRoots({fg2});
  ASSERT_EQ(mng->func_graphs().size(), 1);
  ASSERT_TRUE(mng->func_graphs().contains(fg2));
}
|
|
|
|
// NOTE(review): the leading `return;` disables this test; nothing after it runs.
TEST_F(TestManager, test_keep_roots_recursion) {
|
|
return;
|
|
|
|
// Unreachable: fetch "test_keep_roots_recursion", resolve it, and expect 4 managed graphs.
FuncGraphPtr fg = getPyFun("test_keep_roots_recursion");
|
|
ASSERT_NE(fg, nullptr);
|
|
auto mng = Manage(fg);
|
|
parse::ResolveAll(mng);
|
|
|
|
ASSERT_NE(mng, nullptr);
|
|
ASSERT_EQ(mng->func_graphs().size(), 4);
|
|
|
|
// Replacing fg's output with its first parameter is expected to leave 3 graphs.
ASSERT_GT(fg->parameters().size(), 0);
|
|
mng->Replace(fg->output(), fg->parameters()[0]);
|
|
ASSERT_EQ(mng->func_graphs().size(), 3);
|
|
|
|
// KeepRoots() is expected to leave only the root graph.
mng->KeepRoots();
|
|
ASSERT_EQ(mng->func_graphs().size(), 1);
|
|
}
|
|
|
|
// AddEdge appends an input to a CNode and records the new use. Replacing that
// CNode with a fresh one that lacks the extra input removes the use again.
TEST_F(TestManager, test_add_edge_replace) {
  // fg(x, y, u):
  //   x1 = load(x, u)
  //   a = add(x1, y)
  //   u1 = update_state(u, x1);
  //   out = depend(a, u1)
  //   return out
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  auto x = fg->add_parameter();
  auto y = fg->add_parameter();
  auto u = fg->add_parameter();
  auto x1 = fg->NewCNode({NewValueNode(prim::kPrimLoad), x, u});
  auto a = fg->NewCNode({NewValueNode(prim::kPrimAdd), x1, y});
  auto u1 = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u, x1});
  auto out = fg->NewCNode({NewValueNode(prim::kPrimDepend), a, u1});
  fg->set_output(out);

  auto mgr = Manage(fg);
  ASSERT_NE(mgr, nullptr);
  // Number of users recorded for `node`; operator[] inserts an empty entry if none.
  auto user_count = [&mgr](const AnfNodePtr &node) { return mgr->node_users()[node].size(); };

  // a is used only by out = depend(a, u1).
  ASSERT_EQ(user_count(a), 1);

  // u1 = update_state(u, x1, a): a gains a second user.
  mgr->AddEdge(u1, a);
  ASSERT_EQ(user_count(a), 2);

  // Swap u1 for a fresh update_state(u, x1), i.e. out = depend(a, u2).
  auto u2 = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u, x1});
  mgr->Replace(u1, u2);
  ASSERT_EQ(user_count(a), 1);

  // u2 = update_state(u, x1, a): a again has two users.
  mgr->AddEdge(u2, a);
  ASSERT_EQ(user_count(a), 2);
}
|
|
|
|
// AddEdge with a CNode that is not yet in the graph also records that node's own
// input uses. Replacing the node moves those uses to the replacement.
TEST_F(TestManager, test_add_edge_replace_new) {
  // fg(x, y, u):
  //   x1 = load(x, u)
  //   a = add(x1, y)
  //   u1 = update_state(u, x1);
  //   out = depend(a, u1)
  //   return out
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  auto x = fg->add_parameter();
  auto y = fg->add_parameter();
  auto u = fg->add_parameter();
  auto x1 = fg->NewCNode({NewValueNode(prim::kPrimLoad), x, u});
  auto a = fg->NewCNode({NewValueNode(prim::kPrimAdd), x1, y});
  auto u1 = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u, x1});
  auto out = fg->NewCNode({NewValueNode(prim::kPrimDepend), a, u1});
  fg->set_output(out);

  auto mgr = Manage(fg);
  ASSERT_NE(mgr, nullptr);
  // Number of users recorded for `node`; operator[] inserts an empty entry if none.
  auto user_count = [&mgr](const AnfNodePtr &node) { return mgr->node_users()[node].size(); };

  // u1 = update_state(u, x1, new_add), with new_add = add(x1, y).
  auto new_add = fg->NewCNode({NewValueNode(prim::kPrimAdd), x1, y});
  mgr->AddEdge(u1, new_add);
  ASSERT_EQ(user_count(x1), 3);       // a, u1, new_add
  ASSERT_EQ(user_count(y), 2);        // a, new_add
  ASSERT_EQ(user_count(new_add), 1);  // u1

  // Replace new_add with new_add1 = add(y, y).
  auto new_add1 = fg->NewCNode({NewValueNode(prim::kPrimAdd), y, y});
  mgr->Replace(new_add, new_add1);
  ASSERT_EQ(user_count(x1), 2);  // a, u1
  ASSERT_EQ(user_count(y), 3);   // a, new_add1 (two inputs)
  ASSERT_EQ(user_count(new_add), 0);
  ASSERT_EQ(user_count(new_add1), 1);  // u1
}
|
|
|
|
// SetEdge redirects one input of a CNode. Here get_item bypasses the depend, so
// the depend loses its only user and its use of t disappears.
TEST_F(TestManager, test_set_edge) {
  // fg(x, y, u):
  //   t = make_tuple(x, y)
  //   d = depend(t, u);
  //   get_item = tuple_get_item(d, 0)
  //   return get_item
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  auto x = fg->add_parameter();
  auto y = fg->add_parameter();
  auto u = fg->add_parameter();
  auto t = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), x, y});
  auto d = fg->NewCNode({NewValueNode(prim::kPrimDepend), t, u});
  auto get_item = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), d, NewValueNode(0)});
  fg->set_output(get_item);

  auto mgr = Manage(fg);
  ASSERT_NE(mgr, nullptr);
  // Users recorded for `node`; operator[] inserts an empty entry if none.
  auto users_of = [&mgr](const AnfNodePtr &node) -> auto & { return mgr->node_users()[node]; };

  // t is used by d; d is used by get_item.
  ASSERT_EQ(users_of(t).size(), 1);
  ASSERT_EQ(users_of(d).size(), 1);

  // get_item = tuple_get_item(t, 0): point input 1 at the depend's input 1.
  auto depend = get_item->input(1)->cast<CNodePtr>();
  mgr->SetEdge(get_item, 1, depend->input(1));

  ASSERT_EQ(get_item->input(1), t);
  ASSERT_EQ(depend->input(1), t);
  ASSERT_EQ(users_of(d).size(), 0);
  ASSERT_EQ(users_of(t).size(), 1);  // the depend's use of t is gone
  ASSERT_EQ(users_of(t).front().first, get_item);
}
|
|
|
|
} // namespace mindspore
|