[auto-monad] Enforce order of execution for Load user nodes in frontend

This commit is contained in:
He Wei 2021-04-27 11:26:08 +08:00
parent 9edf30bd05
commit 121a6a28d9
27 changed files with 568 additions and 81 deletions

View File

@ -132,7 +132,7 @@ def Depend(value, expr):
return value
def UpdateState(monad, expr):
def UpdateState(monad, *exprs):
"""Implement `UpdateState`."""
return monad

View File

@ -90,7 +90,7 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
MS_EXCEPTION_IF_NULL(node_with_index.first);
auto real_input = node_with_index.first;
if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select);
input_node = InsertTransOpForOutput(func_graph, input_node, input_node, kernel_select);
MS_EXCEPTION_IF_NULL(input_node);
AnfAlgo::SetNodeInput(node, input_node, index);
}
@ -120,10 +120,16 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
return node;
}
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select) {
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node,
const AnfNodePtr &node, const KernelSelectPtr &kernel_select) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_node);
for (auto &update_state : update_states) {
manager->SetEdge(update_state.first, update_state.second, node);
}
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
size_t out_num = AnfAlgo::GetOutputTensorNum(node);
@ -282,7 +288,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &
return cast;
}
AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select) {
size_t outputs_num = AnfAlgo::GetOutputTensorNum(node);
if (outputs_num == 0) {
@ -298,7 +304,7 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP
return new_node;
}
// Multiple output
return InsertTransOpForMultipleOutput(func_graph, node, kernel_select);
return InsertTransOpForMultipleOutput(func_graph, orig_node, node, kernel_select);
}
AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,

View File

@ -103,7 +103,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &
AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select);
AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select);
CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);

View File

@ -66,6 +66,9 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
for (auto out_getitem : manager->node_users()[bnupdate]) {
MS_EXCEPTION_IF_NULL(out_getitem.first);
if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
continue;
}
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
auto input2 = out_getitem_ptr->input(2);

View File

@ -43,6 +43,9 @@ void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr
std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
for (auto out_getitem : manager->node_users()[bnupdate]) {
MS_EXCEPTION_IF_NULL(out_getitem.first);
if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
continue;
}
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
auto input2 = out_getitem_ptr->input(2);

View File

@ -297,9 +297,11 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
} else {
int64_t prev_idx = 0;
std::vector<AnfNodePtr> tuple_getitem_nodes;
std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(),
std::back_inserter(tuple_getitem_nodes),
[](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; });
for (auto &user : manager->node_users()[node]) {
if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimTupleGetItem)) {
tuple_getitem_nodes.emplace_back(user.first);
}
}
std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare);
for (auto &getitem : tuple_getitem_nodes) {
MS_EXCEPTION_IF_NULL(getitem);

View File

@ -163,7 +163,20 @@ CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get
return func_graph->NewCNode(depend_nodes);
}
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput(
const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto cnode = orig_cnode;
auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_cnode);
if (!update_states.empty()) {
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
cnode = kernel_graph->NewCNode(orig_cnode);
cnode->set_inputs(orig_cnode->inputs());
for (auto &update_state : update_states) {
manager->SetEdge(update_state.first, update_state.second, cnode);
}
}
MS_EXCEPTION_IF_NULL(op_info);
auto ref_infos = op_info->ref_infos();
std::vector<AnfNodePtr> make_tuple_inputs;

View File

@ -30,9 +30,16 @@
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode,
const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_cnode);
for (auto &update_state : update_states) {
manager->SetEdge(update_state.first, update_state.second, cnode);
}
std::vector<AnfNodePtr> make_tuple_inputs;
AbstractBasePtrList abstract_list;
make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
@ -69,9 +76,9 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
MS_EXCEPTION_IF_NULL(make_tuple);
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
return make_tuple;
} // namespace
}
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetOutputTensorNum(cnode) == 0) {
@ -99,7 +106,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
return replace_node;
}
// Multiple output
return InsertCastForMultipleOutput(func_graph, cnode);
return InsertCastForMultipleOutput(func_graph, orig_cnode, cnode);
}
} // namespace
@ -124,7 +131,7 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
kernel_graph->ReplaceInternalOutput(node, new_node);
}
// process output
return InsertCastForOutput(func_graph, new_node);
return InsertCastForOutput(func_graph, cnode, new_node);
}
} // namespace opt
} // namespace mindspore

