!2931 Ascend control flow not split graphs

Merge pull request !2931 from zhoufeng/liantiao1
This commit is contained in:
mindspore-ci-bot 2020-07-15 20:57:37 +08:00 committed by Gitee
commit 130cc29603
11 changed files with 766 additions and 271 deletions

View File

@ -40,6 +40,9 @@ using kernel::KernelBuildInfoPtr;
using kernel::KernelMod;
using kernel::KernelModPtr;
namespace {
// A nop node is expected to have exactly this many inputs; any other size is reported as an invalid nop node.
constexpr size_t kNopNodeInputSize = 2;
// Index, inside a nop node's input list, of the input that traversal continues with when nop nodes are visited.
constexpr size_t kNopNodeRealInputIndex = 1;
std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
MS_EXCEPTION_IF_NULL(shape);
std::vector<size_t> shape_size_t;
@ -48,6 +51,26 @@ std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
}
} // namespace
// Returns the input of a tuple_get_item node found at kRealInputNodeIndexInTupleGetItem.
// Raises an exception when the node is null or its input count differs from kTupleGetItemInputSize.
AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
  MS_EXCEPTION_IF_NULL(tuple_get_item);
  const bool has_expected_size = (tuple_get_item->size() == kTupleGetItemInputSize);
  if (!has_expected_size) {
    MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
  }
  return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem);
}
// Reads the constant index stored in a tuple_get_item node (the input at kInputNodeOutputIndexInTupleGetItem)
// and returns it converted to size_t.
// Raises an exception when the node is null, its input count differs from kTupleGetItemInputSize, or the index
// input is null or is not a ValueNode.
size_t AnfRuntimeAlgorithm::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
  MS_EXCEPTION_IF_NULL(tuple_get_item);
  const bool has_expected_size = (tuple_get_item->size() == kTupleGetItemInputSize);
  if (!has_expected_size) {
    MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
  }
  auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
  MS_EXCEPTION_IF_NULL(output_index_value_node);
  auto value_node = output_index_value_node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  // The index is stored as a signed int value; convert it explicitly to the unsigned return type.
  const int signed_index = GetValue<int>(value_node->value());
  return IntToSize(signed_index);
}
KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
MS_EXCEPTION_IF_NULL(anf_node);
if (anf_node->isa<ValueNode>()) {
@ -83,49 +106,47 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz
}
}
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, int index,
bool visit_nop_node,
const std::vector<PrimitivePtr> &return_types) {
MS_EXCEPTION_IF_NULL(anf_node);
for (const auto &prim_type : return_types) {
if (CheckPrimitiveType(anf_node, prim_type)) {
return std::make_pair(anf_node, index);
}
if (std::any_of(return_types.begin(), return_types.end(), [&anf_node](const PrimitivePtr &prim_type) -> bool {
return CheckPrimitiveType(anf_node, prim_type);
})) {
return KernelWithIndex(anf_node, index);
}
if (anf_node->isa<ValueNode>()) {
return std::make_pair(anf_node, 0);
} else if (anf_node->isa<Parameter>()) {
return std::make_pair(anf_node, 0);
} else if (anf_node->isa<CNode>()) {
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input0 = cnode->input(0);
MS_EXCEPTION_IF_NULL(input0);
if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
if (cnode->inputs().size() != kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
}
auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(input2);
auto value_node = input2->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int item_idx = GetValue<int>(value_node->value());
return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx),
visit_nop_node, return_types);
} else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types);
} else if (opt::IsNopNode(cnode) && visit_nop_node) {
if (cnode->inputs().size() == 2) {
return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types);
} else {
MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
}
} else {
return std::make_pair(anf_node, index);
}
} else {
MS_LOG(EXCEPTION) << "The input is invalid";
if (!anf_node->isa<CNode>()) {
return KernelWithIndex(anf_node, 0);
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
auto item_with_index_tmp = VisitKernelWithReturnType(GetTupleGetItemRealInput(cnode),
GetTupleGetItemOutIndex(cnode), visit_nop_node, return_types);
if (CheckPrimitiveType(item_with_index_tmp.first, prim::kPrimMakeTuple)) {
MS_EXCEPTION_IF_NULL(item_with_index_tmp.first);
auto make_tuple = item_with_index_tmp.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
const std::vector<AnfNodePtr> &make_tuple_inputs = make_tuple->inputs();
size_t make_tuple_input_index = item_with_index_tmp.second + 1;
if (make_tuple_input_index >= make_tuple_inputs.size()) {
MS_LOG(EXCEPTION) << "Index[" << make_tuple_input_index << "] out of range[" << make_tuple_inputs.size()
<< "].";
}
return VisitKernelWithReturnType(make_tuple_inputs[make_tuple_input_index], 0, visit_nop_node, return_types);
}
return item_with_index_tmp;
}
if (CheckPrimitiveType(cnode, prim::kPrimDepend) || CheckPrimitiveType(cnode, prim::kPrimControlDepend)) {
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, visit_nop_node, return_types);
}
if (opt::IsNopNode(cnode) && visit_nop_node) {
if (cnode->size() != kNopNodeInputSize) {
MS_LOG(EXCEPTION) << "Invalid nop node " << cnode->DebugString();
}
return VisitKernelWithReturnType(cnode->input(kNopNodeRealInputIndex), 0, visit_nop_node, return_types);
}
return KernelWithIndex(anf_node, index);
}
std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node,
@ -591,7 +612,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
if (opt::IsNopNode(node) && visit_nop_node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() == 2) {
if (cnode->size() == kNopNodeInputSize) {
return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
} else {
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
@ -613,7 +634,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
if (opt::IsNopNode(node) && visit_nop_node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() == 2) {
if (cnode->inputs().size() == kNopNodeInputSize) {
return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
} else {
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
@ -806,7 +827,7 @@ bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
IsPrimitive(input, prim::kPrimReturn);
IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
return !is_virtual_node;
}
@ -1117,5 +1138,14 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, s
}
return GetCNodeOutputPrecision(kernel_with_index.first);
}
// Tells whether the primitive input of `node` is LabelGoto or LabelSwitch.
// Raises an exception when `node` is null or has no inputs at all.
bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const bool has_no_inputs = node->inputs().empty();
  if (has_no_inputs) {
    MS_LOG(EXCEPTION) << "Illegal null input of cnode.";
  }
  auto primitive_input = node->input(kAnfPrimitiveIndex);
  if (IsPrimitive(primitive_input, prim::kPrimLabelGoto)) {
    return true;
  }
  return IsPrimitive(primitive_input, prim::kPrimLabelSwitch);
}
} // namespace session
} // namespace mindspore

View File

@ -42,9 +42,12 @@ using DeviceAddress = device::DeviceAddress;
using DeviceAddressPtr = device::DeviceAddressPtr;
class AnfRuntimeAlgorithm {
public:
// get real input node of tuple_get_item
static AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item);
static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
// get input_anf_node's real kernel by recurse
static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index);
static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index,
static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, int output_index,
bool visit_nop_node = false,
const std::vector<PrimitivePtr> &return_types = {
prim::kPrimMakeTuple});
@ -205,6 +208,7 @@ class AnfRuntimeAlgorithm {
static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node);
// get fix output precision from prev node, input_idx is the input index of current node related to prev node.
static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx);
static bool IsCondControlKernel(const CNodePtr &node);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;

View File

