forked from mindspore-Ecosystem/mindspore
!492 Add AllGather fusion pass
Merge pull request !492 from YuJianfeng/allgather
This commit is contained in:
commit
4a3e5cb944
|
@ -21,7 +21,7 @@
|
|||
#include "pre_activate/ascend/ir_fission/bn_grad_split.h"
|
||||
#include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h"
|
||||
#include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h"
|
||||
#include "pre_activate/pass/allreduce_fusion.h"
|
||||
#include "pre_activate/pass/communication_op_fusion.h"
|
||||
#include "pre_activate/ascend/ir_fusion/square_sum_fusion.h"
|
||||
#include "pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h"
|
||||
#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h"
|
||||
|
@ -261,6 +261,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
|||
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||
auto other_pm = std::make_shared<PassManager>("other_pm");
|
||||
other_pm->AddPass(std::make_shared<AllReduceFusion>());
|
||||
other_pm->AddPass(std::make_shared<AllGatherFusion>());
|
||||
other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
|
||||
other_pm->AddPass(std::make_shared<BufferFusion>());
|
||||
other_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
|
|
|
@ -1,49 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ALLREDUCE_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ALLREDUCE_FUSION_H_
|
||||
#include <vector>
|
||||
|
||||
#include "pre_activate/common/pass.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
struct AllReduceInfo_t {
|
||||
std::vector<CNodePtr> allreduce_node;
|
||||
std::vector<float> input_grad_size;
|
||||
std::vector<float> input_grad_time;
|
||||
};
|
||||
|
||||
class AllReduceFusion : public Pass {
|
||||
public:
|
||||
explicit AllReduceFusion(size_t groups = 1) : Pass("all_reduce_fusion"), groups_(groups) {}
|
||||
~AllReduceFusion() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
bool DoFusion(const FuncGraphPtr &func_graph, const AllReduceInfo_t &allreduce_node_info, size_t segment_num,
|
||||
const std::vector<size_t> &segment_index) const;
|
||||
AnfNodePtr CreateFusedAllReduce(const FuncGraphPtr &func_graph, const AllReduceInfo_t &allreduce_node_info,
|
||||
size_t start_index, size_t end_index) const;
|
||||
bool GetSplitSegments(const AllReduceInfo_t &allreduce_node_info, size_t *segment_num,
|
||||
std::vector<size_t> *segment_index) const;
|
||||
size_t groups_ = 1;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ALLREDUCE_FUSION_H_
|
|
@ -13,14 +13,12 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pre_activate/pass/allreduce_fusion.h"
|
||||
#include "pre_activate/pass/communication_op_fusion.h"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "utils/utils.h"
|
||||
#include "utils/graph_utils.h"
|
||||
#include "operator/ops.h"
|
||||
#include "device/kernel_info.h"
|
||||
|
@ -31,9 +29,12 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AllReduceInfo_t &allreduce_node_info, size_t start_index,
|
||||
constexpr auto kAttrDefaultGroup = "default_group";
|
||||
constexpr auto kAttrDefaultOp = "default_op";
|
||||
|
||||
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &communication_op_info, size_t start_index,
|
||||
size_t end_index) {
|
||||
if (end_index >= allreduce_node_info.allreduce_node.size()) {
|
||||
if (end_index >= communication_op_info.communication_op_nodes.size()) {
|
||||
MS_LOG(EXCEPTION) << "end index out of vector size";
|
||||
}
|
||||
std::vector<std::string> inputs_device_format;
|
||||
|
@ -43,7 +44,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AllReduceInfo_t &allred
|
|||
std::vector<std::vector<size_t>> outputs_shape;
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
for (size_t idx = start_index; idx <= end_index; ++idx) {
|
||||
auto cnode = allreduce_node_info.allreduce_node[idx];
|
||||
auto cnode = communication_op_info.communication_op_nodes[idx];
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
|
||||
inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
|
||||
|
@ -64,14 +65,38 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AllReduceInfo_t &allred
|
|||
builder.SetOutputsDeviceType(outputs_device_type);
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
std::string GetFusionGroupKey(const AnfNodePtr &node) {
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(node);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
|
||||
if (attr_fusion == nullptr) {
|
||||
return "";
|
||||
}
|
||||
int fusion = GetValue<int>(attr_fusion);
|
||||
if (fusion == 0) {
|
||||
return "";
|
||||
}
|
||||
std::string group = kAttrDefaultGroup;
|
||||
ValuePtr attr_group = primitive->GetAttr(kAttrGroup);
|
||||
if (attr_group != nullptr) {
|
||||
group = GetValue<std::string>(attr_group);
|
||||
}
|
||||
std::string op = kAttrDefaultOp;
|
||||
ValuePtr attr_op = primitive->GetAttr(kAttrOp);
|
||||
if (attr_op != nullptr) {
|
||||
op = GetValue<std::string>(attr_op);
|
||||
}
|
||||
return group + op + std::to_string(fusion);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool AllReduceFusion::GetSplitSegments(const AllReduceInfo_t &allreduce_node_info, size_t *segment_num,
|
||||
std::vector<size_t> *segment_index) const {
|
||||
bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
|
||||
std::vector<size_t> *segment_index) const {
|
||||
MS_EXCEPTION_IF_NULL(segment_num);
|
||||
MS_EXCEPTION_IF_NULL(segment_index);
|
||||
size_t allreduce_node_size = allreduce_node_info.allreduce_node.size();
|
||||
MS_LOG(INFO) << "graph all reduce node size " << allreduce_node_size;
|
||||
size_t communication_op_node_size = communication_op_info.communication_op_nodes.size();
|
||||
MS_LOG(INFO) << "graph " << op_name_ << " node size " << communication_op_node_size;
|
||||
|
||||
auto parallel_context = parallel::ParallelContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(parallel_context);
|
||||
|
@ -82,30 +107,31 @@ bool AllReduceFusion::GetSplitSegments(const AllReduceInfo_t &allreduce_node_inf
|
|||
uint32_t last_index = 0;
|
||||
for (size_t i = 0; i < split_indices.size(); ++i) {
|
||||
uint32_t index = split_indices[i];
|
||||
if (index <= last_index || index >= allreduce_node_size) {
|
||||
MS_LOG(EXCEPTION) << "invalid allreduce split index " << i << " " << index;
|
||||
if (index <= last_index || index >= communication_op_node_size) {
|
||||
MS_LOG(EXCEPTION) << "invalid " << op_name_ << " split index " << i << " " << index;
|
||||
}
|
||||
segment_index->push_back(index);
|
||||
last_index = index;
|
||||
segments++;
|
||||
}
|
||||
if (last_index != allreduce_node_size - 1) {
|
||||
segment_index->push_back(allreduce_node_size - 1);
|
||||
if (last_index != communication_op_node_size - 1) {
|
||||
segment_index->push_back(communication_op_node_size - 1);
|
||||
segments++;
|
||||
}
|
||||
} else {
|
||||
segments = groups_;
|
||||
for (size_t i = 0; i < segments - 1; ++i) {
|
||||
segment_index->push_back((i + 1) * (allreduce_node_size / segments) - 1);
|
||||
segment_index->push_back((i + 1) * (communication_op_node_size / segments) - 1);
|
||||
}
|
||||
segment_index->push_back(allreduce_node_size - 1);
|
||||
segment_index->push_back(communication_op_node_size - 1);
|
||||
}
|
||||
|
||||
if (segments >= allreduce_node_size) {
|
||||
MS_LOG(INFO) << "fusion not changed: segment_num=" << segments << ", allreduce_node_size=" << allreduce_node_size;
|
||||
if (segments >= communication_op_node_size) {
|
||||
MS_LOG(INFO) << "fusion not changed: segment_num=" << segments
|
||||
<< ", communication_op_node_size=" << communication_op_node_size;
|
||||
return false;
|
||||
}
|
||||
if (segment_index->at(segments - 1) != allreduce_node_size - 1) {
|
||||
if (segment_index->at(segments - 1) != communication_op_node_size - 1) {
|
||||
MS_LOG(EXCEPTION) << "the last segment index is invalid.";
|
||||
}
|
||||
for (size_t i = 0; i < segments - 1; ++i) {
|
||||
|
@ -118,19 +144,19 @@ bool AllReduceFusion::GetSplitSegments(const AllReduceInfo_t &allreduce_node_inf
|
|||
return true;
|
||||
}
|
||||
|
||||
AnfNodePtr AllReduceFusion::CreateFusedAllReduce(const FuncGraphPtr &func_graph,
|
||||
const AllReduceInfo_t &allreduce_node_info, size_t start_index,
|
||||
size_t end_index) const {
|
||||
AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
|
||||
const CommunicationOpInfo &communication_op_info,
|
||||
size_t start_index, size_t end_index) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto prim = std::make_shared<Primitive>(kAllReduceOpName);
|
||||
auto prim = std::make_shared<Primitive>(op_name_);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> fusion_inputs = {NewValueNode(prim)};
|
||||
// get all inputs of current segment
|
||||
if (end_index >= allreduce_node_info.allreduce_node.size()) {
|
||||
if (end_index >= communication_op_info.communication_op_nodes.size()) {
|
||||
MS_LOG(EXCEPTION) << "end index out of vector size";
|
||||
}
|
||||
for (size_t idx = start_index; idx <= end_index; ++idx) {
|
||||
auto cnode = allreduce_node_info.allreduce_node[idx];
|
||||
auto cnode = communication_op_info.communication_op_nodes[idx];
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||
}
|
||||
|
@ -141,14 +167,14 @@ AnfNodePtr AllReduceFusion::CreateFusedAllReduce(const FuncGraphPtr &func_graph,
|
|||
fused_node->set_kernel_info(kernel_info);
|
||||
AbstractBasePtrList abstract_list;
|
||||
for (size_t idx = start_index; idx <= end_index; ++idx) {
|
||||
auto cnode = allreduce_node_info.allreduce_node[idx];
|
||||
auto cnode = communication_op_info.communication_op_nodes[idx];
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
AnfAlgo::CopyNodeAttr("fusion", cnode, fused_node);
|
||||
AnfAlgo::CopyNodeAttr("op", cnode, fused_node);
|
||||
AnfAlgo::CopyNodeAttr("group", cnode, fused_node);
|
||||
abstract_list.push_back(cnode->abstract());
|
||||
}
|
||||
auto kernel_build_info = GenerateKernelBuildInfo(allreduce_node_info, start_index, end_index);
|
||||
auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get());
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
MS_EXCEPTION_IF_NULL(abstract_tuple);
|
||||
|
@ -156,8 +182,8 @@ AnfNodePtr AllReduceFusion::CreateFusedAllReduce(const FuncGraphPtr &func_graph,
|
|||
return fused_node;
|
||||
}
|
||||
|
||||
bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceInfo_t &allreduce_node_info,
|
||||
size_t segment_num, const std::vector<size_t> &segment_index) const {
|
||||
bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info,
|
||||
size_t segment_num, const std::vector<size_t> &segment_index) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
@ -169,12 +195,13 @@ bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceIn
|
|||
start_index = end_index + 1;
|
||||
continue;
|
||||
}
|
||||
AnfNodePtr new_allreduce = CreateFusedAllReduce(func_graph, allreduce_node_info, start_index, end_index);
|
||||
// replace old allreduce with new allreduce
|
||||
AnfNodePtr new_communication_op =
|
||||
CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index);
|
||||
// replace old communication op with new communication op
|
||||
for (auto idx = start_index; idx <= end_index; ++idx) {
|
||||
std::vector<AnfNodePtr> tuple_getitem_input;
|
||||
tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem));
|
||||
tuple_getitem_input.push_back(new_allreduce);
|
||||
tuple_getitem_input.push_back(new_communication_op);
|
||||
auto index = NewValueNode(SizeToInt(idx - start_index));
|
||||
MS_EXCEPTION_IF_NULL(index);
|
||||
auto imm = std::make_shared<Int32Imm>(idx - start_index);
|
||||
|
@ -185,10 +212,10 @@ bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceIn
|
|||
tuple_getitem_input.push_back(index);
|
||||
AnfNodePtr tuple_getitem = func_graph->NewCNode(tuple_getitem_input);
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
auto allreduce_node_item = allreduce_node_info.allreduce_node.at(idx);
|
||||
MS_EXCEPTION_IF_NULL(allreduce_node_item);
|
||||
tuple_getitem->set_abstract(allreduce_node_item->abstract());
|
||||
if (!manager->Replace(allreduce_node_item, tuple_getitem)) {
|
||||
auto communication_op_node_item = communication_op_info.communication_op_nodes.at(idx);
|
||||
MS_EXCEPTION_IF_NULL(communication_op_node_item);
|
||||
tuple_getitem->set_abstract(communication_op_node_item->abstract());
|
||||
if (!manager->Replace(communication_op_node_item, tuple_getitem)) {
|
||||
MS_LOG(EXCEPTION) << "manager replace node failed";
|
||||
}
|
||||
}
|
||||
|
@ -198,29 +225,24 @@ bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceIn
|
|||
return changed;
|
||||
}
|
||||
|
||||
bool AllReduceFusion::Run(const FuncGraphPtr &func_graph) {
|
||||
bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
const float input_grad_size_num = 0.0;
|
||||
const float input_grad_time_num = 0.0;
|
||||
// divide candidate fusion groups with same (group,op,fusion) attrs, fusion==0 means not fusion
|
||||
std::unordered_map<std::string, AllReduceInfo_t> candidate_groups;
|
||||
std::unordered_map<std::string, CommunicationOpInfo> candidate_groups;
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) {
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(node);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
int fusion = GetValue<int>(primitive->GetAttr("fusion"));
|
||||
if (fusion == 0) {
|
||||
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == op_name_) {
|
||||
std::string key = GetFusionGroupKey(node);
|
||||
if (key.empty()) {
|
||||
continue;
|
||||
}
|
||||
std::string group = GetValue<std::string>(primitive->GetAttr("group"));
|
||||
std::string op = GetValue<std::string>(primitive->GetAttr("op"));
|
||||
std::string key = group + op + std::to_string(fusion);
|
||||
if (candidate_groups.find(key) == candidate_groups.end()) {
|
||||
AllReduceInfo_t allreduce_node_info;
|
||||
candidate_groups[key] = allreduce_node_info;
|
||||
CommunicationOpInfo communication_op_info;
|
||||
candidate_groups[key] = communication_op_info;
|
||||
}
|
||||
candidate_groups[key].allreduce_node.push_back(node->cast<CNodePtr>());
|
||||
candidate_groups[key].communication_op_nodes.push_back(node->cast<CNodePtr>());
|
||||
candidate_groups[key].input_grad_size.push_back(input_grad_size_num);
|
||||
candidate_groups[key].input_grad_time.push_back(input_grad_time_num);
|
||||
}
|
||||
|
@ -228,7 +250,7 @@ bool AllReduceFusion::Run(const FuncGraphPtr &func_graph) {
|
|||
// split candidate group to segments according to _group class member
|
||||
bool changed = false;
|
||||
for (auto &it : candidate_groups) {
|
||||
if (it.second.allreduce_node.size() <= 1) {
|
||||
if (it.second.communication_op_nodes.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
size_t segment_num = 0;
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "pre_activate/common/pass.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/anf.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
struct CommunicationOpInfo {
|
||||
std::vector<CNodePtr> communication_op_nodes;
|
||||
std::vector<float> input_grad_size;
|
||||
std::vector<float> input_grad_time;
|
||||
};
|
||||
|
||||
class CommunicationOpFusion : public Pass {
|
||||
public:
|
||||
explicit CommunicationOpFusion(const std::string &name, std::string op_name, size_t groups = 1)
|
||||
: Pass(name), op_name_(std::move(op_name)), groups_(groups) {}
|
||||
~CommunicationOpFusion() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
bool DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, size_t segment_num,
|
||||
const std::vector<size_t> &segment_index) const;
|
||||
AnfNodePtr CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
|
||||
const CommunicationOpInfo &communication_op_info, size_t start_index,
|
||||
size_t end_index) const;
|
||||
bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
|
||||
std::vector<size_t> *segment_index) const;
|
||||
std::string op_name_;
|
||||
size_t groups_ = 1;
|
||||
};
|
||||
|
||||
class AllReduceFusion : public CommunicationOpFusion {
|
||||
public:
|
||||
explicit AllReduceFusion(size_t groups = 1) : CommunicationOpFusion("all_reduce_fusion", kAllReduceOpName, groups) {}
|
||||
~AllReduceFusion() override = default;
|
||||
};
|
||||
|
||||
class AllGatherFusion : public CommunicationOpFusion {
|
||||
public:
|
||||
explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {}
|
||||
~AllGatherFusion() override = default;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
|
|
@ -20,7 +20,7 @@
|
|||
#include "device/gpu/gpu_stream_assign.h"
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
#include "pre_activate/common/pass_manager.h"
|
||||
#include "pre_activate/pass/allreduce_fusion.h"
|
||||
#include "pre_activate/pass/communication_op_fusion.h"
|
||||
#include "device/kernel_runtime_manager.h"
|
||||
#include "predict/predict.h"
|
||||
#include "common/utils.h"
|
||||
|
|
|
@ -155,6 +155,9 @@ constexpr auto kAttrOutputUsedNum = "output_used_num";
|
|||
constexpr auto kAttrHasBias = "has_bias";
|
||||
constexpr auto kAttrN = "n";
|
||||
constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active";
|
||||
constexpr auto kAttrFusion = "fusion";
|
||||
constexpr auto kAttrGroup = "group";
|
||||
constexpr auto kAttrOp = "op";
|
||||
|
||||
// attr value
|
||||
constexpr auto kValueTargetSwitch = "target_switch";
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include "ir/manager.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "pre_activate/pass/allreduce_fusion.h"
|
||||
#include "pre_activate/pass/communication_op_fusion.h"
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
#include "device/kernel_info.h"
|
||||
#include "pre_activate/common/pass_manager.h"
|
||||
|
|
Loading…
Reference in New Issue