View File

@ -43,7 +43,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
kernel_graph->ReplaceInternalOutput(node, new_node);
}
return InsertTransOpForOutput(func_graph, new_node, kernel_select_);
return InsertTransOpForOutput(func_graph, node, new_node, kernel_select_);
}
} // namespace opt
} // namespace mindspore

View File

@ -25,6 +25,7 @@
#include "backend/optimizer/pass/convert_const_scalar_to_tensor.h"
#include "backend/optimizer/pass/convert_attr_to_unify_mindir.h"
#include "backend/optimizer/pass/add_training_attr.h"
#include "backend/optimizer/pass/optimize_updatestate.h"
#include "utils/ms_context.h"
#include "debug/anf_ir_dump.h"
@ -58,5 +59,24 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
DumpIR(file_name, kernel_graph);
}
}
// Run the common final-stage optimization (currently only the UpdateState
// optimization pass) on the kernel graph, then rebuild its execution order.
// Dumps the resulting IR when the save_graphs context flag is enabled.
void CommonFinalOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // Build and run the optimizer with a single "final_opt" pass manager.
  auto pass_manager = std::make_shared<PassManager>("final_opt");
  pass_manager->AddPass(std::make_shared<OptimizeUpdateState>());
  auto optimizer = std::make_shared<GraphOptimizer>();
  optimizer->AddPassManager(pass_manager);
  (void)optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
  // Dump IR if save_graphs is set.
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
    std::string filename = "hwopt_common_final_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
    DumpIR(filename, kernel_graph);
  }
}
} // namespace opt
} // namespace mindspore

View File

@ -20,6 +20,7 @@
namespace mindspore {
namespace opt {
void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void CommonFinalOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
} // namespace opt
} // namespace mindspore

View File

@ -401,11 +401,9 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con
}
auto output_info_list = iter->second;
for (const auto &output_info : output_info_list) {
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
output_info.second == kDependAttachNodeIndex) {
continue;
}
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimUpdateState->name()) {
auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
(cnode_name == prim::kPrimUpdateState->name())) {
continue;
}
output_node_list->push_back(output_info);
@ -426,12 +424,13 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOu
}
auto output_info_list = iter->second;
for (const auto &output_info : output_info_list) {
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
output_info.second == kDependAttachNodeIndex) {
auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
(cnode_name == prim::kPrimUpdateState->name())) {
continue;
}
size_t used_output_index;
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimTupleGetItem->name()) {
if (cnode_name == prim::kPrimTupleGetItem->name()) {
used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
} else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
used_output_index = output_index;
@ -906,12 +905,13 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
// find BatchNorm's output which is a Depend
// Find BatchNorm's output which is a Depend or UpdateState.
for (const auto &node_index : manager->node_users()[old_node]) {
AnfNodePtr output = node_index.first;
size_t index = IntToSize(node_index.second);
MS_EXCEPTION_IF_NULL(output);
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) {
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) ||
AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
auto depend = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend);
depend->set_input(index, new_node);

View File

