forked from mindspore-Ecosystem/mindspore
Add sort by index for each group of AllReduce
This commit is contained in:
parent
6089d58d8d
commit
f15cb6b7c9
|
@ -91,6 +91,30 @@ kernel::KernelBuildInfoPtr CreateKernelBuildInfo() {
|
|||
builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32});
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
bool CheckInputNamesSize(const CNodePtr &cnode) {
|
||||
auto input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
|
||||
if (input_names_vec.size() < kTopkIndexK + 1) {
|
||||
MS_LOG(INFO) << "The input k of topk has been converted to attr";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CheckOutputShape(const AnfNodePtr &node) {
|
||||
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
||||
if (shape.empty()) {
|
||||
MS_LOG(INFO) << "The output shape of topk to split must not be empty";
|
||||
return false;
|
||||
}
|
||||
auto last_dim = shape[shape.size() - 1];
|
||||
const size_t kMaxFloat16 = 65500;
|
||||
if (last_dim > kMaxFloat16) {
|
||||
MS_LOG(INFO) << "The last dim is more than " << kMaxFloat16 << ", switch to aicpu ops.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef TopKSplit::DefinePattern() const {
|
||||
|
@ -107,16 +131,10 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
|
|||
// set value node as topk's input
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
|
||||
if (input_names_vec.size() < kTopkIndexK + 1) {
|
||||
MS_LOG(INFO) << "The input k of topk has been converted to attr";
|
||||
if (!CheckInputNamesSize(cnode)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
||||
auto last_dim = shape[shape.size() - 1];
|
||||
const size_t kMaxFloat16 = 65500;
|
||||
if (last_dim > kMaxFloat16) {
|
||||
MS_LOG(INFO) << "The last dim is more than 65500, switch to aicpu ops.";
|
||||
if (!CheckOutputShape(cnode)) {
|
||||
return nullptr;
|
||||
}
|
||||
// Copy a new node to check supported.
|
||||
|
|
|
@ -253,6 +253,13 @@ bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
|
|||
if (it.second.communication_op_nodes.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
auto first_node = it.second.communication_op_nodes[0];
|
||||
if (AnfAlgo::HasNodeAttr(kAttrIndex, first_node) && AnfAlgo::GetNodeAttr<int>(first_node, kAttrIndex) > 0) {
|
||||
std::stable_sort(it.second.communication_op_nodes.begin(), it.second.communication_op_nodes.end(),
|
||||
[](const CNodePtr &a, const CNodePtr &b) {
|
||||
return AnfAlgo::GetNodeAttr<int>(a, kAttrIndex) < AnfAlgo::GetNodeAttr<int>(b, kAttrIndex);
|
||||
});
|
||||
}
|
||||
size_t segment_num = 0;
|
||||
std::vector<size_t> segment_index;
|
||||
if (GetSplitSegments(it.second, &segment_num, &segment_index, it.first)) {
|
||||
|
|
|
@ -209,6 +209,7 @@ constexpr auto kAttrRecordEvent = "record_event";
|
|||
constexpr auto kAttrWaitEvent = "wait_event";
|
||||
constexpr auto kAttrRecordEventStream = "record_event_stream";
|
||||
constexpr auto kAttrWaitEventStream = "wait_event_stream";
|
||||
// Attribute key for an int index stored on a node; read via AnfAlgo::GetNodeAttr<int> to order communication ops.
constexpr auto kAttrIndex = "index";
|
||||
|
||||
// attr value
|
||||
constexpr auto kValueTargetSwitch = "target_switch";
|
||||
|
|
|
@ -58,7 +58,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_all) {
|
|||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto& node : node_list) {
|
||||
for (auto &node : node_list) {
|
||||
if (node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
@ -99,7 +99,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_group) {
|
|||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto& node : node_list) {
|
||||
for (auto &node : node_list) {
|
||||
if (node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
@ -141,7 +141,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_op) {
|
|||
builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
int count = 0;
|
||||
for (auto& node : node_list) {
|
||||
for (auto &node : node_list) {
|
||||
if (node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
@ -171,5 +171,52 @@ TEST_F(TestHWAllReduceFusion, test_fusion_op) {
|
|||
EXPECT_NE(g_after, nullptr);
|
||||
EXPECT_TRUE(CheckEqualGraph(new_graph, g_after));
|
||||
}
|
||||
|
||||
// Runs the AllReduceFusion pass on the "before" graph of test_all_reduce_fusion_all after tagging
// every input of the inner make_tuple with an ascending kAttrIndex, then compares the optimized
// graph against the "after1" reference graph.
TEST_F(TestHWAllReduceFusion, test_fusion_sorted) {
  getPyFun_.SetDoResolve(true);
  FuncGraphPtr before_graph = getPyFun_.CallAndParseRet("test_all_reduce_fusion_all", "before");
  EXPECT_NE(before_graph, nullptr);

  // Build the kernel graph from five identical float32 tensor abstracts.
  std::vector<int> input_shape{1, 64, 112, 112};
  auto tensor_abs = std::make_shared<abstract::AbstractTensor>(kFloat32, input_shape);
  AbstractBasePtrList abs_list{tensor_abs, tensor_abs, tensor_abs, tensor_abs, tensor_abs};
  auto kernel_graph = GetKernelGraph(before_graph, abs_list);
  EXPECT_NE(kernel_graph, nullptr);

  // Tag each real input of the inner make_tuple (input 0 is skipped) with its position.
  auto outer_tuple = kernel_graph->get_return()->input(1);
  auto inner_tuple = outer_tuple->cast<CNodePtr>()->input(1)->cast<CNodePtr>();
  for (size_t idx = 1; idx < inner_tuple->inputs().size(); ++idx) {
    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(SizeToInt(idx)), inner_tuple->input(idx));
  }

  // set kernel build info
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({"NC1HWC0"});
  builder.SetOutputsFormat({"NC1HWC0"});
  builder.SetInputsDeviceType({kFloat32->type_id()});
  builder.SetOutputsDeviceType({kFloat32->type_id()});
  builder.SetFusionType(kernel::FusionType::ELEMWISE);
  builder.SetProcessor(kernel::Processor::AICORE);
  builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);

  // Attach kernel info and the build info above to every AllReduce cnode and every parameter.
  auto topo_nodes = TopoSort(kernel_graph->get_return());
  for (auto &node : topo_nodes) {
    if (node == nullptr) {
      continue;
    }
    const bool is_all_reduce = node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName;
    if (!is_all_reduce && !node->isa<Parameter>()) {
      continue;
    }
    node->set_kernel_info(std::make_shared<device::KernelInfo>());
    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
  }

  // do all reduce fusion
  auto pass_manager = std::make_shared<opt::PassManager>();
  pass_manager->AddPass(std::make_shared<opt::AllReduceFusion>());
  auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
  graph_optimizer->AddPassManager(pass_manager);
  FuncGraphPtr fused_graph = graph_optimizer->Optimize(kernel_graph);
  EXPECT_NE(fused_graph, nullptr);

  // check result
  FuncGraphPtr expected_graph = getPyFun_.CallAndParseRet("test_all_reduce_fusion_all", "after1");
  EXPECT_NE(expected_graph, nullptr);
  EXPECT_TRUE(CheckEqualGraph(fused_graph, expected_graph));
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -140,6 +140,17 @@ def test_all_reduce_fusion_all(tag):
|
|||
res = make_tuple(y1, y2, y3, y4, y5)
|
||||
return make_tuple(res)
|
||||
|
||||
@fns
def after1(x1, x2, x3, x4, x5):
    # Reference graph: a single allreduce call over all five inputs, whose
    # results are picked out by position and packed into a nested tuple.
    fused = allreduce(x1, x2, x3, x4, x5)
    out1 = tuple_getitem(fused, 0)
    out2 = tuple_getitem(fused, 1)
    out3 = tuple_getitem(fused, 2)
    out4 = tuple_getitem(fused, 3)
    out5 = tuple_getitem(fused, 4)
    packed = make_tuple(out1, out2, out3, out4, out5)
    return make_tuple(packed)
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue