!17819 Add all gather fusion and concat pass for gpu

Merge pull request !17819 from ZPaC/master-add-gpu-all-gather-fusion
This commit is contained in:
i-robot 2021-06-07 14:32:32 +08:00 committed by Gitee
commit 8fe3da0ddc
4 changed files with 202 additions and 4 deletions

View File

@ -0,0 +1,162 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/concat_outputs_for_all_gather.h"
#include <algorithm>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore::opt {
namespace {
// Per-output description of a node; every vector is indexed by output position.
//   <0> inferred dtype, <1> inferred shape, <2> device format, <3> device dtype.
using OutputInfo =
  std::tuple<std::vector<TypeId>, std::vector<std::vector<size_t>>, std::vector<std::string>, std::vector<TypeId>>;

// Gathers, for each output of `node`, the inferred dtype and shape together with the format and
// device dtype recorded in the node's selected kernel build info.
// `node`, its kernel info and its selected build info are each checked with MS_EXCEPTION_IF_NULL.
OutputInfo GetNodeOutputInfo(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto type_ptr = node->Type();
  auto shape_ptr = node->Shape();
  const size_t output_num = AnfAlgo::GetOutputTensorNum(node);
  auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);

  std::vector<TypeId> output_infer_dtype;
  std::vector<std::vector<size_t>> output_infer_shape;
  std::vector<std::string> output_format;
  std::vector<TypeId> output_device_dtype;
  // The final sizes are known up front, so allocate once instead of growing per output.
  output_infer_dtype.reserve(output_num);
  output_infer_shape.reserve(output_num);
  output_format.reserve(output_num);
  output_device_dtype.reserve(output_num);

