forked from mindspore-Ecosystem/mindspore
[auto-monad] Enforce order of execution for Loads user nodes in frontend
This commit is contained in:
parent
9edf30bd05
commit
121a6a28d9
|
@ -132,7 +132,7 @@ def Depend(value, expr):
|
|||
return value
|
||||
|
||||
|
||||
def UpdateState(monad, expr):
|
||||
def UpdateState(monad, *exprs):
|
||||
"""Implement `UpdateState`."""
|
||||
return monad
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
|
|||
MS_EXCEPTION_IF_NULL(node_with_index.first);
|
||||
auto real_input = node_with_index.first;
|
||||
if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
|
||||
input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select);
|
||||
input_node = InsertTransOpForOutput(func_graph, input_node, input_node, kernel_select);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
AnfAlgo::SetNodeInput(node, input_node, index);
|
||||
}
|
||||
|
@ -120,10 +120,16 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
|
|||
return node;
|
||||
}
|
||||
|
||||
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const KernelSelectPtr &kernel_select) {
|
||||
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node,
|
||||
const AnfNodePtr &node, const KernelSelectPtr &kernel_select) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_node);
|
||||
for (auto &update_state : update_states) {
|
||||
manager->SetEdge(update_state.first, update_state.second, node);
|
||||
}
|
||||
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
size_t out_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
|
@ -282,7 +288,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &
|
|||
return cast;
|
||||
}
|
||||
|
||||
AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node,
|
||||
const KernelSelectPtr &kernel_select) {
|
||||
size_t outputs_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
if (outputs_num == 0) {
|
||||
|
@ -298,7 +304,7 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP
|
|||
return new_node;
|
||||
}
|
||||
// Multiple output
|
||||
return InsertTransOpForMultipleOutput(func_graph, node, kernel_select);
|
||||
return InsertTransOpForMultipleOutput(func_graph, orig_node, node, kernel_select);
|
||||
}
|
||||
|
||||
AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
|
|
|
@ -103,7 +103,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &
|
|||
AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const KernelSelectPtr &kernel_select);
|
||||
|
||||
AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node,
|
||||
const KernelSelectPtr &kernel_select);
|
||||
|
||||
CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
|
|
|
@ -66,6 +66,9 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
|
|||
std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
|
||||
for (auto out_getitem : manager->node_users()[bnupdate]) {
|
||||
MS_EXCEPTION_IF_NULL(out_getitem.first);
|
||||
if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
|
||||
continue;
|
||||
}
|
||||
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
|
||||
auto input2 = out_getitem_ptr->input(2);
|
||||
|
|
|
@ -43,6 +43,9 @@ void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr
|
|||
std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
|
||||
for (auto out_getitem : manager->node_users()[bnupdate]) {
|
||||
MS_EXCEPTION_IF_NULL(out_getitem.first);
|
||||
if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
|
||||
continue;
|
||||
}
|
||||
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
|
||||
auto input2 = out_getitem_ptr->input(2);
|
||||
|
|
|
@ -297,9 +297,11 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
|
|||
} else {
|
||||
int64_t prev_idx = 0;
|
||||
std::vector<AnfNodePtr> tuple_getitem_nodes;
|
||||
std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(),
|
||||
std::back_inserter(tuple_getitem_nodes),
|
||||
[](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; });
|
||||
for (auto &user : manager->node_users()[node]) {
|
||||
if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimTupleGetItem)) {
|
||||
tuple_getitem_nodes.emplace_back(user.first);
|
||||
}
|
||||
}
|
||||
std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare);
|
||||
for (auto &getitem : tuple_getitem_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(getitem);
|
||||
|
|
|
@ -163,7 +163,20 @@ CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get
|
|||
return func_graph->NewCNode(depend_nodes);
|
||||
}
|
||||
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput(
|
||||
const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
|
||||
const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto cnode = orig_cnode;
|
||||
auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_cnode);
|
||||
if (!update_states.empty()) {
|
||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
cnode = kernel_graph->NewCNode(orig_cnode);
|
||||
cnode->set_inputs(orig_cnode->inputs());
|
||||
for (auto &update_state : update_states) {
|
||||
manager->SetEdge(update_state.first, update_state.second, cnode);
|
||||
}
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
auto ref_infos = op_info->ref_infos();
|
||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||
|
|
|
@ -30,9 +30,16 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode,
|
||||
const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_cnode);
|
||||
for (auto &update_state : update_states) {
|
||||
manager->SetEdge(update_state.first, update_state.second, cnode);
|
||||
}
|
||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||
AbstractBasePtrList abstract_list;
|
||||
make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
|
@ -69,9 +76,9 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
|
|||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
|
||||
return make_tuple;
|
||||
} // namespace
|
||||
}
|
||||
|
||||
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetOutputTensorNum(cnode) == 0) {
|
||||
|
@ -99,7 +106,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
|
|||
return replace_node;
|
||||
}
|
||||
// Multiple output
|
||||
return InsertCastForMultipleOutput(func_graph, cnode);
|
||||
return InsertCastForMultipleOutput(func_graph, orig_cnode, cnode);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -124,7 +131,7 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
|
|||
kernel_graph->ReplaceInternalOutput(node, new_node);
|
||||
}
|
||||
// process output
|
||||
return InsertCastForOutput(func_graph, new_node);
|
||||
return InsertCastForOutput(func_graph, cnode, new_node);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -43,7 +43,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
|
|||
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
|
||||
kernel_graph->ReplaceInternalOutput(node, new_node);
|
||||
}
|
||||
return InsertTransOpForOutput(func_graph, new_node, kernel_select_);
|
||||
return InsertTransOpForOutput(func_graph, node, new_node, kernel_select_);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "backend/optimizer/pass/convert_const_scalar_to_tensor.h"
|
||||
#include "backend/optimizer/pass/convert_attr_to_unify_mindir.h"
|
||||
#include "backend/optimizer/pass/add_training_attr.h"
|
||||
#include "backend/optimizer/pass/optimize_updatestate.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
|
||||
|
@ -58,5 +59,24 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
|||
DumpIR(file_name, kernel_graph);
|
||||
}
|
||||
}
|
||||
|
||||
// Final backend-independent optimization stage for a kernel graph.
// Runs the "final_opt" pass manager (currently only OptimizeUpdateState), resets the
// graph's execution order, and dumps the resulting IR when the save_graphs flag is set.
void CommonFinalOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);

  // Assemble the pass pipeline and run it over the graph.
  auto final_pm = std::make_shared<PassManager>("final_opt");
  final_pm->AddPass(std::make_shared<OptimizeUpdateState>());
  auto graph_optimizer = std::make_shared<GraphOptimizer>();
  graph_optimizer->AddPassManager(final_pm);
  (void)graph_optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();

  // Nothing more to do unless IR dumping was requested through the context.
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
    return;
  }
  const std::string ir_file = "hwopt_common_final_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
  DumpIR(ir_file, kernel_graph);
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void CommonFinalOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -401,11 +401,9 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con
|
|||
}
|
||||
auto output_info_list = iter->second;
|
||||
for (const auto &output_info : output_info_list) {
|
||||
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
|
||||
output_info.second == kDependAttachNodeIndex) {
|
||||
continue;
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimUpdateState->name()) {
|
||||
auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
|
||||
if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
|
||||
(cnode_name == prim::kPrimUpdateState->name())) {
|
||||
continue;
|
||||
}
|
||||
output_node_list->push_back(output_info);
|
||||
|
@ -426,12 +424,13 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOu
|
|||
}
|
||||
auto output_info_list = iter->second;
|
||||
for (const auto &output_info : output_info_list) {
|
||||
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
|
||||
output_info.second == kDependAttachNodeIndex) {
|
||||
auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
|
||||
if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
|
||||
(cnode_name == prim::kPrimUpdateState->name())) {
|
||||
continue;
|
||||
}
|
||||
size_t used_output_index;
|
||||
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimTupleGetItem->name()) {
|
||||
if (cnode_name == prim::kPrimTupleGetItem->name()) {
|
||||
used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
|
||||
} else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
|
||||
used_output_index = output_index;
|
||||
|
@ -906,12 +905,13 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
// find BatchNorm's output which is a Depend
|
||||
// Find BatchNorm's output which is a Depend or UpdateState.
|
||||
for (const auto &node_index : manager->node_users()[old_node]) {
|
||||
AnfNodePtr output = node_index.first;
|
||||
size_t index = IntToSize(node_index.second);
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) {
|
||||
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) ||
|
||||
AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
|
||||
auto depend = output->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend);
|
||||
depend->set_input(index, new_node);
|
||||
|
|
|
@ -66,13 +66,14 @@ bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr
|
|||
auto output_num = output->cast<CNodePtr>()->size() - 1;
|
||||
getitem_list->clear();
|
||||
getitem_list->resize(output_num, nullptr);
|
||||
const auto &users = mng->node_users()[node];
|
||||
auto users = mng->node_users()[node];
|
||||
bool changed = false;
|
||||
AnfNodePtrList user_nodes;
|
||||
std::transform(users.begin(), users.end(), std::back_inserter(user_nodes),
|
||||
[](const std::pair<AnfNodePtr, int> &user) { return user.first; });
|
||||
for (const auto &getitem : user_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(getitem);
|
||||
for (const auto &user : users) {
|
||||
if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
|
||||
// Sometimes the user of MakeTuple is not a TupleGetItem but an UpdateState.
|
||||
continue;
|
||||
}
|
||||
auto &getitem = user.first;
|
||||
auto idx = GetIndex(getitem);
|
||||
if (idx >= output_num) {
|
||||
MS_LOG(EXCEPTION) << "Index of GetItem is out of range of MakeTuple. getitem node: " << getitem->DebugString();
|
||||
|
|
|
@ -35,19 +35,17 @@ CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cno
|
|||
const std::vector<AnfNodePtr> &new_depend_inputs) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
|
||||
CNodePtr new_depend = nullptr;
|
||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
if (kernel_graph == nullptr) {
|
||||
new_depend = func_graph->NewCNode(new_depend_inputs);
|
||||
auto new_depend = func_graph->NewCNode(new_depend_inputs);
|
||||
MS_EXCEPTION_IF_NULL(new_depend);
|
||||
new_depend->set_abstract(cnode->abstract());
|
||||
new_depend->set_scope(cnode->scope());
|
||||
} else {
|
||||
new_depend = kernel_graph->NewCNode(cnode);
|
||||
MS_EXCEPTION_IF_NULL(new_depend);
|
||||
new_depend->set_inputs(new_depend_inputs);
|
||||
return new_depend;
|
||||
}
|
||||
func_graph->manager()->Replace(cnode, new_depend);
|
||||
auto new_depend = kernel_graph->NewCNode(cnode);
|
||||
MS_EXCEPTION_IF_NULL(new_depend);
|
||||
new_depend->set_inputs(new_depend_inputs);
|
||||
return new_depend;
|
||||
}
|
||||
|
||||
|
@ -77,9 +75,9 @@ AnfNodePtr EliminateIsolatedVirtualNodeInput(const FuncGraphPtr &func_graph, con
|
|||
auto replace_node = eliminate_node->input(kSingleInputIndex);
|
||||
std::vector<AnfNodePtr> new_depend_inputs = cnode->inputs();
|
||||
new_depend_inputs[kIsolatedDependRealInputIndex + 1] = replace_node;
|
||||
auto new_cnode = CreateNewDependNode(func_graph, cnode, new_depend_inputs);
|
||||
auto new_node = new_cnode->cast<AnfNodePtr>();
|
||||
return new_node;
|
||||
auto new_depend = CreateNewDependNode(func_graph, cnode, new_depend_inputs);
|
||||
func_graph->manager()->Replace(cnode, new_depend);
|
||||
return new_depend;
|
||||
}
|
||||
|
||||
AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
|
@ -157,55 +155,53 @@ const BaseRef OptimizeDependence::DefinePattern() const {
|
|||
return VectorRef({X, Xs});
|
||||
}
|
||||
|
||||
std::pair<AnfNodePtr, size_t> SearchTransDataAndCast(const AnfNodePtr &node, bool is_first_node) {
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
return std::pair<AnfNodePtr, size_t>(nullptr, 0);
|
||||
std::vector<size_t> SearchTransDataAndCast(const CNodePtr &cnode) {
|
||||
// Search Depend and UpdateState only.
|
||||
if (!cnode->IsApply(prim::kPrimDepend) && !cnode->IsApply(prim::kPrimUpdateState)) {
|
||||
return {};
|
||||
}
|
||||
// get real input of depend and update state.
|
||||
size_t replace_input_index = 0;
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
|
||||
replace_input_index = is_first_node ? kDependAttachNodeIndex : kRealInputIndexInDepend;
|
||||
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
|
||||
replace_input_index = is_first_node ? kUpdateStateStateInput : kUpdateStateRealInput;
|
||||
} else {
|
||||
return std::pair<AnfNodePtr, size_t>(nullptr, 0);
|
||||
// Find inputs which are Cast, TransData or MakeTuple.
|
||||
std::vector<size_t> result;
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto &input = cnode->input(i);
|
||||
if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimCast) ||
|
||||
AnfAlgo::CheckPrimitiveType(input, prim::KPrimTransData) ||
|
||||
AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
|
||||
result.emplace_back(i);
|
||||
}
|
||||
}
|
||||
// check whether real input is cast or trans data
|
||||
auto real_input = node->cast<CNodePtr>()->input(replace_input_index);
|
||||
if (AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimCast) ||
|
||||
AnfAlgo::CheckPrimitiveType(real_input, prim::KPrimTransData) ||
|
||||
AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimMakeTuple)) {
|
||||
return std::pair<AnfNodePtr, size_t>(node, replace_input_index);
|
||||
}
|
||||
return SearchTransDataAndCast(real_input, false);
|
||||
return result;
|
||||
}
|
||||
|
||||
const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
auto cnode = dyn_cast<CNode>(node);
|
||||
if (cnode == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
// Get the cnode with repalce input index
|
||||
auto cnode_with_input_index = SearchTransDataAndCast(node, true);
|
||||
if (cnode_with_input_index.first == nullptr) {
|
||||
// Search inputs to be replaced.
|
||||
auto candidate_inputs = SearchTransDataAndCast(cnode);
|
||||
if (candidate_inputs.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
size_t replace_index = cnode_with_input_index.second;
|
||||
auto depend_cnode = cnode_with_input_index.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend_cnode);
|
||||
// Get new node which will act as new input of depend or UpdateState.
|
||||
std::vector<AnfNodePtr> new_depend_inputs = depend_cnode->inputs();
|
||||
auto replace_node = GetConvertNode(func_graph, depend_cnode, replace_index);
|
||||
if (replace_node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
new_depend_inputs[replace_index] = replace_node;
|
||||
auto new_depend = CreateNewDependNode(func_graph, depend_cnode, new_depend_inputs);
|
||||
if (new_depend == nullptr) {
|
||||
// Get new nodes which will act as new inputs of Depend or UpdateState.
|
||||
std::vector<AnfNodePtr> new_inputs = cnode->inputs();
|
||||
bool inputs_changed = false;
|
||||
for (auto index : candidate_inputs) {
|
||||
auto replace_node = GetConvertNode(func_graph, cnode, index);
|
||||
if (replace_node != nullptr) {
|
||||
new_inputs[index] = replace_node;
|
||||
inputs_changed = true;
|
||||
}
|
||||
}
|
||||
if (!inputs_changed) {
|
||||
return nullptr;
|
||||
}
|
||||
// Create a new Depend node to replace the old one if inputs changed.
|
||||
auto new_depend = CreateNewDependNode(func_graph, cnode, new_inputs);
|
||||
func_graph->manager()->Replace(cnode, new_depend);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/pass/optimize_updatestate.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
constexpr size_t kInputIndex = 1;
|
||||
constexpr size_t kAttachIndex = 2;
|
||||
constexpr size_t kAdditionalAttachIndex = 3;
|
||||
|
||||
// Pattern for this pass: the UpdateState primitive followed by a sequence variable
// (presumably matching any number of inputs -- confirm against SeqVar semantics).
const BaseRef OptimizeUpdateState::DefinePattern() const {
  VarPtr inputs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimUpdateState, inputs});
}
|
||||
|
||||
// Drops additional attach inputs of an UpdateState node that nobody else uses.
//
// Inputs 0..kAttachIndex are always kept. Every input at index >= kAdditionalAttachIndex
// is dropped when its only recorded user is this UpdateState node.
//
// Returns a new UpdateState CNode (same abstract and scope as the old one) when at least
// one attach was dropped, or nullptr when the node has no additional attaches or none of
// them could be dropped.
const AnfNodePtr OptimizeUpdateState::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                              const EquivPtr &) const {
  // Validate the graph up front: it is dereferenced below for the manager and NewCNode().
  MS_EXCEPTION_IF_NULL(func_graph);
  auto update_state = dyn_cast<CNode>(node);
  MS_EXCEPTION_IF_NULL(update_state);
  const size_t input_num = update_state->size();
  if (input_num <= kAdditionalAttachIndex) {
    // Skip UpdateState nodes with no additional attaches.
    return nullptr;
  }
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto &node_users = manager->node_users();

  std::vector<AnfNodePtr> new_inputs;
  new_inputs.reserve(input_num);
  // The first three inputs (index 0, kInputIndex and kAttachIndex) are kept unconditionally.
  new_inputs.emplace_back(update_state->input(0));
  new_inputs.emplace_back(update_state->input(kInputIndex));
  new_inputs.emplace_back(update_state->input(kAttachIndex));
  for (size_t i = kAdditionalAttachIndex; i < input_num; ++i) {
    auto &attach = update_state->input(i);
    auto &users = node_users[attach];
    // If the only user of attach is the UpdateState node itself, drop the attach node.
    const bool only_used_here = (users.size() == 1) && (users.front().first == update_state);
    if (!only_used_here) {
      new_inputs.emplace_back(attach);
    }
  }
  if (new_inputs.size() == input_num) {
    // Attaches not changed.
    return nullptr;
  }

  // Attaches changed, make a new UpdateState.
  auto new_update_state = func_graph->NewCNode(new_inputs);
  MS_EXCEPTION_IF_NULL(new_update_state);
  new_update_state->set_abstract(update_state->abstract());
  new_update_state->set_scope(update_state->scope());
  return new_update_state;
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class OptimizeUpdateState : public PatternProcessPass {
|
||||
public:
|
||||
explicit OptimizeUpdateState(bool multigraph = true) : PatternProcessPass("optimize_updatestate", multigraph) {}
|
||||
~OptimizeUpdateState() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_
|
|
@ -1931,5 +1931,15 @@ void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_
|
|||
{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
|
||||
root_graph->set_output(make_tuple);
|
||||
}
|
||||
|
||||
// Collects the users of `node` that are UpdateState nodes.
//
// manager: graph manager whose node-user map is queried; must not be null.
// node:    the node whose users are inspected.
// Returns the (user node, input index) entries whose user is an UpdateState primitive;
// empty when `node` has no recorded users.
AnfNodeIndexSet AnfRuntimeAlgorithm::GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet update_states;
  auto &node_users = manager->node_users();
  // Look the node up with find() instead of operator[] so that a pure query does not
  // add an empty entry for `node` to the manager's user map.
  auto iter = node_users.find(node);
  if (iter == node_users.end()) {
    return update_states;
  }
  for (auto &user : iter->second) {
    if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimUpdateState)) {
      update_states.insert(user);
    }
  }
  return update_states;
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -267,6 +267,7 @@ class AnfRuntimeAlgorithm {
|
|||
static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
|
||||
std::set<AnfNodePtr> *visited);
|
||||
static void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph);
|
||||
static AnfNodeIndexSet GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node);
|
||||
};
|
||||
} // namespace session
|
||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||
|
|
|
@ -936,6 +936,7 @@ void AscendSession::InitRuntimeResource() {
|
|||
void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
MS_LOG(INFO) << "HardwareOptimize start!";
|
||||
opt::AscendBackendOptimization(kernel_graph);
|
||||
FinalOptimize(kernel_graph);
|
||||
GraphKernelOptimize(kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
|
|
|
@ -104,6 +104,7 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
|
|||
SetKernelInfo(graph.get());
|
||||
MS_LOG(INFO) << "Set kernel info end";
|
||||
Optimize(graph);
|
||||
FinalOptimize(graph);
|
||||
MS_LOG(INFO) << "Build kernel";
|
||||
BuildKernel(graph.get());
|
||||
// Remove reorder after PS feature finish adapting push/pull in auto_monad.
|
||||
|
|
|
@ -341,6 +341,8 @@ GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) {
|
|||
SelectKernel(graph);
|
||||
// Graph optimization relevant to device data format
|
||||
HardwareOptimize(graph);
|
||||
// Run final optimization
|
||||
FinalOptimize(graph);
|
||||
// Graph kernel fusion optimization
|
||||
GraphKernelOptimize(graph);
|
||||
// Start gpu kernel runtime
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include <queue>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
|
@ -2343,6 +2344,12 @@ void SessionBasic::ClearAllBucket(const GraphId &graph_id) {
|
|||
}
|
||||
}
|
||||
|
||||
void SessionBasic::FinalOptimize(const KernelGraphPtr &graph) const {
|
||||
MS_LOG(INFO) << "Start FinalOptimize for graph: " << graph->graph_id();
|
||||
opt::CommonFinalOptimization(graph);
|
||||
MS_LOG(INFO) << "End FinalOptimize for graph: " << graph->graph_id();
|
||||
}
|
||||
|
||||
void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
|
|
|
@ -172,6 +172,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
virtual void UpdateOutputTensors(const VectorRef *outputs,
|
||||
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node);
|
||||
virtual void UnifyMindIR(const KernelGraphPtr &graph) {}
|
||||
virtual void FinalOptimize(const KernelGraphPtr &graph) const;
|
||||
virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { return 0; }
|
||||
virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
|
||||
virtual void BuildGraphImpl(GraphId) {}
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "pipeline/jit/parse/parse_base.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "pipeline/jit/static_analysis/auto_monad.h"
|
||||
#include "pipeline/jit/static_analysis/order_enforce.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
#include "pipeline/jit/static_analysis/program_specialize.h"
|
||||
|
@ -343,6 +344,18 @@ bool AutoMonadAction(const ResourcePtr &res) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool OrderEnforceAction(const ResourcePtr &res) {
|
||||
if (res->manager() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Order-Enforce error, manager is null";
|
||||
}
|
||||
auto func_graph = res->func_graph();
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Order-Enforce error, graph is null";
|
||||
}
|
||||
pipeline::OrderEnforce(func_graph);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool InferenceOptPrepareAction(const ResourcePtr &res) {
|
||||
if (res->manager() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null.";
|
||||
|
@ -752,6 +765,7 @@ std::vector<ActionItem> GePipeline() {
|
|||
// Add opt-stage python pass stub
|
||||
actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub));
|
||||
actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
|
||||
actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
|
||||
actions.emplace_back(std::make_pair("validate", ValidateAction));
|
||||
return actions;
|
||||
}
|
||||
|
@ -765,6 +779,8 @@ std::vector<ActionItem> VmPipeline() {
|
|||
// Add opt-stage python pass stub
|
||||
actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub));
|
||||
|
||||
actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
|
||||
|
||||
actions.emplace_back(std::make_pair("validate", ValidateAction));
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (ps::PSContext::instance()->is_worker()) {
|
||||
|
@ -784,6 +800,7 @@ std::vector<ActionItem> VmPipeline() {
|
|||
std::vector<ActionItem> PServerPipeline() {
|
||||
auto actions = CommonPipeline();
|
||||
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
|
||||
actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
|
||||
actions.emplace_back(std::make_pair("validate", ValidateAction));
|
||||
actions.emplace_back(std::make_pair("pserver", StartPSServerAction));
|
||||
return actions;
|
||||
|
|
|
@ -0,0 +1,258 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pipeline/jit/static_analysis/order_enforce.h"
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include "base/core_ops.h"
|
||||
|
||||
namespace mindspore::pipeline {
|
||||
namespace {
|
||||
|
||||
class OrderEnforcer {
|
||||
public:
|
||||
explicit OrderEnforcer(const FuncGraphPtr &func_graph) : func_graph_(func_graph), manager_(func_graph->manager()) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
MS_EXCEPTION_IF_NULL(manager_);
|
||||
}
|
||||
~OrderEnforcer() = default;
|
||||
|
||||
void Run() {
|
||||
auto nodes = MakeTopoSortMap();
|
||||
for (auto &node : nodes) {
|
||||
HandleNode(node);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
AnfNodePtrList MakeTopoSortMap() {
|
||||
auto nodes = TopoSort(func_graph_->get_return());
|
||||
for (size_t i = 0; i < nodes.size(); ++i) {
|
||||
topo_sort_map_.emplace(nodes[i], i);
|
||||
}
|
||||
return nodes;
|
||||
}
|
||||
|
||||
void HandleNode(const AnfNodePtr &node) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
|
||||
// Skip nodes other than UpdateState.
|
||||
return;
|
||||
}
|
||||
auto update_state = node->cast<CNodePtr>();
|
||||
if (!HasAbstractUMonad(update_state->input(1))) {
|
||||
// Skip UpdateStates for IO.
|
||||
return;
|
||||
}
|
||||
auto updated_refs = FindUpdatedRefs(update_state);
|
||||
if (updated_refs.empty()) {
|
||||
// Skip UpdateStates that do not have updated refs.
|
||||
return;
|
||||
}
|
||||
auto &attach = update_state->input(2);
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimLoad)) {
|
||||
// Handle UpdateState with Load.
|
||||
EnforceOrderForLoad(update_state, attach->cast<CNodePtr>(), updated_refs);
|
||||
} else if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
|
||||
// Handle UpdateState with MakeTuple.
|
||||
EnforceOrderForTuple(update_state, attach->cast<CNodePtr>(), updated_refs);
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_set<AnfNodePtr> FindUpdatedRefs(const CNodePtr &update_state) {
|
||||
std::unordered_set<AnfNodePtr> updated_refs;
|
||||
auto &users = manager_->node_users()[update_state];
|
||||
for (auto &user : users) {
|
||||
auto cnode = dyn_cast<CNode>(user.first);
|
||||
if (cnode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (cnode->IsApply(prim::kPrimLoad) || cnode->IsApply(prim::kPrimDepend) ||
|
||||
cnode->IsApply(prim::kPrimUpdateState)) {
|
||||
continue;
|
||||
}
|
||||
for (auto &input : cnode->inputs()) {
|
||||
if (IsRef(input)) {
|
||||
updated_refs.insert(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
return updated_refs;
|
||||
}
|
||||
|
||||
bool IsRef(const AnfNodePtr &node) {
|
||||
auto &abs = node->abstract();
|
||||
return abs != nullptr && abs->isa<abstract::AbstractRef>();
|
||||
}
|
||||
|
||||
void EnforceOrderForLoad(const CNodePtr &update_state, const CNodePtr &load,
|
||||
const std::unordered_set<AnfNodePtr> &refs) {
|
||||
if (refs.find(load->input(1)) == refs.end()) {
|
||||
// Skip if loaded parameter is not updated.
|
||||
return;
|
||||
}
|
||||
// Find load users, ignore processed nodes.
|
||||
auto load_users = FindLoadUsers(load, update_state);
|
||||
// Find load users that not depend on the UpdateState,
|
||||
// and than let UpdateState depend on them.
|
||||
AddInputEdges(update_state, load_users);
|
||||
}
|
||||
|
||||
void EnforceOrderForTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
|
||||
const std::unordered_set<AnfNodePtr> &refs) {
|
||||
// The UpdateState should be the only one user of the make_tuple.
|
||||
// for performance, we only check the number of output edges.
|
||||
if (manager_->node_users()[make_tuple].size() != 1) {
|
||||
return;
|
||||
}
|
||||
// Find load users from the tuple of Load nodes.
|
||||
std::unordered_set<AnfNodePtr> all_load_users;
|
||||
auto &inputs = make_tuple->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
auto &input = inputs.at(i);
|
||||
if (!IsPrimitiveCNode(input, prim::kPrimLoad)) {
|
||||
// Skip non-Load nodes.
|
||||
continue;
|
||||
}
|
||||
auto load = input->cast<CNodePtr>();
|
||||
if (refs.find(load->input(1)) == refs.end()) {
|
||||
// Skip if loaded parameter is not updated.
|
||||
continue;
|
||||
}
|
||||
auto load_users = FindLoadUsers(load, make_tuple);
|
||||
all_load_users.insert(load_users.begin(), load_users.end());
|
||||
}
|
||||
// Find load users that not depend on the UpdateState,
|
||||
// and than let UpdateState depend on them.
|
||||
AddInputEdges(update_state, all_load_users);
|
||||
}
|
||||
|
||||
// Add load users as input edges of the update_state node.
|
||||
void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) {
|
||||
auto sorted_load_users = SortLoadUsers(load_users);
|
||||
for (auto &load_user : sorted_load_users) {
|
||||
if (!IsDependOn(load_user, update_state)) {
|
||||
processed_nodes_.insert(load_user);
|
||||
manager_->AddEdge(update_state, load_user);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort load users by their topo sort order.
|
||||
std::vector<AnfNodePtr> SortLoadUsers(const std::unordered_set<AnfNodePtr> &load_users) {
|
||||
std::vector<AnfNodePtr> vec{load_users.begin(), load_users.end()};
|
||||
std::sort(vec.begin(), vec.end(), [this](const AnfNodePtr &a, const AnfNodePtr &b) { return IsBefore(a, b); });
|
||||
return vec;
|
||||
}
|
||||
|
||||
// Check if the load user node depend on the given UpdateState node.
|
||||
bool IsDependOn(const AnfNodePtr &load_user, const AnfNodePtr &update_state) {
|
||||
size_t update_state_order = topo_sort_map_[update_state];
|
||||
if (topo_sort_map_[load_user] < update_state_order) {
|
||||
return false;
|
||||
}
|
||||
auto user_cnode = dyn_cast<CNode>(load_user);
|
||||
if (user_cnode == nullptr) {
|
||||
return false;
|
||||
}
|
||||
size_t seen = NewSeenGeneration();
|
||||
std::queue<CNodePtr> q;
|
||||
user_cnode->seen_ = seen;
|
||||
q.push(user_cnode);
|
||||
while (!q.empty()) {
|
||||
auto cnode = q.front();
|
||||
q.pop();
|
||||
for (auto &input : cnode->inputs()) {
|
||||
if (input == update_state) {
|
||||
// Dependency found.
|
||||
return true;
|
||||
}
|
||||
if (input->seen_ == seen) {
|
||||
// Skip visited nodes.
|
||||
continue;
|
||||
}
|
||||
if (topo_sort_map_[input] < update_state_order) {
|
||||
// Skip input nodes that before the UpdateState node.
|
||||
continue;
|
||||
}
|
||||
auto input_cnode = dyn_cast<CNode>(input);
|
||||
if (input_cnode != nullptr) {
|
||||
input_cnode->seen_ = seen;
|
||||
q.push(input_cnode);
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsBefore(const AnfNodePtr &node1, const AnfNodePtr &node2) {
|
||||
return topo_sort_map_[node1] < topo_sort_map_[node2];
|
||||
}
|
||||
|
||||
// Find Load users as the candidate nodes to enforce order of execution.
|
||||
std::unordered_set<AnfNodePtr> FindLoadUsers(const CNodePtr &load, const AnfNodePtr &exclude) {
|
||||
auto &node_users = manager_->node_users();
|
||||
auto iter = node_users.find(load);
|
||||
if (iter == node_users.end()) {
|
||||
return {};
|
||||
}
|
||||
std::unordered_set<AnfNodePtr> load_users;
|
||||
auto &users = iter->second;
|
||||
for (auto &user : users) {
|
||||
auto &user_node = user.first;
|
||||
if (user_node == exclude) {
|
||||
// Skip excluded node.
|
||||
continue;
|
||||
}
|
||||
if (processed_nodes_.find(user_node) != processed_nodes_.end()) {
|
||||
// Skip processed nodes.
|
||||
continue;
|
||||
}
|
||||
auto cnode = dyn_cast<CNode>(user_node);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &inputs = cnode->inputs();
|
||||
const bool has_u_input =
|
||||
std::any_of(inputs.begin() + 1, inputs.end(), [](const AnfNodePtr &input) { return HasAbstractUMonad(input); });
|
||||
if (has_u_input) {
|
||||
// Skip nodes with memory side effects, which use u input.
|
||||
continue;
|
||||
}
|
||||
load_users.insert(cnode);
|
||||
}
|
||||
return load_users;
|
||||
}
|
||||
|
||||
private:
|
||||
const FuncGraphPtr &func_graph_;
|
||||
FuncGraphManagerPtr manager_;
|
||||
std::unordered_map<AnfNodePtr, size_t> topo_sort_map_;
|
||||
std::unordered_set<AnfNodePtr> processed_nodes_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//
|
||||
// Enforce order of execution for Load users node.
|
||||
//
|
||||
void OrderEnforce(const FuncGraphPtr &func_graph) {
|
||||
OrderEnforcer enforcer(func_graph);
|
||||
enforcer.Run();
|
||||
}
|
||||
} // namespace mindspore::pipeline
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_
|
||||
#define MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_
|
||||
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore::pipeline {
|
||||
// Enforce order of execution for the user nodes of Load nodes in the given graph.
|
||||
void OrderEnforce(const FuncGraphPtr &func_graph);
|
||||
} // namespace mindspore::pipeline
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_
|
|
@ -1456,7 +1456,10 @@ def test_while_forward():
|
|||
assert np.allclose(output.asnumpy(), expect, 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="not supported yet")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_multi_add_assign():
|
||||
class Net(Cell):
|
||||
def __init__(self, i1):
|
||||
|
@ -1493,7 +1496,10 @@ def test_multi_add_assign():
|
|||
np.testing.assert_array_equal(outputs, expects)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="not supported yet")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_multi_abs_add_assign():
|
||||
class Net(Cell):
|
||||
def __init__(self, para):
|
||||
|
|
Loading…
Reference in New Issue