forked from mindspore-Ecosystem/mindspore
!48624 SparseSoftmaxCrossEntropyWithLogits ADD DefinePattern V3
Merge pull request !48624 from nomindcarry/master
This commit is contained in:
commit
43ffa26a3a
|
@ -31,6 +31,7 @@ namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr size_t kCastInputNum = 2;
|
constexpr size_t kCastInputNum = 2;
|
||||||
|
constexpr size_t kDependInputNum = 2;
|
||||||
constexpr auto softmax_output_shape_size = 2;
|
constexpr auto softmax_output_shape_size = 2;
|
||||||
constexpr auto kAttrDepth = "depth";
|
constexpr auto kAttrDepth = "depth";
|
||||||
constexpr auto kAttrMultiples = "multiples";
|
constexpr auto kAttrMultiples = "multiples";
|
||||||
|
@ -411,6 +412,20 @@ CNodePtr CreateCast(const FuncGraphPtr &graph, const CNodePtr &cast, const AnfNo
|
||||||
return new_cast_node;
|
return new_cast_node;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CNodePtr CreateDepend(const FuncGraphPtr &graph, const CNodePtr &depend, const AnfNodePtr &depend_input,
|
||||||
|
const PatternProcessPass &pass) {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(depend);
|
||||||
|
MS_EXCEPTION_IF_NULL(depend_input);
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> new_depend_inputs = {depend->input(kAnfPrimitiveIndex), depend_input, depend->input(kIndex2)};
|
||||||
|
auto new_depend_node = pass.NewCNode(new_depend_inputs, graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(new_depend_node);
|
||||||
|
new_depend_node->set_scope(depend->scope());
|
||||||
|
new_depend_node->set_abstract(depend->abstract());
|
||||||
|
return new_depend_node;
|
||||||
|
}
|
||||||
|
|
||||||
bool IsSparseSoftmaxCrossEntropyWithLogitsGrad(const CNodePtr &sparse, const string &pass_name) {
|
bool IsSparseSoftmaxCrossEntropyWithLogitsGrad(const CNodePtr &sparse, const string &pass_name) {
|
||||||
MS_EXCEPTION_IF_NULL(sparse);
|
MS_EXCEPTION_IF_NULL(sparse);
|
||||||
if (common::AnfAlgo::GetCNodeName(sparse) != kSparseSoftmaxCrossEntropyWithLogitsOpName) {
|
if (common::AnfAlgo::GetCNodeName(sparse) != kSparseSoftmaxCrossEntropyWithLogitsOpName) {
|
||||||
|
@ -696,5 +711,48 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::P
|
||||||
new_mul_node->set_abstract(mul_node->abstract());
|
new_mul_node->set_abstract(mul_node->abstract());
|
||||||
return new_mul_node;
|
return new_mul_node;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const BaseRef PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3::DefinePattern() const {
|
||||||
|
VarPtr x1 = std::make_shared<Var>();
|
||||||
|
VarPtr x2 = std::make_shared<Var>();
|
||||||
|
VarPtr x3 = std::make_shared<Var>();
|
||||||
|
VarPtr x4 = std::make_shared<Var>();
|
||||||
|
VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
|
||||||
|
VectorRef depend({prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits, x3});
|
||||||
|
return VectorRef({prim::kPrimMul, depend, x4});
|
||||||
|
}
|
||||||
|
|
||||||
|
const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3::Process(const FuncGraphPtr &graph,
|
||||||
|
const AnfNodePtr &node,
|
||||||
|
const EquivPtr &) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
|
||||||
|
auto mul_node = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(mul_node);
|
||||||
|
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
|
||||||
|
|
||||||
|
auto depend_node = mul_node->input(kIndex1);
|
||||||
|
auto depend_cnode = depend_node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(depend_cnode);
|
||||||
|
CheckCNodeInputSize(depend_cnode, kDependInputNum);
|
||||||
|
|
||||||
|
auto sparse_softmax_node = depend_cnode->input(kIndex1);
|
||||||
|
bool is_sp_grad_flag = true;
|
||||||
|
std::vector<AnfNodePtr> softmax_node_outputs;
|
||||||
|
auto expand_dims_node =
|
||||||
|
CreateMulInput(graph, mul_node, sparse_softmax_node, name(), *this, &softmax_node_outputs, &is_sp_grad_flag);
|
||||||
|
if (!is_sp_grad_flag) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto new_depend = CreateDepend(graph, depend_cnode, softmax_node_outputs[1], *this);
|
||||||
|
std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), new_depend,
|
||||||
|
expand_dims_node};
|
||||||
|
auto new_mul_node = NewCNode(new_mul_inputs, graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(new_mul_node);
|
||||||
|
new_mul_node->set_scope(mul_node->scope());
|
||||||
|
new_mul_node->set_abstract(mul_node->abstract());
|
||||||
|
return new_mul_node;
|
||||||
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -77,6 +77,15 @@ class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2 : public Patt
|
||||||
const BaseRef DefinePattern() const override;
|
const BaseRef DefinePattern() const override;
|
||||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3 : public PatternProcessPass {
|
||||||
|
public:
|
||||||
|
explicit PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3(bool multigraph = true)
|
||||||
|
: PatternProcessPass("pynative_grad_sparse_softmax_cross_entropy_with_logits_unify_mindir_v3", multigraph) {}
|
||||||
|
~PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||||
|
};
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_UNIFY_MINDIR_H
|
#endif // MINDSPORE_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_UNIFY_MINDIR_H
|
||||||
|
|
Loading…
Reference in New Issue