forked from mindspore-Ecosystem/mindspore
!2931 Ascend control flow not split graphs
Merge pull request !2931 from zhoufeng/liantiao1
This commit is contained in:
commit
130cc29603
|
@ -40,6 +40,9 @@ using kernel::KernelBuildInfoPtr;
|
|||
using kernel::KernelMod;
|
||||
using kernel::KernelModPtr;
|
||||
namespace {
|
||||
constexpr size_t kNopNodeInputSize = 2;
|
||||
constexpr size_t kNopNodeRealInputIndex = 1;
|
||||
|
||||
std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
std::vector<size_t> shape_size_t;
|
||||
|
@ -48,6 +51,26 @@ std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
// Returns the input of `tuple_get_item` stored at kRealInputNodeIndexInTupleGetItem.
// The node must be non-null and hold exactly kTupleGetItemInputSize inputs; otherwise an
// EXCEPTION-level log is emitted.
AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
  MS_EXCEPTION_IF_NULL(tuple_get_item);
  const bool has_expected_arity = (tuple_get_item->size() == kTupleGetItemInputSize);
  if (!has_expected_arity) {
    MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
  }
  return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem);
}
|
||||
|
||||
size_t AnfRuntimeAlgorithm::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item);
|
||||
if (tuple_get_item->size() != kTupleGetItemInputSize) {
|
||||
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
||||
}
|
||||
auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
|
||||
MS_EXCEPTION_IF_NULL(output_index_value_node);
|
||||
auto value_node = output_index_value_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
return IntToSize(GetValue<int>(value_node->value()));
|
||||
}
|
||||
|
||||
KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
if (anf_node->isa<ValueNode>()) {
|
||||
|
@ -83,49 +106,47 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz
|
|||
}
|
||||
}
|
||||
|
||||
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
|
||||
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, int index,
|
||||
bool visit_nop_node,
|
||||
const std::vector<PrimitivePtr> &return_types) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
for (const auto &prim_type : return_types) {
|
||||
if (CheckPrimitiveType(anf_node, prim_type)) {
|
||||
return std::make_pair(anf_node, index);
|
||||
}
|
||||
if (std::any_of(return_types.begin(), return_types.end(), [&anf_node](const PrimitivePtr &prim_type) -> bool {
|
||||
return CheckPrimitiveType(anf_node, prim_type);
|
||||
})) {
|
||||
return KernelWithIndex(anf_node, index);
|
||||
}
|
||||
if (anf_node->isa<ValueNode>()) {
|
||||
return std::make_pair(anf_node, 0);
|
||||
} else if (anf_node->isa<Parameter>()) {
|
||||
return std::make_pair(anf_node, 0);
|
||||
} else if (anf_node->isa<CNode>()) {
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto input0 = cnode->input(0);
|
||||
MS_EXCEPTION_IF_NULL(input0);
|
||||
if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
|
||||
if (cnode->inputs().size() != kTupleGetItemInputSize) {
|
||||
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
||||
}
|
||||
auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
|
||||
MS_EXCEPTION_IF_NULL(input2);
|
||||
auto value_node = input2->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
int item_idx = GetValue<int>(value_node->value());
|
||||
return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx),
|
||||
visit_nop_node, return_types);
|
||||
} else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
|
||||
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types);
|
||||
} else if (opt::IsNopNode(cnode) && visit_nop_node) {
|
||||
if (cnode->inputs().size() == 2) {
|
||||
return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
|
||||
}
|
||||
} else {
|
||||
return std::make_pair(anf_node, index);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The input is invalid";
|
||||
if (!anf_node->isa<CNode>()) {
|
||||
return KernelWithIndex(anf_node, 0);
|
||||
}
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
|
||||
auto item_with_index_tmp = VisitKernelWithReturnType(GetTupleGetItemRealInput(cnode),
|
||||
GetTupleGetItemOutIndex(cnode), visit_nop_node, return_types);
|
||||
if (CheckPrimitiveType(item_with_index_tmp.first, prim::kPrimMakeTuple)) {
|
||||
MS_EXCEPTION_IF_NULL(item_with_index_tmp.first);
|
||||
auto make_tuple = item_with_index_tmp.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
const std::vector<AnfNodePtr> &make_tuple_inputs = make_tuple->inputs();
|
||||
size_t make_tuple_input_index = item_with_index_tmp.second + 1;
|
||||
if (make_tuple_input_index >= make_tuple_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Index[" << make_tuple_input_index << "] out of range[" << make_tuple_inputs.size()
|
||||
<< "].";
|
||||
}
|
||||
return VisitKernelWithReturnType(make_tuple_inputs[make_tuple_input_index], 0, visit_nop_node, return_types);
|
||||
}
|
||||
return item_with_index_tmp;
|
||||
}
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimDepend) || CheckPrimitiveType(cnode, prim::kPrimControlDepend)) {
|
||||
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, visit_nop_node, return_types);
|
||||
}
|
||||
if (opt::IsNopNode(cnode) && visit_nop_node) {
|
||||
if (cnode->size() != kNopNodeInputSize) {
|
||||
MS_LOG(EXCEPTION) << "Invalid nop node " << cnode->DebugString();
|
||||
}
|
||||
return VisitKernelWithReturnType(cnode->input(kNopNodeRealInputIndex), 0, visit_nop_node, return_types);
|
||||
}
|
||||
return KernelWithIndex(anf_node, index);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node,
|
||||
|
@ -591,7 +612,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
|
|||
if (opt::IsNopNode(node) && visit_nop_node) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().size() == 2) {
|
||||
if (cnode->size() == kNopNodeInputSize) {
|
||||
return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
|
||||
|
@ -613,7 +634,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
|
|||
if (opt::IsNopNode(node) && visit_nop_node) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().size() == 2) {
|
||||
if (cnode->inputs().size() == kNopNodeInputSize) {
|
||||
return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
|
||||
|
@ -806,7 +827,7 @@ bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
|
|||
IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
|
||||
IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
|
||||
IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
|
||||
IsPrimitive(input, prim::kPrimReturn);
|
||||
IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
|
||||
return !is_virtual_node;
|
||||
}
|
||||
|
||||
|
@ -1117,5 +1138,14 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, s
|
|||
}
|
||||
return GetCNodeOutputPrecision(kernel_with_index.first);
|
||||
}
|
||||
|
||||
bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->inputs().empty()) {
|
||||
MS_LOG(EXCEPTION) << "Illegal null input of cnode.";
|
||||
}
|
||||
auto input = node->input(kAnfPrimitiveIndex);
|
||||
return IsPrimitive(input, prim::kPrimLabelGoto) || IsPrimitive(input, prim::kPrimLabelSwitch);
|
||||
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -42,9 +42,12 @@ using DeviceAddress = device::DeviceAddress;
|
|||
using DeviceAddressPtr = device::DeviceAddressPtr;
|
||||
class AnfRuntimeAlgorithm {
|
||||
public:
|
||||
// get real input node of tuple_get_item
|
||||
static AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item);
|
||||
static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
|
||||
// get input_anf_node's real kernel by recurse
|
||||
static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index);
|
||||
static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index,
|
||||
static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, int output_index,
|
||||
bool visit_nop_node = false,
|
||||
const std::vector<PrimitivePtr> &return_types = {
|
||||
prim::kPrimMakeTuple});
|
||||
|
@ -205,6 +208,7 @@ class AnfRuntimeAlgorithm {
|
|||
static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node);
|
||||
// get fix output precision from prev node, input_idx is the input index of current node related to prev node.
|
||||
static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx);
|
||||
static bool IsCondControlKernel(const CNodePtr &node);
|
||||
};
|
||||
} // namespace session
|
||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "backend/session/ascend_control_parser.h"
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "utils/union_find_set.h"
|
||||
#include "runtime/device/ascend/ascend_label_assign.h"
|
||||
|
@ -31,94 +32,11 @@ static constexpr size_t kCNodePartialLength = 2;
|
|||
static constexpr size_t kCNodePartialFunc = 1;
|
||||
static constexpr size_t kCNodeSwitchLayerBranch = 2;
|
||||
static constexpr size_t kCNodeSwitchLayerLength = 3;
|
||||
static constexpr size_t kCNodeAssignTarget = 1;
|
||||
static constexpr size_t kCNodeAssignSource = 2;
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
// Finds the jump node in `parent_graph`'s execution order that leads to `child_graph`.
// A LabelGoto matches when its kCNodeCallArg input equals the child's start label; a LabelSwitch
// matches when either its kCNodeSwitchFalse or kCNodeSwitchTrue input does. The first match is
// returned. With no match, the last LabelGoto/LabelSwitch seen is returned instead; if the
// execution order holds no jump node at all, an EXCEPTION-level log is emitted.
static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) {
  CNodePtr fallback_jump = nullptr;
  for (auto &candidate : parent_graph->execution_order()) {
    const bool is_goto = IsPrimitiveCNode(candidate, prim::kPrimLabelGoto);
    const bool is_switch = !is_goto && IsPrimitiveCNode(candidate, prim::kPrimLabelSwitch);
    if (!is_goto && !is_switch) {
      continue;
    }

    bool jumps_to_child = false;
    if (is_goto) {
      jumps_to_child = (child_graph->get_start_label() == candidate->input(kCNodeCallArg));
    } else {
      jumps_to_child = (child_graph->get_start_label() == candidate->input(kCNodeSwitchFalse) ||
                        child_graph->get_start_label() == candidate->input(kCNodeSwitchTrue));
    }
    if (jumps_to_child) {
      return candidate;
    }
    // Remember the most recent jump so there is something to return when nothing matches.
    fallback_jump = candidate;
  }

  if (fallback_jump == nullptr) {
    MS_LOG(EXCEPTION) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
  }
  return fallback_jump;
}
|
||||
|
||||
// Adds to `union_find_set` every Parameter found in the real_inputs() of `kg` and, recursively,
// of the graphs in kg->child_graph_order().
// For each (formal, args) entry: the formal is added when it is a Parameter, and every arg that is
// a Parameter is added as well; non-Parameter nodes are skipped.
// `memo` records graphs already handled so each graph is processed once even when reachable
// through several paths.
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
                             const NotNull<std::set<KernelGraphPtr> *> memo) {
  // insert() reports whether the graph was new, so one lookup covers both the visited test and the marking.
  if (!memo->insert(kg.get()).second) {
    return;
  }

  const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
  for (const auto &[para, args] : real_inputs) {
    MS_EXCEPTION_IF_NULL(para);
    if (para->isa<Parameter>()) {
      union_find_set->Add(para);
    }
    for (const auto &arg : args) {
      MS_EXCEPTION_IF_NULL(arg);
      if (arg->isa<Parameter>()) {
        union_find_set->Add(arg);
      }
    }
  }

  for (auto &child : kg->child_graph_order()) {
    InitUnionFindSet(NOT_NULL(child), union_find_set, memo);
  }
}
|
||||
|
||||
// Unions each formal in kg->real_inputs() with the Parameter arguments bound to it, then repeats
// the same for every graph in kg->child_graph_order().
// An argument is skipped when it is not a Parameter or when it appears in kg->unreuse_args().
// `memo` records graphs already handled so each graph is processed once.
static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
                                 const NotNull<std::set<KernelGraphPtr> *> memo) {
  // insert() reports whether the graph was new, so one lookup covers both the visited test and the marking.
  if (!memo->insert(kg.get()).second) {
    return;
  }

  // Looked up once per graph instead of twice per argument.
  const auto &unreuse_args = kg->unreuse_args();
  const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
  for (const auto &[para, args] : real_inputs) {
    for (const auto &arg : args) {
      MS_EXCEPTION_IF_NULL(arg);
      if (!arg->isa<Parameter>()) {
        continue;
      }
      if (unreuse_args.find(arg) != unreuse_args.end()) {
        continue;
      }
      union_find_set->Union(arg, para);
    }
  }

  for (auto &child : kg->child_graph_order()) {
    UnionParentParameter(NOT_NULL(child), union_find_set, memo);
  }
}
|
||||
|
||||
// Builds the parameter union-find set for the graph tree rooted at `root_kg`.
// Two traversals are run, each with an empty visited set: InitUnionFindSet first, then
// UnionParentParameter.
static UnionFindSet<AnfNodePtr> MakeUnionFindSet(NotNull<KernelGraphPtr> root_kg) {
  UnionFindSet<AnfNodePtr> parameter_sets;
  std::set<KernelGraphPtr> visited_graphs;

  InitUnionFindSet(root_kg, NOT_NULL(&parameter_sets), NOT_NULL(&visited_graphs));

  // The second traversal must start from a clean visited set.
  visited_graphs.clear();
  UnionParentParameter(root_kg, NOT_NULL(&parameter_sets), NOT_NULL(&visited_graphs));

  return parameter_sets;
}
|
||||
|
||||
static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> main_parameter,
|
||||
const std::set<AnfNodePtr> &parameter_reuse_set,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
|
@ -135,8 +53,9 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr>
|
|||
continue;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
MS_LOG(INFO) << "Replace " << para->DebugString() << " of graph " << AnfAlgo::GetGraphId(para.get()) << " to "
|
||||
<< main_parameter->DebugString() << " of graph " << AnfAlgo::GetGraphId(main_parameter.get().get());
|
||||
MS_LOG(INFO) << "In " << kg->ToString() << " replace " << para->DebugString() << " of graph "
|
||||
<< AnfAlgo::GetGraphId(para.get()) << " to " << main_parameter->DebugString() << " of graph "
|
||||
<< AnfAlgo::GetGraphId(main_parameter.get().get());
|
||||
kg->ReplaceNode(NOT_NULL(para), main_parameter);
|
||||
}
|
||||
|
||||
|
@ -145,7 +64,7 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr>
|
|||
}
|
||||
}
|
||||
|
||||
static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNodePtr key,
|
||||
static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNodePtr &key,
|
||||
const std::set<AnfNodePtr> &parameter_reuse_set) {
|
||||
AnfNodePtr main_parameter = key;
|
||||
std::set<AnfNodePtr> root_inputs_set;
|
||||
|
@ -160,8 +79,19 @@ static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNod
|
|||
return main_parameter;
|
||||
}
|
||||
|
||||
static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet<AnfNodePtr> *> parameter_set) {
|
||||
auto parameter_reuse_sets = parameter_set->GetSets();
|
||||
static void ReuseParameter(NotNull<KernelGraphPtr> root_kg,
|
||||
const std::vector<std::pair<AnfNodePtr, AnfNodePtr>> &link_list) {
|
||||
// make union find set
|
||||
UnionFindSet<AnfNodePtr> union_find_set;
|
||||
for (auto &[param, arg] : link_list) {
|
||||
union_find_set.Add(param);
|
||||
union_find_set.Add(arg);
|
||||
}
|
||||
for (auto &[param, arg] : link_list) {
|
||||
union_find_set.Union(param, arg);
|
||||
}
|
||||
auto parameter_reuse_sets = union_find_set.GetSets();
|
||||
|
||||
for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) {
|
||||
if (parameter_reuse_set.size() <= 1) {
|
||||
continue;
|
||||
|
@ -172,7 +102,7 @@ static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet
|
|||
}
|
||||
}
|
||||
|
||||
CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
|
||||
static CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
|
||||
for (size_t i = start; i < list.size() - 1; ++i) {
|
||||
if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) {
|
||||
return list[i];
|
||||
|
@ -181,71 +111,287 @@ CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
|
||||
std::set<KernelGraphPtr> memo;
|
||||
(void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
|
||||
device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg);
|
||||
std::map<uint32_t, KernelGraphPtr> graph_id_map;
|
||||
for (auto &g : memo) {
|
||||
MS_EXCEPTION_IF_NULL(g);
|
||||
if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Two graph has same graph id " << g->graph_id()
|
||||
<< ", graph: " << graph_id_map[g->graph_id()]->ToString() << " " << g->ToString();
|
||||
static void UpdateLabelIdToLabelSetMap(const std::vector<CNodePtr> &exec_order,
|
||||
const NotNull<std::map<uint32_t, CNodePtr> *> label_id_to_label_set) {
|
||||
for (auto &node : exec_order) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimLabelSet)) {
|
||||
continue;
|
||||
}
|
||||
graph_id_map[g->graph_id()] = g;
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) {
|
||||
MS_LOG(EXCEPTION) << node->DebugString() << " has no attr kAttrLabelIndex";
|
||||
}
|
||||
uint32_t label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
|
||||
if (auto iter = label_id_to_label_set->find(label_id); iter != label_id_to_label_set->end()) {
|
||||
MS_LOG(EXCEPTION) << "There are more than one node has same label id " << label_id
|
||||
<< ", node: " << iter->second->DebugString() << " and " << node->DebugString();
|
||||
}
|
||||
(*label_id_to_label_set)[label_id] = node;
|
||||
}
|
||||
}
|
||||
|
||||
// Resolves the LabelSet nodes that `jump_node` can transfer control to.
//   - LabelGoto:   the single target id is read from attribute kAttrLabelIndex.
//   - LabelSwitch: the target ids are read from attribute kAttrLabelSwitchList.
// Each id is then looked up in `label_id_to_label_set`; the returned nodes keep the order of the ids.
// An EXCEPTION-level log is emitted when the required attribute is missing, when `jump_node` is
// neither a LabelGoto nor a LabelSwitch, or when an id has no entry in the map.
static std::vector<CNodePtr> GetTargetLabelSetNodes(NotNull<CNodePtr> jump_node,
                                                    const std::map<uint32_t, CNodePtr> &label_id_to_label_set) {
  std::vector<uint32_t> target_label_list;
  if (IsPrimitiveCNode(jump_node.get(), prim::kPrimLabelGoto)) {
    if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, jump_node)) {
      MS_LOG(EXCEPTION) << jump_node->DebugString() << " has no attr kAttrLabelIndex";
    }
    uint32_t label_id = AnfAlgo::GetNodeAttr<uint32_t>(jump_node.get(), kAttrLabelIndex);
    target_label_list.push_back(label_id);
  } else if (IsPrimitiveCNode(jump_node.get(), prim::kPrimLabelSwitch)) {
    if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, jump_node)) {
      // Name the attribute that was actually checked.
      MS_LOG(EXCEPTION) << jump_node->DebugString() << " has no attr kAttrLabelSwitchList";
    }
    target_label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(jump_node.get(), kAttrLabelSwitchList);
  } else {
    MS_LOG(EXCEPTION) << "Unknown type jump node " << jump_node->DebugString();
  }

  std::vector<CNodePtr> target_labelset_nodes;
  target_labelset_nodes.reserve(target_label_list.size());
  for (auto label_id : target_label_list) {
    auto iter = label_id_to_label_set.find(label_id);
    if (iter == label_id_to_label_set.end()) {
      MS_LOG(EXCEPTION) << "Cannot find LabelSet node with label id " << label_id;
    }
    target_labelset_nodes.push_back(iter->second);
  }
  return target_labelset_nodes;
}
|
||||
|
||||
static void EraseNodeFromExecOrder(const AnfNodePtr &node, const NotNull<std::vector<CNodePtr> *> exec_order) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto exec_iter = std::find(exec_order->begin(), exec_order->end(), node);
|
||||
if (exec_iter == exec_order->end()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find " << node->DebugString() << " in exec order.";
|
||||
}
|
||||
exec_order->erase(exec_iter);
|
||||
}
|
||||
|
||||
void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
|
||||
std::set<KernelGraphPtr> memo;
|
||||
std::vector<std::pair<AnfNodePtr, AnfNodePtr>> link_list;
|
||||
// Insert Assign
|
||||
ChildGraphDataAssign(graph_id_map);
|
||||
// Make UnionFindSet
|
||||
UnionFindSet<AnfNodePtr> parameter_set = MakeUnionFindSet(kg);
|
||||
ChildGraphDataAssign(kg, NOT_NULL(&link_list), NOT_NULL(&memo));
|
||||
// Reuse Parameter
|
||||
ReuseParameter(kg, NOT_NULL(&parameter_set));
|
||||
ReuseParameter(kg, link_list);
|
||||
// replace call by label goto / label switch
|
||||
memo.clear();
|
||||
(void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
|
||||
// assign label resource
|
||||
device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg);
|
||||
}
|
||||
|
||||
// Drops Assign nodes from the root graph's execution order and rewires users of the assigned
// parameter to read the assignment source directly.
//
// Steps, as implemented below:
//   1. Collect the nodes to inspect (`search_list`): the root execution order plus the return node
//      of every graph in `graph_list`. Also collect every CNode reachable from those return nodes
//      (`all_nodes`), which is the set whose inputs may be rewritten.
//   2. For each inspected node, classify every input that resolves (via VisitKernelWithReturnType)
//      to a Parameter that is not a root graph input: if the root ref map links this node to that
//      parameter it counts as a write and the writer is remembered, otherwise it counts as a read.
//   3. While the counter reports an element selected by its predicate (write == 1): if the
//      remembered writer is an Assign, erase it from the execution order and replace every use of
//      the parameter in `all_nodes` with the Assign source; otherwise just drop the element.
//   4. Store the pruned execution order back on the root graph.
//
// root_graph: graph whose execution order is rewritten.
// graph_list: graphs whose nodes are scanned; entries must be non-null.
void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph,
                                         const std::set<KernelGraphPtr> &graph_list) {
  std::vector<CNodePtr> exec_order = root_graph->execution_order();
  std::set<CNodePtr> search_list(exec_order.begin(), exec_order.end());
  std::set<AnfNodePtr> root_inputs(root_graph->inputs().begin(), root_graph->inputs().end());
  auto ref_map = root_graph->GetRefMap();
  ReferenceCounter parameter_count([](int32_t read, int32_t write) -> bool { return write == 1; });

  // Re-key the ref map by node only, so all refs of one node can be fetched with equal_range.
  std::multimap<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> ref_multimap;
  std::transform(ref_map.begin(), ref_map.end(), std::inserter(ref_multimap, ref_multimap.end()),
                 [](const std::pair<std::pair<AnfNodePtr, size_t>, std::pair<AnfNodePtr, size_t>> &p)
                   -> std::pair<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> {
                   return {p.first.first, {p.first.second, p.second.first, p.second.second}};
                 });

  std::set<CNodePtr> all_nodes;
  std::map<AnfNodePtr, CNodePtr> para_to_written_node;
  for (auto &graph : graph_list) {
    // Guard the dereference below; other graph pointers in this file get the same check.
    MS_EXCEPTION_IF_NULL(graph);
    auto out = graph->get_return();
    MS_EXCEPTION_IF_NULL(out);
    search_list.insert(out->cast<CNodePtr>());
    auto nodes = TopoSort(out);
    for (auto &node : nodes) {
      MS_EXCEPTION_IF_NULL(node);
      auto cnode = node->cast<CNodePtr>();
      if (cnode != nullptr) {
        all_nodes.insert(cnode);
      }
    }
  }

  // Prepare reference counts.
  for (auto &node : search_list) {
    MS_EXCEPTION_IF_NULL(node);
    // Parameters this node is linked to through the ref map; inputs found here count as writes.
    std::set<AnfNodePtr> refed_parameters;
    for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) {
      refed_parameters.insert(std::get<1>(iter->second));
    }

    for (auto &in : node->inputs()) {
      auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first;
      if (!visit_node->isa<Parameter>() || root_inputs.find(visit_node) != root_inputs.end()) {
        continue;
      }
      if (refed_parameters.find(visit_node) != refed_parameters.end()) {
        parameter_count.AddWriteCount(visit_node, 1);
        para_to_written_node[visit_node] = node;
      } else {
        parameter_count.AddReadCount(visit_node, 1);
      }
    }
  }

  while (parameter_count.HasValidElem()) {
    auto [para, read, written] = parameter_count.GetOneValidElem();
    MS_LOG(INFO) << para->DebugString() << " was read " << read << " times, written " << written << " times.";
    auto assign_iter = para_to_written_node.find(para);
    if (assign_iter == para_to_written_node.end()) {
      MS_LOG(EXCEPTION) << "Cannot find assign node that write " << para->DebugString();
    }
    auto &assign_node = assign_iter->second;
    MS_EXCEPTION_IF_NULL(assign_node);
    if (!IsPrimitiveCNode(assign_node, prim::kPrimAssign)) {
      // The single writer is not an Assign: nothing to fold, just retire the element.
      parameter_count.EraseElem(para);
      continue;
    }
    MS_LOG(INFO) << "Erase " << assign_node->DebugString(5);
    EraseNodeFromExecOrder(assign_node, NOT_NULL(&exec_order));

    // The Assign no longer reads its source nor writes the parameter.
    auto source = AnfAlgo::VisitKernelWithReturnType(assign_node->input(kCNodeAssignSource), 0).first;
    parameter_count.AddReadCount(source, -1);
    parameter_count.AddWriteCount(para, -1);
    for (auto &node : all_nodes) {
      for (size_t i = 0; i < node->size(); ++i) {
        if (node->input(i) == para) {
          MS_LOG(INFO) << "Replace " << node->DebugString() << " input " << i << " by " << source->DebugString();
          node->set_input(i, source);
        }
      }
    }
    // Former readers of the parameter now read the source instead.
    parameter_count.AddReadCount(source, 1);
    parameter_count.AddReadCount(para, -1);
  }
  root_graph->set_execution_order(exec_order);
}
|
||||
|
||||
// Prunes LabelSet / LabelGoto nodes from the root graph's execution order.
//
// First pass over the execution order:
//   - every LabelGoto/LabelSwitch adds one read to each LabelSet it targets and is remembered as
//     that label's most recent jump source;
//   - every LabelSet gets one write, plus one read when the node right before it is not a
//     LabelGoto/LabelSwitch (that preceding node is then remembered as the source).
// Second pass: for each label selected by the counter predicate (read <= 1), the LabelSet is erased
// when it has no remembered source or the source is a LabelGoto; in the latter case the LabelGoto
// is erased as well. The label is then removed from the counter.
void AscendControlParser::EraseLabel(NotNull<KernelGraphPtr> root_graph) {
  std::vector<CNodePtr> exec_order = root_graph->execution_order();
  ReferenceCounter label_count([](int32_t read, int32_t write) -> bool { return read <= 1; });
  std::map<AnfNodePtr, CNodePtr> label_to_written_node;
  std::map<uint32_t, CNodePtr> label_id_to_label_set;
  UpdateLabelIdToLabelSetMap(exec_order, NOT_NULL(&label_id_to_label_set));

  CNodePtr previous = nullptr;
  for (auto &current : exec_order) {
    MS_EXCEPTION_IF_NULL(current);
    if (AnfAlgo::IsCondControlKernel(current)) {
      for (auto &target : GetTargetLabelSetNodes(NOT_NULL(current), label_id_to_label_set)) {
        label_count.AddReadCount(target, 1);
        label_to_written_node[target] = current;
      }
    } else if (IsPrimitiveCNode(current, prim::kPrimLabelSet)) {
      label_count.AddWriteCount(current, 1);
      // The preceding node counts as a reader unless it is itself a jump.
      const bool preceded_by_non_jump = (previous != nullptr) && !AnfAlgo::IsCondControlKernel(previous);
      if (preceded_by_non_jump) {
        label_count.AddReadCount(current, 1);
        label_to_written_node[current] = previous;
      }
    }
    previous = current;
  }

  while (label_count.HasValidElem()) {
    auto [label_set, read, written] = label_count.GetOneValidElem();
    MS_LOG(INFO) << label_set->DebugString() << " was read " << read << " times, written " << written << " times.";

    CNodePtr jump_node = nullptr;
    auto source_iter = label_to_written_node.find(label_set);
    if (read > 0) {
      if (source_iter == label_to_written_node.end()) {
        MS_LOG(EXCEPTION) << "Cannot find node jump to " << label_set->DebugString();
      }
      jump_node = source_iter->second;
    }

    const bool source_is_goto = (jump_node != nullptr) && IsPrimitiveCNode(jump_node, prim::kPrimLabelGoto);
    if (jump_node == nullptr || source_is_goto) {
      MS_LOG(INFO) << "Erase node " << label_set->DebugString();
      EraseNodeFromExecOrder(label_set, NOT_NULL(&exec_order));
    }
    if (source_is_goto) {
      MS_LOG(INFO) << "Erase node " << jump_node->DebugString();
      EraseNodeFromExecOrder(jump_node, NOT_NULL(&exec_order));
    }
    label_count.EraseElem(label_set);
  }

  root_graph->set_execution_order(exec_order);
}
|
||||
|
||||
// Post-processes the root graph: RecurseGraph fills the set of visited graphs, EraseParameter then
// runs over those graphs, and finally EraseLabel runs on the root graph.
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
  std::set<KernelGraphPtr> visited_graphs;
  (void)RecurseGraph(root_graph, NOT_NULL(&visited_graphs));
  EraseParameter(root_graph, visited_graphs);
  EraseLabel(root_graph);
}
|
||||
|
||||
void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
|
||||
for (auto &iter : graph_id_map) {
|
||||
auto &kg = iter.second;
|
||||
MS_LOG(INFO) << "Data assign graph:" << kg->graph_id();
|
||||
MS_EXCEPTION_IF_NULL(kg);
|
||||
std::set<std::pair<AnfNodePtr, AnfNodePtr>> memo;
|
||||
const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
|
||||
for (auto &it : real_inputs) {
|
||||
auto &parameter = it.first;
|
||||
auto &args = it.second;
|
||||
for (auto &arg : args) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (memo.find({parameter, arg}) != memo.end()) {
|
||||
continue;
|
||||
} else {
|
||||
memo.emplace(parameter, arg);
|
||||
}
|
||||
auto unreuse_args_map = kg->unreuse_args();
|
||||
auto unreuse_arg_iter = unreuse_args_map.find(arg);
|
||||
if (unreuse_arg_iter == unreuse_args_map.end()) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
if (!arg->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "Reused arg must be parameter, arg:" << arg->DebugString() << ".";
|
||||
}
|
||||
MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
|
||||
<< ", arg:" << arg->DebugString();
|
||||
// Lists the (target graph, arguments) pairs a call node may invoke.
//   - When the kCNodeCallArg input is a KernelGraph value node, there is one pair: that graph and
//     all inputs after kCNodeCallArg.
//   - When it is a Switch CNode, each Switch input after kCNodeSwitchCond is passed to ParsePartial
//     and contributes one pair.
// Any other callee, a non-call node, or a node/Switch with too few inputs is reported with an
// EXCEPTION-level log.
std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlParser::ParseCallNode(
  NotNull<CNodePtr> call_node) {
  std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> targets;
  if (!IsPrimitiveCNode(call_node.get(), prim::kPrimCall)) {
    MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " is not a call node.";
  }
  if (call_node->size() <= kCNodeCallArg) {
    MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " has invalid inputs size " << call_node->size();
  }

  const std::vector<AnfNodePtr> &call_node_inputs = call_node->inputs();
  auto callee = call_node_inputs[kCNodeCallArg];
  MS_EXCEPTION_IF_NULL(callee);

  if (IsValueNode<KernelGraph>(callee)) {
    // Direct call: everything after the callee is an argument.
    std::vector<AnfNodePtr> direct_args(call_node_inputs.begin() + kCNodeCallArg + 1, call_node_inputs.end());
    targets.emplace_back(GetValueNode<KernelGraphPtr>(callee), direct_args);
  } else if (IsPrimitiveCNode(callee, prim::kPrimSwitch)) {
    auto switch_cnode = callee->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(switch_cnode);
    const std::vector<AnfNodePtr> &switch_inputs = switch_cnode->inputs();
    if (switch_inputs.size() <= kCNodeSwitchCond) {
      MS_LOG(EXCEPTION) << "Node " << switch_cnode->DebugString() << " has invalid inputs size "
                        << switch_inputs.size();
    }
    // Every input after the condition is a branch.
    for (size_t branch = kCNodeSwitchCond + 1; branch < switch_inputs.size(); ++branch) {
      const auto &[target_graph, args] = ParsePartial(NOT_NULL(switch_inputs[branch]));
      targets.emplace_back(target_graph, args);
    }
  } else {
    MS_LOG(EXCEPTION) << "Unsupport call node: " << call_node->DebugString(5);
  }
  return targets;
}
|
||||
|
||||
void AscendControlParser::ChildGraphDataAssign(
|
||||
NotNull<KernelGraphPtr> kg, const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
if (memo->find(kg) != memo->end()) {
|
||||
return;
|
||||
}
|
||||
memo->insert(kg.get());
|
||||
|
||||
MS_LOG(INFO) << "Start link data for " << kg->ToString();
|
||||
const std::vector<CNodePtr> &nodes = kg->execution_order();
|
||||
|
||||
for (auto &node : nodes) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimCall)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto child_graph_list = ParseCallNode(NOT_NULL(node));
|
||||
for (auto &[child_graph, args] : child_graph_list) {
|
||||
MS_EXCEPTION_IF_NULL(child_graph);
|
||||
const std::vector<AnfNodePtr> &params = child_graph->inputs();
|
||||
if (args.size() != params.size()) {
|
||||
MS_LOG(EXCEPTION) << child_graph->ToString() << " needs " << params.size() << " inputs but call node "
|
||||
<< node->DebugString(5) << " gives " << args.size();
|
||||
}
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
if (args[i]->isa<Parameter>() && memo->find(child_graph) == memo->end()) {
|
||||
MS_LOG(INFO) << args[i]->DebugString() << " to " << params[i]->DebugString()
|
||||
<< " should be reused, continue.";
|
||||
link_list->emplace_back(args[i], params[i]);
|
||||
continue;
|
||||
}
|
||||
auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get()));
|
||||
if (target_graph_iter == graph_id_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
|
||||
}
|
||||
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg),
|
||||
NOT_NULL(parameter));
|
||||
|
||||
InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i]));
|
||||
}
|
||||
}
|
||||
kg->SetExecOrderByDefault();
|
||||
}
|
||||
kg->SetExecOrderByDefault();
|
||||
for (auto &child_graph : kg->child_graph_order()) {
|
||||
ChildGraphDataAssign(NOT_NULL(child_graph), link_list, memo);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -325,7 +471,7 @@ void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNul
|
|||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
|
||||
return_node->input(kFirstDataInputIndex), attch_node.get()};
|
||||
auto depend_node = kg->NewCNode(inputs);
|
||||
return_node->set_input(1, depend_node);
|
||||
return_node->set_input(kFirstDataInputIndex, depend_node);
|
||||
}
|
||||
|
||||
void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
|
||||
|
@ -381,6 +527,7 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
|
|||
new_inputs.push_back(sub_label);
|
||||
cur_node->set_inputs(new_inputs);
|
||||
cur_node->set_abstract(nullptr);
|
||||
AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>({call_kg}), cur_node.get());
|
||||
MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString();
|
||||
}
|
||||
|
||||
|
@ -409,9 +556,12 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
|
|||
std::vector<AnfNodePtr> new_switch_inputs = {
|
||||
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
|
||||
origin_switch_inputs[kCNodeSwitchCond]};
|
||||
std::vector<KernelGraphPtr> child_graphs;
|
||||
for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
|
||||
// 3.1 branch kernel graph and args
|
||||
KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
||||
KernelGraphPtr branch_fg;
|
||||
std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
||||
child_graphs.push_back(branch_fg);
|
||||
// 3.2 recurse sub graph
|
||||
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
|
||||
new_switch_inputs.push_back(branch_label);
|
||||
|
@ -420,6 +570,7 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
|
|||
|
||||
cur_node->set_inputs(new_switch_inputs);
|
||||
cur_node->set_abstract(nullptr);
|
||||
AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>(child_graphs), cur_node.get());
|
||||
MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString();
|
||||
}
|
||||
|
||||
|
@ -453,9 +604,12 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
|
|||
std::vector<AnfNodePtr> new_switch_inputs = {
|
||||
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
|
||||
origin_switch_inputs[kCNodeSwitchCond]};
|
||||
std::vector<KernelGraphPtr> child_graphs;
|
||||
for (size_t i = 0; i < branch_partial.size(); ++i) {
|
||||
// 3.1 branch kernel graph and args
|
||||
KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
||||
KernelGraphPtr branch_fg;
|
||||
std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
||||
child_graphs.push_back(branch_fg);
|
||||
// 3.2 recurse sub graph
|
||||
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
|
||||
new_switch_inputs.push_back(branch_label);
|
||||
|
@ -463,13 +617,14 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
|
|||
new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
|
||||
cur_node->set_inputs(new_switch_inputs);
|
||||
cur_node->set_abstract(nullptr);
|
||||
AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>(child_graphs), cur_node.get());
|
||||
MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString();
|
||||
}
|
||||
|
||||
KernelGraphPtr AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
|
||||
std::tuple<KernelGraphPtr, std::vector<AnfNodePtr>> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
|
||||
if (!node.get()->isa<CNode>()) {
|
||||
if (IsValueNode<KernelGraph>(node)) {
|
||||
return GetValueNode<KernelGraphPtr>(node);
|
||||
return {GetValueNode<KernelGraphPtr>(node), {}};
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString();
|
||||
}
|
||||
|
@ -485,12 +640,11 @@ KernelGraphPtr AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
|
|||
MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << ".";
|
||||
}
|
||||
auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]);
|
||||
return branch_kg;
|
||||
return {branch_kg, std::vector<AnfNodePtr>(partial_inputs.begin() + kCNodePartialFunc + 1, partial_inputs.end())};
|
||||
}
|
||||
|
||||
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph,
|
||||
NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> from,
|
||||
NotNull<AnfNodePtr> to) {
|
||||
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, const AnfNodePtr &jump_node,
|
||||
NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to) {
|
||||
std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem});
|
||||
std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem});
|
||||
MS_LOG(INFO) << "Insert multi-assign from [" << from->DebugString() << "] to [" << to->DebugString() << "]";
|
||||
|
@ -500,22 +654,35 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> fr
|
|||
}
|
||||
for (size_t i = 0; i < from_outputs.size(); i++) {
|
||||
auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
|
||||
if (assign_node != nullptr) {
|
||||
auto jump_node = GetJumpNode(from_graph, to_graph);
|
||||
const auto &from_graph_exe_order = from_graph->execution_order();
|
||||
auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node);
|
||||
if (jump_node_iter == from_graph_exe_order.end()) {
|
||||
MS_EXCEPTION_IF_NULL(jump_node);
|
||||
MS_LOG(EXCEPTION) << "Can't find node:" << jump_node->DebugString() << " in graph:" << from_graph->graph_id();
|
||||
}
|
||||
// insert assign between jump_node -1 and jump_node
|
||||
if (jump_node_iter != from_graph_exe_order.begin()) {
|
||||
InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node));
|
||||
}
|
||||
if (jump_node != nullptr) {
|
||||
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
|
||||
const auto &from_graph_exe_order = from_graph->execution_order();
|
||||
std::vector<CNodePtr> real_exe_order(from_graph_exe_order.size());
|
||||
size_t real_exe_order_size = 0;
|
||||
std::copy_if(from_graph_exe_order.begin(), from_graph_exe_order.end(), real_exe_order.begin(),
|
||||
[&real_exe_order_size](const CNodePtr &node) -> bool {
|
||||
return (IsPrimitiveCNode(node, prim::kPrimSwitch) || IsPrimitiveCNode(node, prim::kPrimPartial))
|
||||
? false
|
||||
: (++real_exe_order_size, true);
|
||||
});
|
||||
real_exe_order.resize(real_exe_order_size);
|
||||
if (jump_node == nullptr) {
|
||||
if (!real_exe_order.empty()) {
|
||||
InsertControlDependToGraph(from_graph, NOT_NULL(*(real_exe_order.rbegin())), NOT_NULL(assign_node));
|
||||
} else {
|
||||
InsertDependToGraph(from_graph, NOT_NULL(assign_node));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto jump_node_iter = std::find(real_exe_order.begin(), real_exe_order.end(), jump_node);
|
||||
if (jump_node_iter == real_exe_order.end()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find jump node " << jump_node->DebugString() << " in graph "
|
||||
<< from_graph->ToString();
|
||||
}
|
||||
// insert assign between jump_node -1 and jump_node
|
||||
if (jump_node_iter != real_exe_order.begin()) {
|
||||
InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node));
|
||||
}
|
||||
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -618,26 +785,45 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
|
|||
}
|
||||
}
|
||||
|
||||
void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
|
||||
MS_LOG(INFO) << "Graph id:" << kg->graph_id();
|
||||
kg->SetExecOrderByDefault();
|
||||
auto call_nodes = kg->FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
|
||||
std::vector<KernelGraphPtr> child_graph_order;
|
||||
for (auto &call_node : call_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(call_node);
|
||||
auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
|
||||
for (const auto &child_graph : call_child_graphs) {
|
||||
MS_EXCEPTION_IF_NULL(child_graph);
|
||||
if (child_graph != kg->parent_graph()) {
|
||||
child_graph->set_parent_graph(kg.get());
|
||||
}
|
||||
child_graph_order.push_back(child_graph);
|
||||
}
|
||||
void AscendControlParser::ReferenceCounter::AddReadCount(const AnfNodePtr &key, int32_t num) {
|
||||
auto iter = count_.find(key);
|
||||
if (iter != count_.end()) {
|
||||
iter->second.first += num;
|
||||
} else {
|
||||
count_[key] = {num, 0};
|
||||
}
|
||||
for (size_t i = 0; i < child_graph_order.size(); i++) {
|
||||
MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]";
|
||||
}
|
||||
|
||||
void AscendControlParser::ReferenceCounter::AddWriteCount(const AnfNodePtr &key, int32_t num) {
|
||||
auto iter = count_.find(key);
|
||||
if (iter != count_.end()) {
|
||||
iter->second.second += num;
|
||||
} else {
|
||||
count_[key] = {0, num};
|
||||
}
|
||||
kg->set_child_graph_order(child_graph_order);
|
||||
}
|
||||
|
||||
// Removes the counter entry stored under `key`. Erasing by key on the underlying map is a no-op when the
// key is not tracked.
void AscendControlParser::ReferenceCounter::EraseElem(const AnfNodePtr &key) {
  (void)count_.erase(key);
}
|
||||
|
||||
bool AscendControlParser::ReferenceCounter::HasValidElem() const {
|
||||
auto it = std::find_if(count_.begin(), count_.end(),
|
||||
[this](const std::pair<AnfNodePtr, std::pair<uint32_t, uint32_t>> &p) -> bool {
|
||||
auto &[read, written] = p.second;
|
||||
return predicate_(read, written);
|
||||
});
|
||||
return it != count_.end();
|
||||
}
|
||||
|
||||
std::tuple<AnfNodePtr, int32_t, int32_t> AscendControlParser::ReferenceCounter::GetOneValidElem() const {
|
||||
auto it = std::find_if(count_.begin(), count_.end(),
|
||||
[this](const std::pair<AnfNodePtr, std::pair<uint32_t, uint32_t>> &p) -> bool {
|
||||
auto &[read, written] = p.second;
|
||||
return predicate_(read, written);
|
||||
});
|
||||
if (it == count_.end()) {
|
||||
MS_LOG(EXCEPTION) << "No valid parameter.";
|
||||
}
|
||||
return {it->first, it->second.first, it->second.second};
|
||||
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
#include <map>
|
||||
#include <vector>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <functional>
|
||||
#include "backend/session/kernel_graph.h"
|
||||
#include "utils/base_ref.h"
|
||||
#include "utils/contract.h"
|
||||
|
@ -29,16 +31,23 @@ namespace mindspore {
|
|||
namespace session {
|
||||
class AscendControlParser {
|
||||
public:
|
||||
static void ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map);
|
||||
static void LinkGraph(NotNull<KernelGraphPtr> kg);
|
||||
|
||||
static void InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node);
|
||||
static void InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
|
||||
NotNull<AnfNodePtr> second_node);
|
||||
static void ExecutorValidate(NotNull<KernelGraphPtr> root_graph);
|
||||
static void UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg);
|
||||
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, const AnfNodePtr &jump_node,
|
||||
NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
|
||||
|
||||
private:
|
||||
class ReferenceCounter;
|
||||
|
||||
static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list);
|
||||
static void EraseLabel(NotNull<KernelGraphPtr> root_graph);
|
||||
static void ChildGraphDataAssign(NotNull<KernelGraphPtr> kg,
|
||||
const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
static NotNull<CNodePtr> GetStartLabel(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
|
||||
const CNodePtr &last_label);
|
||||
static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
|
||||
|
@ -53,11 +62,10 @@ class AscendControlParser {
|
|||
|
||||
static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
|
||||
const CNodePtr &last_label);
|
||||
static KernelGraphPtr ParsePartial(NotNull<AnfNodePtr> node);
|
||||
|
||||
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph,
|
||||
NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
|
||||
static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
|
||||
static std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ParseCallNode(NotNull<CNodePtr> call_node);
|
||||
static std::tuple<KernelGraphPtr, std::vector<AnfNodePtr>> ParsePartial(NotNull<AnfNodePtr> node);
|
||||
|
||||
// root graph order
|
||||
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
|
||||
|
@ -65,6 +73,19 @@ class AscendControlParser {
|
|||
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
};
|
||||
// Keeps a (read count, write count) pair per node and answers queries about which nodes satisfy a
// caller-supplied predicate over those two counts.
class AscendControlParser::ReferenceCounter {
 public:
  // `func` receives (read count, write count) and decides whether an element is "valid".
  // The parameter is taken by value and moved into place, so the callable is not copied a second time.
  explicit ReferenceCounter(std::function<bool(int32_t, int32_t)> func) : predicate_(std::move(func)) {}
  // Adds `num` to the read count of `key`, creating the entry when it is missing.
  void AddReadCount(const AnfNodePtr &key, int32_t num);
  // Adds `num` to the write count of `key`, creating the entry when it is missing.
  void AddWriteCount(const AnfNodePtr &key, int32_t num);
  // Drops the entry of `key`.
  void EraseElem(const AnfNodePtr &key);
  // True when at least one entry satisfies the predicate.
  bool HasValidElem() const;
  // Returns one entry that satisfies the predicate as (node, read count, write count).
  std::tuple<AnfNodePtr, int32_t, int32_t> GetOneValidElem() const;

 private:
  std::function<bool(int32_t, int32_t)> predicate_;
  // node -> (read count, write count)
  std::map<AnfNodePtr, std::pair<int32_t, int32_t>> count_;
};
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -289,6 +289,17 @@ static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph,
|
|||
// this action should from bottom to top
|
||||
graph->UpdateCallRealInput();
|
||||
}
|
||||
|
||||
// Wraps the current output of `root_graph` in a MakeTuple node and installs that node as the new output.
// Nothing is changed when the return node has no input at kReturnDataIndex.
void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) {
  auto ret = root_graph->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  if (ret->size() > kReturnDataIndex) {
    // The primitive value node is built before the current output is read, as in a braced list.
    std::vector<AnfNodePtr> tuple_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())),
                                            root_graph->output()};
    auto tuple_node = root_graph->NewCNode(tuple_inputs);
    root_graph->set_output(tuple_node);
  }
}
|
||||
} // namespace
|
||||
|
||||
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||
|
@ -305,22 +316,39 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|||
std::vector<KernelGraphPtr> all_graphs;
|
||||
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
|
||||
BackendOptimization(all_graphs);
|
||||
// split switch
|
||||
SplitGraphs(NOT_NULL(root_graph));
|
||||
// an empty graph does not enter the backend
|
||||
if (root_graph->execution_order().empty()) {
|
||||
MS_LOG(INFO) << root_graph->ToString() << " is empty graph.";
|
||||
InsertMakeTupleForOutput(NOT_NULL(root_graph));
|
||||
root_graph->set_executable(false);
|
||||
InitRuntimeResource();
|
||||
return root_graph->graph_id();
|
||||
}
|
||||
// create parameter for multiple branch
|
||||
std::set<KernelGraphPtr> memo;
|
||||
CreateMultiBranchOutput(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||
memo.clear();
|
||||
// insert goto labels and label_sets
|
||||
LinkChildGraphs(NOT_NULL(root_graph));
|
||||
// resource initialize
|
||||
InitRuntimeResource();
|
||||
// recurse compile child root_graph
|
||||
std::set<KernelGraphPtr> memo;
|
||||
RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||
|
||||
IrFusionPass(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||
memo.clear();
|
||||
|
||||
SelectKernel(NOT_NULL(root_graph));
|
||||
memo.clear();
|
||||
|
||||
HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||
memo.clear();
|
||||
|
||||
AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||
memo.clear();
|
||||
|
||||
UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||
memo.clear();
|
||||
// add make_tuple to the output graph
|
||||
InsertMakeTupleForOutput(NOT_NULL(root_graph));
|
||||
// validate the root graph, including generating the execution order and so on
|
||||
RootGraphExecutorValidate(NOT_NULL(root_graph));
|
||||
// adjust kernel
|
||||
|
@ -1682,7 +1710,7 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri
|
|||
bool split_flag = false;
|
||||
auto apply_list = GetCNodes(TopoSort(graph->get_return()));
|
||||
// update the root graph child graph order
|
||||
AscendControlParser::UpdateChildGraphOrder(graph);
|
||||
graph->UpdateChildGraphOrder();
|
||||
// get child list from current graph
|
||||
std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(apply_list, cut_prims);
|
||||
if (child_graph_lists.size() > 1) {
|
||||
|
@ -1714,7 +1742,7 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri
|
|||
}
|
||||
split_flag = true;
|
||||
}
|
||||
AscendControlParser::UpdateChildGraphOrder(graph);
|
||||
graph->UpdateChildGraphOrder();
|
||||
UpdateRealInput(graph, split_flag, memo);
|
||||
MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end";
|
||||
}
|
||||
|
@ -1753,5 +1781,216 @@ void AscendSession::RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const Not
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For every call node in `graph` (child graphs are handled first), creates a parameter that receives the
// output of the called branch graphs, inserts assigns from each branch output to that parameter, and
// replaces uses of the call node by Depend(parameter, call).
// `memo` holds the graphs already processed so that each graph is handled once.
void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) {
  // insert() reports through `.second` whether the graph was newly added; a revisit stops here.
  if (!memo->insert(graph.get()).second) {
    return;
  }

  graph->UpdateChildGraphOrder();
  for (auto &sub_graph : graph->child_graph_order()) {
    CreateMultiBranchOutput(NOT_NULL(sub_graph), memo);
  }

  // call node -> Depend node that should stand in for it
  std::map<AnfNodePtr, AnfNodePtr> call_to_depend;
  auto cnodes = GetCNodes(TopoSort(graph->get_return()));
  for (auto &cnode : cnodes) {
    if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
      continue;
    }
    // Create a parameter to store the output of multiple branches. The graph's input list is saved
    // before the parameter is created and written back afterwards.
    auto saved_inputs = graph->inputs();
    auto result_param = CreateNewParameterFromCNode(cnode, true, graph.get().get());
    MS_EXCEPTION_IF_NULL(graph->MutableInputs());
    *(graph->MutableInputs()) = saved_inputs;
    graph->AddChildGraphResult(result_param);

    std::vector<AnfNodePtr> depend_inputs = {
      graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name()))), result_param, cnode};
    auto depend_node = graph->NewCNode(depend_inputs);
    call_to_depend.emplace(cnode, depend_node);
    MS_LOG(INFO) << "Create parameter " << result_param->DebugString() << " for call node " << cnode->DebugString()
                 << ", depend node is " << depend_node->DebugString();

    // Insert assigns in order to transfer each branch graph's output to the parameter.
    auto branch_graphs = AnfAlgo::GetCallNodeKernelGraph(cnode);
    for (auto &branch_graph : branch_graphs) {
      MS_EXCEPTION_IF_NULL(branch_graph);
      if (!branch_graph->get_output_null()) {
        auto branch_output = branch_graph->output();
        AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(branch_graph), nullptr, NOT_NULL(branch_output),
                                                         NOT_NULL(result_param));
      }
    }
  }

  // Search the nodes' inputs and replace each call by its depend(parameter, call).
  for (auto &cnode : cnodes) {
    for (size_t idx = 0; idx < cnode->size(); ++idx) {
      auto found = call_to_depend.find(cnode->input(idx));
      if (found != call_to_depend.end()) {
        cnode->set_input(idx, found->second);
      }
    }
  }
}
|
||||
|
||||
void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
if (memo->find(graph) != memo->end()) {
|
||||
return;
|
||||
}
|
||||
memo->insert(graph.get());
|
||||
|
||||
opt::AscendBackendIRFusionOptimization(graph);
|
||||
opt::AscendBackendFuseBasicOpt(graph, true);
|
||||
opt::AscendBackendGraphKernelOpt(graph, true);
|
||||
graph->SetExecOrderByDefault();
|
||||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
if (save_graphs) {
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
std::string file_path =
|
||||
save_graphs_path + "/" + "select_kernel_before" + "_graph_" + std::to_string(graph->graph_id()) + ".ir";
|
||||
DumpIR(file_path, graph.get());
|
||||
}
|
||||
|
||||
for (auto &child_graph : graph->child_graph_order()) {
|
||||
IrFusionPass(NOT_NULL(child_graph), memo);
|
||||
}
|
||||
}
|
||||
|
||||
// Selects kernel info for `root_graph` through RecurseSelectKernelInfo, which also descends into the child
// graphs attached to conditional control kernels. In graph mode it then warns about how many nodes had to
// raise or reduce precision to get a kernel.
void AscendSession::SelectKernel(NotNull<KernelGraphPtr> root_graph) {
  MS_LOG(INFO) << "Start select kernel.";
  size_t raise_precision_count = 0;
  size_t reduce_precision_count = 0;

  // `memo` only lives for this traversal; it keeps each graph from being visited twice.
  std::set<KernelGraphPtr> memo;
  (void)RecurseSelectKernelInfo(root_graph, NOT_NULL(&memo), &raise_precision_count, &reduce_precision_count);
  memo.clear();

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->execution_mode() == kGraphMode) {
    if (raise_precision_count > 0) {
      MS_LOG(WARNING) << "There has " << raise_precision_count
                      << " node/nodes used raise precision to selected the kernel!";
    }
    if (reduce_precision_count > 0) {
      // This warning must print the reduce-precision counter, not the raise-precision one.
      MS_LOG(WARNING) << "There has " << reduce_precision_count
                      << " node/nodes used reduce precision to selected the kernel!";
    }
  }
  MS_LOG(INFO) << "Finish!";
}
|
||||
|
||||
// Selects kernel info for every node in the execution order of `graph`. For a conditional control kernel
// that carries the kAttrChildGraph attribute, the listed child graphs are processed first.
// `memo` holds graphs already processed; the two counters are incremented for each node whose selection
// status reports a raised or reduced precision. The graph IR is dumped afterwards when graph saving is on.
void AscendSession::RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph,
                                            NotNull<std::set<KernelGraphPtr> *> const memo,
                                            size_t *const raise_precision_count,
                                            size_t *const reduce_precision_count) const {
  // insert() reports through `.second` whether the graph was newly added; a revisit stops here.
  if (!memo->insert(graph.get()).second) {
    return;
  }
  MS_LOG(INFO) << "Start to select kernel info in graph: " << graph->graph_id();

  for (const auto &kernel : graph->execution_order()) {
    if (AnfAlgo::IsCondControlKernel(kernel) && AnfAlgo::HasNodeAttr(kAttrChildGraph, kernel)) {
      std::vector<KernelGraphPtr> branch_graphs =
        AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(kernel, kAttrChildGraph);
      for (auto &branch_graph : branch_graphs) {
        RecurseSelectKernelInfo(NOT_NULL(branch_graph), memo, raise_precision_count, reduce_precision_count);
      }
    }

    auto select_status = device::ascend::SelectKernelInfo(kernel);
    if (select_status == device::ascend::kStatusRaisePrecision) {
      ++(*raise_precision_count);
    } else if (select_status == device::ascend::kStatusReducePrecision) {
      ++(*reduce_precision_count);
    }
    MS_LOG(INFO) << "Select ApplyKernel: " << kernel->DebugString();
  }

  auto ms_ctx = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_ctx);
  bool dump_enabled = ms_ctx->save_graphs_flag();
  auto dump_dir = ms_ctx->save_graphs_path();
  if (dump_enabled) {
    // An empty dump directory falls back to the current directory.
    if (dump_dir.empty()) {
      dump_dir = ".";
    }
    std::string ir_file =
      dump_dir + "/" + "select_kernel_after" + "_graph_" + std::to_string(graph->graph_id()) + ".ir";
    DumpIR(ir_file, graph.get());
  }
  MS_LOG(INFO) << "Finish selecting kernel info in graph: " << graph->graph_id();
}
|
||||
|
||||
// Applies the hardware optimization step to `graph` and then to each of its child graphs.
// `memo` holds the graphs already processed so that each graph is handled once.
void AscendSession::HardwareOptimize(NotNull<KernelGraphPtr> graph,
                                     NotNull<std::set<KernelGraphPtr> *> const memo) const {
  // insert() reports through `.second` whether the graph was newly added; a revisit stops here.
  if (!memo->insert(graph.get()).second) {
    return;
  }

  MS_LOG(INFO) << "Start to do HardwareOptimize in graph: " << graph->graph_id();
  // convert kernel Graph to model
  predictmodel::StepConvertGraph(graph.get());

  // Single-argument call: this selects the overload taking the graph pointer, not this function.
  HardwareOptimize(graph.get());
  for (auto &sub_graph : graph->child_graph_order()) {
    HardwareOptimize(NOT_NULL(sub_graph), memo);
  }
  MS_LOG(INFO) << "Finish doing HardwareOptimize in graph: " << graph->graph_id();
}
|
||||
|
||||
// Asks the Ascend kernel runtime to assign static memory for the inputs and value nodes of `graph`, then
// does the same for each child graph. `memo` holds the graphs already processed.
void AscendSession::AssignStaticMemory(NotNull<KernelGraphPtr> graph,
                                       NotNull<std::set<KernelGraphPtr> *> const memo) const {
  // insert() reports through `.second` whether the graph was newly added; a revisit stops here.
  if (!memo->insert(graph.get()).second) {
    return;
  }

  MS_LOG(INFO) << "Start to assign static memory for parameter in graph: " << graph->graph_id();
  // assign static memory for parameters
  auto runtime = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime);
  auto *raw_graph = graph.get().get();
  runtime->AssignStaticMemoryInput(raw_graph);
  runtime->AssignStaticMemoryValueNode(raw_graph);
  for (auto &sub_graph : graph->child_graph_order()) {
    AssignStaticMemory(NOT_NULL(sub_graph), memo);
  }
  MS_LOG(INFO) << "Finish assigning static memory for parameter in graph: " << graph->graph_id();
}
|
||||
|
||||
// Copies the ref pairs of every child graph (processed recursively, children first) into `graph`.
// A pair whose key is already present in `graph` is reported with a warning and left untouched.
// `memo` holds the graphs already processed.
void AscendSession::UpdateRefOutputMap(NotNull<KernelGraphPtr> graph,
                                       NotNull<std::set<KernelGraphPtr> *> const memo) const {
  // insert() reports through `.second` whether the graph was newly added; a revisit stops here.
  if (!memo->insert(graph.get()).second) {
    return;
  }

  for (auto &sub_graph : graph->child_graph_order()) {
    UpdateRefOutputMap(NOT_NULL(sub_graph), memo);
    // copy ref map to final graph
    auto sub_ref_map = sub_graph->GetRefMap();
    for (auto &ref_entry : sub_ref_map) {
      const auto &ref_key = ref_entry.first;
      if (!graph->IsInRefOutputMap(ref_key)) {
        graph->AddRefCorrespondPairs(ref_key, ref_entry.second);
      } else {
        MS_LOG(WARNING) << "The ref pair <" << ref_key.first->DebugString() << ", " << ref_key.second
                        << "> is already in " << graph->ToString();
      }
    }
  }
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -151,6 +151,15 @@ class AscendSession : public SessionBasic {
|
|||
// sync initial tensors' data to device
|
||||
void SyncInitialTenosrToDevice();
|
||||
void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
// create parameter to receive data from multiple branch output
|
||||
void CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
void SelectKernel(NotNull<KernelGraphPtr> root_graph);
|
||||
void RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> const memo,
|
||||
size_t *const raise_precision_count, size_t *const reduce_precision_count) const;
|
||||
void IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
void HardwareOptimize(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
|
||||
void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
|
||||
void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
|
||||
|
||||
// member variables
|
||||
// key is final_graph_id,value is child graph execute order of final graph
|
||||
|
|
|
@ -616,8 +616,8 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
|
|||
if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
|
||||
depend_mode = AnfAlgo::GetNodeAttr<int>(cnode, kControlDependMode);
|
||||
}
|
||||
MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString()
|
||||
<< "], depend_mode :" << depend_mode << ".";
|
||||
MS_LOG(DEBUG) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString()
|
||||
<< "], depend_mode :" << depend_mode << ".";
|
||||
if (prior_node->isa<Parameter>() && depend_mode == 1) {
|
||||
prior_nodes = GetOutputNodes(prior_node);
|
||||
}
|
||||
|
@ -647,7 +647,8 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
|
|||
}
|
||||
MS_EXCEPTION_IF_NULL(first_node);
|
||||
MS_EXCEPTION_IF_NULL(second_node);
|
||||
MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString();
|
||||
MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString()
|
||||
<< ",second node:" << second_node->DebugString();
|
||||
AddDependEdge(second_node, first_node, 1);
|
||||
}
|
||||
}
|
||||
|
@ -991,6 +992,30 @@ bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const {
|
|||
return false;
|
||||
}
|
||||
|
||||
void KernelGraph::UpdateChildGraphOrder() {
|
||||
MS_LOG(INFO) << "Update " << ToString() << " child graph order.";
|
||||
SetExecOrderByDefault();
|
||||
auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
|
||||
std::vector<KernelGraphPtr> child_graph_order;
|
||||
for (auto &call_node : call_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(call_node);
|
||||
auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
|
||||
for (const auto &child_graph : call_child_graphs) {
|
||||
MS_EXCEPTION_IF_NULL(child_graph);
|
||||
if (child_graph != parent_graph_) {
|
||||
auto shared_this = std::dynamic_pointer_cast<KernelGraph>(shared_from_this());
|
||||
MS_EXCEPTION_IF_NULL(shared_this);
|
||||
child_graph->set_parent_graph(shared_this);
|
||||
}
|
||||
child_graph_order.push_back(child_graph);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < child_graph_order.size(); ++i) {
|
||||
MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]";
|
||||
}
|
||||
child_graph_order_ = child_graph_order;
|
||||
}
|
||||
|
||||
// Returns the display name of this graph: the prefix "kernel_graph_" followed by the decimal graph id.
std::string KernelGraph::ToString() const {
  std::string name("kernel_graph_");
  name += std::to_string(graph_id_);
  return name;
}
|
||||
|
||||
// On destruction, asks the kernel runtime manager to clear the resources registered under this graph id.
KernelGraph::~KernelGraph() {
  auto &runtime_manager = device::KernelRuntimeManager::Instance();
  runtime_manager.ClearGraphResource(graph_id_);
}
|
||||
|
|
|
@ -156,6 +156,12 @@ class KernelGraph : public FuncGraph {
|
|||
bool IsFinalOutputKernel(const AnfNodePtr &node) const;
|
||||
uint32_t current_epoch() const { return current_epoch_; }
|
||||
// Overwrites current_epoch_ with the given epoch value.
void set_current_epoch(uint32_t epoch) {
  current_epoch_ = epoch;
}
|
||||
void UpdateChildGraphOrder();
|
||||
// Read-only access to the nodes recorded in child_graph_result_.
const std::vector<AnfNodePtr> &child_graph_result() const {
  return child_graph_result_;
}
|
||||
// Appends one node to child_graph_result_.
// @param parameter  node to record; stored by copy of the shared pointer.
void AddChildGraphResult(const AnfNodePtr &parameter) { child_graph_result_.push_back(parameter); }
|
||||
// Replaces the whole content of child_graph_result_ with a copy of the given vector.
void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) { child_graph_result_ = child_graph_result; }
|
||||
|
||||
private:
|
||||
// remove value node from graph
|
||||
|
@ -173,6 +179,7 @@ class KernelGraph : public FuncGraph {
|
|||
void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends);
|
||||
|
||||
std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
|
||||
std::vector<AnfNodePtr> child_graph_result_;
|
||||
std::vector<CNodePtr> execution_order_;
|
||||
uint32_t graph_id_;
|
||||
uint32_t stream_distinction_label_;
|
||||
|
|
|
@ -74,7 +74,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
|
|||
return input_tensors[input_idx];
|
||||
}
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << "has no output addr";
|
||||
MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr";
|
||||
}
|
||||
}
|
||||
// if the process reaches here, item_with_index is a real node (Parameter or executable CNode)
|
||||
|
@ -107,8 +107,8 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
|
|||
return tensor;
|
||||
}
|
||||
|
||||
BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors) {
|
||||
BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
|
||||
|
@ -120,7 +120,7 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
VectorRef ret;
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors);
|
||||
auto out = CreateTensorForOutput(cnode->input(i), graph, input_tensors);
|
||||
ret.push_back(out);
|
||||
}
|
||||
return ret;
|
||||
|
@ -133,25 +133,6 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
|
|||
return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors);
|
||||
}
|
||||
|
||||
// Builds the output reference for a node whose output is a tuple.
// - Throws (MS_LOG(EXCEPTION)) when anf is null or is not reported as a real kernel.
// - A ValueNode yields the single tensor produced by CreateOneTensor for output 0.
// - A CNode other than MakeTuple yields a VectorRef with one CreateOneTensor result per
//   output index in [0, AnfAlgo::GetOutputTensorNum(anf)).
// - Any other node yields an empty VectorRef.
BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
                            const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(anf);
  const bool is_real_kernel = AnfAlgo::IsRealKernel(anf);
  if (!is_real_kernel) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] should be a executable kernel";
  }
  if (anf->isa<ValueNode>()) {
    return CreateOneTensor(anf, 0, graph, input_tensors);
  }

  VectorRef tuple_outputs;
  // Only non-MakeTuple CNodes contribute elements; everything else leaves the tuple empty.
  if (!anf->isa<CNode>() || AnfAlgo::GetCNodeName(anf) == prim::kPrimMakeTuple->name()) {
    return tuple_outputs;
  }
  for (size_t out_idx = 0; out_idx < AnfAlgo::GetOutputTensorNum(anf); ++out_idx) {
    auto tensor_ref = CreateOneTensor(anf, out_idx, graph, input_tensors);
    tuple_outputs.emplace_back(tensor_ref);
  }
  return tuple_outputs;
}
|
||||
|
||||
ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
@ -880,20 +861,11 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
const std::vector<tensor::TensorPtr> &input_tensors) const {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
if (!kernel_graph->child_graph_order().empty()) {
|
||||
// use the last child graph output as the root graph output
|
||||
UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors);
|
||||
return;
|
||||
}
|
||||
auto anf_outputs = kernel_graph->outputs();
|
||||
for (auto &item : anf_outputs) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
|
||||
if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
|
||||
outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
|
||||
continue;
|
||||
}
|
||||
outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors));
|
||||
outputs->emplace_back(CreateTensorForOutput(item, *kernel_graph, input_tensors));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -294,6 +294,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
|
|||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
auto graph_inputs = graph->inputs();
|
||||
auto graph_valid_input = graph->valid_inputs();
|
||||
graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
|
||||
std::vector<AnfNodePtr> need_alloc_nodes;
|
||||
for (size_t i = 0; i < graph_inputs.size(); ++i) {
|
||||
auto item = graph_inputs[i];
|
||||
|
|
|
@ -240,6 +240,7 @@ constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
|
|||
constexpr auto kAttrOffset = "offset";
|
||||
constexpr auto kAttrPsKey = "ps_key";
|
||||
constexpr auto kAttrOptimizerType = "optim_type";
|
||||
constexpr auto kAttrChildGraph = "child_graph";
|
||||
|
||||
// attr value
|
||||
constexpr auto kValueTargetSwitch = "target_switch";
|
||||
|
|
Loading…
Reference in New Issue