forked from mindspore-Ecosystem/mindspore
concat allgather pass performance optimization
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
parent 858c3b19b6
commit ab5b36988e

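The change caches per-output metadata instead of re-deriving it inside loops: a new `GetNodeOutputInfo` walks the fused AllGather once and stores inferred dtypes/shapes plus device formats/dtypes in an `OutputInfo` tuple, and `Type()`/`Shape()` lookups are hoisted out of the loops in `CreateMultipleOutputsOfAnfNode` and `AnfRuntimeAlgorithm`. As a rough, self-contained analogue of that refactor (not MindSpore code; every name below is invented for illustration):

// Illustrative stand-in for the caching refactor in this commit (invented names, not MindSpore API):
// gather per-output metadata once, keep it in a tuple, and let every later consumer index into it
// instead of re-querying the node on each loop iteration.
#include <cstddef>
#include <cstdio>
#include <string>
#include <tuple>
#include <vector>

using OutputInfoSketch = std::tuple<std::vector<int>, std::vector<std::string>>;  // dtype ids, formats

OutputInfoSketch CollectOutputInfoOnce(std::size_t output_num) {
  OutputInfoSketch info;
  for (std::size_t i = 0; i < output_num; ++i) {
    std::get<0>(info).push_back(static_cast<int>(i));  // stand-in for an inferred data type
    std::get<1>(info).push_back("DefaultFormat");      // stand-in for a device format
  }
  return info;
}

int main() {
  const std::size_t output_num = 4;                                  // hypothetical output count
  const OutputInfoSketch cached = CollectOutputInfoOnce(output_num); // queried once, reused below
  for (std::size_t i = 0; i < output_num; ++i) {
    std::printf("output %zu: dtype=%d format=%s\n", i, std::get<0>(cached)[i], std::get<1>(cached)[i].c_str());
  }
  return 0;
}

The hunks below apply the same idea with the real APIs: the new `GetOutputInferDataType(const TypePtr &, size_t)` and `GetOutputInferShape(node, base_shape, output_idx)` overloads accept the cached `Type()`/`Shape()` handles.
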
@@ -16,19 +16,50 @@
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
#include <string>
#include <tuple>
#include <utility>
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore::opt {
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat) {
namespace {
using OutputInfo =
  std::tuple<std::vector<TypeId>, std::vector<std::vector<size_t>>, std::vector<std::string>, std::vector<TypeId>>;
OutputInfo GetNodeOutputInfo(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<TypeId> output_infer_dtype;
  std::vector<std::vector<size_t>> output_infer_shape;
  std::vector<std::string> output_format;
  std::vector<TypeId> output_device_dtype;
  auto type_ptr = node->Type();
  auto shape_ptr = node->Shape();
  size_t output_num = AnfAlgo::GetOutputTensorNum(node);
  auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  for (size_t i = 0; i < output_num; i++) {
    output_infer_dtype.emplace_back(AnfAlgo::GetOutputInferDataType(type_ptr, i));
    output_infer_shape.emplace_back(AnfAlgo::GetOutputInferShape(node, shape_ptr, i));
    output_format.emplace_back(build_info->GetOutputFormat(i));
    output_device_dtype.emplace_back(build_info->GetOutputDeviceType(i));
  }

  return {output_infer_dtype, output_infer_shape, output_format, output_device_dtype};
}

kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat, const OutputInfo &allgather_output_info,
                                                   size_t allgather_input_num, size_t allgather_input_idx) {
  MS_EXCEPTION_IF_NULL(concat);
  std::vector<std::string> inputs_device_format;
  std::vector<std::string> outputs_device_format;
  std::vector<TypeId> inputs_device_type;
  std::vector<TypeId> outputs_device_type;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(concat); ++input_index) {
    inputs_device_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(concat, input_index));
    inputs_device_type.emplace_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(concat, input_index));
  size_t concat_input_num = AnfAlgo::GetInputTensorNum(concat);
  for (size_t i = 0; i < concat_input_num; ++i) {
    size_t input_index = allgather_input_idx + i * allgather_input_num;
    inputs_device_format.emplace_back(std::get<2>(allgather_output_info)[input_index]);
    inputs_device_type.emplace_back(std::get<3>(allgather_output_info)[input_index]);
  }
  // Current only support default format & float16
  auto cmp_format = inputs_device_format.begin();

@@ -57,9 +88,8 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat) {
  return builder.Build();
}

AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                            const std::vector<AnfNodePtr> &new_tuple_getitems,
                                                            int64_t rank_size) {
AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const OutputInfo &output_info,
                                 const std::vector<AnfNodePtr> &new_tuple_getitems, int64_t rank_size) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
  size_t inputs_size = AnfAlgo::GetInputTensorNum(node);

@@ -71,16 +101,16 @@ AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &
    auto concat = func_graph->NewCNode(concat_inputs);
    MS_EXCEPTION_IF_NULL(concat);
    MS_EXCEPTION_IF_NULL(new_tuple_getitems[i]);
    auto dtypes = {AnfAlgo::GetOutputInferDataType(new_tuple_getitems[i], 0)};
    std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(new_tuple_getitems[i], 0);
    shape[0] *= rank_size;
    const std::vector<TypeId> &dtypes = {std::get<0>(output_info)[i]};
    const auto &shape = std::get<1>(output_info)[i];
    std::vector<std::vector<size_t>> shapes = {shape};
    shapes[0][0] *= rank_size;
    AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, concat.get());
    AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), concat);
    AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat);
    std::vector<int64_t> dyn_input_size{rank_size};
    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
    auto kernel_build_info = GenerateKernelBuildInfo(concat);
    auto kernel_build_info = GenerateKernelBuildInfo(concat, output_info, inputs_size, i);
    AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, concat.get());
    make_tuple_inputs.push_back(concat);
  }

@@ -88,6 +118,7 @@ AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
}  // namespace

const BaseRef ConcatOutputsForAllGather::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();

@@ -113,7 +144,21 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra
  AnfAlgo::SetNodeAttr("fused", MakeValue(true), node);
  auto rank_size = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrRankSize);
  std::vector<AnfNodePtr> new_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs);
  return InsertConcatForOutput(func_graph, node, new_outputs, rank_size);
  OutputInfo output_info = GetNodeOutputInfo(node);
  size_t output_num = AnfAlgo::GetOutputTensorNum(node);
  for (size_t i = 0; i < output_num; ++i) {
    int64_t temp = SizeToLong(i);
    auto idx = NewValueNode(temp);
    MS_EXCEPTION_IF_NULL(idx);
    auto imm = std::make_shared<Int64Imm>(temp);
    auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
    idx->set_abstract(abstract_scalar);
    auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
    MS_EXCEPTION_IF_NULL(tuple_getitem);
    AnfAlgo::SetOutputInferTypeAndShape({std::get<0>(output_info)[i]}, {std::get<1>(output_info)[i]},
                                        tuple_getitem.get());
    new_outputs.emplace_back(std::move(tuple_getitem));
  }
  return InsertConcatForOutput(func_graph, node, output_info, new_outputs, rank_size);
}
}  // namespace mindspore::opt

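One detail worth calling out from `GenerateKernelBuildInfo` above: the index arithmetic `allgather_input_idx + i * allgather_input_num` means the concat built for AllGather input `idx` consumes the AllGather outputs `idx, idx + input_num, idx + 2 * input_num, ...`, joined along axis 0 (`shapes[0][0] *= rank_size`); on my reading of the pass that is one slice per rank. A standalone sketch of just that grouping, with made-up input and rank counts:

// Standalone illustration of the output-index grouping used by GenerateKernelBuildInfo above;
// the counts are hypothetical, only the idx + i * input_num arithmetic mirrors the pass.
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t allgather_input_num = 3;  // hypothetical: fused AllGather with 3 inputs
  const std::size_t rank_size = 4;            // hypothetical: 4 ranks contribute one slice each
  for (std::size_t idx = 0; idx < allgather_input_num; ++idx) {
    std::printf("concat for AllGather input %zu takes outputs:", idx);
    for (std::size_t i = 0; i < rank_size; ++i) {
      std::printf(" %zu", idx + i * allgather_input_num);
    }
    std::printf("\n");
  }
  return 0;
}
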
@@ -33,8 +33,6 @@ class ConcatOutputsForAllGather : public PatternProcessPass {
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  static AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const std::vector<AnfNodePtr> &new_tuple_getitems, int64_t rank_size);
  KernelSelectPtr kernel_select_;
};
}  // namespace opt

@@ -173,6 +173,8 @@ void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNod
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(outputs);
  auto type_ptr = node->Type();
  auto shape_ptr = node->Shape();
  for (size_t i = 0; i < output_num; i++) {
    int64_t temp = SizeToLong(i);
    auto idx = NewValueNode(temp);

@@ -182,8 +184,8 @@ void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNod
    idx->set_abstract(abstract_scalar);
    auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
    MS_EXCEPTION_IF_NULL(tuple_getitem);
    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(node, i)},
                                        {AnfAlgo::GetOutputInferShape(node, i)}, tuple_getitem.get());
    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(type_ptr, i)},
                                        {AnfAlgo::GetOutputInferShape(node, shape_ptr, i)}, tuple_getitem.get());
    (*outputs).push_back(tuple_getitem);
  }
}

@@ -653,9 +653,10 @@ std::string AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &
  return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
}

std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node,
                                                             const abstract::BaseShapePtr &base_shape,
                                                             size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  abstract::BaseShapePtr base_shape = node->Shape();
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::Shape>()) {
    if (output_idx == 0) {

@@ -691,6 +692,11 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n
                      << " trace: " << trace::DumpSourceLines(node);
}

std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  return GetOutputInferShape(node, node->Shape(), output_idx);
}

std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
  KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);

@@ -763,10 +769,19 @@ std::string AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, si
  return build_info->GetOutputReshapeType(output_idx);
}

TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  auto get_single_type = [](const TypePtr &type_ptr) -> TypeId {
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
  auto type_ptr = type;
  MS_EXCEPTION_IF_NULL(type_ptr);
  if (type_ptr->isa<Tuple>()) {
    auto tuple_ptr = type_ptr->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_ptr);
    if (output_idx >= tuple_ptr->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
    }
    type_ptr = (*tuple_ptr)[output_idx];
    MS_EXCEPTION_IF_NULL(type_ptr);
  }

  if (type_ptr->isa<TensorType>()) {
    auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_ptr);

@@ -774,25 +789,13 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_
    MS_EXCEPTION_IF_NULL(elem);
    return elem->type_id();
  }
  if (type_ptr->isa<Number>()) {
    return type_ptr->type_id();
  }
  return type_ptr->type_id();
  };
  auto get_tuple_type = [get_single_type](const TypePtr &type_ptr, size_t output_idx) -> TypeId {
    MS_EXCEPTION_IF_NULL(type_ptr);
    if (!type_ptr->isa<Tuple>()) {
      return get_single_type(type_ptr);
    }
    auto tuple_ptr = type_ptr->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_ptr);
    if (output_idx >= tuple_ptr->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
    }
    return get_single_type((*tuple_ptr)[output_idx]);
  };
  TypePtr type_ptr = node->Type();
  return get_tuple_type(type_ptr, output_idx);
}

TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  return GetOutputInferDataType(node->Type(), output_idx);
}

TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {

@@ -142,6 +142,8 @@ class AnfRuntimeAlgorithm {
  static std::string GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
  // get output shapes inferred by ME from input nodes.
  static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
  static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, const abstract::BaseShapePtr &base_shape,
                                                 size_t output_idx);
  // get input shapes inferred by ME from input nodes.
  static std::vector<size_t> GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx);
  // get output shapes which will built and run in device

@@ -154,6 +156,7 @@ class AnfRuntimeAlgorithm {
  static std::string GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
  // get output data type inferred by ME of anf node
  static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
  static TypeId GetOutputInferDataType(const TypePtr &type_ptr, size_t output_idx);
  // get output original data type from prev node,input_index is the input index of current node related to prev node
  static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
  // get output select data type of anf node

@@ -180,7 +180,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/

add_library(_ut_mindspore_obj OBJECT ${MINDSPORE_SRC_LIST})
add_library(_ut_ut_obj OBJECT ${UT_SRCS})
add_dependencies(_ut_ut_obj engine-cache-server)
add_dependencies(_ut_ut_obj engine-cache-server graph)
add_executable(ut_tests $<TARGET_OBJECTS:_ut_ut_obj>
               $<TARGET_OBJECTS:_ut_mindspore_obj>)