  for (size_t i = 0; i < output_num; ++i) {
    output_infer_dtype.emplace_back(AnfAlgo::GetOutputInferDataType(type_ptr, i));
    output_infer_shape.emplace_back(AnfAlgo::GetOutputInferShape(node, shape_ptr, i));
    output_format.emplace_back(build_info->GetOutputFormat(i));
    output_device_dtype.emplace_back(build_info->GetOutputDeviceType(i));
  }
  // Move the vectors into the tuple; returning the lvalues in a braced list would copy all four.
  return OutputInfo(std::move(output_infer_dtype), std::move(output_infer_shape), std::move(output_format),
                    std::move(output_device_dtype));
}
// Builds the kernel build info for one Concat node that stitches together pieces of an AllGather output.
//
// concat                 Concat node whose build info is being generated; must not be null.
// allgather_output_info  Output description of the AllGather node (see OutputInfo).
// allgather_input_num    Number of inputs of the AllGather node; used as the stride between the
//                        AllGather outputs consumed by this Concat.
// allgather_input_idx    Index of the first AllGather output consumed by this Concat.
//
// Concat input i takes its format and device dtype from AllGather output
// `allgather_input_idx + i * allgather_input_num`. All inputs must share one format and one device
// dtype; the single Concat output inherits them. An exception is raised when the Concat has no
// input, when a computed index falls outside the AllGather outputs, or when formats/dtypes differ.
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat, const OutputInfo &allgather_output_info,
                                                   size_t allgather_input_num, size_t allgather_input_idx) {
  MS_EXCEPTION_IF_NULL(concat);
  const auto &allgather_formats = std::get<2>(allgather_output_info);
  const auto &allgather_dtypes = std::get<3>(allgather_output_info);

  const size_t concat_input_num = AnfAlgo::GetInputTensorNum(concat);
  if (concat_input_num == 0) {
    // The uniformity checks below dereference the first element, so an empty input list is invalid.
    MS_LOG(EXCEPTION) << "Concat node has no input, allgather input index: " << allgather_input_idx;
  }

  std::vector<std::string> inputs_device_format;
  std::vector<TypeId> inputs_device_type;
  inputs_device_format.reserve(concat_input_num);
  inputs_device_type.reserve(concat_input_num);
  for (size_t i = 0; i < concat_input_num; ++i) {
    const size_t input_index = allgather_input_idx + i * allgather_input_num;
    if (input_index >= allgather_formats.size() || input_index >= allgather_dtypes.size()) {
      MS_LOG(EXCEPTION) << "AllGather output index " << input_index << " is out of range, format num: "
                        << allgather_formats.size() << ", dtype num: " << allgather_dtypes.size();
    }
    inputs_device_format.emplace_back(allgather_formats[input_index]);
    inputs_device_type.emplace_back(allgather_dtypes[input_index]);
  }

  // Every input must use the same device format as the first one.
  auto cmp_format = inputs_device_format.begin();
  auto format_iter = std::find_if(inputs_device_format.begin(), inputs_device_format.end(),
                                  [&](const auto &format) { return format != (*cmp_format); });
  if (format_iter != inputs_device_format.end()) {
    MS_LOG(EXCEPTION) << "Input format is not same, value: " << (*format_iter) << ", need format: " << (*cmp_format);
  }
  // Every input must use the same device dtype as the first one.
  auto cmp_dtype = inputs_device_type.begin();
  auto dtype_iter = std::find_if(inputs_device_type.begin(), inputs_device_type.end(),
                                 [&](const auto &dtype) { return dtype != (*cmp_dtype); });
  if (dtype_iter != inputs_device_type.end()) {
    MS_LOG(EXCEPTION) << "Input dtype is not same, value: " << TypeIdLabel(*dtype_iter)
                      << ", need dtype: " << TypeIdLabel(*cmp_dtype);
  }

  // Concat has a single output that shares the common input format and dtype.
  std::vector<std::string> outputs_device_format{*cmp_format};
  std::vector<TypeId> outputs_device_type{*cmp_dtype};

  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetFusionType(kernel::FusionType::OPAQUE);
  builder.SetInputsFormat(inputs_device_format);
  builder.SetOutputsFormat(outputs_device_format);
  builder.SetInputsDeviceType(inputs_device_type);
  builder.SetOutputsDeviceType(outputs_device_type);
  return builder.Build();
}
// Creates one Concat node per AllGather input and returns a MakeTuple node holding all of them.
//
// func_graph          Graph in which the new nodes are created; must not be null.
// node                The AllGather node; must not be null.
// output_info         Output description of `node` (see OutputInfo).
// new_tuple_getitems  One TupleGetItem per AllGather output, in output order.
// rank_size           Number of pieces concatenated per input; must be positive.
//
// Concat i takes outputs i, i + inputs_size, i + 2 * inputs_size, ... (rank_size of them) and joins
// them on axis 0, so its inferred shape is the shape of output i with dimension 0 multiplied by
// rank_size. An exception is raised if rank_size is not positive, if fewer than
// inputs_size * rank_size getitems (or fewer than inputs_size output descriptions) are supplied,
// if any consumed getitem is null, or if an output shape has no dimension to concatenate on.
AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const OutputInfo &output_info,
                                 const std::vector<AnfNodePtr> &new_tuple_getitems, int64_t rank_size) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  if (rank_size <= 0) {
    MS_LOG(EXCEPTION) << "Invalid rank size: " << rank_size;
  }
  // Convert once so all size arithmetic below stays unsigned.
  const size_t rank_num = LongToSize(rank_size);
  const size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
  const auto &infer_dtypes = std::get<0>(output_info);
  const auto &infer_shapes = std::get<1>(output_info);
  if (new_tuple_getitems.size() < inputs_size * rank_num || infer_dtypes.size() < inputs_size ||
      infer_shapes.size() < inputs_size) {
    MS_LOG(EXCEPTION) << "AllGather output number is not enough, input num: " << inputs_size
                      << ", rank size: " << rank_num << ", getitem num: " << new_tuple_getitems.size()
                      << ", output dtype num: " << infer_dtypes.size()
                      << ", output shape num: " << infer_shapes.size();
  }

  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
  make_tuple_inputs.reserve(inputs_size + 1);
  for (size_t i = 0; i < inputs_size; ++i) {
    std::vector<AnfNodePtr> concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
    concat_inputs.reserve(rank_num + 1);
    for (size_t j = 0, idx = i; j < rank_num; ++j, idx += inputs_size) {
      // Check every piece that is actually fed to Concat, not only the first one.
      MS_EXCEPTION_IF_NULL(new_tuple_getitems[idx]);
      concat_inputs.push_back(new_tuple_getitems[idx]);
    }
    auto concat = func_graph->NewCNode(concat_inputs);
    MS_EXCEPTION_IF_NULL(concat);

    const std::vector<TypeId> dtypes = {infer_dtypes[i]};
    std::vector<std::vector<size_t>> shapes = {infer_shapes[i]};
    if (shapes[0].empty()) {
      MS_LOG(EXCEPTION) << "AllGather output " << i << " has an empty shape, can not concat on axis 0.";
    }
    // The pieces are joined on axis 0, so only the leading dimension grows.
    shapes[0][0] *= rank_num;
    AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, concat.get());
    AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), concat);
    AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat);
    std::vector<int64_t> dyn_input_size{rank_size};
    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
    auto kernel_build_info = GenerateKernelBuildInfo(concat, output_info, inputs_size, i);
    AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, concat.get());
    make_tuple_inputs.push_back(concat);
  }
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);
  return make_tuple;
}
} // namespace
// Pattern: the AllGather primitive followed by a sequence variable standing for its inputs.
const BaseRef ConcatOutputsForAllGather::DefinePattern() const {
  VarPtr input_seq = std::make_shared<SeqVar>();
  auto all_gather_prim = std::make_shared<Primitive>(kAllGatherOpName);
  VectorRef pattern({all_gather_prim, input_seq});
  return pattern;
}
// Rewrites the outputs of a fused AllGather node so that the pieces belonging to each input are
// concatenated back together.
//
// Returns nullptr (no replacement) when the node lacks the fusion or rank-size attribute, when its
// fusion attribute is not positive, when it already carries the "fused" attribute, or when it has
// exactly one input. Otherwise the node is tagged with the "fused" attribute, one TupleGetItem is
// created per AllGather output, and the MakeTuple built by InsertConcatForOutput is returned.
const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                    const EquivPtr &) const {
  // Attribute used to remember that a node has already been handled by this pass.
  constexpr auto kFusedFlag = "fused";

  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (!AnfAlgo::HasNodeAttr(kAttrFusion, cnode) || !AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) {
    return nullptr;
  }
  auto fusion = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
  if (fusion <= 0) {
    return nullptr;
  }
  if (AnfAlgo::HasNodeAttr(kFusedFlag, cnode) || AnfAlgo::GetInputTensorNum(node) == 1) {
    return nullptr;
  }
  // Checked here rather than at the top so nodes that are skipped above keep returning nullptr.
  MS_EXCEPTION_IF_NULL(func_graph);

  AnfAlgo::SetNodeAttr(kFusedFlag, MakeValue(true), node);
  auto rank_size = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrRankSize);
  OutputInfo output_info = GetNodeOutputInfo(node);
  const size_t output_num = AnfAlgo::GetOutputTensorNum(node);

  std::vector<AnfNodePtr> new_outputs;
  new_outputs.reserve(output_num);
  for (size_t i = 0; i < output_num; ++i) {
    const int64_t output_index = SizeToLong(i);
    auto idx = NewValueNode(output_index);
    MS_EXCEPTION_IF_NULL(idx);
    // The index value node gets a scalar abstract carrying the same index value.
    auto imm = std::make_shared<Int64Imm>(output_index);
    auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
    idx->set_abstract(abstract_scalar);
    auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
    MS_EXCEPTION_IF_NULL(tuple_getitem);
    AnfAlgo::SetOutputInferTypeAndShape({std::get<0>(output_info)[i]}, {std::get<1>(output_info)[i]},
                                        tuple_getitem.get());
    new_outputs.emplace_back(std::move(tuple_getitem));
  }
  return InsertConcatForOutput(func_graph, node, output_info, new_outputs, rank_size);
}
} // namespace mindspore::opt

