!15756 add conv2d + add + relu & add + reluv2 ub fusion

From: @jjfeing
Reviewed-by: @kisnwang,@zhoufeng54
Signed-off-by: @zhoufeng54
This commit is contained in:
mindspore-ci-bot 2021-04-28 11:02:35 +08:00 committed by Gitee
commit 75504ad378
4 changed files with 12 additions and 5 deletions

View File

@ -950,7 +950,7 @@ void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_
}
(*output_desc)[kJFormat] = format;
// special node
if (fusion_data_type == kFusionAddN && format == kOpFormat_NC1HWC0) {
if ((fusion_data_type == kFusionAddN || fusion_data_type == kFusionAdd) && shape.size() == 5) {
std::vector<size_t> spec_shape = {};
spec_shape.emplace_back(shape[0]);
spec_shape.emplace_back(shape[1]);
@ -995,7 +995,8 @@ void TbeKernelBuild::GenReusedOutputDesc(const std::shared_ptr<mindspore::AnfNod
bool TbeKernelBuild::GetSpecInputLayers(const std::string &op_name,
const std::vector<mindspore::AnfNodePtr> &reorder_layer,
std::map<const AnfNodePtr, FusionDataType> *spec_data_input) {
if ((op_name == kReluGradV2OpName || op_name == kAddNOpName) && reorder_layer.empty()) {
if ((op_name == kReluGradV2OpName || op_name == kAddNOpName || op_name == kTensorAddOpName) &&
reorder_layer.empty()) {
MS_LOG(INFO) << "Fusion error: node(" << op_name << " )'s input is null. ";
return false;
}
@ -1005,6 +1006,8 @@ bool TbeKernelBuild::GetSpecInputLayers(const std::string &op_name,
for (const auto &it : reorder_layer) {
(*spec_data_input)[it] = kFusionAddN;
}
} else if (op_name == kTensorAddOpName) {
(*spec_data_input)[reorder_layer[0]] = kFusionAdd;
}
return true;
}
@ -1020,7 +1023,7 @@ bool TbeKernelBuild::GetInputLayers(const std::vector<mindspore::AnfNodePtr> &in
MS_EXCEPTION_IF_NULL(spec_data_input);
auto result = std::find_if(compute_nodes.begin(), compute_nodes.end(), [](const auto &it) {
auto op_name = AnfAlgo::GetCNodeName(it);
return op_name == kConv2DBackpropInputOpName;
return (op_name == kConv2DBackpropInputOpName || op_name == kConv2DOpName);
});
bool need_spec = (result != compute_nodes.end());
size_t input_size = 0;

View File

@ -34,7 +34,7 @@ namespace kernel {
// kernel operate type used for generate json
class TbeKernelBuild {
enum FusionDataType { kFusionNormal = 0, kFusionAddN, kFusionReLUGradV2 };
enum FusionDataType { kFusionNormal = 0, kFusionAddN, kFusionReLUGradV2, kFusionAdd };
public:
static bool GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,

View File

@ -66,7 +66,8 @@ void ConvDoubleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE &&
!AnfAlgo::CheckPrimitiveType(node, prim::kPrimReluV2)) {
MatchConvDoubleInEltwise(cnode, kernel_graph, candidate_fusion);
}
}

View File

@ -42,6 +42,9 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne
MS_EXCEPTION_IF_NULL(input_cnode);
eltwise_input = input_cnode->input(1);
}
if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input)) {
(void)record.insert(eltwise_input);
}
if (record.size() < MIN_ELTWISE_SIZE) {
return;
}