forked from OSSInnovation/mindspore

Optimize depend edge with make tuple input

parent 85df19b23c
commit 7f53bb062d
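Before this change, OptimizeDependence only shortened the attached input of a depend node when that input was a single-use TransData or Cast; this commit extends the pass so that a make_tuple whose elements are such nodes is rewritten as well, leaving the depend edge pointing at the original producers. A minimal before/after sketch of the rewrite, written in the same Python graph DSL as the unit tests added below (it mirrors the new test case and is illustrative, not standalone runnable network code):

from mindspore.ops import Primitive
from mindspore.ops import operations as P

depend = Primitive('depend')
TransData = Primitive('TransData')
make_tuple = Primitive('make_tuple')
add = P.TensorAdd()


def before(x, y, a, b):
    # The depend's attached input is a make_tuple of single-use TransData nodes.
    z = make_tuple(TransData(a), TransData(b))
    depend_input = depend(y, z)
    return add(x, depend_input)


def after(x, y, a, b):
    # The pass bypasses the TransData nodes: the depend edge now points at a and b directly.
    z = make_tuple(a, b)
    depend_input = depend(y, z)
    return add(x, depend_input)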
@@ -27,6 +27,69 @@
 namespace mindspore {
 namespace opt {
 constexpr auto kSingleInputIndex = 1;
+namespace {
+AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>()) {
+    return nullptr;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  string op_name = AnfAlgo::GetCNodeName(cnode);
+  // Currently we only eliminate transdata or cast nodes.
+  if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) {
+    return nullptr;
+  }
+  auto manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  // Check whether the node has only one output node.
+  if (manager->node_users().find(cnode) == manager->node_users().end()) {
+    MS_LOG(EXCEPTION) << "The node should be used by at least another node's input";
+  }
+  if (manager->node_users()[cnode].size() > 1) {
+    return nullptr;
+  }
+  CheckCNodeInputSize(cnode, kSingleInputIndex + 1);
+  return cnode->input(kSingleInputIndex);
+}
+
+bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) {
+    return false;
+  }
+  std::vector<AnfNodePtr> new_make_tuple_inputs;
+  bool need_update = false;
+  for (const auto &input : cnode->inputs()) {
+    AnfNodePtr replace_input = GetReplaceNode(func_graph, input);
+    // If replace input is not null, it will be the input of the TransData or Cast.
+    if (replace_input == nullptr) {
+      new_make_tuple_inputs.push_back(input);
+      continue;
+    }
+    new_make_tuple_inputs.push_back(replace_input);
+    need_update = true;
+  }
+  if (need_update) {
+    auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
+    CNodePtr new_make_tuple = nullptr;
+    if (kernel_graph == nullptr) {
+      new_make_tuple = func_graph->NewCNode(new_make_tuple_inputs);
+    } else {
+      new_make_tuple = kernel_graph->NewCNode(cnode);
+    }
+    MS_EXCEPTION_IF_NULL(new_make_tuple);
+    new_make_tuple->set_inputs(new_make_tuple_inputs);
+    auto manager = func_graph->manager();
+    MS_EXCEPTION_IF_NULL(manager);
+    manager->Replace(cnode, new_make_tuple);
+  }
+  return true;
+}
+}  // namespace
+
 const BaseRef OptimizeDependence::DefinePattern() const {
   VarPtr X = std::make_shared<Var>("X");
   MS_EXCEPTION_IF_NULL(X);

@@ -43,9 +106,8 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
     return nullptr;
   }
   auto depend_cnode = node->cast<CNodePtr>();
-  if (depend_cnode->inputs().size() < kDependInputNum) {
-    return nullptr;
-  }
+  MS_EXCEPTION_IF_NULL(depend_cnode);
+  CheckCNodeInputSize(depend_cnode, kDependInputNum);
   auto replacing_node = depend_cnode->input(kDependInputNum - 1);
   MS_EXCEPTION_IF_NULL(replacing_node);
   if (!replacing_node->isa<CNode>()) {

@@ -53,36 +115,29 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
   }
   auto replacing_cnode = replacing_node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(replacing_cnode);
-  // Currently we only optimize transdata or cast nodes.
-  string replacing_cnode_op_name = AnfAlgo::GetCNodeName(replacing_cnode);
-  if (replacing_cnode_op_name != kTransDataOpName && replacing_cnode_op_name != prim::kPrimCast->name()) {
+  // Deal with the make_tuple with TransData or Cast inputs.
+  if (ReplaceMakeTuple(func_graph, replacing_cnode)) {
     return nullptr;
   }
-  auto manager = func_graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  // Check whether the replacing node has only one input and one output.
-  if (replacing_cnode->inputs().size() != kSingleInputIndex + 1) {
-    return nullptr;
-  }
-  if (manager->node_users().find(replacing_node) == manager->node_users().end()) {
-    MS_LOG(EXCEPTION) << "The node should be used by at least another node input";
-  }
-  if (manager->node_users()[replacing_node].size() > 1) {
+  AnfNodePtr replace_node = GetReplaceNode(func_graph, replacing_cnode);
+  if (replace_node == nullptr) {
+    MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString();
     return nullptr;
   }
   std::vector<AnfNodePtr> new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex),
-                                               depend_cnode->input(kRealInputIndexInDepend),
-                                               replacing_cnode->input(kSingleInputIndex)};
+                                               depend_cnode->input(kRealInputIndexInDepend), replace_node};
   auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
   CNodePtr new_depend;
   if (kernel_graph == nullptr) {
     new_depend = func_graph->NewCNode(new_depend_inputs);
+    MS_EXCEPTION_IF_NULL(new_depend);
+    new_depend->set_abstract(node->abstract());
+    new_depend->set_scope(node->scope());
   } else {
     new_depend = kernel_graph->NewCNode(depend_cnode);
     MS_EXCEPTION_IF_NULL(new_depend);
     new_depend->set_inputs(new_depend_inputs);
   }
-  new_depend->set_abstract(node->abstract());
   return new_depend;
 }
 }  // namespace opt
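One detail of the new helper worth calling out: GetReplaceNode only bypasses a TransData or Cast with a single consumer, so an input that is also used elsewhere is left in place. A hypothetical illustration in the same Python DSL as the tests (before_shared, after_shared, and the shared node t are my own names, not part of this commit; the primitives are the ones defined in the sketch above):

def before_shared(x, y, a, b):
    t = TransData(a)                  # two consumers: the make_tuple and the add below
    z = make_tuple(t, TransData(b))   # TransData(b) has a single consumer
    depend_input = depend(y, z)
    sum = add(t, depend_input)
    return sum


def after_shared(x, y, a, b):
    t = TransData(a)                  # kept, because it is used outside the tuple
    z = make_tuple(t, b)              # the single-use TransData(b) is bypassed
    depend_input = depend(y, z)
    sum = add(t, depend_input)
    return sum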
@@ -48,5 +48,25 @@ TEST_F(TestHWOptimizeDependence, test_optimize_dependence) {
   FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_optimize_dependence", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
+
+TEST_F(TestHWOptimizeDependence, test_optimize_dependence_with_make_tuple) {
+  /*
+   * def before(x, y, a, b):
+   *     z = make_tuple(TransData(a), TransData(b))
+   *     depend_input = depend(y, z)
+   *     sum = add(x, depend_input)
+   *     return sum
+   */
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_dependence_with_make_tuple", "before");
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::OptimizeDependence>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(g);
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_optimize_dependence_with_make_tuple", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
 }  // namespace opt
 }  // namespace mindspore
@@ -18,6 +18,8 @@ from mindspore.ops import Primitive
 depend = Primitive('depend')
 TransData = Primitive('TransData')
 add = P.TensorAdd()
+make_tuple = Primitive('make_tuple')
+
 
 class FnDict:
     def __init__(self):
@@ -47,3 +49,23 @@ def test_optimize_dependence(tag):
         return sum
 
     return fns[tag]
+
+
+def test_optimize_dependence_with_make_tuple(tag):
+    fns = FnDict()
+
+    @fns
+    def before(x, y, a, b):
+        z = make_tuple(TransData(a), TransData(b))
+        depend_input = depend(y, z)
+        sum = add(x, depend_input)
+        return sum
+
+    @fns
+    def after(x, y, a, b):
+        z = make_tuple(a, b)
+        depend_input = depend(y, z)
+        sum = add(x, depend_input)
+        return sum
+
+    return fns[tag]