ReduceScatter Fusion

This commit is contained in:
alouhahaha 2020-12-18 10:47:06 +08:00
parent 24466b92c1
commit 3da59427bc
6 changed files with 174 additions and 2 deletions

View File

@ -168,6 +168,9 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
if (AnfAlgo::GetInputTensorNum(anf_node_) > 1 && op_name == kAllGatherOpName && fusion >= 1) {
loop_size *= rank_size;
}
if (op_name == kReduceScatterOpName && fusion >= 1) {
loop_size = AnfAlgo::GetOutputTensorNum(anf_node_);
}
for (ulong i = 0; i < loop_size; ++i) {
if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[0], hccl_kernel_output_shape_list_[i], &size)) {
MS_LOG(ERROR) << "GetHcclOpOutputSize failed";

View File

@ -70,7 +70,11 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
outputs_format.emplace_back(GetKernelFormat(kernel_node, output_index));
if (op_name == kReduceScatter && AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrFusion) > 0) {
outputs_format.emplace_back(GetKernelFormat(kernel_node, 0));
} else {
outputs_format.emplace_back(GetKernelFormat(kernel_node, output_index));
}
outputs_type.push_back(type);
}
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();

View File

@ -125,6 +125,7 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataTyp
if (AnfAlgo::GetCNodeName(anf_node) == kReduceScatterOpName) {
int64_t rank_size;
auto cnode = anf_node->cast<CNodePtr>();
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->GetAttr("rank_size") != nullptr) {
@ -133,7 +134,11 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataTyp
MS_LOG(ERROR) << "Get rank size failed";
return false;
}
block_size = input_size / LongToSize(rank_size);
int64_t actual_input_size = input_size;
if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(anf_node, kAttrFusion)) {
actual_input_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
}
block_size = actual_input_size / LongToSize(rank_size);
total_size = total_size + block_size;
} else {
if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) {

View File

@ -116,6 +116,7 @@
#include "backend/optimizer/ascend/ir_fission/concat_fission.h"
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
#include "backend/optimizer/ascend/enhancer/split_inputs_for_reduce_scatter.h"
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_gru.h"
#include "utils/ms_context.h"
@ -363,6 +364,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
other_pm->AddPass(std::make_shared<AllGatherFusion>());
other_pm->AddPass(std::make_shared<ConcatOutputsForAllGather>());
other_pm->AddPass(std::make_shared<ReduceScatterFusion>());
other_pm->AddPass(std::make_shared<SplitInputsForReduceScatter>());
other_pm->AddPass(std::make_shared<BroadcastFusion>());
other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForCascade>());
other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());

View File

@ -0,0 +1,114 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/enhancer/split_inputs_for_reduce_scatter.h"
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
std::vector<AnfNodePtr> SplitInputsForReduceScatter::InsertSplitForInput(const FuncGraphPtr &func_graph,
const CNodePtr &node,
int64_t rank_size) const {
MS_EXCEPTION_IF_NULL(func_graph);
size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
std::vector<AnfNodePtr> split_outputs;
for (size_t i = 0; i < inputs_size; i++) {
std::vector<AnfNodePtr> split_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
split_inputs.push_back(AnfAlgo::GetInputNode(node, i));
auto split = func_graph->NewCNode(split_inputs);
MS_EXCEPTION_IF_NULL(split);
std::vector<TypeId> dtypes(rank_size, AnfAlgo::GetPrevNodeOutputInferDataType(node, i));
std::vector<std::vector<size_t>> shapes;
std::vector<int> size_splits;
for (size_t j = 0; j < IntToSize(rank_size); j++) {
std::vector<size_t> output_node_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, i);
output_node_shape[0] /= rank_size;
shapes.push_back(output_node_shape);
size_splits.push_back(output_node_shape[0]);
}
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
AnfAlgo::SetNodeAttr("split_dim", MakeValue(0), split);
AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToInt(rank_size)), split);
AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split);
kernel_select_->SelectKernel(split);
std::vector<AnfNodePtr> new_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, split, AnfAlgo::GetOutputTensorNum(split), &new_outputs);
for (size_t j = 0; j < new_outputs.size(); j++) {
split_outputs.push_back(new_outputs[j]);
}
}
return split_outputs;
}
AnfNodePtr SplitInputsForReduceScatter::RearrangeInputsForReduceScatter(const FuncGraphPtr &func_graph,
const AnfNodePtr &node,
const std::vector<AnfNodePtr> &inputs,
int64_t rank_size) const {
MS_EXCEPTION_IF_NULL(func_graph);
size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
std::vector<AnfNodePtr> reduce_scatter_inputs{
NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceScatter->name()))};
for (size_t i = 0; i < IntToSize(rank_size); i++) {
for (size_t j = 0, idx = i; j < inputs_size; j++, idx += IntToSize(rank_size)) {
reduce_scatter_inputs.push_back(inputs[idx]);
}
}
auto reduce_scatter = func_graph->NewCNode(reduce_scatter_inputs);
MS_EXCEPTION_IF_NULL(reduce_scatter);
reduce_scatter->set_abstract(node->abstract());
AnfAlgo::CopyNodeAttrs(node, reduce_scatter);
AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(1L), reduce_scatter);
kernel_select_->SelectKernel(reduce_scatter);
return reduce_scatter;
}
const BaseRef SplitInputsForReduceScatter::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
auto prim = std::make_shared<Primitive>(kReduceScatterOpName);
return VectorRef({prim, Xs});
}
const AnfNodePtr SplitInputsForReduceScatter::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetInputTensorNum(node) == 1) {
AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(0L), node);
return nullptr;
}
if (!AnfAlgo::HasNodeAttr(kAttrFusion, cnode) || !AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) {
return nullptr;
}
auto fusion = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
if (fusion <= 0) {
return nullptr;
}
if (AnfAlgo::HasNodeAttr("Fused", cnode)) {
return nullptr;
}
AnfAlgo::SetNodeAttr("Fused", MakeValue(true), node);
auto rank_size = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrRankSize);
std::vector<AnfNodePtr> split_outputs = InsertSplitForInput(func_graph, cnode, rank_size);
return RearrangeInputsForReduceScatter(func_graph, node, split_outputs, rank_size);
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_SPLIT_INPUTS_FOR_REDUCE_SCATTER_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_SPLIT_INPUTS_FOR_REDUCE_SCATTER_H_
#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
class SplitInputsForReduceScatter : public PatternProcessPass {
public:
explicit SplitInputsForReduceScatter(bool multigraph = true)
: PatternProcessPass("split_inputs_for_reduce_scatter", multigraph),
kernel_select_(std::make_shared<KernelSelect>()) {}
~SplitInputsForReduceScatter() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr RearrangeInputsForReduceScatter(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::vector<AnfNodePtr> &inputs, int64_t rank_size) const;
std::vector<AnfNodePtr> InsertSplitForInput(const FuncGraphPtr &func_graph, const CNodePtr &node,
int64_t rank_size) const;
KernelSelectPtr kernel_select_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_SPLIT_INPUTS_FOR_REDUCE_SCATTER_H_