forked from mindspore-Ecosystem/mindspore
fix confusion_softmax_grad_rule pass
This commit is contained in:
parent
507b63ea20
commit
14df771175
|
@ -47,7 +47,7 @@ void SetAttrsForFusionNode(const AnfNodePtr &sub_anf, const AnfNodePtr &fusion_n
|
|||
|
||||
const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const {
|
||||
return VectorRef(
|
||||
{prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input0_, input1_})})});
|
||||
{prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input1_, input0_})})});
|
||||
}
|
||||
|
||||
const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
|
|
|
@ -41,7 +41,7 @@ def test_confusion_softmax_grad_rule(tag):
|
|||
|
||||
@fns
|
||||
def before(input0, input1):
|
||||
res = mul(input0, input1)
|
||||
res = mul(input1, input0)
|
||||
# input axis will be convert to attr in ConstructKernelGraph step
|
||||
res = reduce_sum(res, axis)
|
||||
res = sub(input0, res)
|
||||
|
|
Loading…
Reference in New Issue