@ -17,6 +17,7 @@
#include "backend/session/ascend_control_parser.h"
#include <utility>
#include <memory>
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/union_find_set.h"
#include "runtime/device/ascend/ascend_label_assign.h"
@ -31,94 +32,11 @@ static constexpr size_t kCNodePartialLength = 2;
static constexpr size_t kCNodePartialFunc = 1;
static constexpr size_t kCNodeSwitchLayerBranch = 2;
static constexpr size_t kCNodeSwitchLayerLength = 3;
// Input index of an Assign node's target. NOTE(review): not referenced in the visible code — confirm its use.
static constexpr size_t kCNodeAssignTarget = 1;
// Input index of an Assign node's source; read when an erased Assign is replaced by its source node.
static constexpr size_t kCNodeAssignSource = 2;
namespace mindspore {
namespace session {
// Looks through parent_graph's execution order for the jump node that targets child_graph.
// A LabelGoto matches when its kCNodeCallArg input equals the child's start label; a LabelSwitch matches when
// its kCNodeSwitchFalse or kCNodeSwitchTrue input does. If no jump node matches, the last LabelGoto/LabelSwitch
// encountered is returned instead; if the order contains no jump node at all an exception is raised.
static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) {
  // True when the given input slot of `jump` holds the child graph's start label.
  const auto targets_child = [&child_graph](const CNodePtr &jump, size_t label_input) -> bool {
    return child_graph->get_start_label() == jump->input(label_input);
  };

  CNodePtr fallback_jump = nullptr;
  auto &exec_order = parent_graph->execution_order();
  for (auto &candidate : exec_order) {
    const bool is_goto = IsPrimitiveCNode(candidate, prim::kPrimLabelGoto);
    if (!is_goto && !IsPrimitiveCNode(candidate, prim::kPrimLabelSwitch)) {
      continue;
    }
    const bool matches = is_goto ? targets_child(candidate, kCNodeCallArg)
                                 : (targets_child(candidate, kCNodeSwitchFalse) ||
                                    targets_child(candidate, kCNodeSwitchTrue));
    if (matches) {
      return candidate;
    }
    fallback_jump = candidate;
  }

  if (fallback_jump == nullptr) {
    MS_LOG(EXCEPTION) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
  }
  return fallback_jump;
}
// Adds to the union-find set every Parameter found in kg's real-input table, both the formal parameters and
// the arguments bound to them, then does the same for each child graph.
// `memo` records graphs already processed so that each graph is handled once.
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
                             const NotNull<std::set<KernelGraphPtr> *> memo) {
  // insert() reports through .second whether the graph was newly added; a repeated graph is skipped.
  if (!memo->insert(kg.get()).second) {
    return;
  }

  const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
  for (const auto &[para, args] : real_inputs) {
    MS_EXCEPTION_IF_NULL(para);
    if (para->isa<Parameter>()) {
      union_find_set->Add(para);
    }
    for (const auto &arg : args) {
      MS_EXCEPTION_IF_NULL(arg);
      if (arg->isa<Parameter>()) {
        union_find_set->Add(arg);
      }
    }
  }

  for (auto &child : kg->child_graph_order()) {
    InitUnionFindSet(NOT_NULL(child), union_find_set, memo);
  }
}
// Unions each Parameter argument in kg's real-input table with the formal parameter it is bound to, then does
// the same for each child graph. Arguments that are not Parameters, or that appear in kg->unreuse_args(), are
// left untouched. `memo` records graphs already processed so that each graph is handled once.
static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
                                 const NotNull<std::set<KernelGraphPtr> *> memo) {
  if (memo->find(kg.get()) != memo->end()) {
    return;
  }
  memo->insert(kg.get());

  // Bind the lookup table once: find() and end() must be taken from the same container object, and the table
  // is not modified while this graph is processed, so there is no need to fetch it for every argument.
  const auto &unreuse_args = kg->unreuse_args();
  const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
  for (const auto &[para, args] : real_inputs) {
    for (const auto &arg : args) {
      MS_EXCEPTION_IF_NULL(arg);
      if (!arg->isa<Parameter>()) {
        continue;
      }
      if (unreuse_args.find(arg) != unreuse_args.end()) {
        continue;
      }
      union_find_set->Union(arg, para);
    }
  }

  for (auto &child : kg->child_graph_order()) {
    UnionParentParameter(NOT_NULL(child), union_find_set, memo);
  }
}
// Builds the union-find set for root_kg in two passes over the graph tree: InitUnionFindSet first, then
// UnionParentParameter. Each pass starts with an empty visited set.
static UnionFindSet<AnfNodePtr> MakeUnionFindSet(NotNull<KernelGraphPtr> root_kg) {
  UnionFindSet<AnfNodePtr> parameter_sets;
  {
    std::set<KernelGraphPtr> visited_graphs;
    InitUnionFindSet(root_kg, NOT_NULL(&parameter_sets), NOT_NULL(&visited_graphs));
  }
  {
    std::set<KernelGraphPtr> visited_graphs;
    UnionParentParameter(root_kg, NOT_NULL(&parameter_sets), NOT_NULL(&visited_graphs));
  }
  return parameter_sets;
}
static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> main_parameter,
const std::set<AnfNodePtr> &parameter_reuse_set,
const NotNull<std::set<KernelGraphPtr> *> memo) {
@ -135,8 +53,9 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr>
continue;
}
MS_EXCEPTION_IF_NULL(para);
MS_LOG(INFO) << "Replace " << para->DebugString() << " of graph " << AnfAlgo::GetGraphId(para.get()) << " to "
<< main_parameter->DebugString() << " of graph " << AnfAlgo::GetGraphId(main_parameter.get().get());
MS_LOG(INFO) << "In " << kg->ToString() << " replace " << para->DebugString() << " of graph "
<< AnfAlgo::GetGraphId(para.get()) << " to " << main_parameter->DebugString() << " of graph "
<< AnfAlgo::GetGraphId(main_parameter.get().get());
kg->ReplaceNode(NOT_NULL(para), main_parameter);
}
@ -145,7 +64,7 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr>
}
}
static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNodePtr key,
static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNodePtr &key,
const std::set<AnfNodePtr> &parameter_reuse_set) {
AnfNodePtr main_parameter = key;
std::set<AnfNodePtr> root_inputs_set;
@ -160,8 +79,19 @@ static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNod
return main_parameter;
}
static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet<AnfNodePtr> *> parameter_set) {
auto parameter_reuse_sets = parameter_set->GetSets();
static void ReuseParameter(NotNull<KernelGraphPtr> root_kg,
const std::vector<std::pair<AnfNodePtr, AnfNodePtr>> &link_list) {
// make union find set
UnionFindSet<AnfNodePtr> union_find_set;
for (auto &[param, arg] : link_list) {
union_find_set.Add(param);
union_find_set.Add(arg);
}
for (auto &[param, arg] : link_list) {
union_find_set.Union(param, arg);
}
auto parameter_reuse_sets = union_find_set.GetSets();
for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) {
if (parameter_reuse_set.size() <= 1) {
continue;
@ -172,7 +102,7 @@ static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet
}
}
CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
static CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
for (size_t i = start; i < list.size() - 1; ++i) {
if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) {
return list[i];
@ -181,71 +111,287 @@ CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
return nullptr;
}
void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
std::set<KernelGraphPtr> memo;
(void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg);
std::map<uint32_t, KernelGraphPtr> graph_id_map;
for (auto &g : memo) {
MS_EXCEPTION_IF_NULL(g);
if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) {
MS_LOG(EXCEPTION) << "Two graph has same graph id " << g->graph_id()
<< ", graph: " << graph_id_map[g->graph_id()]->ToString() << " " << g->ToString();
static void UpdateLabelIdToLabelSetMap(const std::vector<CNodePtr> &exec_order,
const NotNull<std::map<uint32_t, CNodePtr> *> label_id_to_label_set) {
for (auto &node : exec_order) {
MS_EXCEPTION_IF_NULL(node);
if (!IsPrimitiveCNode(node, prim::kPrimLabelSet)) {
continue;
}
graph_id_map[g->graph_id()] = g;
if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) {
MS_LOG(EXCEPTION) << node->DebugString() << " has no attr kAttrLabelIndex";
}
uint32_t label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
if (auto iter = label_id_to_label_set->find(label_id); iter != label_id_to_label_set->end()) {
MS_LOG(EXCEPTION) << "There are more than one node has same label id " << label_id
<< ", node: " << iter->second->DebugString() << " and " << node->DebugString();
}
(*label_id_to_label_set)[label_id] = node;
}
}
// Resolves the LabelSet nodes that a jump node can transfer control to.
//   jump_node:             a LabelGoto (one target, attr kAttrLabelIndex) or a LabelSwitch (several targets,
//                          attr kAttrLabelSwitchList).
//   label_id_to_label_set: lookup table from label id to the LabelSet node carrying that id.
// Returns the target LabelSet nodes in the order their label ids are listed on the jump node.
// Raises an exception when the node is neither kind of jump, lacks the required attribute, or refers to a
// label id that is missing from the lookup table.
static std::vector<CNodePtr> GetTargetLabelSetNodes(NotNull<CNodePtr> jump_node,
                                                    const std::map<uint32_t, CNodePtr> &label_id_to_label_set) {
  std::vector<uint32_t> target_label_list;
  if (IsPrimitiveCNode(jump_node.get(), prim::kPrimLabelGoto)) {
    if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, jump_node.get())) {
      MS_LOG(EXCEPTION) << jump_node->DebugString() << " has no attr kAttrLabelIndex";
    }
    uint32_t label_id = AnfAlgo::GetNodeAttr<uint32_t>(jump_node.get(), kAttrLabelIndex);
    target_label_list.push_back(label_id);
  } else if (IsPrimitiveCNode(jump_node.get(), prim::kPrimLabelSwitch)) {
    if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, jump_node.get())) {
      // Name the attribute that is actually missing, not the primitive.
      MS_LOG(EXCEPTION) << jump_node->DebugString() << " has no attr kAttrLabelSwitchList";
    }
    target_label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(jump_node.get(), kAttrLabelSwitchList);
  } else {
    MS_LOG(EXCEPTION) << "Unknown type jump node " << jump_node->DebugString();
  }

  std::vector<CNodePtr> target_labelset_nodes;
  target_labelset_nodes.reserve(target_label_list.size());
  for (auto label_id : target_label_list) {
    auto iter = label_id_to_label_set.find(label_id);
    if (iter == label_id_to_label_set.end()) {
      MS_LOG(EXCEPTION) << "Cannot find LabelSet node with label id " << label_id;
    }
    target_labelset_nodes.push_back(iter->second);
  }
  return target_labelset_nodes;
}
static void EraseNodeFromExecOrder(const AnfNodePtr &node, const NotNull<std::vector<CNodePtr> *> exec_order) {
MS_EXCEPTION_IF_NULL(node);
auto exec_iter = std::find(exec_order->begin(), exec_order->end(), node);
if (exec_iter == exec_order->end()) {
MS_LOG(EXCEPTION) << "Cannot find " << node->DebugString() << " in exec order.";
}
exec_order->erase(exec_iter);
}
void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
std::set<KernelGraphPtr> memo;
std::vector<std::pair<AnfNodePtr, AnfNodePtr>> link_list;
// Insert Assign
ChildGraphDataAssign(graph_id_map);
// Make UnionFindSet
UnionFindSet<AnfNodePtr> parameter_set = MakeUnionFindSet(kg);
ChildGraphDataAssign(kg, NOT_NULL(&link_list), NOT_NULL(&memo));
// Reuse Parameter
ReuseParameter(kg, NOT_NULL(&parameter_set));
ReuseParameter(kg, link_list);
// replace call by label goto / label switch
memo.clear();
(void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
// assign label resource
device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg);
}
// Erases Assign nodes from the root graph's execution order and redirects the users of the assigned parameter
// to the Assign's source node.
//   root_graph: graph whose execution order is scanned and finally overwritten with the reduced order.
//   graph_list: graphs whose nodes (collected by TopoSort from each graph's return node) are searched for
//               inputs that must be redirected.
// Parameters are tracked in a ReferenceCounter whose predicate is `write == 1`.
// NOTE(review): ReferenceCounter is defined outside this view; presumably an element is "valid" when the
// predicate holds for its (read, write) counts — confirm.
void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph,
const std::set<KernelGraphPtr> &graph_list) {
// Work on a copy of the execution order; it is written back to the graph at the end.
std::vector<CNodePtr> exec_order = root_graph->execution_order();
std::set<CNodePtr> search_list(exec_order.begin(), exec_order.end());
std::set<AnfNodePtr> root_inputs(root_graph->inputs().begin(), root_graph->inputs().end());
auto ref_map = root_graph->GetRefMap();
ReferenceCounter parameter_count([](int32_t read, int32_t write) -> bool { return write == 1; });
// Re-key the ref map by the node of its first pair so that all entries of one node can be fetched with
// equal_range(); the mapped tuple keeps (first index, second node, second index).
std::multimap<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> ref_multimap;
std::transform(ref_map.begin(), ref_map.end(), std::inserter(ref_multimap, ref_multimap.end()),
[](const std::pair<std::pair<AnfNodePtr, size_t>, std::pair<AnfNodePtr, size_t>> &p)
-> std::pair<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> {
return {p.first.first, {p.first.second, p.second.first, p.second.second}};
});
std::set<CNodePtr> all_nodes;
std::map<AnfNodePtr, CNodePtr> para_to_written_node;
// Collect every CNode reachable from each graph's return node; the return nodes themselves also join the
// set of nodes whose inputs are counted below.
for (auto &graph : graph_list) {
auto out = graph->get_return();
MS_EXCEPTION_IF_NULL(out);
search_list.insert(out->cast<CNodePtr>());
auto nodes = TopoSort(out);
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode != nullptr) {
all_nodes.insert(cnode);
}
}
}
// prepare reference count
for (auto &node : search_list) {
MS_EXCEPTION_IF_NULL(node);
// Gather the nodes that the ref map associates with this node; an input found in this set is counted as a
// write, any other parameter input as a read.
std::set<AnfNodePtr> refed_parameters;
for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) {
refed_parameters.insert(std::get<1>(iter->second));
}
for (auto &in : node->inputs()) {
auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first;
// Only parameters that are not inputs of the root graph are tracked.
if (!visit_node->isa<Parameter>() || root_inputs.find(visit_node) != root_inputs.end()) {
continue;
}
if (refed_parameters.find(visit_node) != refed_parameters.end()) {
parameter_count.AddWriteCount(visit_node, 1);
// Remember the node that writes this parameter (a later writer overwrites an earlier entry).
para_to_written_node[visit_node] = node;
} else {
parameter_count.AddReadCount(visit_node, 1);
}
}
}
// Process elements reported by the counter until none is left.
while (parameter_count.HasValidElem()) {
auto [para, read, written] = parameter_count.GetOneValidElem();
MS_LOG(INFO) << para->DebugString() << " was read " << read << " times, written " << written << " times.";
auto assign_iter = para_to_written_node.find(para);
if (assign_iter == para_to_written_node.end()) {
MS_LOG(EXCEPTION) << "Cannot find assign node that write " << para->DebugString();
}
auto &assign_node = assign_iter->second;
MS_EXCEPTION_IF_NULL(assign_node);
// A writer that is not an Assign is left in place; the parameter is just dropped from the counter.
if (!IsPrimitiveCNode(assign_node, prim::kPrimAssign)) {
parameter_count.EraseElem(para);
continue;
}
MS_LOG(INFO) << "Erase " << assign_node->DebugString(5);
EraseNodeFromExecOrder(assign_node, NOT_NULL(&exec_order));
auto source = AnfAlgo::VisitKernelWithReturnType(assign_node->input(kCNodeAssignSource), 0).first;
// The erased Assign no longer reads its source nor writes the parameter.
parameter_count.AddReadCount(source, -1);
parameter_count.AddWriteCount(para, -1);
// Redirect every input that refers to the parameter to the Assign's source.
for (auto &node : all_nodes) {
for (size_t i = 0; i < node->size(); ++i) {
if (node->input(i) == para) {
MS_LOG_INFO << "Replace " << node->DebugString() << " input " << i << " by " << source->DebugString();
node->set_input(i, source);
}
}
}
// One read is moved from the parameter to the source, regardless of how many inputs were replaced above.
parameter_count.AddReadCount(source, 1);
parameter_count.AddReadCount(para, -1);
}
root_graph->set_execution_order(exec_order);
}
// Erases LabelSet nodes, together with the LabelGoto that targets them, from the root graph's execution order.
// Labels are tracked in a ReferenceCounter whose predicate is `read <= 1`.
// NOTE(review): ReferenceCounter is defined outside this view; presumably an element is "valid" when the
// predicate holds for its (read, write) counts — confirm.
void AscendControlParser::EraseLabel(NotNull<KernelGraphPtr> root_graph) {
// Work on a copy of the execution order; it is written back to the graph at the end.
std::vector<CNodePtr> exec_order = root_graph->execution_order();
ReferenceCounter label_count([](int32_t read, int32_t write) -> bool { return read <= 1; });
// For each LabelSet node, the node last recorded as reaching it (a jump node or the preceding node).
std::map<AnfNodePtr, CNodePtr> label_to_written_node;
std::map<uint32_t, CNodePtr> label_id_to_label_set;
UpdateLabelIdToLabelSetMap(exec_order, NOT_NULL(&label_id_to_label_set));
CNodePtr last_node = nullptr;
for (auto &cur_node : exec_order) {
MS_EXCEPTION_IF_NULL(cur_node);
if (AnfAlgo::IsCondControlKernel(cur_node)) {
// A jump node counts as one read of every LabelSet it can target.
std::vector<CNodePtr> target_labelset_nodes = GetTargetLabelSetNodes(NOT_NULL(cur_node), label_id_to_label_set);
for (auto &label_set : target_labelset_nodes) {
label_count.AddReadCount(label_set, 1);
label_to_written_node[label_set] = cur_node;
}
} else if (IsPrimitiveCNode(cur_node, prim::kPrimLabelSet)) {
label_count.AddWriteCount(cur_node, 1);
// A LabelSet preceded by a node that is not a jump kernel gets one more read, attributed to that node.
if (last_node != nullptr && !AnfAlgo::IsCondControlKernel(last_node)) {
label_count.AddReadCount(cur_node, 1);
label_to_written_node[cur_node] = last_node;
}
}
last_node = cur_node;
}
// Process elements reported by the counter until none is left.
while (label_count.HasValidElem()) {
auto [label_set, read, written] = label_count.GetOneValidElem();
MS_LOG(INFO) << label_set->DebugString() << " was read " << read << " times, written " << written << " times.";
auto iter = label_to_written_node.find(label_set);
if (read > 0 && iter == label_to_written_node.end()) {
MS_LOG(EXCEPTION) << "Cannot find node jump to " << label_set->DebugString();
}
// `iter` is only dereferenced when the label was read at least once.
CNodePtr jump_node = read > 0 ? iter->second : nullptr;
// The LabelSet is erased when nothing reaches it or when the recorded node is a LabelGoto; in the latter
// case the LabelGoto is erased as well.
if (jump_node == nullptr || IsPrimitiveCNode(jump_node, prim::kPrimLabelGoto)) {
MS_LOG(INFO) << "Erase node " << label_set->DebugString();
EraseNodeFromExecOrder(label_set, NOT_NULL(&exec_order));
}
if (jump_node != nullptr && IsPrimitiveCNode(jump_node, prim::kPrimLabelGoto)) {
MS_LOG(INFO) << "Erase node " << jump_node->DebugString();
EraseNodeFromExecOrder(jump_node, NOT_NULL(&exec_order));
}
label_count.EraseElem(label_set);
}
root_graph->set_execution_order(exec_order);
}
// Runs the post-processing steps on the root graph in order: RecurseGraph (which fills the set of visited
// graphs), EraseParameter over those graphs, then EraseLabel.
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
  std::set<KernelGraphPtr> visited_graphs;
  (void)RecurseGraph(root_graph, NOT_NULL(&visited_graphs));

  const std::set<KernelGraphPtr> &graph_list = visited_graphs;
  EraseParameter(root_graph, graph_list);

  EraseLabel(root_graph);
}
void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
for (auto &iter : graph_id_map) {
auto &kg = iter.second;
MS_LOG(INFO) << "Data assign graph:" << kg->graph_id();
MS_EXCEPTION_IF_NULL(kg);
std::set<std::pair<AnfNodePtr, AnfNodePtr>> memo;
const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
for (auto &it : real_inputs) {
auto &parameter = it.first;
auto &args = it.second;
for (auto &arg : args) {
MS_EXCEPTION_IF_NULL(arg);
if (memo.find({parameter, arg}) != memo.end()) {
continue;
} else {
memo.emplace(parameter, arg);
}
auto unreuse_args_map = kg->unreuse_args();
auto unreuse_arg_iter = unreuse_args_map.find(arg);
if (unreuse_arg_iter == unreuse_args_map.end()) {
MS_EXCEPTION_IF_NULL(arg);
MS_EXCEPTION_IF_NULL(parameter);
if (!arg->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "Reused arg must be parameter, arg:" << arg->DebugString() << ".";
}
MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
<< ", arg:" << arg->DebugString();
// Extracts the (target graph, argument list) pairs a call node can invoke.
//   - When the input at kCNodeCallArg is a KernelGraph value node, one pair is returned whose arguments are the
//     call node's remaining inputs.
//   - When it is a Switch CNode, one pair is returned for every switch input after kCNodeSwitchCond, as parsed
//     by ParsePartial.
// Raises an exception when the node is not a Call, has too few inputs, or its call argument is of another kind.
std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlParser::ParseCallNode(
  NotNull<CNodePtr> call_node) {
  if (!IsPrimitiveCNode(call_node.get(), prim::kPrimCall)) {
    MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " is not a call node.";
  }
  if (call_node->size() <= kCNodeCallArg) {
    MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " has invalid inputs size " << call_node->size();
  }
  const std::vector<AnfNodePtr> &call_node_inputs = call_node->inputs();
  auto call_arg = call_node_inputs[kCNodeCallArg];
  MS_EXCEPTION_IF_NULL(call_arg);

  std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ret;

  // Direct call: a single target graph taking everything after the call argument.
  if (IsValueNode<KernelGraph>(call_arg)) {
    const auto first_arg = call_node_inputs.begin() + kCNodeCallArg + 1;
    ret.emplace_back(GetValueNode<KernelGraphPtr>(call_arg), std::vector<AnfNodePtr>(first_arg, call_node_inputs.end()));
    return ret;
  }

  if (!IsPrimitiveCNode(call_arg, prim::kPrimSwitch)) {
    MS_LOG(EXCEPTION) << "Unsupport call node: " << call_node->DebugString(5);
  }

  // Call through a switch: every branch after the condition contributes one target.
  auto switch_cnode = call_arg->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(switch_cnode);
  const std::vector<AnfNodePtr> &switch_inputs = switch_cnode->inputs();
  if (switch_inputs.size() <= kCNodeSwitchCond) {
    MS_LOG(EXCEPTION) << "Node " << switch_cnode->DebugString() << " has invalid inputs size "
                      << switch_inputs.size();
  }
  const auto first_branch = switch_inputs.begin() + kCNodeSwitchCond + 1;
  for (auto iter = first_branch; iter != switch_inputs.end(); ++iter) {
    const auto &[target_graph, args] = ParsePartial(NOT_NULL(*iter));
    ret.emplace_back(target_graph, args);
  }
  return ret;
}
void AscendControlParser::ChildGraphDataAssign(
NotNull<KernelGraphPtr> kg, const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list,
const NotNull<std::set<KernelGraphPtr> *> memo) {
if (memo->find(kg) != memo->end()) {
return;
}
memo->insert(kg.get());
MS_LOG(INFO) << "Start link data for " << kg->ToString();
const std::vector<CNodePtr> &nodes = kg->execution_order();
for (auto &node : nodes) {
if (!IsPrimitiveCNode(node, prim::kPrimCall)) {
continue;
}
auto child_graph_list = ParseCallNode(NOT_NULL(node));
for (auto &[child_graph, args] : child_graph_list) {
MS_EXCEPTION_IF_NULL(child_graph);
const std::vector<AnfNodePtr> &params = child_graph->inputs();
if (args.size() != params.size()) {
MS_LOG(EXCEPTION) << child_graph->ToString() << " needs " << params.size() << " inputs but call node "
<< node->DebugString(5) << " gives " << args.size();
}
for (size_t i = 0; i < args.size(); ++i) {
if (args[i]->isa<Parameter>() && memo->find(child_graph) == memo->end()) {
MS_LOG(INFO) << args[i]->DebugString() << " to " << params[i]->DebugString()
<< " should be reused, continue.";
link_list->emplace_back(args[i], params[i]);
continue;
}
auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get()));
if (target_graph_iter == graph_id_map.end()) {
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
}
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg),
NOT_NULL(parameter));
InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i]));
}
}
kg->SetExecOrderByDefault();
}
kg->SetExecOrderByDefault();
for (auto &child_graph : kg->child_graph_order()) {
ChildGraphDataAssign(NOT_NULL(child_graph), link_list, memo);
}
}
@ -325,7 +471,7 @@ void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNul
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
return_node->input(kFirstDataInputIndex), attch_node.get()};
auto depend_node = kg->NewCNode(inputs);
return_node->set_input(1, depend_node);
return_node->set_input(kFirstDataInputIndex, depend_node);
}
void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
@ -381,6 +527,7 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
new_inputs.push_back(sub_label);
cur_node->set_inputs(new_inputs);
cur_node->set_abstract(nullptr);
AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>({call_kg}), cur_node.get());
MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString();
}
@ -409,9 +556,12 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
std::vector<AnfNodePtr> new_switch_inputs = {
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
origin_switch_inputs[kCNodeSwitchCond]};
std::vector<KernelGraphPtr> child_graphs;
for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
// 3.1 branch kernel graph and args
KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
KernelGraphPtr branch_fg;
std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
child_graphs.push_back(branch_fg);
// 3.2 recurse sub graph
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
new_switch_inputs.push_back(branch_label);
@ -420,6 +570,7 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
cur_node->set_inputs(new_switch_inputs);
cur_node->set_abstract(nullptr);
AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>(child_graphs), cur_node.get());
MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString();
}
@ -453,9 +604,12 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
std::vector<AnfNodePtr> new_switch_inputs = {
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
origin_switch_inputs[kCNodeSwitchCond]};
std::vector<KernelGraphPtr> child_graphs;
for (size_t i = 0; i < branch_partial.size(); ++i) {
// 3.1 branch kernel graph and args
KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
KernelGraphPtr branch_fg;
std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
child_graphs.push_back(branch_fg);
// 3.2 recurse sub graph
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
new_switch_inputs.push_back(branch_label);
@ -463,13 +617,14 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
cur_node->set_inputs(new_switch_inputs);
cur_node->set_abstract(nullptr);
AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>(child_graphs), cur_node.get());
MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString();
}
KernelGraphPtr AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
std::tuple<KernelGraphPtr, std::vector<AnfNodePtr>> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
if (!node.get()->isa<CNode>()) {
if (IsValueNode<KernelGraph>(node)) {
return GetValueNode<KernelGraphPtr>(node);
return {GetValueNode<KernelGraphPtr>(node), {}};
}
MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString();
}
@ -485,12 +640,11 @@ KernelGraphPtr AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << ".";
}
auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]);
return branch_kg;
return {branch_kg, std::vector<AnfNodePtr>(partial_inputs.begin() + kCNodePartialFunc + 1, partial_inputs.end())};
}
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph,
NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> from,
NotNull<AnfNodePtr> to) {
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, const AnfNodePtr &jump_node,
NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to) {
std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem});
std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem});
MS_LOG(INFO) << "Insert multi-assign from [" << from->DebugString() << "] to [" << to->DebugString() << "]";
@ -500,22 +654,35 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> fr
}
for (size_t i = 0; i < from_outputs.size(); i++) {
auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
if (assign_node != nullptr) {
auto jump_node = GetJumpNode(from_graph, to_graph);
const auto &from_graph_exe_order = from_graph->execution_order();
auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node);
if (jump_node_iter == from_graph_exe_order.end()) {
MS_EXCEPTION_IF_NULL(jump_node);
MS_LOG(EXCEPTION) << "Can't find node:" << jump_node->DebugString() << " in graph:" << from_graph->graph_id();
}
// insert assign between jump_node -1 and jump_node
if (jump_node_iter != from_graph_exe_order.begin()) {
InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node));
}
if (jump_node != nullptr) {
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
const auto &from_graph_exe_order = from_graph->execution_order();
std::vector<CNodePtr> real_exe_order(from_graph_exe_order.size());
size_t real_exe_order_size = 0;
std::copy_if(from_graph_exe_order.begin(), from_graph_exe_order.end(), real_exe_order.begin(),
[&real_exe_order_size](const CNodePtr &node) -> bool {
return (IsPrimitiveCNode(node, prim::kPrimSwitch) || IsPrimitiveCNode(node, prim::kPrimPartial))
? false
: (++real_exe_order_size, true);
});
real_exe_order.resize(real_exe_order_size);
if (jump_node == nullptr) {
if (!real_exe_order.empty()) {
InsertControlDependToGraph(from_graph, NOT_NULL(*(real_exe_order.rbegin())), NOT_NULL(assign_node));
} else {
InsertDependToGraph(from_graph, NOT_NULL(assign_node));
}
continue;
}
auto jump_node_iter = std::find(real_exe_order.begin(), real_exe_order.end(), jump_node);
if (jump_node_iter == real_exe_order.end()) {
MS_LOG(EXCEPTION) << "Cannot find jump node " << jump_node->DebugString() << " in graph "
<< from_graph->ToString();
}
// insert assign between jump_node -1 and jump_node
if (jump_node_iter != real_exe_order.begin()) {
InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node));
}
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
}
}
@ -618,26 +785,45 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
}
}
void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
MS_LOG(INFO) << "Graph id:" << kg->graph_id();
kg->SetExecOrderByDefault();
auto call_nodes = kg->FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
std::vector<KernelGraphPtr> child_graph_order;
for (auto &call_node : call_nodes) {
MS_EXCEPTION_IF_NULL(call_node);
auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
for (const auto &child_graph : call_child_graphs) {
MS_EXCEPTION_IF_NULL(child_graph);
if (child_graph != kg->parent_graph()) {
child_graph->set_parent_graph(kg.get());
}
child_graph_order.push_back(child_graph);
}
void AscendControlParser::ReferenceCounter::AddReadCount(const AnfNodePtr &key, int32_t num) {
auto iter = count_.find(key);
if (iter != count_.end()) {
iter->second.first += num;
} else {
count_[key] = {num, 0};
}
for (size_t i = 0; i < child_graph_order.size(); i++) {
MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]";
}
void AscendControlParser::ReferenceCounter::AddWriteCount(const AnfNodePtr &key, int32_t num) {
auto iter = count_.find(key);
if (iter != count_.end()) {
iter->second.second += num;
} else {
count_[key] = {0, num};
}
kg->set_child_graph_order(child_graph_order);
}
// Stops tracking `key`: removes its read/write counters from count_. A key that is not tracked is left alone.
void AscendControlParser::ReferenceCounter::EraseElem(const AnfNodePtr &key) {
  (void)count_.erase(key);
}
// Reports whether at least one tracked node has (read, write) counts accepted by predicate_.
bool AscendControlParser::ReferenceCounter::HasValidElem() const {
  // The element is taken as `const auto &` so it binds directly to the map's value_type
  // (std::pair<const AnfNodePtr, std::pair<int32_t, int32_t>>). Spelling the type with a non-const key and
  // uint32_t counts would make every visited element be converted into a temporary: the key is copied and the
  // signed counts take a signed -> unsigned -> signed round trip before reaching the predicate.
  auto it = std::find_if(count_.begin(), count_.end(), [this](const auto &elem) -> bool {
    const auto &[read, written] = elem.second;
    return predicate_(read, written);
  });
  return it != count_.end();
}
// Returns the first tracked node (in map iteration order) whose (read, write) counts are accepted by
// predicate_, as the tuple {node, read count, write count}.
// When no tracked node qualifies, "No valid parameter." is reported through MS_LOG(EXCEPTION).
std::tuple<AnfNodePtr, int32_t, int32_t> AscendControlParser::ReferenceCounter::GetOneValidElem() const {
  // `const auto &` binds to the map's real value_type (const key, int32_t counts). A hand-written pair type with
  // a non-const key and uint32_t counts would force a converted temporary for every element that is inspected.
  auto it = std::find_if(count_.begin(), count_.end(), [this](const auto &elem) -> bool {
    const auto &[read, written] = elem.second;
    return predicate_(read, written);
  });
  if (it == count_.end()) {
    MS_LOG(EXCEPTION) << "No valid parameter.";
  }
  return {it->first, it->second.first, it->second.second};
}
} // namespace session
} // namespace mindspore

