!46516 add cell reuse inline

Merge pull request !46516 from 王禹程/inline
i-robot 2022-12-30 06:51:15 +00:00 committed by Gitee
commit f65cfcaf4d
20 changed files with 711 additions and 101 deletions
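The commit carries no description beyond its title, so here is a usage sketch inferred from the pipeline and test changes below: the MS_DEV_GRAPH_REUSE environment variable selects the cell-reuse mode ("0" off, "1" reuse the cell's graph as a called sub-graph, "2" reuse and then inline the sub-graph in the backend, which is what this commit adds). A minimal example, assuming a MindSpore Ascend build containing this commit; the Block/Net names are invented for illustration:

import os
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore._extends import cell_attr_register


class Block(nn.Cell):
    @cell_attr_register  # identical instances share one reusable graph
    def __init__(self):
        super(Block, self).__init__()
        self.fc = nn.Dense(100, 100)

    def construct(self, x):
        return self.fc(x)


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Dense(200, 100)
        self.layers = nn.CellList([Block() for _ in range(12)])

    def construct(self, x):
        x = self.fc(x)
        for layer in self.layers:
            x = layer(x)
        return x


context.set_context(mode=context.GRAPH_MODE)
os.environ['MS_DEV_GRAPH_REUSE'] = "2"  # "2": reuse + backend inline (this commit)
net = Net()
out = net(Tensor(np.random.random([1, 200]).astype(np.float32)))
del os.environ['MS_DEV_GRAPH_REUSE']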

View File

@@ -78,6 +78,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
         device_target_(DeviceType::kUnknown),
         executable_(true),
         summary_node_exist_(false),
+        need_inline_(false),
         start_label_(nullptr),
         end_goto_(nullptr),
         current_epoch_(0),
@@ -102,6 +103,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
     updated_parameters_ = graph.updated_parameters_;
     executable_ = graph.executable_;
     summary_node_exist_ = graph.summary_node_exist_;
+    need_inline_ = graph.need_inline_;
     valid_inputs_ = graph.valid_inputs_;
     child_graph_order_ = graph.child_graph_order_;
     device_loop_ctrl_tensors_ = graph.device_loop_ctrl_tensors_;
@@ -216,6 +218,10 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
 #endif
   // check whether exist summary node in graph
   bool summary_node_exist() const { return summary_node_exist_; }
+  // set need inline
+  void set_need_inline(bool need_inline) { need_inline_ = need_inline; }
+  // check whether need inline
+  bool need_inline() const { return need_inline_; }
   // set invalid inputs for control sink
   std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
   std::vector<bool> valid_inputs() const { return valid_inputs_; }
@@ -520,6 +526,8 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
   bool summary_node_exist_{false};
   // valid inputs
   std::vector<bool> valid_inputs_;
+  // need inline
+  bool need_inline_;
   // child graph execute order in parent graph
   std::vector<std::weak_ptr<KernelGraph>> child_graph_order_;

View File

@@ -735,9 +735,20 @@ std::vector<AnfNodePtr> KernelGraphMgr::CreateSwitchOrPartialNode(const CNodePtr
     auto info = kernel_graph_partial_map_[sub_kernel_graph.get()];
     call_node->set_abstract(info.abstract);
     cnode_inputs.emplace_back(info.sub_graph);
-    if (info.param_begin != tuple_get_idx) {
+    if (common::GetEnv("MS_DEV_GRAPH_REUSE") == "2") {
+      // call_graph and info.sub_graph need inline when cell reuse.
+      sub_kernel_graph->set_need_inline(true);
+      auto partial_sub_graph = GetValueNodeKernelGraph(info.sub_graph);
+      MS_EXCEPTION_IF_NULL(partial_sub_graph);
+      partial_sub_graph->set_need_inline(true);
+      MS_LOG(INFO) << "Inline graph " << sub_kernel_graph->graph_id() << " and graph "
+                   << partial_sub_graph->graph_id();
+    }
+    MS_LOG(INFO) << "Use cell reuse: " << sub_kernel_graph->graph_id();
+    if (info.param_begin != tuple_get_idx + std::max(static_cast<int>(info.multi_tuple) - 1, 0)) {
       MS_LOG(EXCEPTION) << "Call param is not a graph, the TupleGetItem index: " << tuple_get_idx
                         << ", the partial graph index: " << info.param_begin
+                        << ", need idx: " << tuple_get_idx + std::max(static_cast<int>(info.multi_tuple) - 1, 0)
                         << ", call graph: " << call_graph->fullname_with_scope();
     }
     for (size_t i = info.param_begin; i < info.param_end; i++) {
@@ -867,6 +878,31 @@ ParameterPtr KernelGraphMgr::CreateNewParameter(const AnfNodePtr &anf, KernelGra
   return new_parameter;
 }
 
+void KernelGraphMgr::FlattenTuple(const CNodePtr &node, KernelGraph *graph) {
+  if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
+    auto call_graph = node->input(kFirstIndex);
+    auto sub_kernel_graph = GetValueNodeKernelGraph(call_graph);
+    MS_EXCEPTION_IF_NULL(sub_kernel_graph);
+    auto iter = kernel_graph_partial_map_.find(sub_kernel_graph.get());
+    if (iter != kernel_graph_partial_map_.end() && iter->second.multi_tuple != 0) {
+      need_flatten_.insert(node);
+    }
+  } else if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
+    auto input = node->input(kFirstIndex);
+    auto get_idx = common::AnfAlgo::GetTupleGetItemOutIndex(node);
+    if (need_flatten_.find(input) != need_flatten_.end() && get_idx == 0) {
+      need_flatten_tuple_map_[node] = input;
+    }
+  }
+  for (size_t i = 0; i < common::AnfAlgo::GetInputNum(node); i++) {
+    auto input = common::AnfAlgo::GetInputNode(node, i);
+    auto iter = need_flatten_tuple_map_.find(input);
+    if (iter != need_flatten_tuple_map_.end()) {
+      node->set_input(i + 1, iter->second);
+    }
+  }
+}
+
 bool KernelGraphMgr::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(graph);
@@ -883,6 +919,7 @@ bool KernelGraphMgr::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGrap
   new_cnode->set_scope(cnode->scope());
   graph->FrontBackendMapAdd(node, new_cnode);
   SetReturnNode(new_cnode, graph);
+  FlattenTuple(new_cnode, graph);
   return true;
 }
@@ -938,14 +975,40 @@ void KernelGraphMgr::SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
   auto last_input_node = common::AnfAlgo::GetInputNode(make_tuple, tuple_input_num - 1);
   MS_EXCEPTION_IF_NULL(last_input_node);
   if (last_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(last_input_node, prim::kPrimPartial)) {
+    size_t multi_tuple = 0;
     auto partial_node = last_input_node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(partial_node);
     size_t partial_input_num = common::AnfAlgo::GetInputTensorNum(partial_node);
     std::vector<AnfNodePtr> make_tuple_inputs;
     make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
     // skip last return node (is a partial)
+    size_t param_begin = 0;
     for (size_t i = 0; i < tuple_input_num - 1; i++) {
-      make_tuple_inputs.emplace_back(common::AnfAlgo::GetInputNode(make_tuple, i));
+      auto input = common::AnfAlgo::GetInputNode(make_tuple, i);
+      auto node_abs = input->abstract();
+      if (node_abs->isa<abstract::AbstractTuple>()) {
+        MS_EXCEPTION_IF_CHECK_FAIL(
+          i == 0, "Input index: " + std::to_string(i) + " is a make tuple, input node: " + input->DebugString());
+        MS_LOG(DEBUG) << "Flatten the make tuple, input node: " << input->DebugString()
+                      << ", output num: " << AnfUtils::GetOutputTensorNum(input);
+        // flatten the make tuple
+        for (size_t j = 0; j < AnfUtils::GetOutputTensorNum(input); j++) {
+          auto idx = NewValueNode(SizeToLong(j));
+          MS_EXCEPTION_IF_NULL(idx);
+          auto imm = std::make_shared<Int64Imm>(j);
+          idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
+          auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx});
+          std::vector<TypeId> types = {common::AnfAlgo::GetOutputInferDataType(input, j)};
+          auto shapes = {common::AnfAlgo::GetOutputInferShape(input, j)};
+          common::AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
+          param_begin++;
+          multi_tuple++;
+          make_tuple_inputs.emplace_back(getitem);
+        }
+      } else {
+        param_begin++;
+        make_tuple_inputs.emplace_back(input);
+      }
     }
     // skip partial graph
     for (size_t i = kFirstIndex; i < partial_input_num; i++) {
@@ -965,8 +1028,8 @@ void KernelGraphMgr::SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
       g_output->set_abstract(abstract);
       graph->set_output(g_output);
-      kernel_graph_partial_map_[graph] = {abstract, common::AnfAlgo::GetInputNode(partial_node, 0),
-                                          tuple_input_num - 1, common::AnfAlgo::GetInputTensorNum(g_output)};
+      kernel_graph_partial_map_[graph] = {abstract, common::AnfAlgo::GetInputNode(partial_node, 0), param_begin,
+                                          common::AnfAlgo::GetInputTensorNum(g_output), multi_tuple};
     }
   }
 }
@@ -1041,6 +1104,10 @@ std::shared_ptr<KernelGraph> KernelGraphMgr::ConstructKernelGraph(const FuncGrap
   auto graph = NewKernelGraph();
   MS_EXCEPTION_IF_NULL(graph);
   front_backend_graph_map_[func_graph.get()] = graph;
+  if (func_graph->has_flag(FUNC_GRAPH_FLAG_NEED_BACKEND_INLINE)) {
+    MS_LOG(INFO) << "Need backend inline: " << graph->graph_id();
+    graph->set_need_inline(true);
+  }
   MS_LOG(INFO) << "Create graph: " << graph->graph_id();
   graph->set_device_target(device_target);
   // Create parameter

View File

@@ -46,6 +46,7 @@ struct PartialFuncInfo {
   AnfNodePtr sub_graph;
   size_t param_begin;
   size_t param_end;
+  size_t multi_tuple;
 };
 
 class BACKEND_EXPORT KernelGraphMgr {
@@ -111,6 +112,7 @@ class BACKEND_EXPORT KernelGraphMgr {
   ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) const;
   void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) const;
   void SetReturnNode(const AnfNodePtr &node, KernelGraph *graph);
+  void FlattenTuple(const CNodePtr &node, KernelGraph *graph);
 
  protected:
   CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);
@@ -123,6 +125,8 @@ class BACKEND_EXPORT KernelGraphMgr {
   mindspore::HashMap<AnfNodePtr, ParameterPtr> default_param_map_;
   mindspore::HashMap<FuncGraph *, KernelGraphPtr> front_backend_graph_map_;
   mindspore::HashMap<KernelGraph *, PartialFuncInfo> kernel_graph_partial_map_;
+  mindspore::HashSet<AnfNodePtr> need_flatten_;
+  mindspore::HashMap<AnfNodePtr, AnfNodePtr> need_flatten_tuple_map_;
   static GraphId graph_sum_;
 };
 }  // namespace session

View File

@@ -918,6 +918,10 @@ constexpr auto kAttrInsertDefaultValue = "insert_default_value";
 constexpr auto kAttrIsSparse = "IsSparse";
 constexpr auto kAttrKernelBackoffWithFailureInfo = "kernel_backoff_with_failure_info";
 constexpr auto kAttrKernelBackoffWithFailureType = "kernel_backoff_with_failure_type";
+constexpr auto kAttrKernelGraph = "kernel_graph";
+constexpr auto kAttrPreKernelGraph = "pre_kernel_graph";
+constexpr auto kAttrNeedInline = "need_inline";
+constexpr auto kAttrOriFusionName = "ori_fusion_name";
 
 // FuncGraph Flags
 constexpr auto kFlagIsDynamicStructure = "is_dynamic_structure";

View File

@@ -628,6 +628,9 @@ FuncGraphPtr GenerateReusingGraph(const FuncGraphPtr &fg) {
     }
   }
   MS_LOG(DEBUG) << "The reusable graph parameter size: " << reusing_graph->parameters().size();
+  if (common::GetEnv("MS_DEV_GRAPH_REUSE") == "2") {
+    reusing_graph->set_flag(FUNC_GRAPH_FLAG_NEED_BACKEND_INLINE, true);
+  }
   return reusing_graph;
 }
@@ -652,14 +655,9 @@ void ReplaceWithReusingGraph(const FuncGraphPtr &reusing_graph, const FuncGraphP
 // Make the reusable cell to be the reusable function graph.
 bool GraphReusingAction(const ResourcePtr &resource) {
   MS_EXCEPTION_IF_NULL(resource);
-  constexpr size_t graph_reusing_count = 2;
   const auto &obj_map = parse::data_converter::GetObjGraphs();
   for (const auto &[cell_key, graphs] : obj_map) {
     MS_LOG(DEBUG) << "Start to handle the reusable graph: " << cell_key << ", size: " << graphs.size();
-    // Only make the reusable cell that is used more than graph_reusing_count to be reusable.
-    if (graphs.size() < graph_reusing_count) {
-      continue;
-    }
     const auto &fg = graphs[0];
     // fg->paramter_obj_nodes().empty() have been handled by combine like.
     if (!fg->paramter_obj_nodes().empty()) {
@@ -1514,7 +1512,8 @@ static std::vector<ActionItem> CommonPipeline() {
   }
   // Make the reusable cell to be the reusable function graph
-  static bool enable_graph_reusing = (common::GetEnv("MS_DEV_GRAPH_REUSE") == "1");
+  static bool enable_graph_reusing =
+    (common::GetEnv("MS_DEV_GRAPH_REUSE") == "1" || common::GetEnv("MS_DEV_GRAPH_REUSE") == "2");
   if (enable_graph_reusing) {
     (void)actions.emplace_back(std::make_pair("graph_reusing", GraphReusingAction));
   }

View File

@@ -869,7 +869,11 @@ std::tuple<KernelSelectStatus, std::string, ExceptionType> SelectKernelInfoWithM
     SelectGraphKernelInfo(kernel_node, func_graph);
     return result;
   }
+  if (IsPrimitiveCNode(kernel_node, prim::kPrimCallInline)) {
+    opt::SelectCallInlineKernelInfo(kernel_node);
+    SetTensorDeviceInfo(kernel_node);
+    return result;
+  }
   if (common::AnfAlgo::HasNodeAttr(ops::kBatchRank, kernel_node)) {
     std::stringstream ss;
     ss << common::AnfAlgo::GetCNodeName(kernel_node)

View File

@@ -57,6 +57,7 @@ const char ITEREND[] = "PROFILING_ITER_END";
 const auto kSingleOutput = 1;
 const auto kFirstOutput = 0;
+constexpr size_t kFirstIndex = 1;
 
 #ifdef ENABLE_DUMP_IR
 bool IsSaveGraph() {
@@ -142,6 +143,20 @@ void DumpExecuteOrder(const NotNull<KernelGraphPtr> &kg) {
 }
 #endif
 
+KernelGraphPtr GetValueNodeKernelGraph(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto value_node = node->cast<ValueNodePtr>();
+  if (value_node == nullptr) {
+    return nullptr;
+  }
+  auto value = value_node->value();
+  if (value == nullptr) {
+    return nullptr;
+  }
+  auto kernel_graph = value->cast<KernelGraphPtr>();
+  return kernel_graph;
+}
+
 // Return kNoLabel when label id attribute not set for the graph.
 uint32_t GetGraphLabel(const KernelGraphPtr &kg) {
   auto value = kg->get_attr(kAttrLabelIndex);
@@ -151,6 +166,15 @@ uint32_t GetGraphLabel(const KernelGraphPtr &kg) {
   return GetValue<uint32_t>(value);
 }
 
+bool CheckCallInline(const CNodePtr &cnode) {
+  if (!common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
+    return false;
+  }
+  auto call_graph = cnode->input(kFirstIndex);
+  auto sub_kernel_graph = GetValueNodeKernelGraph(call_graph);
+  return sub_kernel_graph->need_inline();
+}
+
 bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2);
 
 bool CheckAbstractTupleIsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) {
@@ -501,6 +525,12 @@ class CallInfoFinder {
   // Find call-return pairs.
   void FindCallReturns() {
+    for (auto &[caller, call_info] : context_.call_info_map) {
+      if (caller->need_inline() && (call_info.recursive || call_info.call_sites.size() != 0)) {
+        MS_LOG(INFO) << "Do not inline cell reuse because it has sub-graph call, graph id: " << caller->graph_id();
+        caller->set_need_inline(false);
+      }
+    }
     for (auto &[caller, call_info] : context_.call_info_map) {
       for (auto &call_site : call_info.call_sites) {
         for (auto &callee : call_site.callees) {
@@ -623,7 +653,7 @@ class CallInfoFinder {
     }
     // Create a parameter for the return value.
-    if (call_site->out_param == nullptr) {
+    if (call_site->out_param == nullptr && !CheckCallInline(call_site->cnode)) {
       call_site->out_param = context_.CreateParameter(call_site->cnode->abstract());
     }
     // Add a return point for the callee graph.
@@ -634,7 +664,7 @@ class CallInfoFinder {
     // Setup label index if there are multi return points.
     const auto n_return_points = call_info.return_points.size();
     const size_t return_point_sizes = 2;
-    if (n_return_points > 1) {
+    if (n_return_points > 1 && !CheckCallInline(call_site->cnode)) {
       if (n_return_points == return_point_sizes) {
         // Create a parameter to store label index.
         const ShapeVector shape = {1};
@@ -753,6 +783,10 @@ class AscendAutoMonadConverter {
   ~AscendAutoMonadConverter() = default;
 
   void Run() {
+    // need inline
+    if (kernel_graph_->need_inline()) {
+      return;
+    }
     // Create an stack
     InitStack();
     // Setup entry label if found.
@@ -1033,6 +1067,20 @@ class AscendAutoMonadConverter {
     // The call/switch/switch_layer cnode.
     auto &cnode = call_site->cnode;
+    if (CheckCallInline(cnode)) {
+      auto call_graph = cnode->input(kFirstIndex);
+      auto sub_kernel_graph = GetValueNodeKernelGraph(call_graph);
+      std::vector<AnfNodePtr> call_inline_inputs = {NewPrimitive(prim::kPrimCallInline)};
+      for (size_t i = kFirstIndex; i < common::AnfAlgo::GetInputNum(cnode); i++) {
+        call_inline_inputs.emplace_back(common::AnfAlgo::GetInputNode(cnode, i));
+      }
+      auto call_inline = kernel_graph_->NewCNode(call_inline_inputs);
+      MS_EXCEPTION_IF_NULL(call_inline);
+      call_inline->set_abstract(cnode->abstract());
+      common::AnfAlgo::SetNodeAttr(kAttrKernelGraph, MakeValue(sub_kernel_graph), call_inline);
+      ReplaceNode(cnode, call_inline);
+      return;
+    }
     // Get branches of the call_site.
     // for call, there is one branch;
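Conceptually, the converter now short-circuits call sites whose callee graph is marked need-inline: instead of lowering them to labeled jumps with extra parameters and return points, it rewrites them to a call_inline node that keeps the sub-graph attached (via kAttrKernelGraph) for the later backend inline pass. A toy model of that rewrite, not MindSpore APIs:

def convert_calls(nodes, need_inline_graphs):
    # nodes: (op, callee_graph, args) triples standing in for call cnodes.
    out = []
    for op, graph, args in nodes:
        if op == "call" and graph in need_inline_graphs:
            out.append(("call_inline", graph, args))  # sub-graph stays attached
        else:
            out.append((op, graph, args))
    return out

nodes = [("call", "g1", ["x"]), ("call", "g2", ["y"])]
assert convert_calls(nodes, {"g1"}) == [("call_inline", "g1", ["x"]), ("call", "g2", ["y"])]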

View File

@@ -36,6 +36,7 @@
 #ifndef ENABLE_SECURITY
 #include "include/common/debug/anf_ir_dump.h"
 #include "include/common/debug/dump_proto.h"
+#include "ir/func_graph_cloner.h"
 #endif
 
 namespace mindspore {
@@ -95,6 +96,56 @@ bool SetDefaultFormatForSpecialAclOp(const KernelGraphPtr &graph) {
   }
   return need_change_format;
 }
 
+AnfNodePtr DoInline(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
+                    const AnfNodePtrList &func_graph_args, const ScopePtr &scope) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(target_func_graph);
+  Cloner cloner({}, false);
+  if (scope != nullptr) {
+    cloner.set_scope(scope);
+  }
+  cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline);
+  auto node_list = TopoSort(func_graph->output());
+  for (auto &ori_node : node_list) {
+    if (ori_node->isa<Parameter>()) {
+      continue;
+    }
+    auto new_node = cloner[ori_node];
+    auto kernel_info = dynamic_cast<device::KernelInfo *>(new_node->kernel_info());
+    // deep copy kernel info
+    if (kernel_info != nullptr && new_node->kernel_info()->has_build_info()) {
+      // some check
+      MS_EXCEPTION_IF_CHECK_FAIL(kernel_info->MutableKernelMod() == nullptr,
+                                 "Inline ERROR: " + ori_node->DebugString() + ", kernel mod is not nullptr");
+      MS_EXCEPTION_IF_CHECK_FAIL(kernel_info->output_address_list().empty(),
+                                 "Inline ERROR: " + ori_node->DebugString() + ", output_address_list is not empty");
+      MS_EXCEPTION_IF_CHECK_FAIL(kernel_info->workspace_address_list().empty(),
+                                 "Inline ERROR: " + ori_node->DebugString() + ", workspace_address_list is not empty");
+      auto new_kernel_info = std::make_shared<device::KernelInfo>();
+      auto builder =
+        std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(new_node));
+      new_kernel_info->set_select_kernel_build_info(builder->Build());
+      new_kernel_info->set_graph_id(kernel_info->graph_id());
+      new_kernel_info->set_feature_map_flag(kernel_info->is_feature_map());
+      new_kernel_info->set_ref_map(false, kernel_info->out_in_ref_map());
+      new_node->set_kernel_info(new_kernel_info);
+    }
+    if (ori_node->isa<CNode>()) {
+      auto ori_cnode = ori_node->cast<CNodePtr>();
+      if (common::AnfAlgo::HasNodeAttr(kAttrIsUBFusionOp, ori_cnode) &&
+          common::AnfAlgo::GetNodeAttr<bool>(ori_node, kAttrIsUBFusionOp)) {
+        // already done fusion compile
+        auto ori_full_name = ori_cnode->fullname_with_scope();
+        common::AnfAlgo::SetNodeAttr(kAttrOriFusionName, MakeValue(ori_full_name), new_node);
+      }
+      common::AnfAlgo::SetNodeAttr(kAttrNeedInline, MakeValue(ori_node->fullname_with_scope()), new_node);
+      common::AnfAlgo::SetNodeAttr(kAttrPreKernelGraph, MakeValue(func_graph), new_node);
+    }
+  }
+  return cloner[func_graph->output()];
+}
 }  // namespace
 
@@ -103,6 +154,40 @@
 void AscendGraphOptimization::Reset() {
   graph_manager_->Clear();
 }
 
+void AscendGraphOptimization::InlineSubGraph(const KernelGraphPtr &graph) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+#ifdef ENABLE_DUMP_IR
+  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+  if (save_graphs) {
+    std::string file_name = "hwopt_d_before_inline_graph_" + std::to_string(graph->graph_id()) + ".ir";
+    DumpIR(file_name, graph, true, kWholeStack);
+  }
+#endif
+  auto kernel_cnodes = graph->execution_order();
+  for (auto &kernel_cnode : kernel_cnodes) {
+    if (common::AnfAlgo::CheckPrimitiveType(kernel_cnode, prim::kPrimCallInline)) {
+      auto sub_graph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(kernel_cnode, kAttrKernelGraph);
+      MS_LOG(INFO) << "InlineSubGraph: " << kernel_cnode->DebugString() << ", sub graph: " << sub_graph->graph_id()
+                   << ", need inline: " << sub_graph->need_inline();
+      auto main_graph = kernel_cnode->func_graph();
+      auto mng = main_graph->manager();
+      AnfNodePtrList inp(kernel_cnode->inputs().begin() + 1, kernel_cnode->inputs().end());
+      auto out = DoInline(sub_graph, main_graph, inp, kernel_cnode->input(0)->scope());
+      (void)mng->Replace(kernel_cnode, out);
+    }
+  }
+  memo_.clear();
+  opt::AscendAfterInlineOptimization(graph);
+  memo_.clear();
+#ifdef ENABLE_DUMP_IR
+  if (save_graphs) {
+    std::string file_name = "hwopt_d_after_inline_graph_" + std::to_string(graph->graph_id()) + ".ir";
+    DumpIR(file_name, graph, true, kWholeStack);
+  }
+#endif
+}
+
 void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_LOG(INFO) << "Status record: start optimize graph. graph id: " << graph->graph_id();
@@ -118,6 +203,9 @@ void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) {
   OptimizeGraphWithoutDeviceInfo(graph);
   SelectKernel(graph);
   OptimizeGraphWithDeviceInfo(graph);
+  // inline func before gen execution order
+  InlineSubGraph(graph);
   OptimizeExecutionOrder(graph);
   PostOptimization(graph);
@@ -195,8 +283,8 @@ void AscendGraphOptimization::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &
   MS_EXCEPTION_IF_NULL(graph);
   memo_.clear();
   HardWareOptimization(graph);
-  // copy child graph ref output map to father graph ref output map
   memo_.clear();
+  // copy child graph ref output map to father graph ref output map
   UpdateRefOutputMap(graph);
   AnfAlgo::InsertMakeTupleForOutput(NOT_NULL(graph));
   RemoveUnusedValueNode(graph);
@@ -288,6 +376,9 @@ void AscendGraphOptimization::HardWareOptimization(const KernelGraphPtr &graph)
     return;
   }
   (void)memo_.insert(graph);
+  for (auto &child_graph : graph->child_graph_order()) {
+    HardWareOptimization(child_graph.lock());
+  }
   opt::AscendBackendOptimization(graph);
   opt::CommonFinalOptimization(graph);
   if (graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
@@ -295,10 +386,6 @@ void AscendGraphOptimization::HardWareOptimization(const KernelGraphPtr &graph)
     graph->SetExecOrderByDefault();
   }
   MS_LOG(INFO) << "Status record: end hardware optimize. graph id: " << graph->graph_id();
-  for (auto &child_graph : graph->child_graph_order()) {
-    HardWareOptimization(child_graph.lock());
-  }
 }
 
 void AscendGraphOptimization::AddGraphToManager(const NotNull<KernelGraphPtr> graph,
@@ -349,6 +436,11 @@ void AscendGraphOptimization::RecurseSelectKernelInfo(const KernelGraphPtr &grap
     return;
   }
   (void)memo_.insert(graph);
+  for (auto &child_graph : graph->child_graph_order()) {
+    if (child_graph.lock()->need_inline()) {
+      RecurseSelectKernelInfo(child_graph.lock());
+    }
+  }
 #ifdef ENABLE_DUMP_IR
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
@@ -361,16 +453,16 @@ void AscendGraphOptimization::RecurseSelectKernelInfo(const KernelGraphPtr &grap
   MS_LOG(INFO) << "Status record: start select kernel info. graph id: " << graph->graph_id();
   SetOperatorInfo(graph);
   MS_LOG(INFO) << "Status record: end select kernel info. graph id: " << graph->graph_id();
 #ifdef ENABLE_DUMP_IR
   if (save_graphs) {
     std::string file_name = "select_kernel_after_graph_" + std::to_string(graph->graph_id()) + ".ir";
     DumpIR(file_name, graph);
   }
 #endif
   for (auto &child_graph : graph->child_graph_order()) {
-    RecurseSelectKernelInfo(child_graph.lock());
+    if (!child_graph.lock()->need_inline()) {
+      RecurseSelectKernelInfo(child_graph.lock());
+    }
   }
 }
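DoInline above is the clone step of that pipeline: the callee body is copied into the caller with formal parameters bound to the call's actual inputs, and per-node kernel info is deep-copied rather than shared. A minimal stand-alone sketch of inline-by-clone (plain Python, not the MindSpore Cloner API; the triple representation is invented):

def inline_call(body, params, args):
    # body: (result, op, operands) triples in topological order;
    # params: formal parameter names; args: caller-side operand names.
    env = dict(zip(params, args))  # formal -> actual
    cloned = []
    for res, op, operands in body:
        new_operands = [env.get(o, o) for o in operands]
        env[res] = res + "_inlined"  # fresh name models the cloned node
        cloned.append((env[res], op, new_operands))
    return cloned, env[body[-1][0]]  # cloned nodes + the inlined output

body = [("t0", "MatMul", ["p0", "w"]), ("t1", "Add", ["t0", "b"])]
nodes, out = inline_call(body, ["p0"], ["x"])
assert out == "t1_inlined" and nodes[0][2] == ["x", "w"]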

View File

@@ -51,6 +51,7 @@ class AscendGraphOptimization {
   void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph);
   void OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph);
   void OptimizeExecutionOrder(const KernelGraphPtr &graph);
+  void InlineSubGraph(const KernelGraphPtr &graph);
   void PostOptimization(const KernelGraphPtr &graph) const;
 
   // Graph Optimized level-3 interface
View File

@@ -350,7 +350,7 @@ void TbeKernelCompileManager::GetAllTbeNodes(const std::shared_ptr<session::Kern
   auto all_nodes = kernel_graph->execution_order();
   for (const auto &anf_node : all_nodes) {
     MS_EXCEPTION_IF_NULL(anf_node);
-    if (!AnfUtils::IsRealKernel(anf_node)) {
+    if (!AnfUtils::IsRealKernel(anf_node) || IsPrimitiveCNode(anf_node, prim::kPrimCallInline)) {
       continue;
     }
     KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);
@@ -564,6 +564,9 @@ std::pair<std::vector<CNodePtr>, std::vector<CNodePtr>> TbeKernelCompileManager:
       continue;  // kernel mod already exist, continue;
     }
     auto full_name = node->fullname_with_scope();
+    if (common::AnfAlgo::HasNodeAttr(kAttrOriFusionName, node)) {
+      full_name = common::AnfAlgo::GetNodeAttr<std::string>(node, kAttrOriFusionName);
+    }
     auto json_name = full_name_to_json_name_[full_name];
     auto kernel_pack = tbe::TbeUtils::SearchCache(json_name, false);
     if (kernel_pack == nullptr) {

View File

@@ -94,6 +94,7 @@
 #include "plugin/device/ascend/optimizer/ir_fusion/softmax_dropout_do_mask_v3_fusion.h"
 #include "plugin/device/ascend/optimizer/ir_fusion/conv2d_backprop_input_dilation_fusion.h"
 #include "plugin/device/ascend/optimizer/format_type/insert_trans_op.h"
+#include "plugin/device/ascend/optimizer/format_type/reselect_call_inline_format.h"
 #include "plugin/device/ascend/optimizer/format_type/trans_op_format_refine.h"
 #include "plugin/device/ascend/optimizer/format_type/dynamic_rnn_grad_reformat.h"
 #include "plugin/device/ascend/optimizer/format_type/insert_transpose_for_basiclstm_op.h"
@@ -285,6 +286,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto data_layout_pm = std::make_shared<PassManager>("transop_pm");
+  data_layout_pm->AddPass(std::make_shared<ReselectCallInlineFormat>());
   data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>());
   data_layout_pm->AddPass(std::make_shared<DynamicRNNGradReformat>());
   data_layout_pm->AddPass(std::make_shared<ChangeAxisOfReduceKernel>());
@@ -509,6 +511,16 @@ void RunOpAscendBackendOptimization(const std::shared_ptr<session::KernelGraph>
   kernel_graph->SetExecOrderByDefault();
 }
 
+void AscendAfterInlineOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
+  auto optimizer = std::make_shared<GraphOptimizer>();
+  auto after_inline_pm = std::make_shared<PassManager>("after_inline_pm");
+  after_inline_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
+  after_inline_pm->AddPass(std::make_shared<EliminateRedundantOp>());
+  optimizer->AddPassManager(after_inline_pm);
+  (void)optimizer->Optimize(kernel_graph);
+  kernel_graph->SetExecOrderByDefault();
+}
+
 void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   MS_LOG(INFO) << "Status record: start ascend backend(data layer & mix precision ...) pass. graph id: "
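The after-inline pass manager exists because inlining twelve copies of one cell body tends to duplicate structurally identical nodes; common subexpression elimination then merges them. A toy hash-based CSE over the same triple representation used earlier (illustrative only, not the MindSpore pass):

def cse(nodes):
    # nodes: (name, op, operands) in topological order.
    seen, replace, out = {}, {}, []
    for name, op, operands in nodes:
        key = (op, tuple(replace.get(o, o) for o in operands))
        if key in seen:
            replace[name] = seen[key]  # reuse the earlier identical node
        else:
            seen[key] = name
            out.append((name, op, list(key[1])))
    return out

merged = cse([("a", "Cast", ["x"]), ("b", "Cast", ["x"]), ("c", "Add", ["a", "b"])])
assert merged == [("a", "Cast", ["x"]), ("c", "Add", ["a", "a"])]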

View File

@@ -25,6 +25,7 @@ void RunOpAscendBackendOptimization(const std::shared_ptr<session::KernelGraph>
 void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
+void AscendAfterInlineOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &kernel_graph);

View File

@@ -635,6 +635,41 @@ void SetInputOutputNames(const std::vector<std::string> &input_names, const std:
   common::AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), node);
 }
 
+void SelectCallInlineKernelInfo(const CNodePtr &node) {
+  if (!IsPrimitiveCNode(node, prim::kPrimCallInline)) {
+    return;
+  }
+  // need inline
+  auto sub_graph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(node, kAttrKernelGraph);
+  auto sub_ret = sub_graph->output();
+  std::vector<std::string> input_formats;
+  std::vector<TypeId> input_types;
+  std::vector<std::string> output_formats;
+  std::vector<TypeId> output_types;
+  for (auto &param : sub_graph->inputs()) {
+    TypeId type_id = AnfAlgo::GetOutputDeviceDataType(param, 0);
+    if (type_id == kTypeUnknown) {
+      type_id = common::AnfAlgo::GetOutputInferDataType(param, 0);
+    }
+    if (type_id > kMonadTypeBegin && type_id < kMonadTypeEnd) {
+      continue;
+    }
+    input_types.push_back(type_id);
+    input_formats.push_back(AnfAlgo::GetOutputFormat(param, 0));
+  }
+  for (size_t i = 0; i < AnfUtils::GetOutputTensorNum(node); ++i) {
+    output_formats.push_back(AnfAlgo::GetOutputFormat(sub_ret, i));
+    output_types.push_back(common::AnfAlgo::GetOutputInferDataType(sub_ret, i));
+  }
+  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  MS_EXCEPTION_IF_NULL(builder);
+  builder->SetInputsFormat(input_formats);
+  builder->SetInputsDeviceType(input_types);
+  builder->SetOutputsFormat(output_formats);
+  builder->SetOutputsDeviceType(output_types);
+  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
+}
+
 template <typename T, typename Scalar>
 ValuePtr GetTensorValue(const tensor::TensorPtr &tensor) {
   ValuePtr ret;
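The selection rule above can be summarized as: a call_inline node adopts the kernel formats and device types already chosen at the callee graph's boundary, skipping monad parameters. A small Python model of that derivation (names invented):

def select_call_inline_info(param_infos, output_infos):
    # param_infos: (dtype, fmt, is_monad) per sub-graph input;
    # output_infos: (dtype, fmt) per sub-graph output.
    inputs = [(dtype, fmt) for dtype, fmt, is_monad in param_infos if not is_monad]
    return {"inputs": inputs, "outputs": list(output_infos)}

info = select_call_inline_info([("float32", "NC1HWC0", False), ("monad", "", True)],
                               [("float32", "NC1HWC0")])
assert info["inputs"] == [("float32", "NC1HWC0")]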

View File

@@ -118,6 +118,8 @@ AnfNodePtr AddTransOpNodeToGraphWithFormat(const FuncGraphPtr &func_graph, const
 void SetInputOutputNames(const std::vector<std::string> &input_names, const std::vector<std::string> &output_names,
                          const AnfNodePtr &node);
 
+void SelectCallInlineKernelInfo(const CNodePtr &node);
+
 const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW};
 }  // namespace opt
 }  // namespace mindspore

View File

@@ -0,0 +1,51 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/format_type/reselect_call_inline_format.h"
#include "plugin/device/ascend/optimizer/ascend_helper.h"
namespace mindspore {
namespace opt {
namespace {
constexpr auto Xs = "Xs";
constexpr auto call_inline = "call_inline";
constexpr auto new_call_inline = "new_call_inline";
} // namespace
bool ReselectCallInlineFormat::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &graph,
const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
return true;
}
AnfNodePtr BuildCallInline(const PatternMap &m, const AnfNodePtr &) {
auto anf = m.Get(call_inline);
MS_EXCEPTION_IF_NULL(anf);
auto cnode = anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
SelectCallInlineKernelInfo(cnode);
return cnode;
}
void ReselectCallInlineFormat::DefineSrcPattern(SrcPattern *src_pattern) {
(*src_pattern).AddSeqVar(Xs).AddCNode(call_inline, {prim::kPrimCallInline, Xs});
}
void ReselectCallInlineFormat::DefineDstPattern(DstPattern *dst_pattern) {
(*dst_pattern).AddCNode(new_call_inline, {prim::kPrimCallInline, Xs}, BuildCallInline);
}
} // namespace opt
} // namespace mindspore

View File

@@ -0,0 +1,35 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_RESELECT_CALL_INLINE_FORMAT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_RESELECT_CALL_INLINE_FORMAT_H_
#include <memory>
#include "backend/common/optimizer/pattern_to_pattern.h"
namespace mindspore {
namespace opt {
class ReselectCallInlineFormat : public PatternToPatternPass {
public:
ReselectCallInlineFormat() : PatternToPatternPass("reselect_call_inline_format") {}
~ReselectCallInlineFormat() override = default;
void DefineSrcPattern(SrcPattern *src_pattern) override;
void DefineDstPattern(DstPattern *dst_pattern) override;
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_RESELECT_CALL_INLINE_FORMAT_H_

View File

@@ -619,6 +619,9 @@ void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value
 
 void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph &graph) {
   MS_EXCEPTION_IF_NULL(mem_manager_);
+  if (graph.need_inline()) {
+    return;
+  }
   auto graph_id = graph.graph_id();
   MS_LOG(INFO) << "AssignStaticMemoryInput start for graph " << graph_id;
   auto graph_inputs = GetGraphInputs(graph);
@@ -711,6 +714,9 @@ void KernelRuntime::GetDeviceAddress(const AnfNodePtr &item,
 }
 
 void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph &graph) {
+  if (graph.need_inline()) {
+    return;
+  }
   MS_LOG(INFO) << "AssignStaticMemoryOutput start for graph " << graph.graph_id();
   auto nodes = common::AnfAlgo::GetAllOutput(graph.output(), {prim::kPrimTupleGetItem});
   std::vector<session::KernelWithIndex> non_communication_op;

View File

@@ -83,6 +83,7 @@ const char FUNC_GRAPH_FLAG_IGNORE_VALUE[] = "ignore_value";
 const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
 const char FUNC_GRAPH_FLAG_SPARSE_BPROP[] = "sparse_bprop";
 const char FUNC_GRAPH_FLAG_NO_INLINE[] = "no_inline";
+const char FUNC_GRAPH_FLAG_NEED_BACKEND_INLINE[] = "need_backend_inline";
 const char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
 const char FUNC_GRAPH_FLAG_CORE[] = "core";
 const char FUNC_GRAPH_FLAG_K_GRAPH[] = "k_graph";

View File

@@ -1481,6 +1481,7 @@ GVAR_DEF(PrimitivePtr, kPrimSelect, std::make_shared<Primitive>(kSelect));
 GVAR_DEF(PrimitivePtr, kPrimCall, std::make_shared<Primitive>("call"));
 GVAR_DEF(PrimitivePtr, kPrimRaise, std::make_shared<Primitive>("raise"));
 GVAR_DEF(PrimitivePtr, kPrimJoinedStr, std::make_shared<Primitive>("joinedstr"));
+GVAR_DEF(PrimitivePtr, kPrimCallInline, std::make_shared<Primitive>("call_inline"));
 GVAR_DEF(PrimitivePtr, kPrimMakeTuple, std::make_shared<Primitive>(kMakeTuple));
 GVAR_DEF(PrimitivePtr, kPrimMakeSlice, std::make_shared<Primitive>("make_slice"));

View File

@@ -21,7 +21,9 @@ import pytest
 import numpy as np
 import mindspore.context as context
 import mindspore.nn as nn
+import mindspore as ms
+from mindspore.common import mutable
 from mindspore.common import set_seed
 from mindspore import Tensor
 from mindspore.nn import TrainOneStepCell, WithLossCell
@@ -29,6 +31,28 @@ from mindspore.nn.optim import Momentum
 from mindspore._extends import cell_attr_register
 
 
+def seed_set():
+    set_seed(1)
+    np.random.seed(1)
+    random.seed(1)
+
+
+def train(net, data, label):
+    learning_rate = 0.05
+    momentum = 0.9
+    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
+    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
+    net_with_criterion = WithLossCell(net, criterion)
+    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
+    train_network.set_train()
+    res_list = []
+    for _ in range(20):
+        res = train_network(data, label)
+        res_list.append(res[0].asnumpy())
+    return res_list
+
+
 class CellDense(nn.Cell):
     @cell_attr_register
     def __init__(self):
@@ -60,6 +84,297 @@ class MLP(nn.Cell):
         return out
 
 
+def get_pynative_mlp_cell_reuse_loss():
+    context.set_context(mode=context.PYNATIVE_MODE)
+    # gen data
+    seed_set()
+    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
+    label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
+    # cell reuse
+    net = MLP()
+    loss_list = train(net, data, label)
+    return loss_list
+
+
+def get_mlp_cell_reuse_loss(reuse):
+    context.set_context(mode=context.GRAPH_MODE)
+    os.environ['MS_DEV_GRAPH_REUSE'] = reuse
+    # gen data
+    seed_set()
+    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
+    label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
+    # cell reuse
+    net = MLP()
+    loss_list = train(net, data, label)
+    del os.environ['MS_DEV_GRAPH_REUSE']
+    return loss_list
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_reuse_0():
+    """
+    Feature: cell reuse.
+    Description: MLP with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_loss(str(0))
+    loss_pynative = get_pynative_mlp_cell_reuse_loss()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_reuse_1():
+    """
+    Feature: cell reuse.
+    Description: MLP with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_loss(str(1))
+    loss_pynative = get_pynative_mlp_cell_reuse_loss()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_reuse_2():
+    """
+    Feature: cell reuse.
+    Description: MLP with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_loss(str(2))
+    loss_pynative = get_pynative_mlp_cell_reuse_loss()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+class CellDense2(nn.Cell):
+    @cell_attr_register
+    def __init__(self):
+        super(CellDense2, self).__init__()
+        self.fc = nn.Dense(100, 100)
+
+    def construct(self, input_x):
+        out = self.fc(input_x)
+        return input_x, out
+
+
+class MLP2(nn.Cell):
+    def __init__(self):
+        super(MLP2, self).__init__()
+        self.batch_size = 1
+        self.fc = nn.Dense(200, 100)
+
+        layers = []
+        for _ in range(12):
+            layer = CellDense2()
+            layers.append(layer)
+        self.layers = nn.CellList(layers)
+
+    def construct(self, out):
+        out = self.fc(out)
+        for layer_module in self.layers:
+            tmp, out = layer_module(out)
+            out += tmp
+        return out
+
+
+def get_pynative_mlp_cell_reuse_loss_2():
+    context.set_context(mode=context.PYNATIVE_MODE)
+    # gen data
+    seed_set()
+    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
+    label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
+    # cell reuse
+    net = MLP2()
+    loss_list = train(net, data, label)
+    return loss_list
+
+
+def get_mlp_cell_reuse_loss_2(reuse):
+    context.set_context(mode=context.GRAPH_MODE)
+    os.environ['MS_DEV_GRAPH_REUSE'] = reuse
+    # gen data
+    seed_set()
+    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
+    label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
+    # cell reuse
+    net = MLP2()
+    loss_list = train(net, data, label)
+    del os.environ['MS_DEV_GRAPH_REUSE']
+    return loss_list
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_2_reuse_0():
+    """
+    Feature: cell reuse.
+    Description: MLP (need flatten maketuple) with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_loss_2(str(0))
+    loss_pynative = get_pynative_mlp_cell_reuse_loss_2()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_2_reuse_1():
+    """
+    Feature: cell reuse.
+    Description: MLP (need flatten maketuple) with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_loss_2(str(1))
+    loss_pynative = get_pynative_mlp_cell_reuse_loss_2()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_2_reuse_2():
+    """
+    Feature: cell reuse.
+    Description: MLP (need flatten maketuple) with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_loss_2(str(2))
+    loss_pynative = get_pynative_mlp_cell_reuse_loss_2()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+class CellDenseWithControlFlow(nn.Cell):
+    @cell_attr_register
+    def __init__(self):
+        super(CellDenseWithControlFlow, self).__init__()
+        self.fc = nn.Dense(100, 100)
+
+    def construct(self, input_x, x):
+        out = self.fc(input_x)
+        if x > 0:
+            out = self.fc(out)
+        out = self.fc(out)
+        return out
+
+
+class MLPWithControlFlow(nn.Cell):
+    def __init__(self):
+        super(MLPWithControlFlow, self).__init__()
+        self.batch_size = 1
+        self.fc = nn.Dense(200, 100)
+
+        layers = []
+        for _ in range(12):
+            layer = CellDenseWithControlFlow()
+            layers.append(layer)
+        self.layers = nn.CellList(layers)
+
+    def construct(self, out):
+        out = self.fc(out)
+        for layer_module in self.layers:
+            x = mutable(ms.Tensor(np.array(1), dtype=ms.int32))
+            out = layer_module(out, x)
+        return out
+
+
+def get_pynative_mlp_cell_reuse_infer():
+    context.set_context(mode=context.PYNATIVE_MODE)
+    # gen data
+    seed_set()
+    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
+    # cell reuse
+    net = MLPWithControlFlow()
+    ret = net(data)
+    return ret.asnumpy()
+
+
+def get_mlp_cell_reuse_infer(reuse):
+    context.set_context(mode=context.GRAPH_MODE)
+    os.environ['MS_DEV_GRAPH_REUSE'] = reuse
+    # gen data
+    seed_set()
+    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
+    # cell reuse
+    net = MLPWithControlFlow()
+    ret = net(data)
+    del os.environ['MS_DEV_GRAPH_REUSE']
+    return ret.asnumpy()
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_with_control_flow_reuse_0():
+    """
+    Feature: cell reuse.
+    Description: MLP with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_infer(str(0))
+    loss_pynative = get_pynative_mlp_cell_reuse_infer()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_with_control_flow_reuse_1():
+    """
+    Feature: cell reuse.
+    Description: MLP with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_infer(str(1))
+    loss_pynative = get_pynative_mlp_cell_reuse_infer()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_with_control_flow_reuse_2():
+    """
+    Feature: cell reuse.
+    Description: MLP with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_infer(str(2))
+    loss_pynative = get_pynative_mlp_cell_reuse_infer()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
 class CellDropDense(nn.Cell):
     @cell_attr_register
     def __init__(self):
@@ -93,85 +408,6 @@ class DropMLP(nn.Cell):
         return out
 
 
-def seed_set():
-    set_seed(1)
-    np.random.seed(1)
-    random.seed(1)
-
-
-def train(net, data, label):
-    learning_rate = 0.05
-    momentum = 0.9
-    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
-    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
-    net_with_criterion = WithLossCell(net, criterion)
-    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
-    train_network.set_train()
-    res_list = []
-    for _ in range(20):
-        res = train_network(data, label)
-        res_list.append(res[0].asnumpy())
-    return res_list
-
-
-expect_value = [4.6052, 4.5553, 4.4607, 4.3261, 4.1556, 3.9532, 3.7227,
-                3.4675, 3.1912, 2.8974, 2.5900, 2.2736, 1.9538, 1.6376,
-                1.3335, 1.0511, 0.8002, 0.5884, 0.4195, 0.2920]
-
-
-@pytest.mark.level1
-@pytest.mark.platform_arm_ascend_training
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.env_onecard
-def test_mlp_cell_reuse():
-    """
-    Feature: cell reuse.
-    Description: MLP with cell reuse.
-    Expectation: No exception.
-    """
-    context.set_context(mode=context.GRAPH_MODE)
-    os.environ['MS_DEV_GRAPH_REUSE'] = str(1)
-    # gen data
-    seed_set()
-    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
-    label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
-    # cell reuse
-    net = MLP()
-    loss_list = train(net, data, label)
-    del os.environ['MS_DEV_GRAPH_REUSE']
-    assert np.allclose(loss_list, expect_value, 0.001, 0.001)
-
-
-@pytest.mark.level1
-@pytest.mark.platform_arm_ascend_training
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.env_onecard
-def test_mlp_cell():
-    """
-    Feature: cell reuse.
-    Description: MLP without cell reuse.
-    Expectation: No exception.
-    """
-    context.set_context(mode=context.GRAPH_MODE)
-    os.environ['MS_DEV_GRAPH_REUSE'] = str(0)
-    # gen data
-    seed_set()
-    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
-    label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
-    # cell not reuse
-    net = MLP()
-    loss_list = train(net, data, label)
-    del os.environ['MS_DEV_GRAPH_REUSE']
-    assert np.allclose(loss_list, expect_value, 0.001, 0.001)
-
-
 @pytest.mark.level1
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training