Insert concat for AllGather outputs

This commit is contained in:
yujianfeng 2020-07-20 14:26:26 +08:00
parent 65c6755fab
commit 47ab812edb
5 changed files with 156 additions and 3 deletions

View File

@ -102,6 +102,7 @@
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
#include "backend/optimizer/ascend/ir_fission/concat_fission.h"
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
#include "utils/context/ms_context.h"
#include "utils/config_manager.h"
#include "debug/anf_ir_dump.h"
@ -341,6 +342,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
auto other_pm = std::make_shared<PassManager>("other_pm");
other_pm->AddPass(std::make_shared<AllReduceFusion>());
other_pm->AddPass(std::make_shared<AllGatherFusion>());
other_pm->AddPass(std::make_shared<ConcatOutputsForAllGather>());
other_pm->AddPass(std::make_shared<ReduceScatterFusion>());
other_pm->AddPass(std::make_shared<BroadcastFusion>());
other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForCascade>());

View File

@ -0,0 +1,104 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
namespace {
void AddOutputs(const AnfNodePtr &node, int rank_size) {
MS_EXCEPTION_IF_NULL(node);
auto origin_abstract = node->abstract();
MS_EXCEPTION_IF_NULL(origin_abstract);
auto tuple_abstract = origin_abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
auto &origin_abstracts = tuple_abstract->elements();
AbstractBasePtrList abstract_list;
std::vector<TypeId> outputs_device_type;
std::vector<std::string> outputs_device_format;
for (int i = 0; i < rank_size; ++i) {
for (size_t j = 0; j < origin_abstracts.size(); ++j) {
abstract_list.push_back(origin_abstracts[j]);
outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(node, j));
outputs_device_format.push_back(AnfAlgo::GetOutputFormat(node, j));
}
}
// Update abstract
auto new_abstracts = std::make_shared<abstract::AbstractTuple>(abstract_list);
node->set_abstract(new_abstracts);
// Update kernel build info
auto builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
builder->SetOutputsDeviceType(outputs_device_type);
builder->SetOutputsFormat(outputs_device_format);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
}
} // namespace
// For each original AllGather input (and therefore each original output),
// build ONE Concat node that joins the rank_size per-rank slices along axis 0.
// new_tuple_getitems is laid out rank-major ([rank0_out0, rank0_out1, ...,
// rank1_out0, ...]; see AddOutputs), so the slices for original output i live
// at indices i, i + inputs_size, i + 2 * inputs_size, ...
// All concat results are bundled into a MakeTuple and returned.
//
// Fixes two defects in the previous version: a Concat was created per rank
// with only a single real input (while kAttrInputNums/kAttrDynInputSizes
// claimed rank_size inputs), and make_tuple_inputs was missing the leading
// MakeTuple primitive value node.
AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                            const std::vector<AnfNodePtr> &new_tuple_getitems,
                                                            int rank_size) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
  size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < inputs_size; ++i) {
    // Collect the rank_size getitem nodes that correspond to original output i.
    std::vector<AnfNodePtr> concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
    for (size_t j = 0, idx = i; j < IntToSize(rank_size); ++j, idx += inputs_size) {
      MS_EXCEPTION_IF_NULL(new_tuple_getitems[idx]);
      concat_inputs.push_back(new_tuple_getitems[idx]);
    }
    auto concat = func_graph->NewCNode(concat_inputs);
    MS_EXCEPTION_IF_NULL(concat);
    // NOTE(review): the abstract is taken from a single per-rank slice, so its
    // axis-0 dimension is still the per-rank size, not rank_size times larger
    // — confirm whether a widened infer shape should be set here instead.
    concat->set_abstract(new_tuple_getitems[i]->abstract());
    AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), concat);
    AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat);
    std::vector<int> dyn_input_size{rank_size};
    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
    kernel_select_->SelectKernel(concat);
    make_tuple_inputs.push_back(concat);
  }
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
// Pattern: an AllGather primitive applied to any sequence of inputs.
const BaseRef ConcatOutputsForAllGather::DefinePattern() const {
  auto all_gather_prim = std::make_shared<Primitive>(kAllGatherOpName);
  VarPtr inputs_seq = std::make_shared<SeqVar>();
  return VectorRef({all_gather_prim, inputs_seq});
}
// Entry point of the pass: for a fused AllGather (fusion attr > 0 and a
// rank_size attr present), expand its outputs per rank, split them into
// TupleGetItem nodes, and replace the node with a tuple of Concat results.
// Returns nullptr (no change) for nodes the pass does not apply to.
const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                    const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Only fused AllGather nodes carry both attrs; skip everything else.
  if (!AnfAlgo::HasNodeAttr(kAttrFusion, cnode) || !AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) {
    return nullptr;
  }
  auto fusion = AnfAlgo::GetNodeAttr<int>(cnode, kAttrFusion);
  // fusion <= 0 means the node was not actually fused; leave it untouched.
  if (fusion <= 0) {
    return nullptr;
  }
  auto rank_size = AnfAlgo::GetNodeAttr<int>(node, kAttrRankSize);
  // Order matters: AddOutputs must widen the node's abstract/build info first,
  // so that GetOutputTensorNum below counts the expanded (per-rank) outputs.
  AddOutputs(node, rank_size);
  std::vector<AnfNodePtr> new_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs);
  return InsertConcatForOutput(func_graph, node, new_outputs, rank_size);
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_CONCAT_OUTPUTS_FOR_ALLGATHER_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_CONCAT_OUTPUTS_FOR_ALLGATHER_H_
#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
// Pass that inserts Concat nodes after a fused AllGather so that each
// original output is reassembled from its per-rank slices (see the matching
// .cc for the graph rewrite).
class ConcatOutputsForAllGather : public PatternProcessPass {
 public:
  // multigraph=true lets the pattern match across child graphs.
  explicit ConcatOutputsForAllGather(bool multigraph = true)
      : PatternProcessPass("concat_outputs_for_all_gather", multigraph),
        kernel_select_(std::make_shared<KernelSelect>()) {}
  ~ConcatOutputsForAllGather() override = default;
  // Matches an AllGather node with any inputs.
  const BaseRef DefinePattern() const override;
  // Rewrites a matched fused AllGather; returns nullptr when not applicable.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  // Builds one Concat per original output plus a wrapping MakeTuple.
  AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                   const std::vector<AnfNodePtr> &new_tuple_getitems, int rank_size) const;
  // Selects device kernels for the newly created Concat nodes.
  KernelSelectPtr kernel_select_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_CONCAT_OUTPUTS_FOR_ALLGATHER_H_