View File

@ -20,6 +20,8 @@
#include <map>
#include <vector>
#include <tuple>
#include <utility>
#include <functional>
#include "backend/session/kernel_graph.h"
#include "utils/base_ref.h"
#include "utils/contract.h"
@ -29,16 +31,23 @@ namespace mindspore {
namespace session {
class AscendControlParser {
public:
static void ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map);
static void LinkGraph(NotNull<KernelGraphPtr> kg);
static void InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node);
static void InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
NotNull<AnfNodePtr> second_node);
static void ExecutorValidate(NotNull<KernelGraphPtr> root_graph);
static void UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg);
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, const AnfNodePtr &jump_node,
NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
private:
class ReferenceCounter;
static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list);
static void EraseLabel(NotNull<KernelGraphPtr> root_graph);
static void ChildGraphDataAssign(NotNull<KernelGraphPtr> kg,
const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list,
const NotNull<std::set<KernelGraphPtr> *> memo);
static NotNull<CNodePtr> GetStartLabel(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
const CNodePtr &last_label);
static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
@ -53,11 +62,10 @@ class AscendControlParser {
static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
const CNodePtr &last_label);
static KernelGraphPtr ParsePartial(NotNull<AnfNodePtr> node);
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph,
NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ParseCallNode(NotNull<CNodePtr> call_node);
static std::tuple<KernelGraphPtr, std::vector<AnfNodePtr>> ParsePartial(NotNull<AnfNodePtr> node);
// root graph order
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
@ -65,6 +73,19 @@ class AscendControlParser {
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo);
};
// Per-node read/write bookkeeping. Every tracked AnfNodePtr maps to a {read count, write count} pair, and the
// predicate supplied at construction decides which pairs are reported by HasValidElem / GetOneValidElem.
class AscendControlParser::ReferenceCounter {
 public:
  // `func` is called as func(read count, write count) and returns true for entries that should be reported.
  // It is taken by value and moved into the member so the callable is not copied a second time.
  explicit ReferenceCounter(std::function<bool(int32_t, int32_t)> func) : predicate_(std::move(func)), count_() {}
  // Adds `num` to the read count of `key`; an untracked key starts at {num, 0}.
  void AddReadCount(const AnfNodePtr &key, int32_t num);
  // Adds `num` to the write count of `key`; an untracked key starts at {0, num}.
  void AddWriteCount(const AnfNodePtr &key, int32_t num);
  // Removes `key` and its counts.
  void EraseElem(const AnfNodePtr &key);
  // True when at least one tracked entry satisfies the predicate.
  bool HasValidElem() const;
  // Returns {node, read count, write count} of one entry that satisfies the predicate.
  std::tuple<AnfNodePtr, int32_t, int32_t> GetOneValidElem() const;

