forked from mindspore-Ecosystem/mindspore
!17819 Add all gather fusion and concat pass for gpu
Merge pull request !17819 from ZPaC/master-add-gpu-all-gather-fusion
Commit 8fe3da0ddc

@@ -0,0 +1,162 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/gpu/concat_outputs_for_all_gather.h"
#include <algorithm>  // for std::find_if used below
#include <string>
#include <tuple>
#include <utility>
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore::opt {
namespace {
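// The four OutputInfo fields hold, per output of a node: the inferred data types, the inferred
// shapes, the device output formats, and the device output data types.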
using OutputInfo =
  std::tuple<std::vector<TypeId>, std::vector<std::vector<size_t>>, std::vector<std::string>, std::vector<TypeId>>;

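// Collects the per-output type/shape/format information of the (fused) AllGather node; the result
// is later used to build the TupleGetItem and Concat nodes for each output.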
OutputInfo GetNodeOutputInfo(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<TypeId> output_infer_dtype;
  std::vector<std::vector<size_t>> output_infer_shape;
  std::vector<std::string> output_format;
  std::vector<TypeId> output_device_dtype;
  auto type_ptr = node->Type();
  auto shape_ptr = node->Shape();
  size_t output_num = AnfAlgo::GetOutputTensorNum(node);
  auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  for (size_t i = 0; i < output_num; i++) {
    output_infer_dtype.emplace_back(AnfAlgo::GetOutputInferDataType(type_ptr, i));
    output_infer_shape.emplace_back(AnfAlgo::GetOutputInferShape(node, shape_ptr, i));
    output_format.emplace_back(build_info->GetOutputFormat(i));
    output_device_dtype.emplace_back(build_info->GetOutputDeviceType(i));
  }

  return {output_infer_dtype, output_infer_shape, output_format, output_device_dtype};
}

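// Builds the KernelBuildInfo for one inserted Concat node. The input index arithmetic follows the
// fused AllGather output layout: output (allgather_input_idx + i * allgather_input_num) is the
// slice of parameter allgather_input_idx produced by rank i. All Concat inputs must share one
// format and one device dtype, which also become the Concat output's format and dtype.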
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat, const OutputInfo &allgather_output_info,
                                                   size_t allgather_input_num, size_t allgather_input_idx) {
  MS_EXCEPTION_IF_NULL(concat);
  std::vector<std::string> inputs_device_format;
  std::vector<std::string> outputs_device_format;
  std::vector<TypeId> inputs_device_type;
  std::vector<TypeId> outputs_device_type;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  size_t concat_input_num = AnfAlgo::GetInputTensorNum(concat);
  for (size_t i = 0; i < concat_input_num; ++i) {
    size_t input_index = allgather_input_idx + i * allgather_input_num;
    inputs_device_format.emplace_back(std::get<2>(allgather_output_info)[input_index]);
    inputs_device_type.emplace_back(std::get<3>(allgather_output_info)[input_index]);
  }
  // Currently only the default format & float16 are supported.
  auto cmp_format = inputs_device_format.begin();
  auto format_iter = std::find_if(inputs_device_format.begin(), inputs_device_format.end(),
                                  [&](const auto &format) { return format != (*cmp_format); });
  if (format_iter != inputs_device_format.end()) {
    MS_LOG(EXCEPTION) << "Input format is not same, value: " << (*format_iter) << ", need format: " << (*cmp_format);
  }
  auto cmp_dtype = inputs_device_type.begin();
  auto dtype_iter = std::find_if(inputs_device_type.begin(), inputs_device_type.end(),
                                 [&](const auto &dtype) { return dtype != (*cmp_dtype); });
  if (dtype_iter != inputs_device_type.end()) {
    MS_LOG(EXCEPTION) << "Input dtype is not same, value: " << TypeIdLabel(*dtype_iter)
                      << ", need dtype: " << TypeIdLabel(*cmp_dtype);
  }
  outputs_device_format.emplace_back(*cmp_format);
  outputs_device_type.emplace_back(*cmp_dtype);

  builder.SetFusionType(kernel::FusionType::OPAQUE);
  builder.SetInputsFormat(inputs_device_format);
  builder.SetOutputsFormat(outputs_device_format);
  builder.SetInputsDeviceType(inputs_device_type);
  builder.SetOutputsDeviceType(outputs_device_type);
  return builder.Build();
}

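// For each original input i of the fused AllGather, creates one Concat node along axis 0 that
// gathers the rank_size TupleGetItem outputs {i, i + inputs_size, i + 2 * inputs_size, ...},
// i.e. the slices of input i contributed by every rank. The concatenated shape multiplies the
// first dimension by rank_size, and all Concat nodes are finally wrapped in a MakeTuple.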
AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const OutputInfo &output_info,
                                 const std::vector<AnfNodePtr> &new_tuple_getitems, int64_t rank_size) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
  size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < inputs_size; ++i) {
    std::vector<AnfNodePtr> concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
    for (size_t j = 0, idx = i; j < LongToSize(rank_size); ++j, idx += inputs_size) {
      concat_inputs.push_back(new_tuple_getitems[idx]);
    }
    auto concat = func_graph->NewCNode(concat_inputs);
    MS_EXCEPTION_IF_NULL(concat);
    MS_EXCEPTION_IF_NULL(new_tuple_getitems[i]);
    const std::vector<TypeId> &dtypes = {std::get<0>(output_info)[i]};
    const auto &shape = std::get<1>(output_info)[i];
    std::vector<std::vector<size_t>> shapes = {shape};
    shapes[0][0] *= rank_size;
    AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, concat.get());
    AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), concat);
    AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat);
    std::vector<int64_t> dyn_input_size{rank_size};
    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
    auto kernel_build_info = GenerateKernelBuildInfo(concat, output_info, inputs_size, i);
    AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, concat.get());
    make_tuple_inputs.push_back(concat);
  }

  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
}  // namespace

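// The pattern matches any AllGather node, with a variable-length input list.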
const BaseRef ConcatOutputsForAllGather::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  auto prim = std::make_shared<Primitive>(kAllGatherOpName);
  return VectorRef({prim, Xs});
}

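// Process only rewrites AllGather nodes produced by AllGatherFusion: the node must carry a
// positive fusion attribute and a rank size attribute, must not have been handled before
// (the "fused" attribute), and must bundle more than one input. One TupleGetItem is created per
// fused output, and InsertConcatForOutput then concatenates them back into full tensors.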
const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                    const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (!AnfAlgo::HasNodeAttr(kAttrFusion, cnode) || !AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) {
    return nullptr;
  }
  auto fusion = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
  if (fusion <= 0) {
    return nullptr;
  }
  if (AnfAlgo::HasNodeAttr("fused", cnode) || AnfAlgo::GetInputTensorNum(node) == 1) {
    return nullptr;
  }
  AnfAlgo::SetNodeAttr("fused", MakeValue(true), node);
  auto rank_size = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrRankSize);
  std::vector<AnfNodePtr> new_outputs;
  OutputInfo output_info = GetNodeOutputInfo(node);
  size_t output_num = AnfAlgo::GetOutputTensorNum(node);
  for (size_t i = 0; i < output_num; ++i) {
    int64_t temp = SizeToLong(i);
    auto idx = NewValueNode(temp);
    MS_EXCEPTION_IF_NULL(idx);
    auto imm = std::make_shared<Int64Imm>(temp);
    auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
    idx->set_abstract(abstract_scalar);
    auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
    MS_EXCEPTION_IF_NULL(tuple_getitem);
    AnfAlgo::SetOutputInferTypeAndShape({std::get<0>(output_info)[i]}, {std::get<1>(output_info)[i]},
                                        tuple_getitem.get());
    new_outputs.emplace_back(std::move(tuple_getitem));
  }
  return InsertConcatForOutput(func_graph, node, output_info, new_outputs, rank_size);
}
}  // namespace mindspore::opt

@@ -0,0 +1,35 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CONCAT_OUTPUTS_FOR_ALLGATHER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CONCAT_OUTPUTS_FOR_ALLGATHER_H_

#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class ConcatOutputsForAllGather : public PatternProcessPass {
 public:
  explicit ConcatOutputsForAllGather(bool multigraph = true)
      : PatternProcessPass("concat_outputs_for_all_gather", multigraph) {}
  ~ConcatOutputsForAllGather() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_CONCAT_OUTPUTS_FOR_ALLGATHER_H_

@@ -51,6 +51,7 @@
 #endif
 #include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
 #include "backend/optimizer/pass/communication_op_fusion.h"
+#include "backend/optimizer/gpu/concat_outputs_for_all_gather.h"
 #include "backend/optimizer/pass/getitem_tuple.h"
 #include "common/trans.h"
 #include "debug/anf_ir_dump.h"

@@ -175,6 +176,8 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra
   pm->AddPass(std::make_shared<opt::AddReluV2Fusion>());
   pm->AddPass(std::make_shared<opt::AddReluGradV2Fusion>());
   pm->AddPass(std::make_shared<opt::AllReduceFusion>());
+  pm->AddPass(std::make_shared<opt::AllGatherFusion>());
+  pm->AddPass(std::make_shared<opt::ConcatOutputsForAllGather>());
   pm->AddPass(std::make_shared<opt::GetitemTuple>());
   pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision"));
   optimizer->AddPassManager(pm);

@@ -1052,7 +1052,8 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph
   auto &kernels = graph->execution_order();
   for (auto &kernel : kernels) {
     MS_EXCEPTION_IF_NULL(kernel);
-    if (AnfAlgo::IsCommunicationOp(kernel)) {
+    if (AnfAlgo::IsCommunicationOp(kernel) && AnfAlgo::GetCNodeName(kernel) != kHcomSendOpName &&
+        AnfAlgo::GetCNodeName(kernel) != kReceiveOpName) {
       AllocCommunicationOpInputDynamicRes(kernel);
       AllocCommunicationOpOutputDynamicRes(kernel);
     }

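Send and Receive kernels are now excluded from this continuous-memory path; presumably the
point-to-point ops do not need the contiguous input/output buffers that fused collectives such as
AllReduce and AllGather require.
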
@@ -1120,9 +1121,6 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::Anf
 void GPUKernelRuntime::AllocCommunicationOpMemory(bool is_need_alloc_memory, bool, const DeviceAddressPtrList addr_list,
                                                   size_t total_size, std::vector<size_t> size_list) {
   MS_EXCEPTION_IF_NULL(mem_manager_);
-  if (!is_need_alloc_memory) {
-    return;
-  }
   auto ret = mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list);
   if (!ret) {
     MS_LOG(EXCEPTION) << "Malloc device memory failed.";
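To make the output layout that ConcatOutputsForAllGather relies on concrete, the following
stand-alone C++ sketch (an editorial illustration only, not MindSpore code; the numbers are
hypothetical) prints which fused-AllGather outputs each inserted Concat would gather. With
inputs_size bundled parameters and rank_size ranks, output index rank * inputs_size + param is the
slice of parameter param produced by rank rank, so the Concat for param walks indices param,
param + inputs_size, param + 2 * inputs_size, and so on, exactly like InsertConcatForOutput.

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Toy configuration: two fused parameters, four ranks (hypothetical numbers).
  const size_t inputs_size = 2;  // parameters bundled by AllGatherFusion
  const size_t rank_size = 4;    // ranks, each contributing one slice per parameter

  // The fused AllGather yields inputs_size * rank_size outputs, laid out rank-major:
  // output index = rank * inputs_size + param.
  for (size_t param = 0; param < inputs_size; ++param) {
    std::vector<size_t> concat_indices;
    for (size_t rank = 0, idx = param; rank < rank_size; ++rank, idx += inputs_size) {
      concat_indices.push_back(idx);  // same stride walk as InsertConcatForOutput
    }
    std::cout << "Concat for parameter " << param << " gathers outputs:";
    for (size_t idx : concat_indices) {
      std::cout << ' ' << idx;
    }
    std::cout << '\n';  // e.g. parameter 0 -> outputs 0 2 4 6
  }
  return 0;
}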