forked from mindspore-Ecosystem/mindspore
filter sparse pattern
This commit is contained in:
parent
ecf151929f
commit
e733bce265
|
@ -36,6 +36,16 @@ constexpr auto kAttrDepth = "depth";
|
|||
constexpr auto kAttrMultiples = "multiples";
|
||||
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
|
||||
|
||||
bool CheckMulInputShapeEqual(const CNodePtr &mul_node) {
|
||||
MS_EXCEPTION_IF_NULL(mul_node);
|
||||
if (!IsPrimitiveCNode(mul_node, prim::kPrimMul)) {
|
||||
MS_LOG(EXCEPTION) << "Node is not mul, but is " << mul_node->fullname_with_scope();
|
||||
}
|
||||
auto input1_shape = common::AnfAlgo::GetOutputInferShape(mul_node->input(kIndex1), 0);
|
||||
auto input2_shape = common::AnfAlgo::GetOutputInferShape(mul_node->input(kIndex2), 0);
|
||||
return input1_shape == input2_shape;
|
||||
}
|
||||
|
||||
ValueNodePtr CreateValueNode(const ValuePtr &value_ptr, TypeId output_type) {
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
auto new_node = std::make_shared<ValueNode>(value_ptr);
|
||||
|
@ -549,6 +559,10 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
|
|||
MS_EXCEPTION_IF_NULL(mul_node);
|
||||
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
|
||||
|
||||
if (CheckMulInputShapeEqual(mul_node)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto depend_node = GetDependNode(mul_node);
|
||||
auto sparse_softmax_node = GetSparseNode(depend_node, kIndex2);
|
||||
auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);
|
||||
|
|
Loading…
Reference in New Issue