forked from mindspore-Ecosystem/mindspore
optimize depend for mix target
This commit is contained in:
parent
7345d7471b
commit
a3843659b4
|
@ -386,9 +386,15 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
|
|||
auto new_fg = BasicClone(fg);
|
||||
cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
|
||||
}
|
||||
auto origin_inputs = cnode->inputs();
|
||||
bool optimize_depend = false;
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
|
||||
origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) {
|
||||
optimize_depend = true;
|
||||
}
|
||||
// if has multiple depends,only select first depend as parameter
|
||||
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
|
||||
auto anf = cnode->inputs()[input_idx];
|
||||
for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
|
||||
auto anf = origin_inputs[input_idx];
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
// anf has been created before
|
||||
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
|
||||
|
@ -413,6 +419,9 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
|
|||
(*other_graph_cnode)[anf] = new_parameter;
|
||||
}
|
||||
continue;
|
||||
} else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
|
||||
cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
|
||||
continue;
|
||||
} else if (anf->isa<AnfNode>()) {
|
||||
*from_other_graph = true;
|
||||
// the input node is a cnode from other graph
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include <string>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/utils.h"
|
||||
#include "ir/manager.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "operator/ops.h"
|
||||
|
@ -85,7 +86,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
|
|||
if (lst.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Input anf node list is empty";
|
||||
}
|
||||
|
||||
auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr {
|
||||
if (a->isa<ValueNode>() && !IsValueNode<FuncGraph>(a)) {
|
||||
eqv[a] = a;
|
||||
|
@ -95,17 +95,14 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
|
|||
eqv[a]->set_abstract(a->abstract());
|
||||
eqv[a]->set_kernel_info(a->kernel_info_ptr());
|
||||
}
|
||||
|
||||
return eqv[a];
|
||||
};
|
||||
|
||||
// Merge CNodes into a AnfGraph that represents a linear instruction segment
|
||||
for (auto n : lst) {
|
||||
if (!n->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Inst is not CNode";
|
||||
}
|
||||
auto &inps = n->cast<CNodePtr>()->inputs();
|
||||
|
||||
if (inps.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Input is empty";
|
||||
}
|
||||
|
@ -114,21 +111,22 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
|
|||
inps[0]->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
|
||||
MS_LOG(EXCEPTION) << "Input[0] Must be a Primitive valuenode";
|
||||
}
|
||||
|
||||
auto fn = inps[0];
|
||||
|
||||
std::vector<AnfNodePtr> args{fn};
|
||||
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref);
|
||||
|
||||
if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && inps[kRealInputIndexInDepend]->isa<ValueNode>() &&
|
||||
eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
|
||||
args.emplace_back(inps[kRealInputIndexInDepend]);
|
||||
args.emplace_back(inps[kRealInputIndexInDepend]);
|
||||
} else {
|
||||
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref);
|
||||
}
|
||||
eqv[n] = fg->NewCNode(args);
|
||||
eqv[n]->set_abstract(n->abstract());
|
||||
eqv[n]->set_kernel_info(n->kernel_info_ptr());
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> eqv_keys;
|
||||
(void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys),
|
||||
[](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; });
|
||||
|
||||
auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys);
|
||||
AnfNodePtr fg_output;
|
||||
if (outputs.size() > 1) {
|
||||
|
|
|
@ -136,29 +136,12 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n
|
|||
}
|
||||
}
|
||||
|
||||
bool IsGetItemNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto &inputs = cnode->inputs();
|
||||
if (inputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
||||
}
|
||||
if (!IsValueNode<Primitive>(inputs[0])) {
|
||||
return true;
|
||||
}
|
||||
PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(inputs[0]);
|
||||
return node_prim->name() == prim::kPrimTupleGetItem->name();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> ReorderGetItemNode(const std::vector<AnfNodePtr> &nodes) {
|
||||
std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) {
|
||||
std::vector<AnfNodePtr> result;
|
||||
std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
|
||||
std::map<AnfNodePtr, size_t> node_positions;
|
||||
for (auto &node : nodes) {
|
||||
if (IsGetItemNode(node)) {
|
||||
if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &inputs = cnode->inputs();
|
||||
|
@ -241,7 +224,7 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
|
|||
}
|
||||
}
|
||||
std::reverse(result.begin(), result.end());
|
||||
return ReorderGetItemNode(result);
|
||||
return result;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -309,19 +292,12 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
|
||||
VectorRef CompileGraph::SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto nodes = OptimizeGetItemOrder(input_nodes);
|
||||
VectorRef splits;
|
||||
VectorRef split;
|
||||
auto nodes = TopoSort(graph->get_return());
|
||||
if (ContainMultiTarget(nodes)) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string default_target = context_ptr->device_target();
|
||||
nodes = SplitSort(graph, default_target);
|
||||
}
|
||||
std::string last_target;
|
||||
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
|
||||
for (auto &node : nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (IsCut(node)) {
|
||||
|
@ -343,6 +319,36 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
|
|||
return splits;
|
||||
}
|
||||
|
||||
VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto nodes = TopoSort(graph->get_return());
|
||||
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
|
||||
|
||||
if (ContainMultiTarget(nodes)) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string default_target = context_ptr->device_target();
|
||||
nodes = SplitSort(graph, default_target);
|
||||
return SplitNodesWithTarget(nodes, graph);
|
||||
}
|
||||
|
||||
VectorRef splits;
|
||||
VectorRef split;
|
||||
for (auto &node : nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (IsCut(node)) {
|
||||
if (split.size() != 0) {
|
||||
splits.push_back(split);
|
||||
}
|
||||
splits.push_back(node);
|
||||
split.clear();
|
||||
} else if (node->isa<CNode>()) {
|
||||
split.push_back(node);
|
||||
}
|
||||
}
|
||||
return splits;
|
||||
}
|
||||
|
||||
// Push the value node on the stack.
|
||||
void CompileGraph::Push(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
|
|
@ -78,6 +78,7 @@ class CompileGraph {
|
|||
}
|
||||
|
||||
private:
|
||||
VectorRef SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph);
|
||||
void PushParameters(const FuncGraphPtr &func_graph);
|
||||
bool SplitGraph(const FuncGraphPtr &func_graph);
|
||||
int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = "");
|
||||
|
|
Loading…
Reference in New Issue