View File

@ -0,0 +1,35 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CONCAT_OUTPUTS_FOR_ALLGATHER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CONCAT_OUTPUTS_FOR_ALLGATHER_H_
#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
// Pattern pass, registered under the name "concat_outputs_for_all_gather", that matches AllGather
// nodes and replaces a fused, multi-input one with a MakeTuple of Concat nodes built from its outputs.
class ConcatOutputsForAllGather : public PatternProcessPass {
public:
// `multigraph` is forwarded unchanged to PatternProcessPass.
explicit ConcatOutputsForAllGather(bool multigraph = true)
: PatternProcessPass("concat_outputs_for_all_gather", multigraph) {}
~ConcatOutputsForAllGather() override = default;
// Returns the pattern to match: the AllGather primitive followed by a sequence of inputs.
const BaseRef DefinePattern() const override;
// Processes one matched node; returns the replacement node, or nullptr when the node is left as is.
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CONCAT_OUTPUTS_FOR_ALLGATHER_H_

View File

@ -51,6 +51,7 @@
#endif #endif
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h" #include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
#include "backend/optimizer/pass/communication_op_fusion.h" #include "backend/optimizer/pass/communication_op_fusion.h"
#include "backend/optimizer/gpu/concat_outputs_for_all_gather.h"
#include "backend/optimizer/pass/getitem_tuple.h" #include "backend/optimizer/pass/getitem_tuple.h"
#include "common/trans.h" #include "common/trans.h"
#include "debug/anf_ir_dump.h" #include "debug/anf_ir_dump.h"
@ -175,6 +176,8 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra
pm->AddPass(std::make_shared<opt::AddReluV2Fusion>()); pm->AddPass(std::make_shared<opt::AddReluV2Fusion>());
pm->AddPass(std::make_shared<opt::AddReluGradV2Fusion>()); pm->AddPass(std::make_shared<opt::AddReluGradV2Fusion>());
pm->AddPass(std::make_shared<opt::AllReduceFusion>()); pm->AddPass(std::make_shared<opt::AllReduceFusion>());
pm->AddPass(std::make_shared<opt::AllGatherFusion>());
pm->AddPass(std::make_shared<opt::ConcatOutputsForAllGather>());
pm->AddPass(std::make_shared<opt::GetitemTuple>()); pm->AddPass(std::make_shared<opt::GetitemTuple>());
pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision")); pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision"));
optimizer->AddPassManager(pm); optimizer->AddPassManager(pm);

