Add sort by index for each group of AllReduce

This commit is contained in:
yujianfeng 2020-06-16 19:47:30 +08:00
parent 6089d58d8d
commit f15cb6b7c9
5 changed files with 95 additions and 11 deletions

View File

@ -91,6 +91,30 @@ kernel::KernelBuildInfoPtr CreateKernelBuildInfo() {
builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32});
return builder.Build();
}
// Reports whether the kAttrInputNames attr of `cnode` still lists at least kTopkIndexK + 1 names.
// When it does not, the reason is logged at INFO level and false is returned.
bool CheckInputNamesSize(const CNodePtr &cnode) {
  const auto names = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
  const bool has_k_input = names.size() >= kTopkIndexK + 1;
  if (!has_k_input) {
    MS_LOG(INFO) << "The input k of topk has been converted to attr";
  }
  return has_k_input;
}
// Validates the shape returned by AnfAlgo::GetPrevNodeOutputInferShape(node, 0):
// it must be non-empty and its last dimension must not exceed 65500.
// Each rejection is logged at INFO level before false is returned.
bool CheckOutputShape(const AnfNodePtr &node) {
  const size_t kMaxFloat16 = 65500;
  const auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
  if (infer_shape.empty()) {
    MS_LOG(INFO) << "The output shape of topk to split must not be empty";
    return false;
  }
  // The emptiness guard above makes reading the last element safe.
  if (infer_shape.back() > kMaxFloat16) {
    MS_LOG(INFO) << "The last dim is more than " << kMaxFloat16 << ", switch to aicpu ops.";
    return false;
  }
  return true;
}
} // namespace
const BaseRef TopKSplit::DefinePattern() const {
@ -107,16 +131,10 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
// set value node as topk's input
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
if (input_names_vec.size() < kTopkIndexK + 1) {
MS_LOG(INFO) << "The input k of topk has been converted to attr";
if (!CheckInputNamesSize(cnode)) {
return nullptr;
}
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
auto last_dim = shape[shape.size() - 1];
const size_t kMaxFloat16 = 65500;
if (last_dim > kMaxFloat16) {
MS_LOG(INFO) << "The last dim is more than 65500, switch to aicpu ops.";
if (!CheckOutputShape(cnode)) {
return nullptr;
}
// Copy a new node to check supported.

View File

@ -253,6 +253,13 @@ bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
if (it.second.communication_op_nodes.size() <= 1) {
continue;
}
// Reorder the group by the integer kAttrIndex attr when the first node carries a positive one.
// stable_sort keeps the existing relative order of nodes whose index values compare equal.
auto first_node = it.second.communication_op_nodes[0];
if (AnfAlgo::HasNodeAttr(kAttrIndex, first_node) && AnfAlgo::GetNodeAttr<int>(first_node, kAttrIndex) > 0) {
// NOTE(review): only the first node is checked for kAttrIndex, yet the comparator below reads the
// attr from every node in the group — confirm all nodes of a group carry it whenever the first does.
std::stable_sort(it.second.communication_op_nodes.begin(), it.second.communication_op_nodes.end(),
[](const CNodePtr &a, const CNodePtr &b) {
return AnfAlgo::GetNodeAttr<int>(a, kAttrIndex) < AnfAlgo::GetNodeAttr<int>(b, kAttrIndex);
});
}
size_t segment_num = 0;
std::vector<size_t> segment_index;
if (GetSplitSegments(it.second, &segment_num, &segment_index, it.first)) {

View File

@ -209,6 +209,7 @@ constexpr auto kAttrRecordEvent = "record_event";
constexpr auto kAttrWaitEvent = "wait_event";
constexpr auto kAttrRecordEventStream = "record_event_stream";
constexpr auto kAttrWaitEventStream = "wait_event_stream";
// Name of the integer attr that CommunicationOpFusion::Run compares to order the nodes of a group.
constexpr auto kAttrIndex = "index";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";

View File

@ -58,7 +58,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_all) {
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
auto node_list = TopoSort(func_graph->get_return());
for (auto& node : node_list) {
for (auto &node : node_list) {
if (node == nullptr) {
continue;
}
@ -99,7 +99,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_group) {
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
auto node_list = TopoSort(func_graph->get_return());
for (auto& node : node_list) {
for (auto &node : node_list) {
if (node == nullptr) {
continue;
}
@ -141,7 +141,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_op) {
builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
auto node_list = TopoSort(func_graph->get_return());
int count = 0;
for (auto& node : node_list) {
for (auto &node : node_list) {
if (node == nullptr) {
continue;
}
@ -171,5 +171,52 @@ TEST_F(TestHWAllReduceFusion, test_fusion_op) {
EXPECT_NE(g_after, nullptr);
EXPECT_TRUE(CheckEqualGraph(new_graph, g_after));
}
// Runs AllReduceFusion on the "before" graph of test_all_reduce_fusion_all after tagging every input
// of the inner make_tuple with a kAttrIndex attr equal to its input position, then compares the
// optimized graph with the "after1" pattern.
//
// Pointer checks use ASSERT_NE instead of EXPECT_NE: EXPECT_NE records a failure but lets the test
// continue, so a null graph or a failed cast would be dereferenced on the next line and crash the
// whole test binary instead of failing this one case.
TEST_F(TestHWAllReduceFusion, test_fusion_sorted) {
  getPyFun_.SetDoResolve(true);
  FuncGraphPtr g = getPyFun_.CallAndParseRet("test_all_reduce_fusion_all", "before");
  ASSERT_NE(g, nullptr);
  std::vector<int> shp_x{1, 64, 112, 112};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
  AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract};
  auto func_graph = GetKernelGraph(g, args_spec_list);
  ASSERT_NE(func_graph, nullptr);

  // Walk return -> outer make_tuple -> inner make_tuple, verifying each cast before using it.
  auto ret = func_graph->get_return();
  ASSERT_NE(ret, nullptr);
  auto make_tuple = ret->input(1);
  ASSERT_NE(make_tuple, nullptr);
  auto outer_tuple = make_tuple->cast<CNodePtr>();
  ASSERT_NE(outer_tuple, nullptr);
  auto inner_input = outer_tuple->input(1);
  ASSERT_NE(inner_input, nullptr);
  auto make_tuple1 = inner_input->cast<CNodePtr>();
  ASSERT_NE(make_tuple1, nullptr);

  // Input 0 is skipped; inputs 1..N receive index 1..N.
  for (size_t i = 1; i < make_tuple1->inputs().size(); ++i) {
    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(SizeToInt(i)), make_tuple1->input(i));
  }

  // set kernel build info
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({"NC1HWC0"});
  builder.SetOutputsFormat({"NC1HWC0"});
  builder.SetInputsDeviceType({kFloat32->type_id()});
  builder.SetOutputsDeviceType({kFloat32->type_id()});
  builder.SetFusionType(kernel::FusionType::ELEMWISE);
  builder.SetProcessor(kernel::Processor::AICORE);
  builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (node == nullptr) {
      continue;
    }
    // Only AllReduce CNodes and Parameters get kernel info attached.
    if ((node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) || node->isa<Parameter>()) {
      node->set_kernel_info(std::make_shared<device::KernelInfo>());
      AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
    }
  }

  // do all reduce fusion
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::AllReduceFusion>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(new_graph, nullptr);

  // check result
  FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_all_reduce_fusion_all", "after1");
  ASSERT_NE(g_after, nullptr);
  EXPECT_TRUE(CheckEqualGraph(new_graph, g_after));
}
} // namespace opt
} // namespace mindspore

View File

@ -140,6 +140,17 @@ def test_all_reduce_fusion_all(tag):
res = make_tuple(y1, y2, y3, y4, y5)
return make_tuple(res)
# Pattern selected with tag "after1": a single allreduce call takes x1..x5 together, and its
# results are read back with tuple_getitem at positions 0..4, in the same order as the inputs.
# The result tuple is wrapped in one more make_tuple before being returned.
@fns
def after1(x1, x2, x3, x4, x5):
ar = allreduce(x1, x2, x3, x4, x5)
y1 = tuple_getitem(ar, 0)
y2 = tuple_getitem(ar, 1)
y3 = tuple_getitem(ar, 2)
y4 = tuple_getitem(ar, 3)
y5 = tuple_getitem(ar, 4)
res = make_tuple(y1, y2, y3, y4, y5)
return make_tuple(res)
return fns[tag]