Add circle (cycle) check in UB fusion
This commit is contained in:
parent
03935de4bf
commit
de843b45b6
|
@ -63,10 +63,6 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
|
||||||
auto bnupdate = getitem->input(1);
|
auto bnupdate = getitem->input(1);
|
||||||
MS_EXCEPTION_IF_NULL(bnupdate);
|
MS_EXCEPTION_IF_NULL(bnupdate);
|
||||||
if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
|
if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
|
||||||
if (cnode->size() == ELTWISE_DOUBLE_IN_INPUT_SIZE &&
|
|
||||||
IsDepend(kernel_graph, cnode->input(2), {relu_input, bnupdate})) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
|
std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
|
||||||
for (auto out_getitem : manager->node_users()[bnupdate]) {
|
for (auto out_getitem : manager->node_users()[bnupdate]) {
|
||||||
MS_EXCEPTION_IF_NULL(out_getitem.first);
|
MS_EXCEPTION_IF_NULL(out_getitem.first);
|
||||||
|
|
|
@ -27,6 +27,7 @@
|
||||||
#include "base/core_ops.h"
|
#include "base/core_ops.h"
|
||||||
#include "runtime/device/kernel_info.h"
|
#include "runtime/device/kernel_info.h"
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
|
#include "backend/optimizer/common/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
@ -353,6 +354,28 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<A
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Removes every fusion scope whose fusion would produce a circle (cycle) in the graph.
//
// For each entry of `buffer_fusion_infos`, the scope's inputs are inspected: inputs that are not
// CNodes, and Load nodes, are ignored; for any other input, IsDepend() is asked whether that input
// depends on the nodes of the scope itself (`anf_nodes`). If it does, the whole entry is dropped.
//
// \param kernel_graph         Graph handed to IsDepend() for the dependency query.
// \param buffer_fusion_infos  Map from fusion id to fusion info; modified in place. Must not be null
//                             (checked with MS_EXCEPTION_IF_NULL).
void RemoveCircle(const session::KernelGraph &kernel_graph,
                  std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) {
  MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
  // An explicit iterator loop is required here: erasing the current element inside a range-based
  // for invalidates the hidden iterator, and advancing it afterwards is undefined behaviour.
  for (auto iter = buffer_fusion_infos->begin(); iter != buffer_fusion_infos->end();) {
    const auto &fusion_info = iter->second;
    bool has_circle = false;
    for (const auto &inp : fusion_info.inputs_list) {
      if (!inp->isa<CNode>() || AnfAlgo::CheckPrimitiveType(inp, prim::kPrimLoad)) {
        continue;
      }

      if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) {
        has_circle = true;
        break;
      }
    }

    if (has_circle) {
      // erase() returns the iterator following the removed element, so iteration stays valid.
      iter = buffer_fusion_infos->erase(iter);
    } else {
      ++iter;
    }
  }
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
|
void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
|
||||||
|
@ -361,6 +384,9 @@ void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
|
||||||
GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos);
|
GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos);
|
||||||
GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos);
|
GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos);
|
||||||
GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos);
|
GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos);
|
||||||
|
// Remove the fusion infos that would produce a circle (cycle) in the graph if fused
|
||||||
|
RemoveCircle(*kernel_graph, buffer_fusion_infos);
|
||||||
|
|
||||||
for (auto &buffer_fusion_info : *buffer_fusion_infos) {
|
for (auto &buffer_fusion_info : *buffer_fusion_infos) {
|
||||||
buffer_fusion_info.second.kernel_build_info =
|
buffer_fusion_info.second.kernel_build_info =
|
||||||
CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list);
|
CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list);
|
||||||
|
|
|
@ -49,23 +49,6 @@ std::vector<int64_t> Convert2Long(const std::vector<size_t> &v) {
|
||||||
|
|
||||||
bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes) {
|
bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
std::vector<AnfNodePtr> node_list = TopoSort(graph.get_return());
|
|
||||||
std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map;
|
|
||||||
for (auto &nd : node_list) {
|
|
||||||
MS_EXCEPTION_IF_NULL(nd);
|
|
||||||
if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) {
|
|
||||||
auto control_depend = nd->cast<CNodePtr>();
|
|
||||||
auto prior_node = control_depend->input(kControlDependPriorIndex);
|
|
||||||
auto behind_node = control_depend->input(kControlDependBehindIndex);
|
|
||||||
auto it = control_depend_map.find(behind_node);
|
|
||||||
if (it == control_depend_map.end()) {
|
|
||||||
control_depend_map[behind_node] = std::set<AnfNodePtr>{prior_node};
|
|
||||||
} else {
|
|
||||||
it->second.insert(prior_node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
FuncGraphManagerPtr manager = graph.manager();
|
FuncGraphManagerPtr manager = graph.manager();
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
|
|
||||||
|
@ -88,10 +71,6 @@ bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<
|
||||||
auto inputs = cnode->inputs();
|
auto inputs = cnode->inputs();
|
||||||
(void)todo.insert(todo.end(), inputs.begin(), inputs.end());
|
(void)todo.insert(todo.end(), inputs.begin(), inputs.end());
|
||||||
}
|
}
|
||||||
auto it = control_depend_map.find(nd);
|
|
||||||
if (it != control_depend_map.end()) {
|
|
||||||
(void)todo.insert(todo.end(), it->second.begin(), it->second.end());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,12 +61,13 @@ bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr>
|
||||||
MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost.count() << " us";
|
MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost.count() << " us";
|
||||||
#else
|
#else
|
||||||
(void)gettimeofday(&end_time, nullptr);
|
(void)gettimeofday(&end_time, nullptr);
|
||||||
const uint64_t kUSecondInSecond = 1000000;
|
// time unit: us
|
||||||
uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
|
uint64_t cost = 1000000 * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
|
||||||
cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
|
cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
|
||||||
MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost << " us";
|
MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost << " us";
|
||||||
#endif
|
#endif
|
||||||
if (save_graphs) {
|
static const auto enable_dump = (common::GetEnv("ENV_NO_DUMP_BE_PASS_IR") != "1");
|
||||||
|
if (save_graphs && enable_dump) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
oss << "verbose_ir_files"
|
oss << "verbose_ir_files"
|
||||||
<< "/";
|
<< "/";
|
||||||
|
|
Loading…
Reference in New Issue