filter sparse pattern

This commit is contained in:
jiaorui 2022-08-09 16:10:14 +08:00
parent ecf151929f
commit e733bce265
1 changed files with 14 additions and 0 deletions

View File

@ -36,6 +36,16 @@ constexpr auto kAttrDepth = "depth";
constexpr auto kAttrMultiples = "multiples";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
bool CheckMulInputShapeEqual(const CNodePtr &mul_node) {
MS_EXCEPTION_IF_NULL(mul_node);
if (!IsPrimitiveCNode(mul_node, prim::kPrimMul)) {
MS_LOG(EXCEPTION) << "Node is not mul, but is " << mul_node->fullname_with_scope();
}
auto input1_shape = common::AnfAlgo::GetOutputInferShape(mul_node->input(kIndex1), 0);
auto input2_shape = common::AnfAlgo::GetOutputInferShape(mul_node->input(kIndex2), 0);
return input1_shape == input2_shape;
}
ValueNodePtr CreateValueNode(const ValuePtr &value_ptr, TypeId output_type) {
MS_EXCEPTION_IF_NULL(value_ptr);
auto new_node = std::make_shared<ValueNode>(value_ptr);
@ -549,6 +559,10 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
MS_EXCEPTION_IF_NULL(mul_node);
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
if (CheckMulInputShapeEqual(mul_node)) {
return nullptr;
}
auto depend_node = GetDependNode(mul_node);
auto sparse_softmax_node = GetSparseNode(depend_node, kIndex2);
auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);