 private:
  std::function<bool(int32_t, int32_t)> predicate_;
  // node -> {read count, write count}
  std::map<AnfNodePtr, std::pair<int32_t, int32_t>> count_;
};
} // namespace session
} // namespace mindspore

View File

@ -289,6 +289,17 @@ static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph,
// this action should from bottom to top
graph->UpdateCallRealInput();
}
// Wraps the current output of `root_graph` in a MakeTuple node and installs that node as the new output.
// Nothing is changed when the return node has no input at index kReturnDataIndex.
void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) {
  auto return_node = root_graph->get_return();
  MS_EXCEPTION_IF_NULL(return_node);
  if (return_node->size() > kReturnDataIndex) {
    auto wrapped_output = root_graph->NewCNode(
      {NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
    root_graph->set_output(wrapped_output);
  }
}
} // namespace
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
@ -305,22 +316,39 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
std::vector<KernelGraphPtr> all_graphs;
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
BackendOptimization(all_graphs);
// split switch
SplitGraphs(NOT_NULL(root_graph));
// an empty graph does not enter the backend
if (root_graph->execution_order().empty()) {
MS_LOG(INFO) << root_graph->ToString() << " is empty graph.";
InsertMakeTupleForOutput(NOT_NULL(root_graph));
root_graph->set_executable(false);
InitRuntimeResource();
return root_graph->graph_id();
}
// create parameter for multiple branch
std::set<KernelGraphPtr> memo;
CreateMultiBranchOutput(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
// insert goto labels and label_sets
LinkChildGraphs(NOT_NULL(root_graph));
// resource initialize
InitRuntimeResource();
// recurse compile child root_graph
std::set<KernelGraphPtr> memo;
RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo));
IrFusionPass(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
SelectKernel(NOT_NULL(root_graph));
memo.clear();
HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
// add make_tuple to the output graph
InsertMakeTupleForOutput(NOT_NULL(root_graph));
// validate the root graph, including generating the execution order and so on
RootGraphExecutorValidate(NOT_NULL(root_graph));
// adjust kernel
@ -1682,7 +1710,7 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri
bool split_flag = false;
auto apply_list = GetCNodes(TopoSort(graph->get_return()));
// update the root graph child graph order
AscendControlParser::UpdateChildGraphOrder(graph);
graph->UpdateChildGraphOrder();
// get child list from current graph
std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(apply_list, cut_prims);
if (child_graph_lists.size() > 1) {
@ -1714,7 +1742,7 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri
}
split_flag = true;
}
AscendControlParser::UpdateChildGraphOrder(graph);
graph->UpdateChildGraphOrder();
UpdateRealInput(graph, split_flag, memo);
MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end";
}
@ -1753,5 +1781,216 @@ void AscendSession::RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const Not
}
}
}
// Gives every call node of `graph` a dedicated parameter that receives the output of the called branch graphs,
// then rewires the users of each call node to Depend(parameter, call). Child graphs are processed before the
// graph itself, and graphs already present in `memo` are skipped.
void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) {
  // insert() reports through `.second` whether the graph was new, so the visited check and the marking are one step
  if (!memo->insert(graph.get()).second) {
    return;
  }

  graph->UpdateChildGraphOrder();
  for (auto &child_graph : graph->child_graph_order()) {
    CreateMultiBranchOutput(NOT_NULL(child_graph), memo);
  }

  // call node -> depend node that replaces it in its users' inputs
  std::map<AnfNodePtr, AnfNodePtr> call_to_depend;
  auto graph_cnodes = GetCNodes(TopoSort(graph->get_return()));
  for (auto &call_node : graph_cnodes) {
    if (!AnfAlgo::CheckPrimitiveType(call_node, prim::kPrimCall)) {
      continue;
    }
    // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output
    // The graph inputs are snapshotted and written back afterwards, so the input list ends up unchanged; the new
    // parameter is registered as a child graph result instead.
    auto inputs_snapshot = graph->inputs();
    auto output_param = CreateNewParameterFromCNode(call_node, true, graph.get().get());
    MS_EXCEPTION_IF_NULL(graph->MutableInputs());
    graph->MutableInputs()->operator=(inputs_snapshot);
    graph->AddChildGraphResult(output_param);

    std::vector<AnfNodePtr> depend_operands = {
      graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name()))), output_param,
      call_node};
    auto depend_node = graph->NewCNode(depend_operands);
    call_to_depend.emplace(call_node, depend_node);
    MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node "
                 << call_node->DebugString() << ", depend node is " << depend_node->DebugString();

    // insert assign in order to transfer child graph output to parameter
    for (auto &child_graph : AnfAlgo::GetCallNodeKernelGraph(call_node)) {
      MS_EXCEPTION_IF_NULL(child_graph);
      if (!child_graph->get_output_null()) {
        auto graph_output = child_graph->output();
        AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(child_graph), nullptr, NOT_NULL(graph_output),
                                                         NOT_NULL(output_param));
      }
    }
  }

  // searching for nodes' input to replace call by depend(parameter, call)
  for (auto &user_node : graph_cnodes) {
    for (size_t input_idx = 0; input_idx < user_node->size(); ++input_idx) {
      auto replacement = call_to_depend.find(user_node->input(input_idx));
      if (replacement != call_to_depend.end()) {
        user_node->set_input(input_idx, replacement->second);
      }
    }
  }
}
// Runs the three Ascend backend optimization entry points on `graph`, resets its execution order, optionally
// dumps the graph IR, and then does the same for every graph in its child graph order.
// Graphs already present in `memo` are skipped.
void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) {
  // `.second` is false when the graph was already in the set
  if (!memo->insert(graph.get()).second) {
    return;
  }

  opt::AscendBackendIRFusionOptimization(graph);
  opt::AscendBackendFuseBasicOpt(graph, true);
  opt::AscendBackendGraphKernelOpt(graph, true);
  graph->SetExecOrderByDefault();

  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  const bool dump_enabled = context_ptr->save_graphs_flag();
  auto dump_dir = context_ptr->save_graphs_path();
  if (dump_enabled) {
    // fall back to the current directory when no dump path is configured
    if (dump_dir.empty()) {
      dump_dir = ".";
    }
    std::string file_path = dump_dir;
    file_path.append("/").append("select_kernel_before").append("_graph_");
    file_path.append(std::to_string(graph->graph_id())).append(".ir");
    DumpIR(file_path, graph.get());
  }

  for (auto &child_graph : graph->child_graph_order()) {
    IrFusionPass(NOT_NULL(child_graph), memo);
  }
}
// Selects kernel info for `root_graph` and every graph reachable from it, counting how many nodes needed
// raised or reduced precision to get a kernel. In graph mode a warning is logged for each non-zero counter.
void AscendSession::SelectKernel(NotNull<KernelGraphPtr> root_graph) {
  MS_LOG(INFO) << "Start select kernel.";
  size_t raise_precision_count = 0;
  size_t reduce_precision_count = 0;

  std::set<KernelGraphPtr> memo;
  (void)RecurseSelectKernelInfo(root_graph, NOT_NULL(&memo), &raise_precision_count, &reduce_precision_count);
  memo.clear();

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->execution_mode() == kGraphMode) {
    if (raise_precision_count > 0) {
      MS_LOG(WARNING) << "There has " << raise_precision_count
                      << " node/nodes used raise precision to selected the kernel!";
    }
    if (reduce_precision_count > 0) {
      // Report the reduce-precision counter here; printing raise_precision_count would show the wrong number
      // (possibly 0) next to the reduce-precision message.
      MS_LOG(WARNING) << "There has " << reduce_precision_count
                      << " node/nodes used reduce precision to selected the kernel!";
    }
  }
  MS_LOG(INFO) << "Finish!";
}
// Selects kernel info for every node in the execution order of `graph`. For a conditional control kernel that
// carries the kAttrChildGraph attribute, the attached child graphs are handled first (recursively).
// `raise_precision_count` / `reduce_precision_count` are incremented according to the status returned for each
// node. Graphs already present in `memo` are skipped. The graph IR is dumped afterwards when graph saving is on.
void AscendSession::RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph,
                                            NotNull<std::set<KernelGraphPtr> *> const memo,
                                            size_t *const raise_precision_count,
                                            size_t *const reduce_precision_count) const {
  // `.second` is false when the graph was already in the set
  if (!memo->insert(graph.get()).second) {
    return;
  }
  MS_LOG(INFO) << "Start to select kernel info in graph: " << graph->graph_id();

  for (const auto &kernel_node : graph->execution_order()) {
    if (AnfAlgo::IsCondControlKernel(kernel_node) && AnfAlgo::HasNodeAttr(kAttrChildGraph, kernel_node)) {
      for (auto &child_graph : AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(kernel_node, kAttrChildGraph)) {
        RecurseSelectKernelInfo(NOT_NULL(child_graph), memo, raise_precision_count, reduce_precision_count);
      }
    }

    auto select_status = device::ascend::SelectKernelInfo(kernel_node);
    if (select_status == device::ascend::kStatusRaisePrecision) {
      ++(*raise_precision_count);
    } else if (select_status == device::ascend::kStatusReducePrecision) {
      ++(*reduce_precision_count);
    }
    MS_LOG(INFO) << "Select ApplyKernel: " << kernel_node->DebugString();
  }

  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  const bool dump_enabled = context_ptr->save_graphs_flag();
  auto dump_dir = context_ptr->save_graphs_path();
  if (dump_enabled) {
    // fall back to the current directory when no dump path is configured
    if (dump_dir.empty()) {
      dump_dir = ".";
    }
    std::string file_path = dump_dir;
    file_path.append("/").append("select_kernel_after").append("_graph_");
    file_path.append(std::to_string(graph->graph_id())).append(".ir");
    DumpIR(file_path, graph.get());
  }
  MS_LOG(INFO) << "Finish selecting kernel info in graph: " << graph->graph_id();
}
// Recursive driver: converts `graph` with predictmodel::StepConvertGraph, runs the single-graph HardwareOptimize
// overload on it, and then repeats for every graph in its child graph order.
// Graphs already present in `memo` are skipped.
void AscendSession::HardwareOptimize(NotNull<KernelGraphPtr> graph,
                                     NotNull<std::set<KernelGraphPtr> *> const memo) const {
  // `.second` is false when the graph was already in the set
  if (!memo->insert(graph.get()).second) {
    return;
  }
  MS_LOG(INFO) << "Start to do HardwareOptimize in graph: " << graph->graph_id();

  // convert kernel Graph to model
  predictmodel::StepConvertGraph(graph.get());
  // one-argument overload: optimizes this graph only
  HardwareOptimize(graph.get());

  const auto &ordered_children = graph->child_graph_order();
  for (auto &child_graph : ordered_children) {
    HardwareOptimize(NOT_NULL(child_graph), memo);
  }
  MS_LOG(INFO) << "Finish doing HardwareOptimize in graph: " << graph->graph_id();
}
// Asks the Ascend kernel runtime to assign static memory for the inputs and value nodes of `graph`, then does
// the same for every graph in its child graph order. Graphs already present in `memo` are skipped.
void AscendSession::AssignStaticMemory(NotNull<KernelGraphPtr> graph,
                                       NotNull<std::set<KernelGraphPtr> *> const memo) const {
  // `.second` is false when the graph was already in the set
  if (!memo->insert(graph.get()).second) {
    return;
  }
  MS_LOG(INFO) << "Start to assign static memory for parameter in graph: " << graph->graph_id();

  // assign static memory for parameters
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  auto *raw_graph = graph.get().get();
  runtime_instance->AssignStaticMemoryInput(raw_graph);
  runtime_instance->AssignStaticMemoryValueNode(raw_graph);

  const auto &ordered_children = graph->child_graph_order();
  for (auto &child_graph : ordered_children) {
    AssignStaticMemory(NOT_NULL(child_graph), memo);
  }
  MS_LOG(INFO) << "Finish assigning static memory for parameter in graph: " << graph->graph_id();
}
// Merges the ref map of every child graph into `graph`, children first. A ref pair that `graph` already holds is
// kept as is and only a warning is logged. Graphs already present in `memo` are skipped.
void AscendSession::UpdateRefOutputMap(NotNull<KernelGraphPtr> graph,
                                       NotNull<std::set<KernelGraphPtr> *> const memo) const {
  // `.second` is false when the graph was already in the set
  if (!memo->insert(graph.get()).second) {
    return;
  }

  for (auto &child_graph : graph->child_graph_order()) {
    UpdateRefOutputMap(NOT_NULL(child_graph), memo);
    // copy ref map to final graph
    auto refs_of_child = child_graph->GetRefMap();
    for (auto &[final_pair, origin_pair] : refs_of_child) {
      if (!graph->IsInRefOutputMap(final_pair)) {
        graph->AddRefCorrespondPairs(final_pair, origin_pair);
      } else {
        MS_LOG(WARNING) << "The ref pair <" << final_pair.first->DebugString() << ", " << final_pair.second
                        << "> is already in " << graph->ToString();
      }
    }
  }
}
} // namespace session
} // namespace mindspore

