PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3

This commit is contained in:
nomindcarry 2023-02-09 10:22:48 +08:00
parent b699c4f939
commit 445702c063
2 changed files with 67 additions and 0 deletions

View File

@ -31,6 +31,7 @@ namespace mindspore {
namespace opt {
namespace {
constexpr size_t kCastInputNum = 2;
constexpr size_t kDependInputNum = 2;
constexpr auto softmax_output_shape_size = 2;
constexpr auto kAttrDepth = "depth";
constexpr auto kAttrMultiples = "multiples";
@ -411,6 +412,20 @@ CNodePtr CreateCast(const FuncGraphPtr &graph, const CNodePtr &cast, const AnfNo
return new_cast_node;
}
CNodePtr CreateDepend(const FuncGraphPtr &graph, const CNodePtr &depend, const AnfNodePtr &depend_input,
const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(depend);
MS_EXCEPTION_IF_NULL(depend_input);
std::vector<AnfNodePtr> new_depend_inputs = {depend->input(kAnfPrimitiveIndex), depend_input, depend->input(kIndex2)};
auto new_depend_node = pass.NewCNode(new_depend_inputs, graph);
MS_EXCEPTION_IF_NULL(new_depend_node);
new_depend_node->set_scope(depend->scope());
new_depend_node->set_abstract(depend->abstract());
return new_depend_node;
}
bool IsSparseSoftmaxCrossEntropyWithLogitsGrad(const CNodePtr &sparse, const string &pass_name) {
MS_EXCEPTION_IF_NULL(sparse);
if (common::AnfAlgo::GetCNodeName(sparse) != kSparseSoftmaxCrossEntropyWithLogitsOpName) {
@ -696,5 +711,48 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::P
new_mul_node->set_abstract(mul_node->abstract());
return new_mul_node;
}
const BaseRef PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3::DefinePattern() const {
VarPtr x1 = std::make_shared<Var>();
VarPtr x2 = std::make_shared<Var>();
VarPtr x3 = std::make_shared<Var>();
VarPtr x4 = std::make_shared<Var>();
VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
VectorRef depend({prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits, x3});
return VectorRef({prim::kPrimMul, depend, x4});
}
const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3::Process(const FuncGraphPtr &graph,
const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto mul_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(mul_node);
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
auto depend_node = mul_node->input(kIndex1);
auto depend_cnode = depend_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_cnode);
CheckCNodeInputSize(depend_cnode, kDependInputNum);
auto sparse_softmax_node = depend_cnode->input(kIndex1);
bool is_sp_grad_flag = true;
std::vector<AnfNodePtr> softmax_node_outputs;
auto expand_dims_node =
CreateMulInput(graph, mul_node, sparse_softmax_node, name(), *this, &softmax_node_outputs, &is_sp_grad_flag);
if (!is_sp_grad_flag) {
return nullptr;
}
auto new_depend = CreateDepend(graph, depend_cnode, softmax_node_outputs[1], *this);
std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), new_depend,
expand_dims_node};
auto new_mul_node = NewCNode(new_mul_inputs, graph);
MS_EXCEPTION_IF_NULL(new_mul_node);
new_mul_node->set_scope(mul_node->scope());
new_mul_node->set_abstract(mul_node->abstract());
return new_mul_node;
}
} // namespace opt
} // namespace mindspore

View File

@ -77,6 +77,15 @@ class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2 : public Patt
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
};
class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3 : public PatternProcessPass {
public:
explicit PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3(bool multigraph = true)
: PatternProcessPass("pynative_grad_sparse_softmax_cross_entropy_with_logits_unify_mindir_v3", multigraph) {}
~PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_UNIFY_MINDIR_H