View File

@ -1052,7 +1052,8 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph
auto &kernels = graph->execution_order(); auto &kernels = graph->execution_order();
for (auto &kernel : kernels) { for (auto &kernel : kernels) {
MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(kernel);
if (AnfAlgo::IsCommunicationOp(kernel)) { if (AnfAlgo::IsCommunicationOp(kernel) && AnfAlgo::GetCNodeName(kernel) != kHcomSendOpName &&
AnfAlgo::GetCNodeName(kernel) != kReceiveOpName) {
AllocCommunicationOpInputDynamicRes(kernel); AllocCommunicationOpInputDynamicRes(kernel);
AllocCommunicationOpOutputDynamicRes(kernel); AllocCommunicationOpOutputDynamicRes(kernel);
} }
@ -1120,9 +1121,6 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::Anf
void GPUKernelRuntime::AllocCommunicationOpMemory(bool is_need_alloc_memory, bool, const DeviceAddressPtrList addr_list, void GPUKernelRuntime::AllocCommunicationOpMemory(bool is_need_alloc_memory, bool, const DeviceAddressPtrList addr_list,
size_t total_size, std::vector<size_t> size_list) { size_t total_size, std::vector<size_t> size_list) {
MS_EXCEPTION_IF_NULL(mem_manager_); MS_EXCEPTION_IF_NULL(mem_manager_);
if (!is_need_alloc_memory) {
return;
}
auto ret = mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); auto ret = mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list);
if (!ret) { if (!ret) {
MS_LOG(EXCEPTION) << "Malloc device memory failed."; MS_LOG(EXCEPTION) << "Malloc device memory failed.";