@ -66,13 +66,14 @@ bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr
auto output_num = output->cast<CNodePtr>()->size() - 1;
getitem_list->clear();
getitem_list->resize(output_num, nullptr);
const auto &users = mng->node_users()[node];
auto users = mng->node_users()[node];
bool changed = false;
AnfNodePtrList user_nodes;
std::transform(users.begin(), users.end(), std::back_inserter(user_nodes),
[](const std::pair<AnfNodePtr, int> &user) { return user.first; });
for (const auto &getitem : user_nodes) {
MS_EXCEPTION_IF_NULL(getitem);
for (const auto &user : users) {
if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
// Sometime, the user of MakeTuple is not a TupleGetItem, but a UpdateState.
continue;
}
auto &getitem = user.first;
auto idx = GetIndex(getitem);
if (idx >= output_num) {
MS_LOG(EXCEPTION) << "Index of GetItem is out of range of MakeTuple. getitem node: " << getitem->DebugString();

View File

@ -35,19 +35,17 @@ CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cno
const std::vector<AnfNodePtr> &new_depend_inputs) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
CNodePtr new_depend = nullptr;
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
if (kernel_graph == nullptr) {
new_depend = func_graph->NewCNode(new_depend_inputs);
auto new_depend = func_graph->NewCNode(new_depend_inputs);
MS_EXCEPTION_IF_NULL(new_depend);
new_depend->set_abstract(cnode->abstract());
new_depend->set_scope(cnode->scope());
} else {
new_depend = kernel_graph->NewCNode(cnode);
MS_EXCEPTION_IF_NULL(new_depend);
new_depend->set_inputs(new_depend_inputs);
return new_depend;
}
func_graph->manager()->Replace(cnode, new_depend);
auto new_depend = kernel_graph->NewCNode(cnode);
MS_EXCEPTION_IF_NULL(new_depend);
new_depend->set_inputs(new_depend_inputs);
return new_depend;
}
@ -77,9 +75,9 @@ AnfNodePtr EliminateIsolatedVirtualNodeInput(const FuncGraphPtr &func_graph, con
auto replace_node = eliminate_node->input(kSingleInputIndex);
std::vector<AnfNodePtr> new_depend_inputs = cnode->inputs();
new_depend_inputs[kIsolatedDependRealInputIndex + 1] = replace_node;
auto new_cnode = CreateNewDependNode(func_graph, cnode, new_depend_inputs);
auto new_node = new_cnode->cast<AnfNodePtr>();
return new_node;
auto new_depend = CreateNewDependNode(func_graph, cnode, new_depend_inputs);
func_graph->manager()->Replace(cnode, new_depend);
return new_depend;
}
AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
@ -157,55 +155,53 @@ const BaseRef OptimizeDependence::DefinePattern() const {
return VectorRef({X, Xs});
}
std::pair<AnfNodePtr, size_t> SearchTransDataAndCast(const AnfNodePtr &node, bool is_first_node) {
if (node == nullptr || !node->isa<CNode>()) {
return std::pair<AnfNodePtr, size_t>(nullptr, 0);
std::vector<size_t> SearchTransDataAndCast(const CNodePtr &cnode) {
// Search Depend and UpdateState only.
if (!cnode->IsApply(prim::kPrimDepend) && !cnode->IsApply(prim::kPrimUpdateState)) {
return {};
}
// get real input of depend and update state.
size_t replace_input_index = 0;
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
replace_input_index = is_first_node ? kDependAttachNodeIndex : kRealInputIndexInDepend;
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
replace_input_index = is_first_node ? kUpdateStateStateInput : kUpdateStateRealInput;
} else {
return std::pair<AnfNodePtr, size_t>(nullptr, 0);
// Find inputs which is Cast or TransData.
std::vector<size_t> result;
for (size_t i = 1; i < cnode->size(); ++i) {
auto &input = cnode->input(i);
if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimCast) ||
AnfAlgo::CheckPrimitiveType(input, prim::KPrimTransData) ||
AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
result.emplace_back(i);
}
}
// check whether real input is cast or trans data
auto real_input = node->cast<CNodePtr>()->input(replace_input_index);
if (AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimCast) ||
AnfAlgo::CheckPrimitiveType(real_input, prim::KPrimTransData) ||
AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimMakeTuple)) {
return std::pair<AnfNodePtr, size_t>(node, replace_input_index);
}
return SearchTransDataAndCast(real_input, false);
return result;
}
const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr) {
return nullptr;
}
// Get the cnode with repalce input index
auto cnode_with_input_index = SearchTransDataAndCast(node, true);
if (cnode_with_input_index.first == nullptr) {
// Search inputs to be replaced.
auto candidate_inputs = SearchTransDataAndCast(cnode);
if (candidate_inputs.empty()) {
return nullptr;
}
size_t replace_index = cnode_with_input_index.second;
auto depend_cnode = cnode_with_input_index.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_cnode);
// Get new node which will act as new input of depend or UpdateState.
std::vector<AnfNodePtr> new_depend_inputs = depend_cnode->inputs();
auto replace_node = GetConvertNode(func_graph, depend_cnode, replace_index);
if (replace_node == nullptr) {
return nullptr;
}
new_depend_inputs[replace_index] = replace_node;
auto new_depend = CreateNewDependNode(func_graph, depend_cnode, new_depend_inputs);
if (new_depend == nullptr) {
// Get new nodes which will act as new inputs of Depend or UpdateState.
std::vector<AnfNodePtr> new_inputs = cnode->inputs();
bool inputs_changed = false;
for (auto index : candidate_inputs) {
auto replace_node = GetConvertNode(func_graph, cnode, index);
if (replace_node != nullptr) {
new_inputs[index] = replace_node;
inputs_changed = true;
}
}
if (!inputs_changed) {
return nullptr;
}
// Create a new Depend node to replace the old one if inputs changed.
auto new_depend = CreateNewDependNode(func_graph, cnode, new_inputs);
func_graph->manager()->Replace(cnode, new_depend);
return nullptr;
}