View File

@ -188,9 +188,13 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
MS_EXCEPTION_IF_NULL(abstract_tuple);
fused_node->set_abstract(abstract_tuple);
AnfAlgo::CopyNodeAttr("fusion", communication_op_info.communication_op_nodes[end_index], fused_node);
AnfAlgo::CopyNodeAttr("op", communication_op_info.communication_op_nodes[end_index], fused_node);
AnfAlgo::CopyNodeAttr("group", communication_op_info.communication_op_nodes[end_index], fused_node);
auto final_node = communication_op_info.communication_op_nodes[end_index];
AnfAlgo::CopyNodeAttr(kAttrFusion, final_node, fused_node);
AnfAlgo::CopyNodeAttr(kAttrOp, final_node, fused_node);
AnfAlgo::CopyNodeAttr(kAttrGroup, final_node, fused_node);
if (AnfAlgo::HasNodeAttr(kAttrRankSize, final_node)) {
AnfAlgo::CopyNodeAttr(kAttrRankSize, final_node, fused_node);
}
return fused_node;
}

View File

@ -250,6 +250,7 @@ constexpr auto kAttrChildGraph = "child_graph";
constexpr auto kAttrInputNums = "inputNums";
constexpr auto kAttrT = "T";
constexpr auto kAttrNum = "num";
constexpr auto kAttrRankSize = "rank_size";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";