forked from mindspore-Ecosystem/mindspore
PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3
This commit is contained in:
parent
b699c4f939
commit
445702c063
|
@ -31,6 +31,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kCastInputNum = 2;
|
||||
constexpr size_t kDependInputNum = 2;
|
||||
constexpr auto softmax_output_shape_size = 2;
|
||||
constexpr auto kAttrDepth = "depth";
|
||||
constexpr auto kAttrMultiples = "multiples";
|
||||
|
@ -411,6 +412,20 @@ CNodePtr CreateCast(const FuncGraphPtr &graph, const CNodePtr &cast, const AnfNo
|
|||
return new_cast_node;
|
||||
}
|
||||
|
||||
CNodePtr CreateDepend(const FuncGraphPtr &graph, const CNodePtr &depend, const AnfNodePtr &depend_input,
|
||||
const PatternProcessPass &pass) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(depend);
|
||||
MS_EXCEPTION_IF_NULL(depend_input);
|
||||
|
||||
std::vector<AnfNodePtr> new_depend_inputs = {depend->input(kAnfPrimitiveIndex), depend_input, depend->input(kIndex2)};
|
||||
auto new_depend_node = pass.NewCNode(new_depend_inputs, graph);
|
||||
MS_EXCEPTION_IF_NULL(new_depend_node);
|
||||
new_depend_node->set_scope(depend->scope());
|
||||
new_depend_node->set_abstract(depend->abstract());
|
||||
return new_depend_node;
|
||||
}
|
||||
|
||||
bool IsSparseSoftmaxCrossEntropyWithLogitsGrad(const CNodePtr &sparse, const string &pass_name) {
|
||||
MS_EXCEPTION_IF_NULL(sparse);
|
||||
if (common::AnfAlgo::GetCNodeName(sparse) != kSparseSoftmaxCrossEntropyWithLogitsOpName) {
|
||||
|
@ -696,5 +711,48 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::P
|
|||
new_mul_node->set_abstract(mul_node->abstract());
|
||||
return new_mul_node;
|
||||
}
|
||||
|
||||
const BaseRef PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3::DefinePattern() const {
|
||||
VarPtr x1 = std::make_shared<Var>();
|
||||
VarPtr x2 = std::make_shared<Var>();
|
||||
VarPtr x3 = std::make_shared<Var>();
|
||||
VarPtr x4 = std::make_shared<Var>();
|
||||
VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
|
||||
VectorRef depend({prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits, x3});
|
||||
return VectorRef({prim::kPrimMul, depend, x4});
|
||||
}
|
||||
|
||||
const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3::Process(const FuncGraphPtr &graph,
|
||||
const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
auto mul_node = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(mul_node);
|
||||
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
|
||||
|
||||
auto depend_node = mul_node->input(kIndex1);
|
||||
auto depend_cnode = depend_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend_cnode);
|
||||
CheckCNodeInputSize(depend_cnode, kDependInputNum);
|
||||
|
||||
auto sparse_softmax_node = depend_cnode->input(kIndex1);
|
||||
bool is_sp_grad_flag = true;
|
||||
std::vector<AnfNodePtr> softmax_node_outputs;
|
||||
auto expand_dims_node =
|
||||
CreateMulInput(graph, mul_node, sparse_softmax_node, name(), *this, &softmax_node_outputs, &is_sp_grad_flag);
|
||||
if (!is_sp_grad_flag) {
|
||||
return nullptr;
|
||||
}
|
||||
auto new_depend = CreateDepend(graph, depend_cnode, softmax_node_outputs[1], *this);
|
||||
std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), new_depend,
|
||||
expand_dims_node};
|
||||
auto new_mul_node = NewCNode(new_mul_inputs, graph);
|
||||
MS_EXCEPTION_IF_NULL(new_mul_node);
|
||||
new_mul_node->set_scope(mul_node->scope());
|
||||
new_mul_node->set_abstract(mul_node->abstract());
|
||||
return new_mul_node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -77,6 +77,15 @@ class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2 : public Patt
|
|||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
};
|
||||
|
||||
class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3 : public PatternProcessPass {
|
||||
public:
|
||||
explicit PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3(bool multigraph = true)
|
||||
: PatternProcessPass("pynative_grad_sparse_softmax_cross_entropy_with_logits_unify_mindir_v3", multigraph) {}
|
||||
~PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_UNIFY_MINDIR_H
|
||||
|
|
Loading…
Reference in New Issue