commit f65cfcaf4d
@@ -78,6 +78,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
         device_target_(DeviceType::kUnknown),
         executable_(true),
         summary_node_exist_(false),
+        need_inline_(false),
         start_label_(nullptr),
         end_goto_(nullptr),
         current_epoch_(0),
@@ -102,6 +103,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
     updated_parameters_ = graph.updated_parameters_;
     executable_ = graph.executable_;
     summary_node_exist_ = graph.summary_node_exist_;
+    need_inline_ = graph.need_inline_;
     valid_inputs_ = graph.valid_inputs_;
     child_graph_order_ = graph.child_graph_order_;
     device_loop_ctrl_tensors_ = graph.device_loop_ctrl_tensors_;
@@ -216,6 +218,10 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
 #endif
   // check whether a summary node exists in the graph
   bool summary_node_exist() const { return summary_node_exist_; }
+  // set whether this graph needs to be inlined
+  void set_need_inline(bool need_inline) { need_inline_ = need_inline; }
+  // check whether this graph needs to be inlined
+  bool need_inline() const { return need_inline_; }
   // set invalid inputs for control sink
   std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
   std::vector<bool> valid_inputs() const { return valid_inputs_; }
@@ -520,6 +526,8 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
   bool summary_node_exist_{false};
   // valid inputs
   std::vector<bool> valid_inputs_;
+  // need inline
+  bool need_inline_;
 
   // child graph execute order in parent graph
   std::vector<std::weak_ptr<KernelGraph>> child_graph_order_;

@@ -735,9 +735,20 @@ std::vector<AnfNodePtr> KernelGraphMgr::CreateSwitchOrPartialNode(const CNodePtr
     auto info = kernel_graph_partial_map_[sub_kernel_graph.get()];
     call_node->set_abstract(info.abstract);
     cnode_inputs.emplace_back(info.sub_graph);
-    if (info.param_begin != tuple_get_idx) {
+    if (common::GetEnv("MS_DEV_GRAPH_REUSE") == "2") {
+      // call_graph and info.sub_graph need inline when cell reuse.
+      sub_kernel_graph->set_need_inline(true);
+      auto partial_sub_graph = GetValueNodeKernelGraph(info.sub_graph);
+      MS_EXCEPTION_IF_NULL(partial_sub_graph);
+      partial_sub_graph->set_need_inline(true);
+      MS_LOG(INFO) << "Inline graph " << sub_kernel_graph->graph_id() << " and graph "
+                   << partial_sub_graph->graph_id();
+    }
+    MS_LOG(INFO) << "Use cell reuse: " << sub_kernel_graph->graph_id();
+    if (info.param_begin != tuple_get_idx + std::max(static_cast<int>(info.multi_tuple) - 1, 0)) {
       MS_LOG(EXCEPTION) << "Call param is not a graph, the TupleGetItem index: " << tuple_get_idx
                         << ", the partial graph index: " << info.param_begin
+                        << ", need idx: " << tuple_get_idx + std::max(static_cast<int>(info.multi_tuple) - 1, 0)
                         << ", call graph: " << call_graph->fullname_with_scope();
     }
     for (size_t i = info.param_begin; i < info.param_end; i++) {
@@ -867,6 +878,31 @@ ParameterPtr KernelGraphMgr::CreateNewParameter(const AnfNodePtr &anf, KernelGra
   return new_parameter;
 }
 
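+// FlattenTuple records call nodes whose callee graph returns a flattened multi-output
+// tuple (multi_tuple != 0), then rewires users of TupleGetItem(call, 0) so that they
+// consume the call node directly.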
+void KernelGraphMgr::FlattenTuple(const CNodePtr &node, KernelGraph *graph) {
+  if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
+    auto call_graph = node->input(kFirstIndex);
+    auto sub_kernel_graph = GetValueNodeKernelGraph(call_graph);
+    MS_EXCEPTION_IF_NULL(sub_kernel_graph);
+    auto iter = kernel_graph_partial_map_.find(sub_kernel_graph.get());
+    if (iter != kernel_graph_partial_map_.end() && iter->second.multi_tuple != 0) {
+      need_flatten_.insert(node);
+    }
+  } else if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
+    auto input = node->input(kFirstIndex);
+    auto get_idx = common::AnfAlgo::GetTupleGetItemOutIndex(node);
+    if (need_flatten_.find(input) != need_flatten_.end() && get_idx == 0) {
+      need_flatten_tuple_map_[node] = input;
+    }
+  }
+  for (size_t i = 0; i < common::AnfAlgo::GetInputNum(node); i++) {
+    auto input = common::AnfAlgo::GetInputNode(node, i);
+    auto iter = need_flatten_tuple_map_.find(input);
+    if (iter != need_flatten_tuple_map_.end()) {
+      node->set_input(i + 1, iter->second);
+    }
+  }
+}
+
 bool KernelGraphMgr::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(graph);

@@ -883,6 +919,7 @@ bool KernelGraphMgr::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGrap
   new_cnode->set_scope(cnode->scope());
   graph->FrontBackendMapAdd(node, new_cnode);
   SetReturnNode(new_cnode, graph);
+  FlattenTuple(new_cnode, graph);
   return true;
 }
 
@@ -938,14 +975,40 @@ void KernelGraphMgr::SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
     auto last_input_node = common::AnfAlgo::GetInputNode(make_tuple, tuple_input_num - 1);
     MS_EXCEPTION_IF_NULL(last_input_node);
     if (last_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(last_input_node, prim::kPrimPartial)) {
+      size_t multi_tuple = 0;
       auto partial_node = last_input_node->cast<CNodePtr>();
       MS_EXCEPTION_IF_NULL(partial_node);
       size_t partial_input_num = common::AnfAlgo::GetInputTensorNum(partial_node);
       std::vector<AnfNodePtr> make_tuple_inputs;
       make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
       // skip last return node (is a partial)
+      size_t param_begin = 0;
       for (size_t i = 0; i < tuple_input_num - 1; i++) {
-        make_tuple_inputs.emplace_back(common::AnfAlgo::GetInputNode(make_tuple, i));
+        auto input = common::AnfAlgo::GetInputNode(make_tuple, i);
+        auto node_abs = input->abstract();
+        if (node_abs->isa<abstract::AbstractTuple>()) {
+          MS_EXCEPTION_IF_CHECK_FAIL(
+            i == 0, "Input index: " + std::to_string(i) + " is a make tuple, input node: " + input->DebugString());
+          MS_LOG(DEBUG) << "Flatten the make tuple, input node: " << input->DebugString()
+                        << ", output num: " << AnfUtils::GetOutputTensorNum(input);
+          // flatten the make tuple
+          for (size_t j = 0; j < AnfUtils::GetOutputTensorNum(input); j++) {
+            auto idx = NewValueNode(SizeToLong(j));
+            MS_EXCEPTION_IF_NULL(idx);
+            auto imm = std::make_shared<Int64Imm>(j);
+            idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
+            auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx});
+            std::vector<TypeId> types = {common::AnfAlgo::GetOutputInferDataType(input, j)};
+            auto shapes = {common::AnfAlgo::GetOutputInferShape(input, j)};
+            common::AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
+            param_begin++;
+            multi_tuple++;
+            make_tuple_inputs.emplace_back(getitem);
+          }
+        } else {
+          param_begin++;
+          make_tuple_inputs.emplace_back(input);
+        }
       }
       // skip partial graph
       for (size_t i = kFirstIndex; i < partial_input_num; i++) {

@@ -965,8 +1028,8 @@ void KernelGraphMgr::SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
     g_output->set_abstract(abstract);
     graph->set_output(g_output);
 
-    kernel_graph_partial_map_[graph] = {abstract, common::AnfAlgo::GetInputNode(partial_node, 0),
-                                        tuple_input_num - 1, common::AnfAlgo::GetInputTensorNum(g_output)};
+    kernel_graph_partial_map_[graph] = {abstract, common::AnfAlgo::GetInputNode(partial_node, 0), param_begin,
+                                        common::AnfAlgo::GetInputTensorNum(g_output), multi_tuple};
     }
   }
 }
@@ -1041,6 +1104,10 @@ std::shared_ptr<KernelGraph> KernelGraphMgr::ConstructKernelGraph(const FuncGrap
   auto graph = NewKernelGraph();
   MS_EXCEPTION_IF_NULL(graph);
   front_backend_graph_map_[func_graph.get()] = graph;
+  if (func_graph->has_flag(FUNC_GRAPH_FLAG_NEED_BACKEND_INLINE)) {
+    MS_LOG(INFO) << "Need backend inline: " << graph->graph_id();
+    graph->set_need_inline(true);
+  }
   MS_LOG(INFO) << "Create graph: " << graph->graph_id();
   graph->set_device_target(device_target);
   // Create parameter

@@ -46,6 +46,7 @@ struct PartialFuncInfo {
   AnfNodePtr sub_graph;
   size_t param_begin;
   size_t param_end;
+  size_t multi_tuple;
 };
 
 class BACKEND_EXPORT KernelGraphMgr {

@@ -111,6 +112,7 @@ class BACKEND_EXPORT KernelGraphMgr {
   ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) const;
   void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) const;
   void SetReturnNode(const AnfNodePtr &node, KernelGraph *graph);
+  void FlattenTuple(const CNodePtr &node, KernelGraph *graph);
 
  protected:
   CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);

@@ -123,6 +125,8 @@ class BACKEND_EXPORT KernelGraphMgr {
   mindspore::HashMap<AnfNodePtr, ParameterPtr> default_param_map_;
   mindspore::HashMap<FuncGraph *, KernelGraphPtr> front_backend_graph_map_;
   mindspore::HashMap<KernelGraph *, PartialFuncInfo> kernel_graph_partial_map_;
+  mindspore::HashSet<AnfNodePtr> need_flatten_;
+  mindspore::HashMap<AnfNodePtr, AnfNodePtr> need_flatten_tuple_map_;
   static GraphId graph_sum_;
 };
 }  // namespace session

@@ -918,6 +918,10 @@ constexpr auto kAttrInsertDefaultValue = "insert_default_value";
 constexpr auto kAttrIsSparse = "IsSparse";
 constexpr auto kAttrKernelBackoffWithFailureInfo = "kernel_backoff_with_failure_info";
 constexpr auto kAttrKernelBackoffWithFailureType = "kernel_backoff_with_failure_type";
+constexpr auto kAttrKernelGraph = "kernel_graph";
+constexpr auto kAttrPreKernelGraph = "pre_kernel_graph";
+constexpr auto kAttrNeedInline = "need_inline";
+constexpr auto kAttrOriFusionName = "ori_fusion_name";
 
 // FuncGraph Flags
 constexpr auto kFlagIsDynamicStructure = "is_dynamic_structure";

@@ -628,6 +628,9 @@ FuncGraphPtr GenerateReusingGraph(const FuncGraphPtr &fg) {
     }
   }
   MS_LOG(DEBUG) << "The reusable graph parameter size: " << reusing_graph->parameters().size();
+  if (common::GetEnv("MS_DEV_GRAPH_REUSE") == "2") {
+    reusing_graph->set_flag(FUNC_GRAPH_FLAG_NEED_BACKEND_INLINE, true);
+  }
   return reusing_graph;
 }
 

@@ -652,14 +655,9 @@ void ReplaceWithReusingGraph(const FuncGraphPtr &reusing_graph, const FuncGraphP
 // Make the reusable cell to be the reusable function graph.
 bool GraphReusingAction(const ResourcePtr &resource) {
   MS_EXCEPTION_IF_NULL(resource);
-  constexpr size_t graph_reusing_count = 2;
   const auto &obj_map = parse::data_converter::GetObjGraphs();
   for (const auto &[cell_key, graphs] : obj_map) {
     MS_LOG(DEBUG) << "Start to handle the reusable graph: " << cell_key << ", size: " << graphs.size();
-    // Only make the reusable cell that is used more than graph_reusing_count to be reusable.
-    if (graphs.size() < graph_reusing_count) {
-      continue;
-    }
     const auto &fg = graphs[0];
     // fg->paramter_obj_nodes().empty() have been handled by combine like.
     if (!fg->paramter_obj_nodes().empty()) {

@@ -1514,7 +1512,8 @@ static std::vector<ActionItem> CommonPipeline() {
   }
 
   // Make the reusable cell to be the reusable function graph
-  static bool enable_graph_reusing = (common::GetEnv("MS_DEV_GRAPH_REUSE") == "1");
+  static bool enable_graph_reusing =
+    (common::GetEnv("MS_DEV_GRAPH_REUSE") == "1" || common::GetEnv("MS_DEV_GRAPH_REUSE") == "2");
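+  // MS_DEV_GRAPH_REUSE: "1" reuses the cell graph as a called subgraph; "2" also marks
+  // the reused graph for backend inline via FUNC_GRAPH_FLAG_NEED_BACKEND_INLINE.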
   if (enable_graph_reusing) {
     (void)actions.emplace_back(std::make_pair("graph_reusing", GraphReusingAction));
   }

@@ -869,7 +869,11 @@ std::tuple<KernelSelectStatus, std::string, ExceptionType> SelectKernelInfoWithM
     SelectGraphKernelInfo(kernel_node, func_graph);
     return result;
   }
 
+  if (IsPrimitiveCNode(kernel_node, prim::kPrimCallInline)) {
+    opt::SelectCallInlineKernelInfo(kernel_node);
+    SetTensorDeviceInfo(kernel_node);
+    return result;
+  }
   if (common::AnfAlgo::HasNodeAttr(ops::kBatchRank, kernel_node)) {
     std::stringstream ss;
     ss << common::AnfAlgo::GetCNodeName(kernel_node)

@@ -57,6 +57,7 @@ const char ITEREND[] = "PROFILING_ITER_END";
 
 const auto kSingleOutput = 1;
 const auto kFirstOutput = 0;
+constexpr size_t kFirstIndex = 1;
 
 #ifdef ENABLE_DUMP_IR
 bool IsSaveGraph() {
@@ -142,6 +143,20 @@ void DumpExecuteOrder(const NotNull<KernelGraphPtr> &kg) {
 }
 #endif
 
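+// Return the KernelGraph held by a graph-valued ValueNode, or nullptr otherwise.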
+KernelGraphPtr GetValueNodeKernelGraph(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto value_node = node->cast<ValueNodePtr>();
+  if (value_node == nullptr) {
+    return nullptr;
+  }
+  auto value = value_node->value();
+  if (value == nullptr) {
+    return nullptr;
+  }
+  auto kernel_graph = value->cast<KernelGraphPtr>();
+  return kernel_graph;
+}
+
 // Return kNoLabel when label id attribute not set for the graph.
 uint32_t GetGraphLabel(const KernelGraphPtr &kg) {
   auto value = kg->get_attr(kAttrLabelIndex);
@@ -151,6 +166,15 @@ uint32_t GetGraphLabel(const KernelGraphPtr &kg) {
   return GetValue<uint32_t>(value);
 }
 
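+// A call site can be inlined when its callee kernel graph is marked need_inline.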
+bool CheckCallInline(const CNodePtr &cnode) {
+  if (!common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
+    return false;
+  }
+  auto call_graph = cnode->input(kFirstIndex);
+  auto sub_kernel_graph = GetValueNodeKernelGraph(call_graph);
+  return sub_kernel_graph->need_inline();
+}
+
 bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2);
 
 bool CheckAbstractTupleIsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) {
@@ -501,6 +525,12 @@ class CallInfoFinder {
 
   // Find call-return pairs.
   void FindCallReturns() {
+    for (auto &[caller, call_info] : context_.call_info_map) {
+      if (caller->need_inline() && (call_info.recursive || call_info.call_sites.size() != 0)) {
+        MS_LOG(INFO) << "Do not inline cell reuse because it has sub-graph call, graph id: " << caller->graph_id();
+        caller->set_need_inline(false);
+      }
+    }
     for (auto &[caller, call_info] : context_.call_info_map) {
       for (auto &call_site : call_info.call_sites) {
         for (auto &callee : call_site.callees) {
@@ -623,7 +653,7 @@ class CallInfoFinder {
     }
 
     // Create a parameter for the return value.
-    if (call_site->out_param == nullptr) {
+    if (call_site->out_param == nullptr && !CheckCallInline(call_site->cnode)) {
       call_site->out_param = context_.CreateParameter(call_site->cnode->abstract());
     }
     // Add a return point for the callee graph.
@@ -634,7 +664,7 @@ class CallInfoFinder {
     // Setup label index if there are multi return points.
     const auto n_return_points = call_info.return_points.size();
     const size_t return_point_sizes = 2;
-    if (n_return_points > 1) {
+    if (n_return_points > 1 && !CheckCallInline(call_site->cnode)) {
       if (n_return_points == return_point_sizes) {
         // Create a parameter to store label index.
         const ShapeVector shape = {1};
@@ -753,6 +783,10 @@ class AscendAutoMonadConverter {
   ~AscendAutoMonadConverter() = default;
 
   void Run() {
+    // Graphs marked need_inline are handled by the backend inline pass instead of
+    // the label/goto conversion below.
+    if (kernel_graph_->need_inline()) {
+      return;
+    }
     // Create a stack
     InitStack();
     // Setup entry label if found.
@@ -1033,6 +1067,20 @@ class AscendAutoMonadConverter {
 
     // The call/switch/switch_layer cnode.
     auto &cnode = call_site->cnode;
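+    // For an inlinable call, replace the call node with a CallInline node carrying the
+    // callee kernel graph as an attribute; no labels are generated for it.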
+    if (CheckCallInline(cnode)) {
+      auto call_graph = cnode->input(kFirstIndex);
+      auto sub_kernel_graph = GetValueNodeKernelGraph(call_graph);
+      std::vector<AnfNodePtr> call_inline_inputs = {NewPrimitive(prim::kPrimCallInline)};
+      for (size_t i = kFirstIndex; i < common::AnfAlgo::GetInputNum(cnode); i++) {
+        call_inline_inputs.emplace_back(common::AnfAlgo::GetInputNode(cnode, i));
+      }
+      auto call_inline = kernel_graph_->NewCNode(call_inline_inputs);
+      MS_EXCEPTION_IF_NULL(call_inline);
+      call_inline->set_abstract(cnode->abstract());
+      common::AnfAlgo::SetNodeAttr(kAttrKernelGraph, MakeValue(sub_kernel_graph), call_inline);
+      ReplaceNode(cnode, call_inline);
+      return;
+    }
 
     // Get branches of the call_site.
     // for call, there is one branch;

@@ -36,6 +36,7 @@
 #ifndef ENABLE_SECURITY
 #include "include/common/debug/anf_ir_dump.h"
 #include "include/common/debug/dump_proto.h"
+#include "ir/func_graph_cloner.h"
 #endif
 
 namespace mindspore {
@@ -95,6 +96,56 @@ bool SetDefaultFormatForSpecialAclOp(const KernelGraphPtr &graph) {
   }
   return need_change_format;
 }
 
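+// Clone func_graph into target_func_graph (kInline clone), deep-copying each cloned
+// node's selected kernel build info so already-selected kernel formats survive, and
+// tagging cloned cnodes with kAttrNeedInline and kAttrPreKernelGraph for later passes.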
+AnfNodePtr DoInline(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
+                    const AnfNodePtrList &func_graph_args, const ScopePtr &scope) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(target_func_graph);
+  Cloner cloner({}, false);
+  if (scope != nullptr) {
+    cloner.set_scope(scope);
+  }
+  cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline);
+  auto node_list = TopoSort(func_graph->output());
+  for (auto &ori_node : node_list) {
+    if (ori_node->isa<Parameter>()) {
+      continue;
+    }
+    auto new_node = cloner[ori_node];
+    auto kernel_info = dynamic_cast<device::KernelInfo *>(new_node->kernel_info());
+    // deep copy kernel info
+    if (kernel_info != nullptr && new_node->kernel_info()->has_build_info()) {
+      // some check
+      MS_EXCEPTION_IF_CHECK_FAIL(kernel_info->MutableKernelMod() == nullptr,
+                                 "Inline ERROR: " + ori_node->DebugString() + ", kernel mod is not nullptr");
+      MS_EXCEPTION_IF_CHECK_FAIL(kernel_info->output_address_list().empty(),
+                                 "Inline ERROR: " + ori_node->DebugString() + ", output_address_list is not empty");
+      MS_EXCEPTION_IF_CHECK_FAIL(kernel_info->workspace_address_list().empty(),
+                                 "Inline ERROR: " + ori_node->DebugString() + ", workspace_address_list is not empty");
+
+      auto new_kernel_info = std::make_shared<device::KernelInfo>();
+      auto builder =
+        std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(new_node));
+      new_kernel_info->set_select_kernel_build_info(builder->Build());
+      new_kernel_info->set_graph_id(kernel_info->graph_id());
+      new_kernel_info->set_feature_map_flag(kernel_info->is_feature_map());
+      new_kernel_info->set_ref_map(false, kernel_info->out_in_ref_map());
+      new_node->set_kernel_info(new_kernel_info);
+    }
+    if (ori_node->isa<CNode>()) {
+      auto ori_cnode = ori_node->cast<CNodePtr>();
+      if (common::AnfAlgo::HasNodeAttr(kAttrIsUBFusionOp, ori_cnode) &&
+          common::AnfAlgo::GetNodeAttr<bool>(ori_node, kAttrIsUBFusionOp)) {
+        // already done fusion compile
+        auto ori_full_name = ori_cnode->fullname_with_scope();
+        common::AnfAlgo::SetNodeAttr(kAttrOriFusionName, MakeValue(ori_full_name), new_node);
+      }
+      common::AnfAlgo::SetNodeAttr(kAttrNeedInline, MakeValue(ori_node->fullname_with_scope()), new_node);
+      common::AnfAlgo::SetNodeAttr(kAttrPreKernelGraph, MakeValue(func_graph), new_node);
+    }
+  }
+  return cloner[func_graph->output()];
+}
 }  // namespace
 
 void AscendGraphOptimization::Reset() {
@@ -103,6 +154,40 @@ void AscendGraphOptimization::Reset() {
   graph_manager_->Clear();
 }
 
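+// Replace every CallInline node in the execution order with the inlined body of its
+// callee graph, then run the after-inline optimization passes on the merged graph.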
+void AscendGraphOptimization::InlineSubGraph(const KernelGraphPtr &graph) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+#ifdef ENABLE_DUMP_IR
+  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+  if (save_graphs) {
+    std::string file_name = "hwopt_d_before_inline_graph_" + std::to_string(graph->graph_id()) + ".ir";
+    DumpIR(file_name, graph, true, kWholeStack);
+  }
+#endif
+  auto kernel_cnodes = graph->execution_order();
+  for (auto &kernel_cnode : kernel_cnodes) {
+    if (common::AnfAlgo::CheckPrimitiveType(kernel_cnode, prim::kPrimCallInline)) {
+      auto sub_graph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(kernel_cnode, kAttrKernelGraph);
+      MS_LOG(INFO) << "InlineSubGraph: " << kernel_cnode->DebugString() << ", sub graph: " << sub_graph->graph_id()
+                   << ", need inline: " << sub_graph->need_inline();
+      auto main_graph = kernel_cnode->func_graph();
+      auto mng = main_graph->manager();
+      AnfNodePtrList inp(kernel_cnode->inputs().begin() + 1, kernel_cnode->inputs().end());
+      auto out = DoInline(sub_graph, main_graph, inp, kernel_cnode->input(0)->scope());
+      (void)mng->Replace(kernel_cnode, out);
+    }
+  }
+  memo_.clear();
+  opt::AscendAfterInlineOptimization(graph);
+  memo_.clear();
+#ifdef ENABLE_DUMP_IR
+  if (save_graphs) {
+    std::string file_name = "hwopt_d_after_inline_graph_" + std::to_string(graph->graph_id()) + ".ir";
+    DumpIR(file_name, graph, true, kWholeStack);
+  }
+#endif
+}
+
 void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_LOG(INFO) << "Status record: start optimize graph. graph id: " << graph->graph_id();
@@ -118,6 +203,9 @@ void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) {
   OptimizeGraphWithoutDeviceInfo(graph);
   SelectKernel(graph);
   OptimizeGraphWithDeviceInfo(graph);
+
+  // inline func before gen execution order
+  InlineSubGraph(graph);
   OptimizeExecutionOrder(graph);
   PostOptimization(graph);
 
@@ -195,8 +283,8 @@ void AscendGraphOptimization::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &
   MS_EXCEPTION_IF_NULL(graph);
   memo_.clear();
   HardWareOptimization(graph);
-  // copy child graph ref output map to father graph ref output map
   memo_.clear();
+  // copy child graph ref output map to father graph ref output map
   UpdateRefOutputMap(graph);
   AnfAlgo::InsertMakeTupleForOutput(NOT_NULL(graph));
   RemoveUnusedValueNode(graph);
@@ -288,6 +376,9 @@ void AscendGraphOptimization::HardWareOptimization(const KernelGraphPtr &graph)
     return;
   }
   (void)memo_.insert(graph);
+  for (auto &child_graph : graph->child_graph_order()) {
+    HardWareOptimization(child_graph.lock());
+  }
   opt::AscendBackendOptimization(graph);
   opt::CommonFinalOptimization(graph);
   if (graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
@@ -295,10 +386,6 @@ void AscendGraphOptimization::HardWareOptimization(const KernelGraphPtr &graph)
     graph->SetExecOrderByDefault();
   }
   MS_LOG(INFO) << "Status record: end hardware optimize. graph id: " << graph->graph_id();
-
-  for (auto &child_graph : graph->child_graph_order()) {
-    HardWareOptimization(child_graph.lock());
-  }
 }
 
 void AscendGraphOptimization::AddGraphToManager(const NotNull<KernelGraphPtr> graph,
@@ -349,6 +436,11 @@ void AscendGraphOptimization::RecurseSelectKernelInfo(const KernelGraphPtr &grap
     return;
   }
   (void)memo_.insert(graph);
+  for (auto &child_graph : graph->child_graph_order()) {
+    if (child_graph.lock()->need_inline()) {
+      RecurseSelectKernelInfo(child_graph.lock());
+    }
+  }
 #ifdef ENABLE_DUMP_IR
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
@@ -361,16 +453,16 @@ void AscendGraphOptimization::RecurseSelectKernelInfo(const KernelGraphPtr &grap
   MS_LOG(INFO) << "Status record: start select kernel info. graph id: " << graph->graph_id();
   SetOperatorInfo(graph);
   MS_LOG(INFO) << "Status record: end select kernel info. graph id: " << graph->graph_id();
 
 #ifdef ENABLE_DUMP_IR
   if (save_graphs) {
     std::string file_name = "select_kernel_after_graph_" + std::to_string(graph->graph_id()) + ".ir";
     DumpIR(file_name, graph);
   }
 #endif
 
   for (auto &child_graph : graph->child_graph_order()) {
-    RecurseSelectKernelInfo(child_graph.lock());
+    if (!child_graph.lock()->need_inline()) {
+      RecurseSelectKernelInfo(child_graph.lock());
+    }
   }
 }

@@ -51,6 +51,7 @@ class AscendGraphOptimization {
   void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph);
   void OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph);
   void OptimizeExecutionOrder(const KernelGraphPtr &graph);
+  void InlineSubGraph(const KernelGraphPtr &graph);
   void PostOptimization(const KernelGraphPtr &graph) const;
 
   // Graph Optimized level-3 interface

@@ -350,7 +350,7 @@ void TbeKernelCompileManager::GetAllTbeNodes(const std::shared_ptr<session::Kern
   auto all_nodes = kernel_graph->execution_order();
   for (const auto &anf_node : all_nodes) {
     MS_EXCEPTION_IF_NULL(anf_node);
-    if (!AnfUtils::IsRealKernel(anf_node)) {
+    if (!AnfUtils::IsRealKernel(anf_node) || IsPrimitiveCNode(anf_node, prim::kPrimCallInline)) {
       continue;
     }
     KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);
@@ -564,6 +564,9 @@ std::pair<std::vector<CNodePtr>, std::vector<CNodePtr>> TbeKernelCompileManager:
       continue;  // kernel mod already exist, continue;
     }
     auto full_name = node->fullname_with_scope();
+    if (common::AnfAlgo::HasNodeAttr(kAttrOriFusionName, node)) {
+      full_name = common::AnfAlgo::GetNodeAttr<std::string>(node, kAttrOriFusionName);
+    }
     auto json_name = full_name_to_json_name_[full_name];
     auto kernel_pack = tbe::TbeUtils::SearchCache(json_name, false);
     if (kernel_pack == nullptr) {

@@ -94,6 +94,7 @@
 #include "plugin/device/ascend/optimizer/ir_fusion/softmax_dropout_do_mask_v3_fusion.h"
 #include "plugin/device/ascend/optimizer/ir_fusion/conv2d_backprop_input_dilation_fusion.h"
 #include "plugin/device/ascend/optimizer/format_type/insert_trans_op.h"
+#include "plugin/device/ascend/optimizer/format_type/reselect_call_inline_format.h"
 #include "plugin/device/ascend/optimizer/format_type/trans_op_format_refine.h"
 #include "plugin/device/ascend/optimizer/format_type/dynamic_rnn_grad_reformat.h"
 #include "plugin/device/ascend/optimizer/format_type/insert_transpose_for_basiclstm_op.h"
@@ -285,6 +286,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto data_layout_pm = std::make_shared<PassManager>("transop_pm");
+  data_layout_pm->AddPass(std::make_shared<ReselectCallInlineFormat>());
   data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>());
   data_layout_pm->AddPass(std::make_shared<DynamicRNNGradReformat>());
   data_layout_pm->AddPass(std::make_shared<ChangeAxisOfReduceKernel>());
@@ -509,6 +511,16 @@ void RunOpAscendBackendOptimization(const std::shared_ptr<session::KernelGraph>
   kernel_graph->SetExecOrderByDefault();
 }
 
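+// Cleanup passes for the main graph after subgraphs have been inlined: common
+// subexpression elimination plus redundant-op elimination.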
+void AscendAfterInlineOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
+  auto optimizer = std::make_shared<GraphOptimizer>();
+  auto after_inline_pm = std::make_shared<PassManager>("after_inline_pm");
+  after_inline_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
+  after_inline_pm->AddPass(std::make_shared<EliminateRedundantOp>());
+  optimizer->AddPassManager(after_inline_pm);
+  (void)optimizer->Optimize(kernel_graph);
+  kernel_graph->SetExecOrderByDefault();
+}
+
 void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   MS_LOG(INFO) << "Status record: start ascend backend(data layer & mix precision ...) pass. graph id: "

@@ -25,6 +25,7 @@ void RunOpAscendBackendOptimization(const std::shared_ptr<session::KernelGraph>
 void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
+void AscendAfterInlineOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &kernel_graph);

@@ -635,6 +635,41 @@ void SetInputOutputNames(const std::vector<std::string> &input_names, const std:
   common::AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), node);
 }
 
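+// Build kernel build info for a CallInline node: input formats/types come from the
+// callee graph's parameters (monad parameters are skipped), output formats/types come
+// from the callee graph's output node.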
+void SelectCallInlineKernelInfo(const CNodePtr &node) {
+  if (!IsPrimitiveCNode(node, prim::kPrimCallInline)) {
+    return;
+  }
+  // need inline
+  auto sub_graph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(node, kAttrKernelGraph);
+  auto sub_ret = sub_graph->output();
+  std::vector<std::string> input_formats;
+  std::vector<TypeId> input_types;
+  std::vector<std::string> output_formats;
+  std::vector<TypeId> output_types;
+  for (auto &param : sub_graph->inputs()) {
+    TypeId type_id = AnfAlgo::GetOutputDeviceDataType(param, 0);
+    if (type_id == kTypeUnknown) {
+      type_id = common::AnfAlgo::GetOutputInferDataType(param, 0);
+    }
+    if (type_id > kMonadTypeBegin && type_id < kMonadTypeEnd) {
+      continue;
+    }
+    input_types.push_back(type_id);
+    input_formats.push_back(AnfAlgo::GetOutputFormat(param, 0));
+  }
+  for (size_t i = 0; i < AnfUtils::GetOutputTensorNum(node); ++i) {
+    output_formats.push_back(AnfAlgo::GetOutputFormat(sub_ret, i));
+    output_types.push_back(common::AnfAlgo::GetOutputInferDataType(sub_ret, i));
+  }
+  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  MS_EXCEPTION_IF_NULL(builder);
+  builder->SetInputsFormat(input_formats);
+  builder->SetInputsDeviceType(input_types);
+  builder->SetOutputsFormat(output_formats);
+  builder->SetOutputsDeviceType(output_types);
+  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
+}
+
 template <typename T, typename Scalar>
 ValuePtr GetTensorValue(const tensor::TensorPtr &tensor) {
   ValuePtr ret;

@@ -118,6 +118,8 @@ AnfNodePtr AddTransOpNodeToGraphWithFormat(const FuncGraphPtr &func_graph, const
 void SetInputOutputNames(const std::vector<std::string> &input_names, const std::vector<std::string> &output_names,
                          const AnfNodePtr &node);
 
+void SelectCallInlineKernelInfo(const CNodePtr &node);
+
 const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW};
 }  // namespace opt
 }  // namespace mindspore

@@ -0,0 +1,51 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "plugin/device/ascend/optimizer/format_type/reselect_call_inline_format.h"
+#include "plugin/device/ascend/optimizer/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr auto Xs = "Xs";
+constexpr auto call_inline = "call_inline";
+constexpr auto new_call_inline = "new_call_inline";
+}  // namespace
+bool ReselectCallInlineFormat::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &graph,
+                                               const AnfNodePtr &node) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  return true;
+}
+
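+// Re-run kernel info selection for a matched call_inline node so that its input and
+// output formats stay consistent with the formats chosen by the data-layout passes.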
+AnfNodePtr BuildCallInline(const PatternMap &m, const AnfNodePtr &) {
+  auto anf = m.Get(call_inline);
+  MS_EXCEPTION_IF_NULL(anf);
+  auto cnode = anf->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  SelectCallInlineKernelInfo(cnode);
+  return cnode;
+}
+
+void ReselectCallInlineFormat::DefineSrcPattern(SrcPattern *src_pattern) {
+  (*src_pattern).AddSeqVar(Xs).AddCNode(call_inline, {prim::kPrimCallInline, Xs});
+}
+
+void ReselectCallInlineFormat::DefineDstPattern(DstPattern *dst_pattern) {
+  (*dst_pattern).AddCNode(new_call_inline, {prim::kPrimCallInline, Xs}, BuildCallInline);
+}
+}  // namespace opt
+}  // namespace mindspore

@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_RESELECT_CALL_INLINE_FORMAT_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_RESELECT_CALL_INLINE_FORMAT_H_
+
+#include <memory>
+#include "backend/common/optimizer/pattern_to_pattern.h"
+
+namespace mindspore {
+namespace opt {
+class ReselectCallInlineFormat : public PatternToPatternPass {
+ public:
+  ReselectCallInlineFormat() : PatternToPatternPass("reselect_call_inline_format") {}
+  ~ReselectCallInlineFormat() override = default;
+  void DefineSrcPattern(SrcPattern *src_pattern) override;
+  void DefineDstPattern(DstPattern *dst_pattern) override;
+  bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_RESELECT_CALL_INLINE_FORMAT_H_

@@ -619,6 +619,9 @@ void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value
 
 void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph &graph) {
   MS_EXCEPTION_IF_NULL(mem_manager_);
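+  // A graph marked need_inline is skipped: its nodes receive memory as part of the
+  // caller graph once it has been inlined.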
+  if (graph.need_inline()) {
+    return;
+  }
   auto graph_id = graph.graph_id();
   MS_LOG(INFO) << "AssignStaticMemoryInput start for graph " << graph_id;
   auto graph_inputs = GetGraphInputs(graph);

@@ -711,6 +714,9 @@ void KernelRuntime::GetDeviceAddress(const AnfNodePtr &item,
 }
 
 void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph &graph) {
+  if (graph.need_inline()) {
+    return;
+  }
   MS_LOG(INFO) << "AssignStaticMemoryOutput start for graph " << graph.graph_id();
   auto nodes = common::AnfAlgo::GetAllOutput(graph.output(), {prim::kPrimTupleGetItem});
   std::vector<session::KernelWithIndex> non_communication_op;

|
@ -83,6 +83,7 @@ const char FUNC_GRAPH_FLAG_IGNORE_VALUE[] = "ignore_value";
|
|||
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
|
||||
const char FUNC_GRAPH_FLAG_SPARSE_BPROP[] = "sparse_bprop";
|
||||
const char FUNC_GRAPH_FLAG_NO_INLINE[] = "no_inline";
|
||||
const char FUNC_GRAPH_FLAG_NEED_BACKEND_INLINE[] = "need_backend_inline";
|
||||
const char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
|
||||
const char FUNC_GRAPH_FLAG_CORE[] = "core";
|
||||
const char FUNC_GRAPH_FLAG_K_GRAPH[] = "k_graph";
|
||||
|
|
|
@@ -1481,6 +1481,7 @@ GVAR_DEF(PrimitivePtr, kPrimSelect, std::make_shared<Primitive>(kSelect));
 GVAR_DEF(PrimitivePtr, kPrimCall, std::make_shared<Primitive>("call"));
 GVAR_DEF(PrimitivePtr, kPrimRaise, std::make_shared<Primitive>("raise"));
 GVAR_DEF(PrimitivePtr, kPrimJoinedStr, std::make_shared<Primitive>("joinedstr"));
+GVAR_DEF(PrimitivePtr, kPrimCallInline, std::make_shared<Primitive>("call_inline"));
 
 GVAR_DEF(PrimitivePtr, kPrimMakeTuple, std::make_shared<Primitive>(kMakeTuple));
 GVAR_DEF(PrimitivePtr, kPrimMakeSlice, std::make_shared<Primitive>("make_slice"));

@@ -21,7 +21,9 @@ import pytest
 import numpy as np
 import mindspore.context as context
 import mindspore.nn as nn
+import mindspore as ms
 
+from mindspore.common import mutable
 from mindspore.common import set_seed
 from mindspore import Tensor
 from mindspore.nn import TrainOneStepCell, WithLossCell
@@ -29,6 +31,28 @@ from mindspore.nn.optim import Momentum
 from mindspore._extends import cell_attr_register
 
 
+def seed_set():
+    set_seed(1)
+    np.random.seed(1)
+    random.seed(1)
+
+
+def train(net, data, label):
+    learning_rate = 0.05
+    momentum = 0.9
+
+    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
+    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
+    net_with_criterion = WithLossCell(net, criterion)
+    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
+    train_network.set_train()
+    res_list = []
+    for _ in range(20):
+        res = train_network(data, label)
+        res_list.append(res[0].asnumpy())
+    return res_list
+
+
 class CellDense(nn.Cell):
     @cell_attr_register
     def __init__(self):
@@ -60,6 +84,297 @@ class MLP(nn.Cell):
         return out
 
 
+def get_pynative_mlp_cell_reuse_loss():
+    context.set_context(mode=context.PYNATIVE_MODE)
+
+    # gen data
+    seed_set()
+    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
+    label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
+
+    # cell reuse
+    net = MLP()
+    loss_list = train(net, data, label)
+    return loss_list
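+
+
+# MS_DEV_GRAPH_REUSE: "0" disables cell reuse, "1" reuses the cell graph as a called
+# subgraph, "2" reuses it and additionally inlines it in the backend.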
+def get_mlp_cell_reuse_loss(reuse):
+    context.set_context(mode=context.GRAPH_MODE)
+    os.environ['MS_DEV_GRAPH_REUSE'] = reuse
+
+    # gen data
+    seed_set()
+    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
+    label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
+
+    # cell reuse
+    net = MLP()
+    loss_list = train(net, data, label)
+    del os.environ['MS_DEV_GRAPH_REUSE']
+
+    return loss_list
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_reuse_0():
+    """
+    Feature: cell reuse.
+    Description: MLP with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_loss(str(0))
+    loss_pynative = get_pynative_mlp_cell_reuse_loss()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_reuse_1():
+    """
+    Feature: cell reuse.
+    Description: MLP with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_loss(str(1))
+    loss_pynative = get_pynative_mlp_cell_reuse_loss()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_reuse_2():
+    """
+    Feature: cell reuse.
+    Description: MLP with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_loss(str(2))
+    loss_pynative = get_pynative_mlp_cell_reuse_loss()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+class CellDense2(nn.Cell):
+    @cell_attr_register
+    def __init__(self):
+        super(CellDense2, self).__init__()
+        self.fc = nn.Dense(100, 100)
+
+    def construct(self, input_x):
+        out = self.fc(input_x)
+        return input_x, out
+
+
+class MLP2(nn.Cell):
+    def __init__(self):
+        super(MLP2, self).__init__()
+        self.batch_size = 1
+        self.fc = nn.Dense(200, 100)
+
+        layers = []
+        for _ in range(12):
+            layer = CellDense2()
+            layers.append(layer)
+
+        self.layers = nn.CellList(layers)
+
+    def construct(self, out):
+        out = self.fc(out)
+        for layer_module in self.layers:
+            tmp, out = layer_module(out)
+            out += tmp
+        return out
+
+
+def get_pynative_mlp_cell_reuse_loss_2():
+    context.set_context(mode=context.PYNATIVE_MODE)
+
+    # gen data
+    seed_set()
+    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
+    label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
+
+    # cell reuse
+    net = MLP2()
+    loss_list = train(net, data, label)
+    return loss_list
+
+
+def get_mlp_cell_reuse_loss_2(reuse):
+    context.set_context(mode=context.GRAPH_MODE)
+    os.environ['MS_DEV_GRAPH_REUSE'] = reuse
+
+    # gen data
+    seed_set()
+    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
+    label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
+
+    # cell reuse
+    net = MLP2()
+    loss_list = train(net, data, label)
+    del os.environ['MS_DEV_GRAPH_REUSE']
+    return loss_list
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_2_reuse_0():
+    """
+    Feature: cell reuse.
+    Description: MLP (need flatten maketuple) with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_loss_2(str(0))
+    loss_pynative = get_pynative_mlp_cell_reuse_loss_2()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_2_reuse_1():
+    """
+    Feature: cell reuse.
+    Description: MLP (need flatten maketuple) with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_loss_2(str(1))
+    loss_pynative = get_pynative_mlp_cell_reuse_loss_2()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_2_reuse_2():
+    """
+    Feature: cell reuse.
+    Description: MLP (need flatten maketuple) with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_loss_2(str(2))
+    loss_pynative = get_pynative_mlp_cell_reuse_loss_2()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+class CellDenseWithControlFlow(nn.Cell):
+    @cell_attr_register
+    def __init__(self):
+        super(CellDenseWithControlFlow, self).__init__()
+        self.fc = nn.Dense(100, 100)
+
+    def construct(self, input_x, x):
+        out = self.fc(input_x)
+        if x > 0:
+            out = self.fc(out)
+        out = self.fc(out)
+        return out
+
+
+class MLPWithControlFlow(nn.Cell):
+    def __init__(self):
+        super(MLPWithControlFlow, self).__init__()
+        self.batch_size = 1
+        self.fc = nn.Dense(200, 100)
+
+        layers = []
+        for _ in range(12):
+            layer = CellDenseWithControlFlow()
+            layers.append(layer)
+
+        self.layers = nn.CellList(layers)
+
+    def construct(self, out):
+        out = self.fc(out)
+        for layer_module in self.layers:
+            x = mutable(ms.Tensor(np.array(1), dtype=ms.int32))
+            out = layer_module(out, x)
+        return out
+
+
+def get_pynative_mlp_cell_reuse_infer():
+    context.set_context(mode=context.PYNATIVE_MODE)
+
+    # gen data
+    seed_set()
+    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
+
+    # cell reuse
+    net = MLPWithControlFlow()
+    ret = net(data)
+    return ret.asnumpy()
+
+
+def get_mlp_cell_reuse_infer(reuse):
+    context.set_context(mode=context.GRAPH_MODE)
+    os.environ['MS_DEV_GRAPH_REUSE'] = reuse
+
+    # gen data
+    seed_set()
+    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
+
+    # cell reuse
+    net = MLPWithControlFlow()
+    ret = net(data)
+    del os.environ['MS_DEV_GRAPH_REUSE']
+    return ret.asnumpy()
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_with_control_flow_reuse_0():
+    """
+    Feature: cell reuse.
+    Description: MLP with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_infer(str(0))
+    loss_pynative = get_pynative_mlp_cell_reuse_infer()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_with_control_flow_reuse_1():
+    """
+    Feature: cell reuse.
+    Description: MLP with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_infer(str(1))
+    loss_pynative = get_pynative_mlp_cell_reuse_infer()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mlp_cell_with_control_flow_reuse_2():
+    """
+    Feature: cell reuse.
+    Description: MLP with cell reuse.
+    Expectation: No exception.
+    """
+    loss_graph = get_mlp_cell_reuse_infer(str(2))
+    loss_pynative = get_pynative_mlp_cell_reuse_infer()
+    assert np.allclose(loss_pynative, loss_graph, 0.001, 0.001)
+
+
 class CellDropDense(nn.Cell):
     @cell_attr_register
     def __init__(self):
@@ -93,85 +408,6 @@ class DropMLP(nn.Cell):
         return out
 
 
-def seed_set():
-    set_seed(1)
-    np.random.seed(1)
-    random.seed(1)
-
-
-def train(net, data, label):
-    learning_rate = 0.05
-    momentum = 0.9
-
-    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
-    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
-    net_with_criterion = WithLossCell(net, criterion)
-    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
-    train_network.set_train()
-    res_list = []
-    for _ in range(20):
-        res = train_network(data, label)
-        res_list.append(res[0].asnumpy())
-    return res_list
-
-
-expect_value = [4.6052, 4.5553, 4.4607, 4.3261, 4.1556, 3.9532, 3.7227,
-                3.4675, 3.1912, 2.8974, 2.5900, 2.2736, 1.9538, 1.6376,
-                1.3335, 1.0511, 0.8002, 0.5884, 0.4195, 0.2920]
-
-
-@pytest.mark.level1
-@pytest.mark.platform_arm_ascend_training
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.env_onecard
-def test_mlp_cell_reuse():
-    """
-    Feature: cell reuse.
-    Description: MLP with cell reuse.
-    Expectation: No exception.
-    """
-    context.set_context(mode=context.GRAPH_MODE)
-    os.environ['MS_DEV_GRAPH_REUSE'] = str(1)
-
-    # gen data
-    seed_set()
-    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
-    label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
-
-    # cell reuse
-    net = MLP()
-    loss_list = train(net, data, label)
-    del os.environ['MS_DEV_GRAPH_REUSE']
-
-    assert np.allclose(loss_list, expect_value, 0.001, 0.001)
-
-
-@pytest.mark.level1
-@pytest.mark.platform_arm_ascend_training
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.env_onecard
-def test_mlp_cell():
-    """
-    Feature: cell reuse.
-    Description: MLP without cell reuse.
-    Expectation: No exception.
-    """
-    context.set_context(mode=context.GRAPH_MODE)
-    os.environ['MS_DEV_GRAPH_REUSE'] = str(0)
-
-    # gen data
-    seed_set()
-    data = Tensor(np.random.random([1, 200]).astype(np.float32) * 0.01)
-    label = Tensor(np.array(np.random.randint(100, size=[1]), dtype=np.int32))
-
-    # cell not reuse
-    net = MLP()
-    loss_list = train(net, data, label)
-    del os.environ['MS_DEV_GRAPH_REUSE']
-
-    assert np.allclose(loss_list, expect_value, 0.001, 0.001)
-
-
 @pytest.mark.level1
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training