!28683 fix some ubfusion pattern

Merge pull request !28683 from yuchaojie/ub_fusion2
This commit is contained in:
i-robot 2022-01-08 08:28:09 +00:00 committed by Gitee
commit 18ce5b9817
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 32 additions and 15 deletions

View File

@ -65,6 +65,10 @@ class TbeAdapter {
(void)std::copy(inputs_list.begin(), inputs_list.end(), std::back_inserter((*inputs_json)));
} else {
if (op_name == kMinimumGradOpName || op_name == kMaximumGradOpName) {
if (inputs_list.size() < kIndex3) {
MS_LOG(EXCEPTION) << "Op " << op_name << " should have at least " << kIndex3 << " inputs, but got "
<< inputs_list.size();
}
inputs_json->push_back(inputs_list[kIndex2]);
inputs_json->push_back(inputs_list[kIndex0]);
inputs_json->push_back(inputs_list[kIndex1]);
@ -74,6 +78,10 @@ class TbeAdapter {
} else if (op_name == kApplyCenteredRMSPropOpName) {
// Parameter order of ApplyCenteredRMSProp's TBE implementation is different from python API, so map
// TBE parameter to correspond python API parameter by latter's index using hardcode
if (inputs_list.size() < kIndex9) {
MS_LOG(EXCEPTION) << "Op " << op_name << " should have at least " << kIndex9 << " inputs, but got "
<< inputs_list.size();
}
inputs_json->push_back(inputs_list[kIndex0]);
inputs_json->push_back(inputs_list[kIndex1]);
inputs_json->push_back(inputs_list[kIndex2]);
@ -84,6 +92,10 @@ class TbeAdapter {
inputs_json->push_back(inputs_list[kIndex8]);
inputs_json->push_back(inputs_list[kIndex4]);
} else {
if (inputs_list.size() < kIndex2) {
MS_LOG(EXCEPTION) << "Op " << op_name << " should have at least " << kIndex2 << " inputs, but got "
<< inputs_list.size();
}
inputs_json->push_back(inputs_list[kIndex1]);
inputs_json->push_back(inputs_list[kIndex0]);
for (size_t i = 2; i < inputs_list.size(); ++i) {

View File

@ -213,6 +213,7 @@ bool FusionBuildTbeJsonCreator::GenInputsJson(const AnfNodePtr &anf_node, nlohma
input_desc_list_tmp.emplace_back(optional_input_desc);
}
std::vector<nlohmann::json> input_desc_list;
// TODO(jjf): error when reordered op have input not in input_nodes.
TbeAdapter::InputOrderPass<nlohmann::json>(cnode, input_desc_list_tmp, &input_desc_list);
(*compute_json)[kJInputDesc] = input_desc_list;
return true;

View File

@ -33,6 +33,9 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
MS_EXCEPTION_IF_NULL(relu_input);
auto add = relu_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(add);
if (AnfAlgo::GetInputTensorNum(cnode) != (ELTWISE_DOUBLE_IN_INPUT_SIZE - 1)) {
return;
}
auto tuple_getitem = add->input(kIndex1);
MS_EXCEPTION_IF_NULL(tuple_getitem);
if (tuple_getitem->isa<CNode>() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) {
@ -62,7 +65,8 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_DOUBLE_OUTPUT_SIZE) {
AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_DOUBLE_OUTPUT_SIZE &&
AnfAlgo::GetInputTensorNum(cnode) == (ELTWISE_INPUT_SIZE - 1)) {
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimAdd)) {

View File

@ -29,7 +29,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
mindspore::HashSet<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input)) {
if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input) && AnfAlgo::GetCNodeName(cnode) == kAddNOpName) {
(void)record.insert(eltwise_input);
} else {
return;
@ -76,7 +76,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const sess
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
(cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) {
AnfAlgo::GetCNodeName(cnode) == kReluGradV2OpName) {
MatchConv2DBackpropInputEltwiseEltwise(cnode, kernel_graph, candidate_fusion);
}
}

View File

@ -30,18 +30,18 @@
namespace mindspore {
namespace opt {
const int8_t MAX_ELTWISE_NUM = 3;
const int8_t MIN_ELTWISE_SIZE = 2;
const int8_t ELTWISE_INPUT_SIZE = 2;
const int8_t ELTWISE_DOUBLE_IN_INPUT_SIZE = 3;
const int8_t ELTWISE_SINGLE_OUTPUT_SIZE = 1;
const int8_t ELTWISE_DOUBLE_OUTPUT_SIZE = 2;
const int8_t CONV_DOUBLE_IN_INPUT_SIZE = 3;
const int8_t CONV_QUART_IN_INPUT_SIZE = 5;
const int8_t ELTWISE_USE = 1;
const int8_t ELTWISE_MULTI_USE = 2;
const int8_t MAX_ELTWISE_SIZE = 6;
const int8_t MULTI_ELTWISE_SIZE = 4;
const size_t MAX_ELTWISE_NUM = 3;
const size_t MIN_ELTWISE_SIZE = 2;
const size_t ELTWISE_INPUT_SIZE = 2;
const size_t ELTWISE_DOUBLE_IN_INPUT_SIZE = 3;
const size_t ELTWISE_SINGLE_OUTPUT_SIZE = 1;
const size_t ELTWISE_DOUBLE_OUTPUT_SIZE = 2;
const size_t CONV_DOUBLE_IN_INPUT_SIZE = 3;
const size_t CONV_QUART_IN_INPUT_SIZE = 5;
const size_t ELTWISE_USE = 1;
const size_t ELTWISE_MULTI_USE = 2;
const size_t MAX_ELTWISE_SIZE = 6;
const size_t MULTI_ELTWISE_SIZE = 4;
constexpr int64_t kBNTrainingUpdateOutputUsedTotalNum = 5;
constexpr int64_t kConvOutputUsedTotalNum = 4;