forked from mindspore-Ecosystem/mindspore
reselect the domask's child node after rectify the node domask
This commit is contained in:
parent
17319d8dfd
commit
60e3849178
|
@ -119,6 +119,8 @@ bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_
|
||||||
|
|
||||||
bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); }
|
bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); }
|
||||||
|
|
||||||
|
bool KernelBuildInfo::operator!=(const KernelBuildInfo &other) const { return !((*this) == other); }
|
||||||
|
|
||||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) {
|
void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_build_info_);
|
MS_EXCEPTION_IF_NULL(kernel_build_info_);
|
||||||
kernel_build_info_->kernel_type_ = kernel_type;
|
kernel_build_info_->kernel_type_ = kernel_type;
|
||||||
|
|
|
@ -85,6 +85,8 @@ class KernelBuildInfo {
|
||||||
|
|
||||||
bool operator==(const KernelBuildInfo &other) const;
|
bool operator==(const KernelBuildInfo &other) const;
|
||||||
|
|
||||||
|
bool operator!=(const KernelBuildInfo &other) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static auto constexpr kInvalidFormat = "InvalidFormat";
|
static auto constexpr kInvalidFormat = "InvalidFormat";
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "kernel/common_utils.h"
|
#include "kernel/common_utils.h"
|
||||||
#include "utils/context/ms_context.h"
|
#include "utils/context/ms_context.h"
|
||||||
|
#include "pre_activate/common/helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
@ -50,16 +51,11 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
std::vector<CNodePtr> do_mask_node_list;
|
std::vector<CNodePtr> do_mask_node_list;
|
||||||
auto manager = graph->manager();
|
auto gen_mask_output_nodes = GetRealNodeUsedList(graph, cnode);
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(gen_mask_output_nodes);
|
||||||
auto node_map = manager->node_users();
|
for (const auto &output_node : *gen_mask_output_nodes) {
|
||||||
auto iter = node_map.find(node);
|
|
||||||
if (iter == node_map.end()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Cannot find the node " << node->DebugString() << " in the graph manager!";
|
|
||||||
}
|
|
||||||
auto gen_mask_output_nodes = iter->second;
|
|
||||||
for (const auto &output_node : gen_mask_output_nodes) {
|
|
||||||
if (AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropoutDoMask->name()) {
|
if (AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropoutDoMask->name()) {
|
||||||
|
MS_EXCEPTION_IF_NULL(output_node.first);
|
||||||
auto output_cnode = output_node.first->cast<CNodePtr>();
|
auto output_cnode = output_node.first->cast<CNodePtr>();
|
||||||
do_mask_node_list.push_back(output_cnode);
|
do_mask_node_list.push_back(output_cnode);
|
||||||
}
|
}
|
||||||
|
@ -76,11 +72,12 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
|
||||||
<< " GenMask " << node->DebugString();
|
<< " GenMask " << node->DebugString();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
RectifyKernelInfo(do_mask_node_list);
|
RectifyKernelInfo(do_mask_node_list, graph);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list) const {
|
void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list,
|
||||||
|
const FuncGraphPtr &graph) const {
|
||||||
std::map<std::string, size_t> format_counter;
|
std::map<std::string, size_t> format_counter;
|
||||||
std::string special_format;
|
std::string special_format;
|
||||||
std::string convert_format;
|
std::string convert_format;
|
||||||
|
@ -94,17 +91,6 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
|
||||||
} else {
|
} else {
|
||||||
format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1;
|
format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1;
|
||||||
}
|
}
|
||||||
// if has two or more special format we need change all domask's format to default that can avoid insert more
|
|
||||||
// transdata
|
|
||||||
if (format_counter.size() > 2) {
|
|
||||||
convert_format = kOpFormat_DEFAULT;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end() &&
|
|
||||||
special_format != do_mask_data_format) {
|
|
||||||
convert_format = kOpFormat_DEFAULT;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (format_counter.size() == 1) {
|
if (format_counter.size() == 1) {
|
||||||
return;
|
return;
|
||||||
|
@ -112,17 +98,23 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
|
||||||
if (convert_format.empty()) {
|
if (convert_format.empty()) {
|
||||||
convert_format = GetConvertFormat(format_counter);
|
convert_format = GetConvertFormat(format_counter);
|
||||||
}
|
}
|
||||||
RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format);
|
RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string, size_t> &format_counter) const {
|
std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string, size_t> &format_counter) const {
|
||||||
std::string convert_format;
|
std::string convert_format = kOpFormat_DEFAULT;
|
||||||
const size_t counter = 0;
|
size_t counter = 0;
|
||||||
|
if (format_counter.size() > 2) {
|
||||||
|
return kOpFormat_DEFAULT;
|
||||||
|
}
|
||||||
|
if (format_counter.size() == 2 && format_counter.find(kOpFormat_DEFAULT) == format_counter.end()) {
|
||||||
|
return kOpFormat_DEFAULT;
|
||||||
|
}
|
||||||
for (const auto &iter : format_counter) {
|
for (const auto &iter : format_counter) {
|
||||||
if (counter < iter.second) {
|
if (counter < iter.second) {
|
||||||
convert_format = iter.first;
|
convert_format = iter.first;
|
||||||
}
|
counter = iter.second;
|
||||||
if (counter == iter.second && kHWSpecialFormatSet.find(convert_format) == kHWSpecialFormatSet.end()) {
|
} else if (counter == iter.second && kHWSpecialFormatSet.find(iter.first) != kHWSpecialFormatSet.end()) {
|
||||||
convert_format = iter.first;
|
convert_format = iter.first;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -130,13 +122,17 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string
|
||||||
}
|
}
|
||||||
|
|
||||||
void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list,
|
void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list,
|
||||||
const std::string &format) const {
|
const std::string &format,
|
||||||
|
const FuncGraphPtr &graph) const {
|
||||||
for (const auto &do_mask : do_mask_node_list) {
|
for (const auto &do_mask : do_mask_node_list) {
|
||||||
auto builder =
|
if (AnfAlgo::GetInputFormat(do_mask, 0) != format) {
|
||||||
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(do_mask));
|
auto builder =
|
||||||
builder->SetInputFormat(format, 0);
|
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(do_mask));
|
||||||
builder->SetOutputFormat(format, 0);
|
builder->SetInputFormat(format, 0);
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get());
|
builder->SetOutputFormat(format, 0);
|
||||||
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get());
|
||||||
|
ReSelecChildNodeKernelInfo(do_mask, graph);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -159,5 +155,30 @@ AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const Anf
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void RectifyDoMaskKernelInfo::ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
auto output_node_list = GetRealNodeUsedList(graph, cnode);
|
||||||
|
MS_EXCEPTION_IF_NULL(output_node_list);
|
||||||
|
for (const auto &out_node_info : *output_node_list) {
|
||||||
|
MS_EXCEPTION_IF_NULL(out_node_info.first);
|
||||||
|
auto out_node = out_node_info.first->cast<CNodePtr>();
|
||||||
|
if (AnfAlgo::IsRealKernel(out_node_info.first)) {
|
||||||
|
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(out_node);
|
||||||
|
kernel_selecter->SelectKernel(out_node);
|
||||||
|
auto new_build_info = AnfAlgo::GetSelectKernelBuildInfo(out_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(new_build_info);
|
||||||
|
MS_EXCEPTION_IF_NULL(ori_build_info);
|
||||||
|
if ((*new_build_info) != (*ori_build_info)) {
|
||||||
|
ReSelecChildNodeKernelInfo(out_node, graph);
|
||||||
|
}
|
||||||
|
} else if (AnfAlgo::GetCNodeName(out_node) == prim::kPrimTupleGetItem->name() ||
|
||||||
|
AnfAlgo::GetCNodeName(out_node) == prim::kPrimDepend->name()) {
|
||||||
|
ReSelecChildNodeKernelInfo(out_node, graph);
|
||||||
|
} else {
|
||||||
|
MS_LOG(INFO) << "Reselected the node " << cnode->DebugString() << " failed";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -19,23 +19,28 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#include "pre_activate/common/optimizer.h"
|
#include "pre_activate/common/optimizer.h"
|
||||||
|
#include "pre_activate/ascend/ascend_helper.h"
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class RectifyDoMaskKernelInfo : public PatternProcessPass {
|
class RectifyDoMaskKernelInfo : public PatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit RectifyDoMaskKernelInfo(bool multigraph = true)
|
explicit RectifyDoMaskKernelInfo(bool multigraph = true)
|
||||||
: PatternProcessPass("batch_norm_bert_fission", multigraph) {}
|
: PatternProcessPass("batch_norm_bert_fission", multigraph), kernel_selecter(std::make_shared<KernelSelect>()) {}
|
||||||
~RectifyDoMaskKernelInfo() override = default;
|
~RectifyDoMaskKernelInfo() override = default;
|
||||||
const BaseRef DefinePattern() const override;
|
const BaseRef DefinePattern() const override;
|
||||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list) const;
|
void RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const FuncGraphPtr &graph) const;
|
||||||
AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const;
|
AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const;
|
||||||
std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) const;
|
std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) const;
|
||||||
void RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const std::string &format) const;
|
void RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const std::string &format,
|
||||||
|
const FuncGraphPtr &graph) const;
|
||||||
|
void ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const;
|
||||||
|
KernelSelectPtr kernel_selecter;
|
||||||
};
|
};
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
Loading…
Reference in New Issue