View File

@ -0,0 +1,71 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/pass/optimize_updatestate.h"
#include <memory>
#include <vector>
#include <string>
#include "base/core_ops.h"
#include "utils/utils.h"
#include "backend/session/kernel_graph.h"
namespace mindspore {
namespace opt {
// Input layout of an UpdateState CNode:
//   input(1) is the monad state input, input(2) is the primary attach node,
//   inputs from index 3 onward are additional attach nodes.
constexpr size_t kInputIndex = 1;
constexpr size_t kAttachIndex = 2;
constexpr size_t kAdditionalAttachIndex = 3;
// Pattern: match any UpdateState node, whatever its inputs are.
const BaseRef OptimizeUpdateState::DefinePattern() const {
  VarPtr seq_inputs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimUpdateState, seq_inputs});
}
// Drop redundant additional attaches from an UpdateState node:
// an attach (from input index 3 on) whose only user is this UpdateState
// contributes nothing and can be removed. Returns a replacement UpdateState
// when any attach was dropped, nullptr otherwise.
const AnfNodePtr OptimizeUpdateState::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                              const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto update_state = dyn_cast<CNode>(node);
  MS_EXCEPTION_IF_NULL(update_state);
  if (update_state->size() <= kAdditionalAttachIndex) {
    // Skip UpdateState nodes with no additional attaches.
    return nullptr;
  }
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto &node_users = manager->node_users();
  // Keep the primitive, the state input and the primary attach unconditionally.
  std::vector<AnfNodePtr> new_inputs;
  new_inputs.reserve(update_state->size());
  new_inputs.emplace_back(update_state->input(0));
  new_inputs.emplace_back(update_state->input(kInputIndex));
  new_inputs.emplace_back(update_state->input(kAttachIndex));
  for (size_t i = kAdditionalAttachIndex; i < update_state->size(); ++i) {
    auto &attach = update_state->input(i);
    auto &users = node_users[attach];
    if ((users.size() == 1) && (users.front().first == update_state)) {
      // If the only user of attach is the UpdateState node, drop the attach node.
      continue;
    }
    new_inputs.emplace_back(attach);
  }
  if (new_inputs.size() == update_state->size()) {
    // Attaches not changed.
    return nullptr;
  }
  // Attaches changed, make a new UpdateState.
  auto new_update_state = func_graph->NewCNode(new_inputs);
  MS_EXCEPTION_IF_NULL(new_update_state);
  new_update_state->set_abstract(update_state->abstract());
  new_update_state->set_scope(update_state->scope());
  return new_update_state;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
