!28683 Fix some UB fusion patterns
Merge pull request !28683 from yuchaojie/ub_fusion2
This commit is contained in:
commit
18ce5b9817
|
@ -65,6 +65,10 @@ class TbeAdapter {
|
|||
(void)std::copy(inputs_list.begin(), inputs_list.end(), std::back_inserter((*inputs_json)));
|
||||
} else {
|
||||
if (op_name == kMinimumGradOpName || op_name == kMaximumGradOpName) {
|
||||
if (inputs_list.size() < kIndex3) {
|
||||
MS_LOG(EXCEPTION) << "Op " << op_name << " should have at least " << kIndex3 << " inputs, but got "
|
||||
<< inputs_list.size();
|
||||
}
|
||||
inputs_json->push_back(inputs_list[kIndex2]);
|
||||
inputs_json->push_back(inputs_list[kIndex0]);
|
||||
inputs_json->push_back(inputs_list[kIndex1]);
|
||||
|
@ -74,6 +78,10 @@ class TbeAdapter {
|
|||
} else if (op_name == kApplyCenteredRMSPropOpName) {
|
||||
// The parameter order of ApplyCenteredRMSProp's TBE implementation differs from the Python API, so map
|
||||
// each TBE parameter to its corresponding Python API parameter using the latter's hard-coded index
|
||||
if (inputs_list.size() < kIndex9) {
|
||||
MS_LOG(EXCEPTION) << "Op " << op_name << " should have at least " << kIndex9 << " inputs, but got "
|
||||
<< inputs_list.size();
|
||||
}
|
||||
inputs_json->push_back(inputs_list[kIndex0]);
|
||||
inputs_json->push_back(inputs_list[kIndex1]);
|
||||
inputs_json->push_back(inputs_list[kIndex2]);
|
||||
|
@ -84,6 +92,10 @@ class TbeAdapter {
|
|||
inputs_json->push_back(inputs_list[kIndex8]);
|
||||
inputs_json->push_back(inputs_list[kIndex4]);
|
||||
} else {
|
||||
if (inputs_list.size() < kIndex2) {
|
||||
MS_LOG(EXCEPTION) << "Op " << op_name << " should have at least " << kIndex2 << " inputs, but got "
|
||||
<< inputs_list.size();
|
||||
}
|
||||
inputs_json->push_back(inputs_list[kIndex1]);
|
||||
inputs_json->push_back(inputs_list[kIndex0]);
|
||||
for (size_t i = 2; i < inputs_list.size(); ++i) {
|
||||
|
|
|
@ -213,6 +213,7 @@ bool FusionBuildTbeJsonCreator::GenInputsJson(const AnfNodePtr &anf_node, nlohma
|
|||
input_desc_list_tmp.emplace_back(optional_input_desc);
|
||||
}
|
||||
std::vector<nlohmann::json> input_desc_list;
|
||||
// TODO(jjf): error when a reordered op has an input that is not in input_nodes.
|
||||
TbeAdapter::InputOrderPass<nlohmann::json>(cnode, input_desc_list_tmp, &input_desc_list);
|
||||
(*compute_json)[kJInputDesc] = input_desc_list;
|
||||
return true;
|
||||
|
|
|
@ -33,6 +33,9 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
|
|||
MS_EXCEPTION_IF_NULL(relu_input);
|
||||
auto add = relu_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(add);
|
||||
if (AnfAlgo::GetInputTensorNum(cnode) != (ELTWISE_DOUBLE_IN_INPUT_SIZE - 1)) {
|
||||
return;
|
||||
}
|
||||
auto tuple_getitem = add->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
if (tuple_getitem->isa<CNode>() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) {
|
||||
|
@ -62,7 +65,8 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
|
||||
AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_DOUBLE_OUTPUT_SIZE) {
|
||||
AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_DOUBLE_OUTPUT_SIZE &&
|
||||
AnfAlgo::GetInputTensorNum(cnode) == (ELTWISE_INPUT_SIZE - 1)) {
|
||||
auto eltwise_input = cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(eltwise_input);
|
||||
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimAdd)) {
|
||||
|
|
|
@ -29,7 +29,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
|
|||
mindspore::HashSet<AnfNodePtr> record{cnode};
|
||||
auto eltwise_input = cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(eltwise_input);
|
||||
if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input)) {
|
||||
if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input) && AnfAlgo::GetCNodeName(cnode) == kAddNOpName) {
|
||||
(void)record.insert(eltwise_input);
|
||||
} else {
|
||||
return;
|
||||
|
@ -76,7 +76,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const sess
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
|
||||
(cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) {
|
||||
AnfAlgo::GetCNodeName(cnode) == kReluGradV2OpName) {
|
||||
MatchConv2DBackpropInputEltwiseEltwise(cnode, kernel_graph, candidate_fusion);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,18 +30,18 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const int8_t MAX_ELTWISE_NUM = 3;
|
||||
const int8_t MIN_ELTWISE_SIZE = 2;
|
||||
const int8_t ELTWISE_INPUT_SIZE = 2;
|
||||
const int8_t ELTWISE_DOUBLE_IN_INPUT_SIZE = 3;
|
||||
const int8_t ELTWISE_SINGLE_OUTPUT_SIZE = 1;
|
||||
const int8_t ELTWISE_DOUBLE_OUTPUT_SIZE = 2;
|
||||
const int8_t CONV_DOUBLE_IN_INPUT_SIZE = 3;
|
||||
const int8_t CONV_QUART_IN_INPUT_SIZE = 5;
|
||||
const int8_t ELTWISE_USE = 1;
|
||||
const int8_t ELTWISE_MULTI_USE = 2;
|
||||
const int8_t MAX_ELTWISE_SIZE = 6;
|
||||
const int8_t MULTI_ELTWISE_SIZE = 4;
|
||||
const size_t MAX_ELTWISE_NUM = 3;
|
||||
const size_t MIN_ELTWISE_SIZE = 2;
|
||||
const size_t ELTWISE_INPUT_SIZE = 2;
|
||||
const size_t ELTWISE_DOUBLE_IN_INPUT_SIZE = 3;
|
||||
const size_t ELTWISE_SINGLE_OUTPUT_SIZE = 1;
|
||||
const size_t ELTWISE_DOUBLE_OUTPUT_SIZE = 2;
|
||||
const size_t CONV_DOUBLE_IN_INPUT_SIZE = 3;
|
||||
const size_t CONV_QUART_IN_INPUT_SIZE = 5;
|
||||
const size_t ELTWISE_USE = 1;
|
||||
const size_t ELTWISE_MULTI_USE = 2;
|
||||
const size_t MAX_ELTWISE_SIZE = 6;
|
||||
const size_t MULTI_ELTWISE_SIZE = 4;
|
||||
|
||||
constexpr int64_t kBNTrainingUpdateOutputUsedTotalNum = 5;
|
||||
constexpr int64_t kConvOutputUsedTotalNum = 4;
|
||||
|
|
Loading…
Reference in New Issue