fix confusion_softmax_grad_rule pass

This commit is contained in:
huanghui 2020-04-23 16:45:06 +08:00
parent 507b63ea20
commit 14df771175
2 changed files with 2 additions and 2 deletions

View File

@ -47,7 +47,7 @@ void SetAttrsForFusionNode(const AnfNodePtr &sub_anf, const AnfNodePtr &fusion_n
const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const {
return VectorRef(
{prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input0_, input1_})})});
{prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input1_, input0_})})});
}
const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,

View File

@ -41,7 +41,7 @@ def test_confusion_softmax_grad_rule(tag):
@fns
def before(input0, input1):
res = mul(input0, input1)
res = mul(input1, input0)
# input axis will be convert to attr in ConstructKernelGraph step
res = reduce_sum(res, axis)
res = sub(input0, res)