// Backend pass that removes redundant additional attach inputs from
// UpdateState nodes (an attach used only by the UpdateState itself is dropped).
class OptimizeUpdateState : public PatternProcessPass {
public:
explicit OptimizeUpdateState(bool multigraph = true) : PatternProcessPass("optimize_updatestate", multigraph) {}
~OptimizeUpdateState() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_

View File

@ -1931,5 +1931,15 @@ void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_
{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
root_graph->set_output(make_tuple);
}
// Collect the UpdateState users of `node` (as (user, input_index) pairs).
// Uses find() instead of operator[] so that querying a node with no users
// does not insert an empty entry into the manager's user map.
AnfNodeIndexSet AnfRuntimeAlgorithm::GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet update_states;
  auto &users_map = manager->node_users();
  auto iter = users_map.find(node);
  if (iter == users_map.end()) {
    // No users at all: return the empty set without mutating the map.
    return update_states;
  }
  for (auto &user : iter->second) {
    if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimUpdateState)) {
      update_states.insert(user);
    }
  }
  return update_states;
}
} // namespace session
} // namespace mindspore

View File

@ -267,6 +267,7 @@ class AnfRuntimeAlgorithm {
static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
std::set<AnfNodePtr> *visited);
static void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph);
static AnfNodeIndexSet GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;

View File