View File

@ -151,6 +151,15 @@ class AscendSession : public SessionBasic {
// sync initial tensors' data to device
void SyncInitialTenosrToDevice();
void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
// create parameter to receive data from multiple branch output
void CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
void SelectKernel(NotNull<KernelGraphPtr> root_graph);
void RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> const memo,
size_t *const raise_precision_count, size_t *const reduce_precision_count) const;
void IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
void HardwareOptimize(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
// member variables
// key is final_graph_id,value is child graph execute order of final graph

View File

@ -616,8 +616,8 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
depend_mode = AnfAlgo::GetNodeAttr<int>(cnode, kControlDependMode);
}
MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString()
<< "], depend_mode :" << depend_mode << ".";
MS_LOG(DEBUG) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString()
<< "], depend_mode :" << depend_mode << ".";
if (prior_node->isa<Parameter>() && depend_mode == 1) {
prior_nodes = GetOutputNodes(prior_node);
}
@ -647,7 +647,8 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
}
MS_EXCEPTION_IF_NULL(first_node);
MS_EXCEPTION_IF_NULL(second_node);
MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString();
MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString()
<< ",second node:" << second_node->DebugString();
AddDependEdge(second_node, first_node, 1);
}
}
@ -991,6 +992,30 @@ bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const {
return false;
}
void KernelGraph::UpdateChildGraphOrder() {
MS_LOG(INFO) << "Update " << ToString() << " child graph order.";
SetExecOrderByDefault();
auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
std::vector<KernelGraphPtr> child_graph_order;
for (auto &call_node : call_nodes) {
MS_EXCEPTION_IF_NULL(call_node);
auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
for (const auto &child_graph : call_child_graphs) {
MS_EXCEPTION_IF_NULL(child_graph);
if (child_graph != parent_graph_) {
auto shared_this = std::dynamic_pointer_cast<KernelGraph>(shared_from_this());
MS_EXCEPTION_IF_NULL(shared_this);
child_graph->set_parent_graph(shared_this);
}
child_graph_order.push_back(child_graph);
}
}
for (size_t i = 0; i < child_graph_order.size(); ++i) {
MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]";
}
child_graph_order_ = child_graph_order;
}
// Builds the display name of the graph: the prefix "kernel_graph_" followed by the decimal graph id.
std::string KernelGraph::ToString() const {
  std::string graph_name("kernel_graph_");
  graph_name += std::to_string(graph_id_);
  return graph_name;
}
// On destruction, asks the KernelRuntimeManager instance to clear the graph resource registered for graph_id_.
KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); }

