forked from mindspore-Ecosystem/mindspore
ReduceScatter Fusion
This commit is contained in:
parent
24466b92c1
commit
3da59427bc
|
@ -168,6 +168,9 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
|
|||
if (AnfAlgo::GetInputTensorNum(anf_node_) > 1 && op_name == kAllGatherOpName && fusion >= 1) {
|
||||
loop_size *= rank_size;
|
||||
}
|
||||
if (op_name == kReduceScatterOpName && fusion >= 1) {
|
||||
loop_size = AnfAlgo::GetOutputTensorNum(anf_node_);
|
||||
}
|
||||
for (ulong i = 0; i < loop_size; ++i) {
|
||||
if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[0], hccl_kernel_output_shape_list_[i], &size)) {
|
||||
MS_LOG(ERROR) << "GetHcclOpOutputSize failed";
|
||||
|
|
|
@ -70,7 +70,11 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
|
|||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> outputs_type;
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
|
||||
outputs_format.emplace_back(GetKernelFormat(kernel_node, output_index));
|
||||
if (op_name == kReduceScatter && AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrFusion) > 0) {
|
||||
outputs_format.emplace_back(GetKernelFormat(kernel_node, 0));
|
||||
} else {
|
||||
outputs_format.emplace_back(GetKernelFormat(kernel_node, output_index));
|
||||
}
|
||||
outputs_type.push_back(type);
|
||||
}
|
||||
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
|
||||
|
|
|
@ -125,6 +125,7 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataTyp
|
|||
|
||||
if (AnfAlgo::GetCNodeName(anf_node) == kReduceScatterOpName) {
|
||||
int64_t rank_size;
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
if (primitive->GetAttr("rank_size") != nullptr) {
|
||||
|
@ -133,7 +134,11 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataTyp
|
|||
MS_LOG(ERROR) << "Get rank size failed";
|
||||
return false;
|
||||
}
|
||||
block_size = input_size / LongToSize(rank_size);
|
||||
int64_t actual_input_size = input_size;
|
||||
if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(anf_node, kAttrFusion)) {
|
||||
actual_input_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
|
||||
}
|
||||
block_size = actual_input_size / LongToSize(rank_size);
|
||||
total_size = total_size + block_size;
|
||||
} else {
|
||||
if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) {
|
||||
|
|
|
@ -116,6 +116,7 @@
|
|||
#include "backend/optimizer/ascend/ir_fission/concat_fission.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
|
||||
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
|
||||
#include "backend/optimizer/ascend/enhancer/split_inputs_for_reduce_scatter.h"
|
||||
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
|
||||
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_gru.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -363,6 +364,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
|||
other_pm->AddPass(std::make_shared<AllGatherFusion>());
|
||||
other_pm->AddPass(std::make_shared<ConcatOutputsForAllGather>());
|
||||
other_pm->AddPass(std::make_shared<ReduceScatterFusion>());
|
||||
other_pm->AddPass(std::make_shared<SplitInputsForReduceScatter>());
|
||||
other_pm->AddPass(std::make_shared<BroadcastFusion>());
|
||||
other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForCascade>());
|
||||
other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/enhancer/split_inputs_for_reduce_scatter.h"
|
||||
#include <string>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
std::vector<AnfNodePtr> SplitInputsForReduceScatter::InsertSplitForInput(const FuncGraphPtr &func_graph,
|
||||
const CNodePtr &node,
|
||||
int64_t rank_size) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
|
||||
std::vector<AnfNodePtr> split_outputs;
|
||||
for (size_t i = 0; i < inputs_size; i++) {
|
||||
std::vector<AnfNodePtr> split_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
|
||||
split_inputs.push_back(AnfAlgo::GetInputNode(node, i));
|
||||
auto split = func_graph->NewCNode(split_inputs);
|
||||
MS_EXCEPTION_IF_NULL(split);
|
||||
std::vector<TypeId> dtypes(rank_size, AnfAlgo::GetPrevNodeOutputInferDataType(node, i));
|
||||
std::vector<std::vector<size_t>> shapes;
|
||||
std::vector<int> size_splits;
|
||||
for (size_t j = 0; j < IntToSize(rank_size); j++) {
|
||||
std::vector<size_t> output_node_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, i);
|
||||
output_node_shape[0] /= rank_size;
|
||||
shapes.push_back(output_node_shape);
|
||||
size_splits.push_back(output_node_shape[0]);
|
||||
}
|
||||
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
|
||||
|
||||
AnfAlgo::SetNodeAttr("split_dim", MakeValue(0), split);
|
||||
AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToInt(rank_size)), split);
|
||||
AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split);
|
||||
kernel_select_->SelectKernel(split);
|
||||
std::vector<AnfNodePtr> new_outputs;
|
||||
CreateMultipleOutputsOfAnfNode(func_graph, split, AnfAlgo::GetOutputTensorNum(split), &new_outputs);
|
||||
for (size_t j = 0; j < new_outputs.size(); j++) {
|
||||
split_outputs.push_back(new_outputs[j]);
|
||||
}
|
||||
}
|
||||
return split_outputs;
|
||||
}
|
||||
|
||||
// Builds the fused ReduceScatter whose inputs are the slices produced by
// InsertSplitForInput, rearranged rank-major: for every rank r, the r-th
// slice of each original input is appended, so after the scatter each rank
// receives one slice per original input.
AnfNodePtr SplitInputsForReduceScatter::RearrangeInputsForReduceScatter(const FuncGraphPtr &func_graph,
                                                                        const AnfNodePtr &node,
                                                                        const std::vector<AnfNodePtr> &inputs,
                                                                        int64_t rank_size) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
  // static_cast avoids the int64 -> int narrowing that IntToSize would go through.
  size_t rank = static_cast<size_t>(rank_size);
  std::vector<AnfNodePtr> reduce_scatter_inputs{
    NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceScatter->name()))};
  for (size_t i = 0; i < rank; i++) {
    // `inputs` is laid out input-major; stride by `rank` to pick slice i of every input.
    for (size_t j = 0, idx = i; j < inputs_size; j++, idx += rank) {
      reduce_scatter_inputs.push_back(inputs[idx]);
    }
  }
  auto reduce_scatter = func_graph->NewCNode(reduce_scatter_inputs);
  MS_EXCEPTION_IF_NULL(reduce_scatter);
  // The rebuilt node produces the same outputs as the original node.
  reduce_scatter->set_abstract(node->abstract());

  AnfAlgo::CopyNodeAttrs(node, reduce_scatter);
  // Mark as fused so downstream passes recognize the rearranged node.
  AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(1L), reduce_scatter);
  kernel_select_->SelectKernel(reduce_scatter);
  return reduce_scatter;
}
|
||||
|
||||
// Matches any ReduceScatter node, regardless of its input count.
const BaseRef SplitInputsForReduceScatter::DefinePattern() const {
  auto reduce_scatter_prim = std::make_shared<Primitive>(kReduceScatterOpName);
  VarPtr seq_inputs = std::make_shared<SeqVar>();
  return VectorRef({reduce_scatter_prim, seq_inputs});
}
|
||||
|
||||
// Entry point of the pass: for a matched multi-input fused ReduceScatter,
// splits every input along dim 0 and rebuilds the node with the slices
// rearranged rank-major. Returns nullptr when the node is left untouched.
const AnfNodePtr SplitInputsForReduceScatter::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                      const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  // A single-input ReduceScatter needs no splitting; clear its fusion flag.
  if (AnfAlgo::GetInputTensorNum(node) == 1) {
    AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(0L), node);
    return nullptr;
  }
  // Both attributes are required before the node can be rewritten.
  if (!AnfAlgo::HasNodeAttr(kAttrFusion, cnode) || !AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) {
    return nullptr;
  }
  if (AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion) <= 0) {
    return nullptr;
  }
  // Skip nodes this pass has already rewritten.
  if (AnfAlgo::HasNodeAttr("Fused", cnode)) {
    return nullptr;
  }

  AnfAlgo::SetNodeAttr("Fused", MakeValue(true), node);
  auto rank_size = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrRankSize);
  std::vector<AnfNodePtr> split_outputs = InsertSplitForInput(func_graph, cnode, rank_size);
  return RearrangeInputsForReduceScatter(func_graph, node, split_outputs, rank_size);
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_SPLIT_INPUTS_FOR_REDUCE_SCATTER_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_SPLIT_INPUTS_FOR_REDUCE_SCATTER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/ascend_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Enhancer pass for ReduceScatter fusion: for a fused ReduceScatter with
// multiple inputs, inserts a SplitV in front of every input (splitting it
// along dim 0 into rank_size slices) and rebuilds the node with the slices
// rearranged rank-major.
class SplitInputsForReduceScatter : public PatternProcessPass {
 public:
  // multigraph: whether the pattern may match across sub-graphs.
  explicit SplitInputsForReduceScatter(bool multigraph = true)
      : PatternProcessPass("split_inputs_for_reduce_scatter", multigraph),
        kernel_select_(std::make_shared<KernelSelect>()) {}
  ~SplitInputsForReduceScatter() override = default;
  // Matches any ReduceScatter node.
  const BaseRef DefinePattern() const override;
  // Rewrites a matched node; returns nullptr when the node is left untouched.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  // Builds the fused ReduceScatter from the rearranged split outputs.
  AnfNodePtr RearrangeInputsForReduceScatter(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                             const std::vector<AnfNodePtr> &inputs, int64_t rank_size) const;
  // Inserts a SplitV node in front of every input of `node`.
  std::vector<AnfNodePtr> InsertSplitForInput(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                              int64_t rank_size) const;
  KernelSelectPtr kernel_select_;
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_SPLIT_INPUTS_FOR_REDUCE_SCATTER_H_
|
Loading…
Reference in New Issue