@ -936,6 +936,7 @@ void AscendSession::InitRuntimeResource() {
void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "HardwareOptimize start!";
opt::AscendBackendOptimization(kernel_graph);
FinalOptimize(kernel_graph);
GraphKernelOptimize(kernel_graph);
MS_EXCEPTION_IF_NULL(kernel_graph);
kernel_graph->SetExecOrderByDefault();

View File

@ -104,6 +104,7 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
SetKernelInfo(graph.get());
MS_LOG(INFO) << "Set kernel info end";
Optimize(graph);
FinalOptimize(graph);
MS_LOG(INFO) << "Build kernel";
BuildKernel(graph.get());
// Remove reorder after PS feature finish adapting push/pull in auto_monad.

View File

@ -341,6 +341,8 @@ GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) {
SelectKernel(graph);
// Graph optimization relevant to device data format
HardwareOptimize(graph);
// Run final optimization
FinalOptimize(graph);
// Graph kernel fusion optimization
GraphKernelOptimize(graph);
// Start gpu kernel runtime

View File

@ -17,6 +17,7 @@
#include <algorithm>
#include <set>
#include <queue>
#include <unordered_map>
#include <utility>
@ -2343,6 +2344,12 @@ void SessionBasic::ClearAllBucket(const GraphId &graph_id) {
}
}
void SessionBasic::FinalOptimize(const KernelGraphPtr &graph) const {
MS_LOG(INFO) << "Start FinalOptimize for graph: " << graph->graph_id();
opt::CommonFinalOptimization(graph);
MS_LOG(INFO) << "End FinalOptimize for graph: " << graph->graph_id();
}
void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {
#ifdef ENABLE_DUMP_IR
auto context_ptr = MsContext::GetInstance();

View File

@ -172,6 +172,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
virtual void UpdateOutputTensors(const VectorRef *outputs,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node);
virtual void UnifyMindIR(const KernelGraphPtr &graph) {}
virtual void FinalOptimize(const KernelGraphPtr &graph) const;
virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { return 0; }
virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
virtual void BuildGraphImpl(GraphId) {}

View File

@ -32,6 +32,7 @@
#include "pipeline/jit/parse/parse_base.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/static_analysis/auto_monad.h"
#include "pipeline/jit/static_analysis/order_enforce.h"
#include "abstract/abstract_value.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "pipeline/jit/static_analysis/program_specialize.h"
@ -343,6 +344,18 @@ bool AutoMonadAction(const ResourcePtr &res) {
return true;
}
// Pipeline action: run the order-enforce pass on the resource's root graph
// so that Load users are properly sequenced by UpdateState nodes.
bool OrderEnforceAction(const ResourcePtr &res) {
  if (res->manager() == nullptr) {
    MS_LOG(EXCEPTION) << "Order-Enforce error, manager is null";
  }
  auto graph = res->func_graph();
  if (graph == nullptr) {
    MS_LOG(EXCEPTION) << "Order-Enforce error, graph is null";
  }
  pipeline::OrderEnforce(graph);
  return true;
}
bool InferenceOptPrepareAction(const ResourcePtr &res) {
if (res->manager() == nullptr) {
MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null.";
@ -752,6 +765,7 @@ std::vector<ActionItem> GePipeline() {
// Add opt-stage python pass stub
actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub));
actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
actions.emplace_back(std::make_pair("validate", ValidateAction));
return actions;
}
@ -765,6 +779,8 @@ std::vector<ActionItem> VmPipeline() {
// Add opt-stage python pass stub
actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub));
actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
actions.emplace_back(std::make_pair("validate", ValidateAction));
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::PSContext::instance()->is_worker()) {
@ -784,6 +800,7 @@ std::vector<ActionItem> VmPipeline() {
std::vector<ActionItem> PServerPipeline() {
auto actions = CommonPipeline();
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
actions.emplace_back(std::make_pair("validate", ValidateAction));
actions.emplace_back(std::make_pair("pserver", StartPSServerAction));
return actions;

View File

@ -0,0 +1,258 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pipeline/jit/static_analysis/order_enforce.h"
#include <algorithm>
#include <map>
#include <queue>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "base/core_ops.h"
namespace mindspore::pipeline {
namespace {
class OrderEnforcer {
public:
explicit OrderEnforcer(const FuncGraphPtr &func_graph) : func_graph_(func_graph), manager_(func_graph->manager()) {
MS_EXCEPTION_IF_NULL(func_graph_);
MS_EXCEPTION_IF_NULL(manager_);
}
~OrderEnforcer() = default;
// Entry point: visit every node in topological order and enforce
// execution order for each UpdateState node found.
void Run() {
  for (auto &node : MakeTopoSortMap()) {
    HandleNode(node);
  }
}
private:
// Topologically sort the graph, record each node's position in
// topo_sort_map_, and return the sorted node list.
AnfNodePtrList MakeTopoSortMap() {
  auto nodes = TopoSort(func_graph_->get_return());
  size_t order = 0;
  for (auto &sorted_node : nodes) {
    topo_sort_map_.emplace(sorted_node, order);
    ++order;
  }
  return nodes;
}
// Enforce order for one node: only UpdateState nodes carrying a UMonad state
// whose users update some Ref are processed. Dispatches on the attach input
// (input(2)): a single Load, or a MakeTuple of Loads.
void HandleNode(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
// Skip nodes other than UpdateState.
return;
}
auto update_state = node->cast<CNodePtr>();
// input(1) is the state input; only U-monad (side-effect) states are handled.
if (!HasAbstractUMonad(update_state->input(1))) {
// Skip UpdateStates for IO.
return;
}
auto updated_refs = FindUpdatedRefs(update_state);
if (updated_refs.empty()) {
// Skip UpdateStates that do not have updated refs.
return;
}
// input(2) is the attach node of this UpdateState.
auto &attach = update_state->input(2);
if (IsPrimitiveCNode(attach, prim::kPrimLoad)) {
// Handle UpdateState with Load.
EnforceOrderForLoad(update_state, attach->cast<CNodePtr>(), updated_refs);
} else if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
// Handle UpdateState with MakeTuple.
EnforceOrderForTuple(update_state, attach->cast<CNodePtr>(), updated_refs);
}
}
std::unordered_set<AnfNodePtr> FindUpdatedRefs(const CNodePtr &update_state) {
std::unordered_set<AnfNodePtr> updated_refs;
auto &users = manager_->node_users()[update_state];
for (auto &user : users) {
auto cnode = dyn_cast<CNode>(user.first);
if (cnode == nullptr) {
continue;
}
if (cnode->IsApply(prim::kPrimLoad) || cnode->IsApply(prim::kPrimDepend) ||
cnode->IsApply(prim::kPrimUpdateState)) {
continue;
}
for (auto &input : cnode->inputs()) {
if (IsRef(input)) {
updated_refs.insert(input);
}
}
}
return updated_refs;
}
// Check whether the node's abstract is an AbstractRef (i.e. a Ref value).
bool IsRef(const AnfNodePtr &node) {
  const auto &abs = node->abstract();
  if (abs == nullptr) {
    return false;
  }
  return abs->isa<abstract::AbstractRef>();
}
// For an UpdateState whose attach is a single Load: if the loaded parameter
// is among the updated refs, make the UpdateState depend on the Load's users.
void EnforceOrderForLoad(const CNodePtr &update_state, const CNodePtr &load,
                         const std::unordered_set<AnfNodePtr> &refs) {
  const auto &loaded_param = load->input(1);
  if (refs.find(loaded_param) == refs.end()) {
    // Skip if loaded parameter is not updated.
    return;
  }
  // Find load users (ignoring processed nodes) that do not depend on the
  // UpdateState, and then let the UpdateState depend on them.
  auto load_users = FindLoadUsers(load, update_state);
  AddInputEdges(update_state, load_users);
}
// For an UpdateState whose attach is a MakeTuple of Loads: gather the users
// of every Load whose loaded parameter is in `refs`, then make the
// UpdateState depend on those users.
void EnforceOrderForTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
const std::unordered_set<AnfNodePtr> &refs) {
// The UpdateState should be the only user of the make_tuple.
// For performance, we only check the number of output edges.
if (manager_->node_users()[make_tuple].size() != 1) {
return;
}
// Find load users from the tuple of Load nodes.
std::unordered_set<AnfNodePtr> all_load_users;
auto &inputs = make_tuple->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
auto &input = inputs.at(i);
if (!IsPrimitiveCNode(input, prim::kPrimLoad)) {
// Skip non-Load nodes.
continue;
}
auto load = input->cast<CNodePtr>();
// load->input(1) is the loaded parameter.
if (refs.find(load->input(1)) == refs.end()) {
// Skip if loaded parameter is not updated.
continue;
}
auto load_users = FindLoadUsers(load, make_tuple);
all_load_users.insert(load_users.begin(), load_users.end());
}
// Find load users that do not depend on the UpdateState,
// and then let UpdateState depend on them.
AddInputEdges(update_state, all_load_users);
}
// Add load users as input edges of the update_state node.
void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) {
auto sorted_load_users = SortLoadUsers(load_users);
for (auto &load_user : sorted_load_users) {
if (!IsDependOn(load_user, update_state)) {
processed_nodes_.insert(load_user);
manager_->AddEdge(update_state, load_user);
}
}
}
// Sort load users by their topo sort order.
std::vector<AnfNodePtr> SortLoadUsers(const std::unordered_set<AnfNodePtr> &load_users) {
std::vector<AnfNodePtr> vec{load_users.begin(), load_users.end()};
std::sort(vec.begin(), vec.end(), [this](const AnfNodePtr &a, const AnfNodePtr &b) { return IsBefore(a, b); });
return vec;
}
  // Check if the load user node depend on the given UpdateState node.
  // Does a backward BFS from load_user through input edges; reaching
  // update_state proves the dependency.
  bool IsDependOn(const AnfNodePtr &load_user, const AnfNodePtr &update_state) {
    size_t update_state_order = topo_sort_map_[update_state];
    // A node topologically before the UpdateState cannot depend on it.
    if (topo_sort_map_[load_user] < update_state_order) {
      return false;
    }
    auto user_cnode = dyn_cast<CNode>(load_user);
    if (user_cnode == nullptr) {
      // Non-CNode (parameter/value node) has no inputs, hence no dependency.
      return false;
    }
    // Mark visited nodes with a fresh seen generation instead of a set.
    size_t seen = NewSeenGeneration();
    std::queue<CNodePtr> q;
    user_cnode->seen_ = seen;
    q.push(user_cnode);
    while (!q.empty()) {
      auto cnode = q.front();
      q.pop();
      for (auto &input : cnode->inputs()) {
        if (input == update_state) {
          // Dependency found.
          return true;
        }
        if (input->seen_ == seen) {
          // Skip visited nodes.
          continue;
        }
        if (topo_sort_map_[input] < update_state_order) {
          // Skip input nodes that before the UpdateState node.
          // NOTE(review): topo_sort_map_[] default-inserts 0 for unmapped
          // nodes, which classifies them as "before" — presumably intended.
          continue;
        }
        auto input_cnode = dyn_cast<CNode>(input);
        if (input_cnode != nullptr) {
          input_cnode->seen_ = seen;
          q.push(input_cnode);
        }
      }
    }
    // BFS exhausted without reaching update_state: no dependency.
    return false;
  }
bool IsBefore(const AnfNodePtr &node1, const AnfNodePtr &node2) {
return topo_sort_map_[node1] < topo_sort_map_[node2];
}
// Find Load users as the candidate nodes to enforce order of execution.
std::unordered_set<AnfNodePtr> FindLoadUsers(const CNodePtr &load, const AnfNodePtr &exclude) {
auto &node_users = manager_->node_users();
auto iter = node_users.find(load);
if (iter == node_users.end()) {
return {};
}
std::unordered_set<AnfNodePtr> load_users;
auto &users = iter->second;
for (auto &user : users) {
auto &user_node = user.first;
if (user_node == exclude) {
// Skip excluded node.
continue;
}
if (processed_nodes_.find(user_node) != processed_nodes_.end()) {
// Skip processed nodes.
continue;
}
auto cnode = dyn_cast<CNode>(user_node);
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
const bool has_u_input =
std::any_of(inputs.begin() + 1, inputs.end(), [](const AnfNodePtr &input) { return HasAbstractUMonad(input); });
if (has_u_input) {
// Skip nodes with memory side effects, which use u input.
continue;
}
load_users.insert(cnode);
}
return load_users;
}
 private:
  // Graph being processed; referenced, not owned.
  const FuncGraphPtr &func_graph_;
  // Manager used to query node users and to add input edges.
  FuncGraphManagerPtr manager_;
  // Topological order index of each node, used for before/after checks.
  std::unordered_map<AnfNodePtr, size_t> topo_sort_map_;
  // Load users already attached to an UpdateState; skipped on later passes.
  std::unordered_set<AnfNodePtr> processed_nodes_;
};
} // namespace
//
// Enforce order of execution for Load users node.
//
void OrderEnforce(const FuncGraphPtr &func_graph) {
OrderEnforcer enforcer(func_graph);
enforcer.Run();
}
} // namespace mindspore::pipeline

View File

@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_
#include "ir/func_graph.h"
namespace mindspore::pipeline {
// Enforce order of execution for Load user nodes of the given graph by
// adding extra input edges to the related UpdateState nodes.
void OrderEnforce(const FuncGraphPtr &func_graph);
} // namespace mindspore::pipeline
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_

View File

@ -1456,7 +1456,10 @@ def test_while_forward():
assert np.allclose(output.asnumpy(), expect, 0.0001, 0.0001)
@pytest.mark.skip(reason="not supported yet")
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_multi_add_assign():
class Net(Cell):
def __init__(self, i1):
@ -1493,7 +1496,10 @@ def test_multi_add_assign():
np.testing.assert_array_equal(outputs, expects)
@pytest.mark.skip(reason="not supported yet")
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_multi_abs_add_assign():
class Net(Cell):
def __init__(self, para):