!16496 concat allgather pass performance optimization

From: @zhoufeng54
Reviewed-by: @kisnwang,@jjfeing
Signed-off-by: @jjfeing
This commit is contained in:
mindspore-ci-bot 2021-05-17 21:09:44 +08:00 committed by Gitee
commit 034193f2df
6 changed files with 96 additions and 45 deletions

View File

@ -16,19 +16,50 @@
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
#include <string>
#include <tuple>
#include <utility>
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore::opt {
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat) {
namespace {
// Per-output information of a node; every vector is indexed by output position.
// Field order (as filled by GetNodeOutputInfo):
//   <0> inferred data types, <1> inferred shapes, <2> output formats, <3> output device data types.
using OutputInfo =
std::tuple<std::vector<TypeId>, std::vector<std::vector<size_t>>, std::vector<std::string>, std::vector<TypeId>>;
// Gathers, for every output of `node`, the inferred data type, the inferred shape, the output format and the
// output device data type taken from the node's selected kernel build info.
// The result tuple holds these four vectors in exactly that order, each with one entry per output.
// `node`, its kernel info and its selected kernel build info are null-checked with MS_EXCEPTION_IF_NULL.
OutputInfo GetNodeOutputInfo(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  // Type and shape are fetched once here and reused for every output index below.
  auto type_ptr = node->Type();
  auto shape_ptr = node->Shape();
  size_t output_num = AnfAlgo::GetOutputTensorNum(node);
  auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);

  std::vector<TypeId> output_infer_dtype;
  std::vector<std::vector<size_t>> output_infer_shape;
  std::vector<std::string> output_format;
  std::vector<TypeId> output_device_dtype;
  // The element count is known up front, so allocate once instead of growing inside the loop.
  output_infer_dtype.reserve(output_num);
  output_infer_shape.reserve(output_num);
  output_format.reserve(output_num);
  output_device_dtype.reserve(output_num);
  for (size_t i = 0; i < output_num; ++i) {
    output_infer_dtype.emplace_back(AnfAlgo::GetOutputInferDataType(type_ptr, i));
    output_infer_shape.emplace_back(AnfAlgo::GetOutputInferShape(node, shape_ptr, i));
    output_format.emplace_back(build_info->GetOutputFormat(i));
    output_device_dtype.emplace_back(build_info->GetOutputDeviceType(i));
  }
  // Move the locals into the tuple; a braced list of lvalues would deep-copy all four vectors.
  return std::make_tuple(std::move(output_infer_dtype), std::move(output_infer_shape), std::move(output_format),
                         std::move(output_device_dtype));
}
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat, const OutputInfo &allgather_output_info,
size_t allgather_input_num, size_t allgather_input_idx) {
MS_EXCEPTION_IF_NULL(concat);
std::vector<std::string> inputs_device_format;
std::vector<std::string> outputs_device_format;
std::vector<TypeId> inputs_device_type;
std::vector<TypeId> outputs_device_type;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(concat); ++input_index) {
inputs_device_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(concat, input_index));
inputs_device_type.emplace_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(concat, input_index));
size_t concat_input_num = AnfAlgo::GetInputTensorNum(concat);
for (size_t i = 0; i < concat_input_num; ++i) {
size_t input_index = allgather_input_idx + i * allgather_input_num;
inputs_device_format.emplace_back(std::get<2>(allgather_output_info)[input_index]);
inputs_device_type.emplace_back(std::get<3>(allgather_output_info)[input_index]);
}
// Current only support default format & float16
auto cmp_format = inputs_device_format.begin();
@ -57,9 +88,8 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat) {
return builder.Build();
}
AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::vector<AnfNodePtr> &new_tuple_getitems,
int64_t rank_size) {
AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const OutputInfo &output_info,
const std::vector<AnfNodePtr> &new_tuple_getitems, int64_t rank_size) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
@ -71,16 +101,16 @@ AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &
auto concat = func_graph->NewCNode(concat_inputs);
MS_EXCEPTION_IF_NULL(concat);
MS_EXCEPTION_IF_NULL(new_tuple_getitems[i]);
auto dtypes = {AnfAlgo::GetOutputInferDataType(new_tuple_getitems[i], 0)};
std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(new_tuple_getitems[i], 0);
shape[0] *= rank_size;
const std::vector<TypeId> &dtypes = {std::get<0>(output_info)[i]};
const auto &shape = std::get<1>(output_info)[i];
std::vector<std::vector<size_t>> shapes = {shape};
shapes[0][0] *= rank_size;
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, concat.get());
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), concat);
AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat);
std::vector<int64_t> dyn_input_size{rank_size};
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
auto kernel_build_info = GenerateKernelBuildInfo(concat);
auto kernel_build_info = GenerateKernelBuildInfo(concat, output_info, inputs_size, i);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, concat.get());
make_tuple_inputs.push_back(concat);
}
@ -88,6 +118,7 @@ AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &
auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
return make_tuple;
}
} // namespace
const BaseRef ConcatOutputsForAllGather::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
@ -113,7 +144,21 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra
AnfAlgo::SetNodeAttr("fused", MakeValue(true), node);
auto rank_size = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrRankSize);
std::vector<AnfNodePtr> new_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs);
return InsertConcatForOutput(func_graph, node, new_outputs, rank_size);
OutputInfo output_info = GetNodeOutputInfo(node);
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
for (size_t i = 0; i < output_num; ++i) {
int64_t temp = SizeToLong(i);
auto idx = NewValueNode(temp);
MS_EXCEPTION_IF_NULL(idx);
auto imm = std::make_shared<Int64Imm>(temp);
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
idx->set_abstract(abstract_scalar);
auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
MS_EXCEPTION_IF_NULL(tuple_getitem);
AnfAlgo::SetOutputInferTypeAndShape({std::get<0>(output_info)[i]}, {std::get<1>(output_info)[i]},
tuple_getitem.get());
new_outputs.emplace_back(std::move(tuple_getitem));
}
return InsertConcatForOutput(func_graph, node, output_info, new_outputs, rank_size);
}
} // namespace mindspore::opt