View File

@ -156,6 +156,12 @@ class KernelGraph : public FuncGraph {
bool IsFinalOutputKernel(const AnfNodePtr &node) const;
uint32_t current_epoch() const { return current_epoch_; }
void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; }
// Rebuilds the child graph order of this graph from its call nodes.
void UpdateChildGraphOrder();
// Read-only access to the nodes registered as child graph results.
const std::vector<AnfNodePtr> &child_graph_result() const { return child_graph_result_; }
// Appends `parameter` to the child graph result list.
void AddChildGraphResult(const AnfNodePtr &parameter) { child_graph_result_.push_back(parameter); }
// Replaces the whole child graph result list with a copy of `child_graph_result`.
void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
child_graph_result_ = child_graph_result;
}
private:
// remove value node form graph
@ -173,6 +179,7 @@ class KernelGraph : public FuncGraph {
void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends);
std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
std::vector<AnfNodePtr> child_graph_result_;
std::vector<CNodePtr> execution_order_;
uint32_t graph_id_;
uint32_t stream_distinction_label_;

View File

@ -74,7 +74,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
return input_tensors[input_idx];
}
}
MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << "has no output addr";
MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr";
}
}
// if the process reaches here, item_with_index is a real node (Parameter or executable CNode)
@ -107,8 +107,8 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
return tensor;
}
BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(anf);
MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
@ -120,7 +120,7 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
MS_EXCEPTION_IF_NULL(cnode);
VectorRef ret;
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors);
auto out = CreateTensorForOutput(cnode->input(i), graph, input_tensors);
ret.push_back(out);
}
return ret;
@ -133,25 +133,6 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors);
}
// Builds the output reference for a real-kernel node `anf`.
//  - a ValueNode yields the single tensor created for output index 0;
//  - a CNode other than MakeTuple yields a VectorRef with one tensor per output of the node;
//  - anything else yields an empty VectorRef.
// A node that is not a real kernel is reported through MS_LOG(EXCEPTION).
BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
                            const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(anf);
  if (!AnfAlgo::IsRealKernel(anf)) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] should be a executable kernel";
  }
  if (anf->isa<ValueNode>()) {
    return CreateOneTensor(anf, 0, graph, input_tensors);
  }

  VectorRef ret;
  // only non-MakeTuple CNodes contribute tensors; everything else returns the empty list
  if (!anf->isa<CNode>() || AnfAlgo::GetCNodeName(anf) == prim::kPrimMakeTuple->name()) {
    return ret;
  }
  for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(anf); ++output_idx) {
    auto out = CreateOneTensor(anf, output_idx, graph, input_tensors);
    ret.emplace_back(out);
  }
  return ret;
}
ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
@ -880,20 +861,11 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
const std::vector<tensor::TensorPtr> &input_tensors) const {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(outputs);
if (!kernel_graph->child_graph_order().empty()) {
// use the last child graph output as the root graph output
UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors);
return;
}
auto anf_outputs = kernel_graph->outputs();
for (auto &item : anf_outputs) {
MS_EXCEPTION_IF_NULL(item);
MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
continue;
}
outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors));
outputs->emplace_back(CreateTensorForOutput(item, *kernel_graph, input_tensors));
}
}

View File

@ -294,6 +294,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(mem_manager_);
auto graph_inputs = graph->inputs();
auto graph_valid_input = graph->valid_inputs();
graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
std::vector<AnfNodePtr> need_alloc_nodes;
for (size_t i = 0; i < graph_inputs.size(); ++i) {
auto item = graph_inputs[i];

View File

@ -240,6 +240,7 @@ constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
constexpr auto kAttrOffset = "offset";
constexpr auto kAttrPsKey = "ps_key";
constexpr auto kAttrOptimizerType = "optim_type";
constexpr auto kAttrChildGraph = "child_graph";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";