!46516 add cell reuse inline

Merge pull request !46516 from 王禹程/inline
Authored by i-robot on 2022-12-30 06:51:15 +00:00; committed by Gitee
commit f65cfcaf4d
20 changed files with 711 additions and 101 deletions

View File

@ -78,6 +78,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
device_target_(DeviceType::kUnknown),
executable_(true),
summary_node_exist_(false),
need_inline_(false),
start_label_(nullptr),
end_goto_(nullptr),
current_epoch_(0),
@ -102,6 +103,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
updated_parameters_ = graph.updated_parameters_;
executable_ = graph.executable_;
summary_node_exist_ = graph.summary_node_exist_;
need_inline_ = graph.need_inline_;
valid_inputs_ = graph.valid_inputs_;
child_graph_order_ = graph.child_graph_order_;
device_loop_ctrl_tensors_ = graph.device_loop_ctrl_tensors_;
@ -216,6 +218,10 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
#endif
// check whether exist summary node in graph
bool summary_node_exist() const { return summary_node_exist_; }
// set need inline
void set_need_inline(bool need_inline) { need_inline_ = need_inline; }
// check whether need inline
bool need_inline() const { return need_inline_; }
// set invalid inputs for control sink
std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
std::vector<bool> valid_inputs() const { return valid_inputs_; }
@ -520,6 +526,8 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
bool summary_node_exist_{false};
// valid inputs
std::vector<bool> valid_inputs_;
// need inline
bool need_inline_;
// child graph execute order in parent graph
std::vector<std::weak_ptr<KernelGraph>> child_graph_order_;

View File

@ -735,9 +735,20 @@ std::vector<AnfNodePtr> KernelGraphMgr::CreateSwitchOrPartialNode(const CNodePtr
auto info = kernel_graph_partial_map_[sub_kernel_graph.get()];
call_node->set_abstract(info.abstract);
cnode_inputs.emplace_back(info.sub_graph);
if (info.param_begin != tuple_get_idx) {
if (common::GetEnv("MS_DEV_GRAPH_REUSE") == "2") {
// call_graph and info.sub_graph need inline when cell reuse.
sub_kernel_graph->set_need_inline(true);
auto partial_sub_graph = GetValueNodeKernelGraph(info.sub_graph);
MS_EXCEPTION_IF_NULL(partial_sub_graph);
partial_sub_graph->set_need_inline(true);
MS_LOG(INFO) << "Inline graph " << sub_kernel_graph->graph_id() << " and graph "
<< partial_sub_graph->graph_id();
}
MS_LOG(INFO) << "Use cell reuse: " << sub_kernel_graph->graph_id();
if (info.param_begin != tuple_get_idx + std::max(static_cast<int>(info.multi_tuple) - 1, 0)) {
MS_LOG(EXCEPTION) << "Call param is not a graph, the TupleGetItem index: " << tuple_get_idx
<< ", the partial graph index: " << info.param_begin
<< ", need idx: " << tuple_get_idx + std::max(static_cast<int>(info.multi_tuple) - 1, 0)
<< ", call graph: " << call_graph->fullname_with_scope();
}
for (size_t i = info.param_begin; i < info.param_end; i++) {
@ -867,6 +878,31 @@ ParameterPtr KernelGraphMgr::CreateNewParameter(const AnfNodePtr &anf, KernelGra
return new_parameter;
}
void KernelGraphMgr::FlattenTuple(const CNodePtr &node, KernelGraph *graph) {
if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
auto call_graph = node->input(kFirstIndex);
auto sub_kernel_graph = GetValueNodeKernelGraph(call_graph);
MS_EXCEPTION_IF_NULL(sub_kernel_graph);
auto iter = kernel_graph_partial_map_.find(sub_kernel_graph.get());
if (iter != kernel_graph_partial_map_.end() && iter->second.multi_tuple != 0) {
need_flatten_.insert(node);
}
} else if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
auto input = node->input(kFirstIndex);
auto get_idx = common::AnfAlgo::GetTupleGetItemOutIndex(node);
if (need_flatten_.find(input) != need_flatten_.end() && get_idx == 0) {
need_flatten_tuple_map_[node] = input;
}
}
for (size_t i = 0; i < common::AnfAlgo::GetInputNum(node); i++) {
auto input = common::AnfAlgo::GetInputNode(node, i);
auto iter = need_flatten_tuple_map_.find(input);
if (iter != need_flatten_tuple_map_.end()) {
node->set_input(i + 1, iter->second);
}
}
}
bool KernelGraphMgr::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
@ -883,6 +919,7 @@ bool KernelGraphMgr::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGrap
new_cnode->set_scope(cnode->scope());
graph->FrontBackendMapAdd(node, new_cnode);
SetReturnNode(new_cnode, graph);
FlattenTuple(new_cnode, graph);
return true;
}
@ -938,14 +975,40 @@ void KernelGraphMgr::SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
auto last_input_node = common::AnfAlgo::GetInputNode(make_tuple, tuple_input_num - 1);
MS_EXCEPTION_IF_NULL(last_input_node);
if (last_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(last_input_node, prim::kPrimPartial)) {
size_t multi_tuple = 0;
auto partial_node = last_input_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_node);
size_t partial_input_num = common::AnfAlgo::GetInputTensorNum(partial_node);
std::vector<AnfNodePtr> make_tuple_inputs;
make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
// skip last return node (is a partial)
size_t param_begin = 0;
for (size_t i = 0; i < tuple_input_num - 1; i++) {
make_tuple_inputs.emplace_back(common::AnfAlgo::GetInputNode(make_tuple, i));
auto input = common::AnfAlgo::GetInputNode(make_tuple, i);
auto node_abs = input->abstract();
if (node_abs->isa<abstract::AbstractTuple>()) {
MS_EXCEPTION_IF_CHECK_FAIL(
i == 0, "Input index: " + std::to_string(i) + " is a make tuple, input node: " + input->DebugString());
MS_LOG(DEBUG) << "Flatten the make tuple, input node: " << input->DebugString()
<< ", output num: " << AnfUtils::GetOutputTensorNum(input);
// flatten the make tuple
for (size_t j = 0; j < AnfUtils::GetOutputTensorNum(input); j++) {
auto idx = NewValueNode(SizeToLong(j));
MS_EXCEPTION_IF_NULL(idx);
auto imm = std::make_shared<Int64Imm>(j);
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx});
std::vector<TypeId> types = {common::AnfAlgo::GetOutputInferDataType(input, j)};
auto shapes = {common::AnfAlgo::GetOutputInferShape(input, j)};
common::AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
param_begin++;
multi_tuple++;
make_tuple_inputs.emplace_back(getitem);
}
} else {
param_begin++;
make_tuple_inputs.emplace_back(input);
}
}
// skip partial graph
for (size_t i = kFirstIndex; i < partial_input_num; i++) {
@ -965,8 +1028,8 @@ void KernelGraphMgr::SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
g_output->set_abstract(abstract);
graph->set_output(g_output);
kernel_graph_partial_map_[graph] = {abstract, common::AnfAlgo::GetInputNode(partial_node, 0),
tuple_input_num - 1, common::AnfAlgo::GetInputTensorNum(g_output)};
kernel_graph_partial_map_[graph] = {abstract, common::AnfAlgo::GetInputNode(partial_node, 0), param_begin,
common::AnfAlgo::GetInputTensorNum(g_output), multi_tuple};
}
}
}
@ -1041,6 +1104,10 @@ std::shared_ptr<KernelGraph> KernelGraphMgr::ConstructKernelGraph(const FuncGrap
auto graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(graph);
front_backend_graph_map_[func_graph.get()] = graph;
if (func_graph->has_flag(FUNC_GRAPH_FLAG_NEED_BACKEND_INLINE)) {
MS_LOG(INFO) << "Need backend inline: " << graph->graph_id();
graph->set_need_inline(true);
}
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
graph->set_device_target(device_target);
// Create parameter

View File

@ -46,6 +46,7 @@ struct PartialFuncInfo {
AnfNodePtr sub_graph;
size_t param_begin;
size_t param_end;
size_t multi_tuple;
};
class BACKEND_EXPORT KernelGraphMgr {
@ -111,6 +112,7 @@ class BACKEND_EXPORT KernelGraphMgr {
ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) const;
void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) const;
void SetReturnNode(const AnfNodePtr &node, KernelGraph *graph);
void FlattenTuple(const CNodePtr &node, KernelGraph *graph);
protected:
CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);
@ -123,6 +125,8 @@ class BACKEND_EXPORT KernelGraphMgr {
mindspore::HashMap<AnfNodePtr, ParameterPtr> default_param_map_;
mindspore::HashMap<FuncGraph *, KernelGraphPtr> front_backend_graph_map_;
mindspore::HashMap<KernelGraph *, PartialFuncInfo> kernel_graph_partial_map_;
mindspore::HashSet<AnfNodePtr> need_flatten_;
mindspore::HashMap<AnfNodePtr, AnfNodePtr> need_flatten_tuple_map_;
static GraphId graph_sum_;
};
} // namespace session

View File

@ -918,6 +918,10 @@ constexpr auto kAttrInsertDefaultValue = "insert_default_value";
constexpr auto kAttrIsSparse = "IsSparse";
constexpr auto kAttrKernelBackoffWithFailureInfo = "kernel_backoff_with_failure_info";
constexpr auto kAttrKernelBackoffWithFailureType = "kernel_backoff_with_failure_type";
constexpr auto kAttrKernelGraph = "kernel_graph";
constexpr auto kAttrPreKernelGraph = "pre_kernel_graph";
constexpr auto kAttrNeedInline = "need_inline";
constexpr auto kAttrOriFusionName = "ori_fusion_name";
// FuncGraph Flags
constexpr auto kFlagIsDynamicStructure = "is_dynamic_structure";

View File

@ -628,6 +628,9 @@ FuncGraphPtr GenerateReusingGraph(const FuncGraphPtr &fg) {
}
}
MS_LOG(DEBUG) << "The reusable graph parameter size: " << reusing_graph->parameters().size();
if (common::GetEnv("MS_DEV_GRAPH_REUSE") == "2") {
reusing_graph->set_flag(FUNC_GRAPH_FLAG_NEED_BACKEND_INLINE, true);
}
return reusing_graph;
}
@ -652,14 +655,9 @@ void ReplaceWithReusingGraph(const FuncGraphPtr &reusing_graph, const FuncGraphP
// Make the reusable cell to be the reusable function graph.
bool GraphReusingAction(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
constexpr size_t graph_reusing_count = 2;
const auto &obj_map = parse::data_converter::GetObjGraphs();
for (const auto &[cell_key, graphs] : obj_map) {
MS_LOG(DEBUG) << "Start to handle the reusable graph: " << cell_key << ", size: " << graphs.size();
// Only make the reusable cell that is used more than graph_reusing_count to be reusable.
if (graphs.size() < graph_reusing_count) {
continue;
}
const auto &fg = graphs[0];
// fg->paramter_obj_nodes().empty() have been handled by combine like.
if (!fg->paramter_obj_nodes().empty()) {
@ -1514,7 +1512,8 @@ static std::vector<ActionItem> CommonPipeline() {
}
// Make the reusable cell to be the reusable function graph
static bool enable_graph_reusing = (common::GetEnv("MS_DEV_GRAPH_REUSE") == "1");
static bool enable_graph_reusing =
(common::GetEnv("MS_DEV_GRAPH_REUSE") == "1" || common::GetEnv("MS_DEV_GRAPH_REUSE") == "2");
if (enable_graph_reusing) {
(void)actions.emplace_back(std::make_pair("graph_reusing", GraphReusingAction));
}
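For reference, the MS_DEV_GRAPH_REUSE switch checked in the hunk above is a plain environment variable read when the pipeline actions are built; combined with the tests added later in this change, a minimal usage sketch might look like the following (the placeholder network and the graph-mode setup are assumptions for illustration, not part of this commit):

```python
# Minimal sketch, not part of the diff: enabling cell reuse via MS_DEV_GRAPH_REUSE.
# Per the hunks above, "1" registers the graph_reusing action only, while "2"
# additionally sets FUNC_GRAPH_FLAG_NEED_BACKEND_INLINE so the backend inlines
# the reused graphs.
import os
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor

os.environ['MS_DEV_GRAPH_REUSE'] = '2'        # set before the network is compiled
context.set_context(mode=context.GRAPH_MODE)  # assumption: graph mode (tests target Ascend)

net = nn.Dense(200, 100)                      # placeholder; any net with repeated sub-cells works
out = net(Tensor(np.random.random([1, 200]).astype(np.float32)))

del os.environ['MS_DEV_GRAPH_REUSE']          # the new tests likewise unset the variable after use
```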

View File

@ -869,7 +869,11 @@ std::tuple<KernelSelectStatus, std::string, ExceptionType> SelectKernelInfoWithM
SelectGraphKernelInfo(kernel_node, func_graph);
return result;
}
if (IsPrimitiveCNode(kernel_node, prim::kPrimCallInline)) {
opt::SelectCallInlineKernelInfo(kernel_node);
SetTensorDeviceInfo(kernel_node);
return result;
}
if (common::AnfAlgo::HasNodeAttr(ops::kBatchRank, kernel_node)) {
std::stringstream ss;
ss << common::AnfAlgo::GetCNodeName(kernel_node)

View File

@ -57,6 +57,7 @@ const char ITEREND[] = "PROFILING_ITER_END";
const auto kSingleOutput = 1;
const auto kFirstOutput = 0;
constexpr size_t kFirstIndex = 1;
#ifdef ENABLE_DUMP_IR
bool IsSaveGraph() {
@ -142,6 +143,20 @@ void DumpExecuteOrder(const NotNull<KernelGraphPtr> &kg) {
}
#endif
KernelGraphPtr GetValueNodeKernelGraph(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
return nullptr;
}
auto value = value_node->value();
if (value == nullptr) {
return nullptr;
}
auto kernel_graph = value->cast<KernelGraphPtr>();
return kernel_graph;
}
// Return kNoLabel when label id attribute not set for the graph.
uint32_t GetGraphLabel(const KernelGraphPtr &kg) {
auto value = kg->get_attr(kAttrLabelIndex);
@ -151,6 +166,15 @@ uint32_t GetGraphLabel(const KernelGraphPtr &kg) {
return GetValue<uint32_t>(value);
}
bool CheckCallInline(const CNodePtr &cnode) {
if (!common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
return false;
}
auto call_graph = cnode->input(kFirstIndex);
auto sub_kernel_graph = GetValueNodeKernelGraph(call_graph);
return sub_kernel_graph->need_inline();
}
bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2);
bool CheckAbstractTupleIsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) {
@ -501,6 +525,12 @@ class CallInfoFinder {
// Find call-return pairs.
void FindCallReturns() {
for (auto &[caller, call_info] : context_.call_info_map) {
if (caller->need_inline() && (call_info.recursive || call_info.call_sites.size() != 0)) {
MS_LOG(INFO) << "Do not inline cell reuse because it has sub-graph call, graph id: " << caller->graph_id();
caller->set_need_inline(false);
}
}
for (auto &[caller, call_info] : context_.call_info_map) {
for (auto &call_site : call_info.call_sites) {
for (auto &callee : call_site.callees) {
@ -623,7 +653,7 @@ class CallInfoFinder {
}
// Create a parameter for the return value.
if (call_site->out_param == nullptr) {
if (call_site->out_param == nullptr && !CheckCallInline(call_site->cnode)) {
call_site->out_param = context_.CreateParameter(call_site->cnode->abstract());
}
// Add a return point for the callee graph.
@ -634,7 +664,7 @@ class CallInfoFinder {
// Setup label index if there are multi return points.
const auto n_return_points = call_info.return_points.size();
const size_t return_point_sizes = 2;
if (n_return_points > 1) {
if (n_return_points > 1 && !CheckCallInline(call_site->cnode)) {
if (n_return_points == return_point_sizes) {
// Create a parameter to store label index.
const ShapeVector shape = {1};
@ -753,6 +783,10 @@ class AscendAutoMonadConverter {
~AscendAutoMonadConverter() = default;
void Run() {
// need inline
if (kernel_graph_->need_inline()) {
return;
}
// Create an stack
InitStack();
// Setup entry label if found.
@ -1033,6 +1067,20 @@ class AscendAutoMonadConverter {
// The call/switch/switch_layer cnode.
auto &cnode = call_site->cnode;
if (CheckCallInline(cnode)) {
auto call_graph = cnode->input(kFirstIndex);
auto sub_kernel_graph = GetValueNodeKernelGraph(call_graph);
std::vector<AnfNodePtr> call_inline_inputs = {NewPrimitive(prim::kPrimCallInline)};
for (size_t i = kFirstIndex; i < common::AnfAlgo::GetInputNum(cnode); i++) {
call_inline_inputs.emplace_back(common::AnfAlgo::GetInputNode(cnode, i));
}
auto call_inline = kernel_graph_->NewCNode(call_inline_inputs);
MS_EXCEPTION_IF_NULL(call_inline);
call_inline->set_abstract(cnode->abstract());
common::AnfAlgo::SetNodeAttr(kAttrKernelGraph, MakeValue(sub_kernel_graph), call_inline);
ReplaceNode(cnode, call_inline);
return;
}
// Get branches of the call_site.
// for call, there is one branch;

View File

@ -36,6 +36,7 @@
#ifndef ENABLE_SECURITY
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/debug/dump_proto.h"
#include "ir/func_graph_cloner.h"
#endif
namespace mindspore {
@ -95,6 +96,56 @@ bool SetDefaultFormatForSpecialAclOp(const KernelGraphPtr &graph) {
}
return need_change_format;
}
AnfNodePtr DoInline(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
const AnfNodePtrList &func_graph_args, const ScopePtr &scope) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph);
Cloner cloner({}, false);
if (scope != nullptr) {
cloner.set_scope(scope);
}
cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline);
auto node_list = TopoSort(func_graph->output());
for (auto &ori_node : node_list) {
if (ori_node->isa<Parameter>()) {
continue;
}
auto new_node = cloner[ori_node];
auto kernel_info = dynamic_cast<device::KernelInfo *>(new_node->kernel_info());
// deep copy kernel info
if (kernel_info != nullptr && new_node->kernel_info()->has_build_info()) {
// some check
MS_EXCEPTION_IF_CHECK_FAIL(kernel_info->MutableKernelMod() == nullptr,
"Inline ERROR: " + ori_node->DebugString() + ", kernel mod is not nullptr");
MS_EXCEPTION_IF_CHECK_FAIL(kernel_info->output_address_list().empty(),
"Inline ERROR: " + ori_node->DebugString() + ", output_address_list is not empty");
MS_EXCEPTION_IF_CHECK_FAIL(kernel_info->workspace_address_list().empty(),
"Inline ERROR: " + ori_node->DebugString() + ", workspace_address_list is not empty");
auto new_kernel_info = std::make_shared<device::KernelInfo>();
auto builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(new_node));
new_kernel_info->set_select_kernel_build_info(builder->Build());
new_kernel_info->set_graph_id(kernel_info->graph_id());
new_kernel_info->set_feature_map_flag(kernel_info->is_feature_map());
new_kernel_info->set_ref_map(false, kernel_info->out_in_ref_map());
new_node->set_kernel_info(new_kernel_info);
}
if (ori_node->isa<CNode>()) {
auto ori_cnode = ori_node->cast<CNodePtr>();
if (common::AnfAlgo::HasNodeAttr(kAttrIsUBFusionOp, ori_cnode) &&
common::AnfAlgo::GetNodeAttr<bool>(ori_node, kAttrIsUBFusionOp)) {
// already done fusion compile
auto ori_full_name = ori_cnode->fullname_with_scope();
common::AnfAlgo::SetNodeAttr(kAttrOriFusionName, MakeValue(ori_full_name), new_node);
}
common::AnfAlgo::SetNodeAttr(kAttrNeedInline, MakeValue(ori_node->fullname_with_scope()), new_node);
common::AnfAlgo::SetNodeAttr(kAttrPreKernelGraph, MakeValue(func_graph), new_node);
}
}
return cloner[func_graph->output()];
}
} // namespace
void AscendGraphOptimization::Reset() {
@ -103,6 +154,40 @@ void AscendGraphOptimization::Reset() {
graph_manager_->Clear();
}
void AscendGraphOptimization::InlineSubGraph(const KernelGraphPtr &graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
#ifdef ENABLE_DUMP_IR
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
if (save_graphs) {
std::string file_name = "hwopt_d_before_inline_graph_" + std::to_string(graph->graph_id()) + ".ir";
DumpIR(file_name, graph, true, kWholeStack);
}
#endif
auto kernel_cnodes = graph->execution_order();
for (auto &kernel_cnode : kernel_cnodes) {
if (common::AnfAlgo::CheckPrimitiveType(kernel_cnode, prim::kPrimCallInline)) {
auto sub_graph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(kernel_cnode, kAttrKernelGraph);
MS_LOG(INFO) << "InlineSubGraph: " << kernel_cnode->DebugString() << ", sub graph: " << sub_graph->graph_id()
<< ", need inline: " << sub_graph->need_inline();
auto main_graph = kernel_cnode->func_graph();
auto mng = main_graph->manager();
AnfNodePtrList inp(kernel_cnode->inputs().begin() + 1, kernel_cnode->inputs().end());
auto out = DoInline(sub_graph, main_graph, inp, kernel_cnode->input(0)->scope());
(void)mng->Replace(kernel_cnode, out);
}
}
memo_.clear();
opt::AscendAfterInlineOptimization(graph);
memo_.clear();
#ifdef ENABLE_DUMP_IR
if (save_graphs) {
std::string file_name = "hwopt_d_after_inline_graph_" + std::to_string(graph->graph_id()) + ".ir";
DumpIR(file_name, graph, true, kWholeStack);
}
#endif
}
void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Status record: start optimize graph. graph id: " << graph->graph_id();
@ -118,6 +203,9 @@ void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) {
OptimizeGraphWithoutDeviceInfo(graph);
SelectKernel(graph);
OptimizeGraphWithDeviceInfo(graph);
// inline func before gen execution order
InlineSubGraph(graph);
OptimizeExecutionOrder(graph);
PostOptimization(graph);
@ -195,8 +283,8 @@ void AscendGraphOptimization::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &
MS_EXCEPTION_IF_NULL(graph);
memo_.clear();
HardWareOptimization(graph);
// copy child graph ref output map to father graph ref output map
memo_.clear();
// copy child graph ref output map to father graph ref output map
UpdateRefOutputMap(graph);
AnfAlgo::InsertMakeTupleForOutput(NOT_NULL(graph));
RemoveUnusedValueNode(graph);
@ -288,6 +376,9 @@ void AscendGraphOptimization::HardWareOptimization(const KernelGraphPtr &graph)
return;
}
(void)memo_.insert(graph);
for (auto &child_graph : graph->child_graph_order()) {
HardWareOptimization(child_graph.lock());
}
opt::AscendBackendOptimization(graph);
opt::CommonFinalOptimization(graph);
if (graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
@ -295,10 +386,6 @@ void AscendGraphOptimization::HardWareOptimization(const KernelGraphPtr &graph)
graph->SetExecOrderByDefault();
}
MS_LOG(INFO) << "Status record: end hardware optimize. graph id: " << graph->graph_id();
for (auto &child_graph : graph->child_graph_order()) {
HardWareOptimization(child_graph.lock());
}
}
void AscendGraphOptimization::AddGraphToManager(const NotNull<KernelGraphPtr> graph,
@ -349,6 +436,11 @@ void AscendGraphOptimization::RecurseSelectKernelInfo(const KernelGraphPtr &grap
return;
}
(void)memo_.insert(graph);
for (auto &child_graph : graph->child_graph_order()) {
if (child_graph.lock()->need_inline()) {
RecurseSelectKernelInfo(child_graph.lock());
}
}
#ifdef ENABLE_DUMP_IR
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
@ -361,16 +453,16 @@ void AscendGraphOptimization::RecurseSelectKernelInfo(const KernelGraphPtr &grap
MS_LOG(INFO) << "Status record: start select kernel info. graph id: " << graph->graph_id();
SetOperatorInfo(graph);
MS_LOG(INFO) << "Status record: end select kernel info. graph id: " << graph->graph_id();
#ifdef ENABLE_DUMP_IR
if (save_graphs) {
std::string file_name = "select_kernel_after_graph_" + std::to_string(graph->graph_id()) + ".ir";
DumpIR(file_name, graph);
}
#endif
for (auto &child_graph : graph->child_graph_order()) {
RecurseSelectKernelInfo(child_graph.lock());
if (!child_graph.lock()->need_inline()) {
RecurseSelectKernelInfo(child_graph.lock());
}
}
}

View File

@ -51,6 +51,7 @@ class AscendGraphOptimization {
void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph);
void OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph);
void OptimizeExecutionOrder(const KernelGraphPtr &graph);
void InlineSubGraph(const KernelGraphPtr &graph);
void PostOptimization(const KernelGraphPtr &graph) const;
// Graph Optimized level-3 interface

View File

@ -350,7 +350,7 @@ void TbeKernelCompileManager::GetAllTbeNodes(const std::shared_ptr<session::Kern
auto all_nodes = kernel_graph->execution_order();
for (const auto &anf_node : all_nodes) {
MS_EXCEPTION_IF_NULL(anf_node);
if (!AnfUtils::IsRealKernel(anf_node)) {
if (!AnfUtils::IsRealKernel(anf_node) || IsPrimitiveCNode(anf_node, prim::kPrimCallInline)) {
continue;
}
KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);
@ -564,6 +564,9 @@ std::pair<std::vector<CNodePtr>, std::vector<CNodePtr>> TbeKernelCompileManager:
continue; // kernel mod already exist, continue;
}
auto full_name = node->fullname_with_scope();
if (common::AnfAlgo::HasNodeAttr(kAttrOriFusionName, node)) {
full_name = common::AnfAlgo::GetNodeAttr<std::string>(node, kAttrOriFusionName);
}
auto json_name = full_name_to_json_name_[full_name];
auto kernel_pack = tbe::TbeUtils::SearchCache(json_name, false);
if (kernel_pack == nullptr) {

View File

@ -94,6 +94,7 @@
#include "plugin/device/ascend/optimizer/ir_fusion/softmax_dropout_do_mask_v3_fusion.h"
#include "plugin/device/ascend/optimizer/ir_fusion/conv2d_backprop_input_dilation_fusion.h"
#include "plugin/device/ascend/optimizer/format_type/insert_trans_op.h"
#include "plugin/device/ascend/optimizer/format_type/reselect_call_inline_format.h"
#include "plugin/device/ascend/optimizer/format_type/trans_op_format_refine.h"
#include "plugin/device/ascend/optimizer/format_type/dynamic_rnn_grad_reformat.h"
#include "plugin/device/ascend/optimizer/format_type/insert_transpose_for_basiclstm_op.h"
@ -285,6 +286,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<GraphOptimizer>();
auto data_layout_pm = std::make_shared<PassManager>("transop_pm");
data_layout_pm->AddPass(std::make_shared<ReselectCallInlineFormat>());
data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>());
data_layout_pm->AddPass(std::make_shared<DynamicRNNGradReformat>());
data_layout_pm->AddPass(std::make_shared<ChangeAxisOfReduceKernel>());
@ -509,6 +511,16 @@ void RunOpAscendBackendOptimization(const std::shared_ptr<session::KernelGraph>
kernel_graph->SetExecOrderByDefault();
}
void AscendAfterInlineOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto optimizer = std::make_shared<GraphOptimizer>();
auto after_inline_pm = std::make_shared<PassManager>("after_inline_pm");
after_inline_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
after_inline_pm->AddPass(std::make_shared<EliminateRedundantOp>());
optimizer->AddPassManager(after_inline_pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
}
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_LOG(INFO) << "Status record: start ascend backend(data layer & mix precision ...) pass. graph id: "

View File

@ -25,6 +25,7 @@ void RunOpAscendBackendOptimization(const std::shared_ptr<session::KernelGraph>
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendAfterInlineOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &kernel_graph);

View File

@ -635,6 +635,41 @@ void SetInputOutputNames(const std::vector<std::string> &input_names, const std:
common::AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), node);
}
void SelectCallInlineKernelInfo(const CNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimCallInline)) {
return;
}
// need inline
auto sub_graph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(node, kAttrKernelGraph);
auto sub_ret = sub_graph->output();
std::vector<std::string> input_formats;
std::vector<TypeId> input_types;
std::vector<std::string> output_formats;
std::vector<TypeId> output_types;
for (auto &param : sub_graph->inputs()) {
TypeId type_id = AnfAlgo::GetOutputDeviceDataType(param, 0);
if (type_id == kTypeUnknown) {
type_id = common::AnfAlgo::GetOutputInferDataType(param, 0);
}
if (type_id > kMonadTypeBegin && type_id < kMonadTypeEnd) {
continue;
}
input_types.push_back(type_id);
input_formats.push_back(AnfAlgo::GetOutputFormat(param, 0));
}
for (size_t i = 0; i < AnfUtils::GetOutputTensorNum(node); ++i) {
output_formats.push_back(AnfAlgo::GetOutputFormat(sub_ret, i));
output_types.push_back(common::AnfAlgo::GetOutputInferDataType(sub_ret, i));
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(builder);
builder->SetInputsFormat(input_formats);
builder->SetInputsDeviceType(input_types);
builder->SetOutputsFormat(output_formats);
builder->SetOutputsDeviceType(output_types);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
}
template <typename T, typename Scalar>
ValuePtr GetTensorValue(const tensor::TensorPtr &tensor) {
ValuePtr ret;

View File

@ -118,6 +118,8 @@ AnfNodePtr AddTransOpNodeToGraphWithFormat(const FuncGraphPtr &func_graph, const
void SetInputOutputNames(const std::vector<std::string> &input_names, const std::vector<std::string> &output_names,
const AnfNodePtr &node);
void SelectCallInlineKernelInfo(const CNodePtr &node);
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW};
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,51 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/format_type/reselect_call_inline_format.h"
#include "plugin/device/ascend/optimizer/ascend_helper.h"
namespace mindspore {
namespace opt {
namespace {
constexpr auto Xs = "Xs";
constexpr auto call_inline = "call_inline";
constexpr auto new_call_inline = "new_call_inline";
} // namespace
bool ReselectCallInlineFormat::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &graph,
const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
return true;
}
AnfNodePtr BuildCallInline(const PatternMap &m, const AnfNodePtr &) {
auto anf = m.Get(call_inline);
MS_EXCEPTION_IF_NULL(anf);
auto cnode = anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
SelectCallInlineKernelInfo(cnode);
return cnode;
}
void ReselectCallInlineFormat::DefineSrcPattern(SrcPattern *src_pattern) {
(*src_pattern).AddSeqVar(Xs).AddCNode(call_inline, {prim::kPrimCallInline, Xs});
}
void ReselectCallInlineFormat::DefineDstPattern(DstPattern *dst_pattern) {
(*dst_pattern).AddCNode(new_call_inline, {prim::kPrimCallInline, Xs}, BuildCallInline);
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,35 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_RESELECT_CALL_INLINE_FORMAT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_RESELECT_CALL_INLINE_FORMAT_H_
#include <memory>
#include "backend/common/optimizer/pattern_to_pattern.h"
namespace mindspore {
namespace opt {
class ReselectCallInlineFormat : public PatternToPatternPass {
public:
ReselectCallInlineFormat() : PatternToPatternPass("reselect_call_inline_format") {}
~ReselectCallInlineFormat() override = default;
void DefineSrcPattern(SrcPattern *src_pattern) override;
void DefineDstPattern(DstPattern *dst_pattern) override;
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_RESELECT_CALL_INLINE_FORMAT_H_

View File

@ -619,6 +619,9 @@ void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value
void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph &graph) {
MS_EXCEPTION_IF_NULL(mem_manager_);
if (graph.need_inline()) {
return;
}
auto graph_id = graph.graph_id();
MS_LOG(INFO) << "AssignStaticMemoryInput start for graph " << graph_id;
auto graph_inputs = GetGraphInputs(graph);
@ -711,6 +714,9 @@ void KernelRuntime::GetDeviceAddress(const AnfNodePtr &item,
}
void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph &graph) {
if (graph.need_inline()) {
return;
}
MS_LOG(INFO) << "AssignStaticMemoryOutput start for graph " << graph.graph_id();
auto nodes = common::AnfAlgo::GetAllOutput(graph.output(), {prim::kPrimTupleGetItem});
std::vector<session::KernelWithIndex> non_communication_op;

View File

@ -83,6 +83,7 @@ const char FUNC_GRAPH_FLAG_IGNORE_VALUE[] = "ignore_value";
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
const char FUNC_GRAPH_FLAG_SPARSE_BPROP[] = "sparse_bprop";
const char FUNC_GRAPH_FLAG_NO_INLINE[] = "no_inline";
const char FUNC_GRAPH_FLAG_NEED_BACKEND_INLINE[] = "need_backend_inline";
const char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
const char FUNC_GRAPH_FLAG_CORE[] = "core";
const char FUNC_GRAPH_FLAG_K_GRAPH[] = "k_graph";

View File

@ -1481,6 +1481,7 @@ GVAR_DEF(PrimitivePtr, kPrimSelect, std::make_shared<Primitive>(kSelect));
GVAR_DEF(PrimitivePtr, kPrimCall, std::make_shared<Primitive>("call"));
GVAR_DEF(PrimitivePtr, kPrimRaise, std::make_shared<Primitive>("raise"));
GVAR_DEF(PrimitivePtr, kPrimJoinedStr, std::make_shared<Primitive>("joinedstr"));
GVAR_DEF(PrimitivePtr, kPrimCallInline, std::make_shared<Primitive>("call_inline"));
GVAR_DEF(PrimitivePtr, kPrimMakeTuple, std::make_shared<Primitive>(kMakeTuple));
GVAR_DEF(PrimitivePtr, kPrimMakeSlice, std::make_shared<Primitive>("make_slice"));

View File

@ -21,7 +21,9 @@ import pytest
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore as ms
from mindspore.common import mutable
from mindspore.common import set_seed
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
@ -29,6 +31,28 @@ from mindspore.nn.optim import Momentum
from mindspore._extends import cell_attr_register
def seed_set():
set_seed(1)
np.random.seed(1)
random.seed(1)
def train(net, data, label):
learning_rate = 0.05
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()
res_list = []
for _ in range(20):
res = train_network(data, label)
res_list.append(res[0].asnumpy())
return res_list
class CellDense(nn.Cell):
@cell_attr_register
def __init__(self):
@ -60,6 +84,297 @@ class MLP(nn.Cell):
return out
def get_pynative_mlp_cell_reuse_loss():
context.set_context(mode=context.PYNATIVE_MODE)
# gen data
seed_set()
data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
# cell reuse
net = MLP()
loss_list = train(net, data, label)
return loss_list
def get_mlp_cell_reuse_loss(reuse):
context.set_context(mode=context.GRAPH_MODE)
os.environ['MS_DEV_GRAPH_REUSE'] = reuse
# gen data
seed_set()
data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
# cell reuse
net = MLP()
loss_list = train(net, data, label)
del os.environ['MS_DEV_GRAPH_REUSE']
return loss_list
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mlp_cell_reuse_0():
"""
Feature: cell reuse.
Description: MLP with cell reuse.
Expectation: No exception.
"""
loss_graph = get_mlp_cell_reuse_loss(str(0))
loss_pynative = get_pynative_mlp_cell_reuse_loss()
assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mlp_cell_reuse_1():
"""
Feature: cell reuse.
Description: MLP with cell reuse.
Expectation: No exception.
"""
loss_graph = get_mlp_cell_reuse_loss(str(1))
loss_pynative = get_pynative_mlp_cell_reuse_loss()
assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mlp_cell_reuse_2():
"""
Feature: cell reuse.
Description: MLP with cell reuse.
Expectation: No exception.
"""
loss_graph = get_mlp_cell_reuse_loss(str(2))
loss_pynative = get_pynative_mlp_cell_reuse_loss()
assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
class CellDense2(nn.Cell):
@cell_attr_register
def __init__(self):
super(CellDense2, self).__init__()
self.fc = nn.Dense(100, 100)
def construct(self, input_x):
out = self.fc(input_x)
return input_x, out
class MLP2(nn.Cell):
def __init__(self):
super(MLP2, self).__init__()
self.batch_size = 1
self.fc = nn.Dense(200, 100)
layers = []
for _ in range(12):
layer = CellDense2()
layers.append(layer)
self.layers = nn.CellList(layers)
def construct(self, out):
out = self.fc(out)
for layer_module in self.layers:
tmp, out = layer_module(out)
out += tmp
return out
def get_pynative_mlp_cell_reuse_loss_2():
context.set_context(mode=context.PYNATIVE_MODE)
# gen data
seed_set()
data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
# cell reuse
net = MLP2()
loss_list = train(net, data, label)
return loss_list
def get_mlp_cell_reuse_loss_2(reuse):
context.set_context(mode=context.GRAPH_MODE)
os.environ['MS_DEV_GRAPH_REUSE'] = reuse
# gen data
seed_set()
data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
# cell reuse
net = MLP2()
loss_list = train(net, data, label)
del os.environ['MS_DEV_GRAPH_REUSE']
return loss_list
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mlp_cell_2_reuse_0():
"""
Feature: cell reuse.
Description: MLP (need flatten maketuple) with cell reuse.
Expectation: No exception.
"""
loss_graph = get_mlp_cell_reuse_loss_2(str(0))
loss_pynative = get_pynative_mlp_cell_reuse_loss_2()
assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mlp_cell_2_reuse_1():
"""
Feature: cell reuse.
Description: MLP (need flatten maketuple) with cell reuse.
Expectation: No exception.
"""
loss_graph = get_mlp_cell_reuse_loss_2(str(1))
loss_pynative = get_pynative_mlp_cell_reuse_loss_2()
assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mlp_cell_2_reuse_2():
"""
Feature: cell reuse.
Description: MLP (need flatten maketuple) with cell reuse.
Expectation: No exception.
"""
loss_graph = get_mlp_cell_reuse_loss_2(str(2))
loss_pynative = get_pynative_mlp_cell_reuse_loss_2()
assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
class CellDenseWithControlFlow(nn.Cell):
@cell_attr_register
def __init__(self):
super(CellDenseWithControlFlow, self).__init__()
self.fc = nn.Dense(100, 100)
def construct(self, input_x, x):
out = self.fc(input_x)
if x > 0:
out = self.fc(out)
out = self.fc(out)
return out
class MLPWithControlFlow(nn.Cell):
def __init__(self):
super(MLPWithControlFlow, self).__init__()
self.batch_size = 1
self.fc = nn.Dense(200, 100)
layers = []
for _ in range(12):
layer = CellDenseWithControlFlow()
layers.append(layer)
self.layers = nn.CellList(layers)
def construct(self, out):
out = self.fc(out)
for layer_module in self.layers:
x = mutable(ms.Tensor(np.array(1), dtype=ms.int32))
out = layer_module(out, x)
return out
def get_pynative_mlp_cell_reuse_infer():
context.set_context(mode=context.PYNATIVE_MODE)
# gen data
seed_set()
data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
# cell reuse
net = MLPWithControlFlow()
ret = net(data)
return ret.asnumpy()
def get_mlp_cell_reuse_infer(reuse):
context.set_context(mode=context.GRAPH_MODE)
os.environ['MS_DEV_GRAPH_REUSE'] = reuse
# gen data
seed_set()
data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
# cell reuse
net = MLPWithControlFlow()
ret = net(data)
del os.environ['MS_DEV_GRAPH_REUSE']
return ret.asnumpy()
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mlp_cell_with_control_flow_reuse_0():
"""
Feature: cell reuse.
Description: MLP with cell reuse.
Expectation: No exception.
"""
loss_graph = get_mlp_cell_reuse_infer(str(0))
loss_pynative = get_pynative_mlp_cell_reuse_infer()
assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mlp_cell_with_control_flow_reuse_1():
"""
Feature: cell reuse.
Description: MLP with cell reuse.
Expectation: No exception.
"""
loss_graph = get_mlp_cell_reuse_infer(str(1))
loss_pynative = get_pynative_mlp_cell_reuse_infer()
assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mlp_cell_with_control_flow_reuse_2():
"""
Feature: cell reuse.
Description: MLP with cell reuse.
Expectation: No exception.
"""
loss_graph = get_mlp_cell_reuse_infer(str(2))
loss_pynative = get_pynative_mlp_cell_reuse_infer()
assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
class CellDropDense(nn.Cell):
@cell_attr_register
def __init__(self):
@ -93,85 +408,6 @@ class DropMLP(nn.Cell):
return out
def seed_set():
set_seed(1)
np.random.seed(1)
random.seed(1)
def train(net, data, label):
learning_rate = 0.05
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()
res_list = []
for _ in range(20):
res = train_network(data, label)
res_list.append(res[0].asnumpy())
return res_list
expect_value = [4.6052, 4.5553, 4.4607, 4.3261, 4.1556, 3.9532, 3.7227,
3.4675, 3.1912, 2.8974, 2.5900, 2.2736, 1.9538, 1.6376,
1.3335, 1.0511, 0.8002, 0.5884, 0.4195, 0.2920]
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mlp_cell_reuse():
"""
Feature: cell reuse.
Description: MLP with cell reuse.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
os.environ['MS_DEV_GRAPH_REUSE'] = str(1)
# gen data
seed_set()
data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
# cell reuse
net = MLP()
loss_list = train(net, data, label)
del os.environ['MS_DEV_GRAPH_REUSE']
assert np.allclose(loss_list, expect_value, 0.001, 0.001)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mlp_cell():
"""
Feature: cell reuse.
Description: MLP without cell reuse.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
os.environ['MS_DEV_GRAPH_REUSE'] = str(0)
# gen data
seed_set()
data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
# cell not reuse
net = MLP()
loss_list = train(net, data, label)
del os.environ['MS_DEV_GRAPH_REUSE']
assert np.allclose(loss_list, expect_value, 0.001, 0.001)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training