forked from mindspore-Ecosystem/mindspore
!9298 Add an AllGather fusion feature
From: @alouhahahahaha Reviewed-by: Signed-off-by:
This commit is contained in:
commit
0225c29e71
|
@ -133,8 +133,22 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
|
|||
if (!output_size_list_.empty()) {
|
||||
return output_size_list_;
|
||||
}
|
||||
for (ulong i = 0; i < hccl_data_type_list_.size(); ++i) {
|
||||
if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[i], hccl_kernel_output_shape_list_[i], &size)) {
|
||||
auto cnode = anf_node_->cast<CNodePtr>();
|
||||
auto op_name = AnfAlgo::GetCNodeName(cnode);
|
||||
int64_t rank_size = 1;
|
||||
if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) {
|
||||
rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
|
||||
}
|
||||
int64_t fusion = 0;
|
||||
if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode)) {
|
||||
fusion = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
|
||||
}
|
||||
ulong loop_size = hccl_data_type_list_.size();
|
||||
if (op_name == kAllGatherOpName && fusion >= 1) {
|
||||
loop_size *= rank_size;
|
||||
}
|
||||
for (ulong i = 0; i < loop_size; ++i) {
|
||||
if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[0], hccl_kernel_output_shape_list_[i], &size)) {
|
||||
MS_LOG(ERROR) << "GetHcclOpOutputSize failed";
|
||||
}
|
||||
output_size_list_.push_back(size);
|
||||
|
|
|
@ -127,7 +127,12 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataTyp
|
|||
total_size = total_size + block_size;
|
||||
} else {
|
||||
if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) {
|
||||
block_size = input_size;
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(anf_node, kAttrFusion)) {
|
||||
block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
|
||||
} else {
|
||||
block_size = input_size;
|
||||
}
|
||||
} else {
|
||||
block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
|
||||
}
|
||||
|
|
|
@ -20,58 +20,33 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
// Expands the outputs of `node` so that every original output appears `rank_size` times.
//
// The node's abstract must be a tuple abstract (the cast result is null-checked). Its
// elements, together with the per-output device data type and format queried from the
// node, are replicated `rank_size` times and written back as the node's new abstract and
// as the outputs section of its selected kernel build info.
//
// Params:
//   node      - node whose abstract and kernel build info are rewritten in place.
//   rank_size - replication factor; a value <= 0 skips the replication loop entirely,
//               leaving the new abstract tuple and output lists empty.
void AddOutputs(const AnfNodePtr &node, int64_t rank_size) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto origin_abstract = node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(origin_abstract);
|
||||
// The original abstract is expected to be a tuple; a failed cast is caught by the null check below.
auto tuple_abstract = origin_abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
auto &origin_abstracts = tuple_abstract->elements();
|
||||
// Accumulators for the expanded outputs: abstracts, device data types and device formats.
AbstractBasePtrList abstract_list;
|
||||
std::vector<TypeId> outputs_device_type;
|
||||
std::vector<std::string> outputs_device_format;
|
||||
// Outer loop repeats once per rank, inner loop walks the original outputs, so the
// resulting order is: all original outputs, then all original outputs again, and so on.
for (int64_t i = 0; i < rank_size; ++i) {
|
||||
for (size_t j = 0; j < origin_abstracts.size(); ++j) {
|
||||
abstract_list.push_back(origin_abstracts[j]);
|
||||
// Device type and format are read from the node's original output index j on every repetition.
outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(node, j));
|
||||
outputs_device_format.push_back(AnfAlgo::GetOutputFormat(node, j));
|
||||
}
|
||||
}
|
||||
// Update abstract
|
||||
auto new_abstracts = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
node->set_abstract(new_abstracts);
|
||||
// Update kernel build info
|
||||
// The builder is constructed from the node's current build info; only the output device
// types and formats are overridden here (presumably the other fields carry over — TODO confirm).
auto builder =
|
||||
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
|
||||
builder->SetOutputsDeviceType(outputs_device_type);
|
||||
builder->SetOutputsFormat(outputs_device_format);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const std::vector<AnfNodePtr> &new_tuple_getitems,
|
||||
int64_t rank_size) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||
std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
|
||||
size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t i = 0; i < inputs_size; ++i) {
|
||||
for (size_t j = 0, idx = i; j < LongToSize(rank_size); ++j, idx += inputs_size) {
|
||||
std::vector<AnfNodePtr> concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
|
||||
std::vector<AnfNodePtr> concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
|
||||
for (size_t j = 0, idx = i; j < IntToSize(rank_size); ++j, idx += inputs_size) {
|
||||
concat_inputs.push_back(new_tuple_getitems[idx]);
|
||||
auto concat = func_graph->NewCNode(concat_inputs);
|
||||
MS_EXCEPTION_IF_NULL(concat);
|
||||
MS_EXCEPTION_IF_NULL(new_tuple_getitems[idx]);
|
||||
concat->set_abstract(new_tuple_getitems[idx]->abstract());
|
||||
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), concat);
|
||||
AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat);
|
||||
std::vector<int64_t> dyn_input_size{rank_size};
|
||||
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
|
||||
kernel_select_->SelectKernel(concat);
|
||||
make_tuple_inputs.push_back(concat);
|
||||
}
|
||||
auto concat = func_graph->NewCNode(concat_inputs);
|
||||
MS_EXCEPTION_IF_NULL(concat);
|
||||
MS_EXCEPTION_IF_NULL(new_tuple_getitems[i]);
|
||||
auto dtypes = {AnfAlgo::GetOutputInferDataType(new_tuple_getitems[i], 0)};
|
||||
std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(new_tuple_getitems[i], 0);
|
||||
shape[0] *= rank_size;
|
||||
std::vector<std::vector<size_t>> shapes = {shape};
|
||||
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, concat.get());
|
||||
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), concat);
|
||||
AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat);
|
||||
std::vector<int64_t> dyn_input_size{rank_size};
|
||||
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
|
||||
kernel_select_->SelectKernel(concat);
|
||||
make_tuple_inputs.push_back(concat);
|
||||
}
|
||||
|
||||
auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
|
||||
return make_tuple;
|
||||
}
|
||||
|
@ -94,8 +69,11 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra
|
|||
if (fusion <= 0) {
|
||||
return nullptr;
|
||||
}
|
||||
if (AnfAlgo::HasNodeAttr("fused", cnode)) {
|
||||
return nullptr;
|
||||
}
|
||||
AnfAlgo::SetNodeAttr("fused", MakeValue(true), node);
|
||||
auto rank_size = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrRankSize);
|
||||
AddOutputs(node, rank_size);
|
||||
std::vector<AnfNodePtr> new_outputs;
|
||||
CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs);
|
||||
return InsertConcatForOutput(func_graph, node, new_outputs, rank_size);
|
||||
|
|
|
@ -46,15 +46,23 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &co
|
|||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
for (size_t idx = start_index; idx <= end_index; ++idx) {
|
||||
auto cnode = communication_op_info.communication_op_nodes[idx];
|
||||
int64_t rank_size = 1;
|
||||
if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode) && AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName) {
|
||||
rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
|
||||
inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
|
||||
inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
|
||||
}
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
|
||||
outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
|
||||
outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
|
||||
for (size_t rank_index = 0; rank_index < IntToSize(rank_size); ++rank_index) {
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
|
||||
outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
|
||||
outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
|
||||
std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, output_index);
|
||||
shape[0] /= rank_size;
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
|
||||
}
|
||||
}
|
||||
builder.SetFusionType(AnfAlgo::GetFusionType(cnode));
|
||||
builder.SetProcessor(AnfAlgo::GetProcessor(cnode));
|
||||
|
@ -182,18 +190,27 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
|
|||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
fused_node->set_kernel_info(kernel_info);
|
||||
AbstractBasePtrList abstract_list;
|
||||
for (size_t idx = start_index; idx <= end_index; ++idx) {
|
||||
auto cnode = communication_op_info.communication_op_nodes[idx];
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
abstract_list.push_back(cnode->abstract());
|
||||
auto final_node = communication_op_info.communication_op_nodes[end_index];
|
||||
size_t node_num = end_index - start_index + 1;
|
||||
int64_t rank_size = 1;
|
||||
if (AnfAlgo::HasNodeAttr(kAttrRankSize, final_node) && AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName) {
|
||||
rank_size = AnfAlgo::GetNodeAttr<int64_t>(final_node, kAttrRankSize);
|
||||
}
|
||||
size_t output_num = node_num * rank_size;
|
||||
std::vector<TypeId> dtypes(output_num, AnfAlgo::GetOutputInferDataType(final_node, 0));
|
||||
std::vector<std::vector<size_t>> shapes;
|
||||
for (size_t i = 0; i < IntToSize(rank_size); ++i) {
|
||||
for (size_t idx = start_index; idx <= end_index; ++idx) {
|
||||
auto cnode = communication_op_info.communication_op_nodes[idx];
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, 0);
|
||||
shape[0] /= rank_size;
|
||||
shapes.push_back(shape);
|
||||
}
|
||||
}
|
||||
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, fused_node.get());
|
||||
auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get());
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
MS_EXCEPTION_IF_NULL(abstract_tuple);
|
||||
fused_node->set_abstract(abstract_tuple);
|
||||
auto final_node = communication_op_info.communication_op_nodes[end_index];
|
||||
AnfAlgo::CopyNodeAttr(kAttrFusion, final_node, fused_node);
|
||||
AnfAlgo::CopyNodeAttr(kAttrOp, final_node, fused_node);
|
||||
AnfAlgo::CopyNodeAttr(kAttrGroup, final_node, fused_node);
|
||||
|
|
Loading…
Reference in New Issue