View File

@ -33,8 +33,6 @@ class ConcatOutputsForAllGather : public PatternProcessPass {
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
static AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::vector<AnfNodePtr> &new_tuple_getitems, int64_t rank_size);
KernelSelectPtr kernel_select_;
};
} // namespace opt

View File

@ -173,6 +173,8 @@ void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNod
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(outputs);
auto type_ptr = node->Type();
auto shape_ptr = node->Shape();
for (size_t i = 0; i < output_num; i++) {
int64_t temp = SizeToLong(i);
auto idx = NewValueNode(temp);
@ -182,8 +184,8 @@ void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNod
idx->set_abstract(abstract_scalar);
auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
MS_EXCEPTION_IF_NULL(tuple_getitem);
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(node, i)},
{AnfAlgo::GetOutputInferShape(node, i)}, tuple_getitem.get());
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(type_ptr, i)},
{AnfAlgo::GetOutputInferShape(node, shape_ptr, i)}, tuple_getitem.get());
(*outputs).push_back(tuple_getitem);
}
}

View File

@ -653,9 +653,10 @@ std::string AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &
return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
}
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node,
const abstract::BaseShapePtr &base_shape,
size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
abstract::BaseShapePtr base_shape = node->Shape();
MS_EXCEPTION_IF_NULL(base_shape);
if (base_shape->isa<abstract::Shape>()) {
if (output_idx == 0) {
@ -691,6 +692,11 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n
<< " trace: " << trace::DumpSourceLines(node);
}
// Convenience overload: reads the shape from `node` itself and forwards to the overload that takes an
// explicit base shape. `node` is null-checked with MS_EXCEPTION_IF_NULL before it is dereferenced.
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  const auto node_shape = node->Shape();
  return GetOutputInferShape(node, node_shape, output_idx);
}
std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
@ -763,36 +769,33 @@ std::string AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, si
return build_info->GetOutputReshapeType(output_idx);
}
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
auto get_single_type = [](const TypePtr &type_ptr) -> TypeId {
MS_EXCEPTION_IF_NULL(type_ptr);
if (type_ptr->isa<TensorType>()) {
auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_ptr);
TypePtr elem = tensor_ptr->element();
MS_EXCEPTION_IF_NULL(elem);
return elem->type_id();
}
if (type_ptr->isa<Number>()) {
return type_ptr->type_id();
}
return type_ptr->type_id();
};
auto get_tuple_type = [get_single_type](const TypePtr &type_ptr, size_t output_idx) -> TypeId {
MS_EXCEPTION_IF_NULL(type_ptr);
if (!type_ptr->isa<Tuple>()) {
return get_single_type(type_ptr);
}
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
auto type_ptr = type;
MS_EXCEPTION_IF_NULL(type_ptr);
if (type_ptr->isa<Tuple>()) {
auto tuple_ptr = type_ptr->cast<TuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_ptr);
if (output_idx >= tuple_ptr->size()) {
MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
}
return get_single_type((*tuple_ptr)[output_idx]);
};
TypePtr type_ptr = node->Type();
return get_tuple_type(type_ptr, output_idx);
type_ptr = (*tuple_ptr)[output_idx];
MS_EXCEPTION_IF_NULL(type_ptr);
}
if (type_ptr->isa<TensorType>()) {
auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_ptr);
TypePtr elem = tensor_ptr->element();
MS_EXCEPTION_IF_NULL(elem);
return elem->type_id();
}
return type_ptr->type_id();
}
// Convenience overload: reads the type from `node` itself and forwards to the overload that takes an
// explicit type. `node` is null-checked with MS_EXCEPTION_IF_NULL before it is dereferenced.
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  const auto node_type = node->Type();
  return GetOutputInferDataType(node_type, output_idx);
}
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {

View File

@ -142,6 +142,8 @@ class AnfRuntimeAlgorithm {
static std::string GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
// get output shapes inferred by ME from input nodes.
static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, const abstract::BaseShapePtr &base_shape,
size_t output_idx);
// get input shapes inferred by ME from input nodes.
static std::vector<size_t> GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx);
// get output shapes which will built and run in device
@ -154,6 +156,7 @@ class AnfRuntimeAlgorithm {
static std::string GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
// get output data type inferred by ME of anf node
static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
static TypeId GetOutputInferDataType(const TypePtr &type_ptr, size_t output_idx);
// get output original data type from prev node,input_index is the input index of current node related to prev node
static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
// get output select data type of anf node

View File

@ -180,7 +180,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/
add_library(_ut_mindspore_obj OBJECT ${MINDSPORE_SRC_LIST})
add_library(_ut_ut_obj OBJECT ${UT_SRCS})
add_dependencies(_ut_ut_obj engine-cache-server)
add_dependencies(_ut_ut_obj engine-cache-server graph)
add_executable(ut_tests $<TARGET_OBJECTS:_ut_ut_obj>
$<TARGET_OBJECTS:_ut_mindspore_obj>)