optimize depend for mix target

This commit is contained in:
kswang 2020-06-28 16:38:59 +08:00
parent 7345d7471b
commit a3843659b4
4 changed files with 55 additions and 41 deletions

View File

@ -386,9 +386,15 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
auto new_fg = BasicClone(fg);
cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
}
auto origin_inputs = cnode->inputs();
bool optimize_depend = false;
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) {
optimize_depend = true;
}
// if has multiple depends,only select first depend as parameter
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
auto anf = cnode->inputs()[input_idx];
for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
auto anf = origin_inputs[input_idx];
MS_EXCEPTION_IF_NULL(anf);
// anf has been created before
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
@ -413,6 +419,9 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
(*other_graph_cnode)[anf] = new_parameter;
}
continue;
} else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
continue;
} else if (anf->isa<AnfNode>()) {
*from_other_graph = true;
// the input node is a cnode from other graph

View File

@ -28,6 +28,7 @@
#include <string>
#include "utils/log_adapter.h"
#include "utils/utils.h"
#include "ir/manager.h"
#include "ir/func_graph_cloner.h"
#include "operator/ops.h"
@ -85,7 +86,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
if (lst.empty()) {
MS_LOG(EXCEPTION) << "Input anf node list is empty";
}
auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr {
if (a->isa<ValueNode>() && !IsValueNode<FuncGraph>(a)) {
eqv[a] = a;
@ -95,17 +95,14 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
eqv[a]->set_abstract(a->abstract());
eqv[a]->set_kernel_info(a->kernel_info_ptr());
}
return eqv[a];
};
// Merge CNodes into a AnfGraph that represents a linear instruction segment
for (auto n : lst) {
if (!n->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Inst is not CNode";
}
auto &inps = n->cast<CNodePtr>()->inputs();
if (inps.empty()) {
MS_LOG(EXCEPTION) << "Input is empty";
}
@ -114,21 +111,22 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
inps[0]->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
MS_LOG(EXCEPTION) << "Input[0] Must be a Primitive valuenode";
}
auto fn = inps[0];
std::vector<AnfNodePtr> args{fn};
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref);
if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && inps[kRealInputIndexInDepend]->isa<ValueNode>() &&
eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
args.emplace_back(inps[kRealInputIndexInDepend]);
args.emplace_back(inps[kRealInputIndexInDepend]);
} else {
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref);
}
eqv[n] = fg->NewCNode(args);
eqv[n]->set_abstract(n->abstract());
eqv[n]->set_kernel_info(n->kernel_info_ptr());
}
std::vector<AnfNodePtr> eqv_keys;
(void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys),
[](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; });
auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys);
AnfNodePtr fg_output;
if (outputs.size() > 1) {

View File

@ -136,29 +136,12 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n
}
}
bool IsGetItemNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
if (!IsValueNode<Primitive>(inputs[0])) {
return true;
}
PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(inputs[0]);
return node_prim->name() == prim::kPrimTupleGetItem->name();
}
return false;
}
std::vector<AnfNodePtr> ReorderGetItemNode(const std::vector<AnfNodePtr> &nodes) {
std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) {
std::vector<AnfNodePtr> result;
std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
std::map<AnfNodePtr, size_t> node_positions;
for (auto &node : nodes) {
if (IsGetItemNode(node)) {
if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
@ -241,7 +224,7 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
}
}
std::reverse(result.begin(), result.end());
return ReorderGetItemNode(result);
return result;
}
} // namespace
@ -309,19 +292,12 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
return false;
}
VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
VectorRef CompileGraph::SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto nodes = OptimizeGetItemOrder(input_nodes);
VectorRef splits;
VectorRef split;
auto nodes = TopoSort(graph->get_return());
if (ContainMultiTarget(nodes)) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->device_target();
nodes = SplitSort(graph, default_target);
}
std::string last_target;
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (IsCut(node)) {
@ -343,6 +319,36 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
return splits;
}
VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto nodes = TopoSort(graph->get_return());
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
if (ContainMultiTarget(nodes)) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->device_target();
nodes = SplitSort(graph, default_target);
return SplitNodesWithTarget(nodes, graph);
}
VectorRef splits;
VectorRef split;
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (IsCut(node)) {
if (split.size() != 0) {
splits.push_back(split);
}
splits.push_back(node);
split.clear();
} else if (node->isa<CNode>()) {
split.push_back(node);
}
}
return splits;
}
// Push the value node on the stack.
void CompileGraph::Push(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);

View File

@ -78,6 +78,7 @@ class CompileGraph {
}
private:
VectorRef SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph);
void PushParameters(const FuncGraphPtr &func_graph);
bool SplitGraph(const FuncGraphPtr &func_graph);
int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = "");