forked from mindspore-Ecosystem/mindspore
!10923 unify mindir sparse op
From: @hwjiaorui Reviewed-by: @zhoufeng54 Signed-off-by:
commit 3f0316583e
@@ -27,6 +27,9 @@
 #include "ir/dtype/type.h"
 
 constexpr auto softmax_output_shape_size = 2;
+constexpr auto kAttrDepth = "depth";
+constexpr auto kAttrMultiples = "multiples";
+constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
 namespace mindspore {
 namespace opt {
 namespace {
@@ -47,12 +50,12 @@ ValueNodePtr CreateValueNode(const ValuePtr &value_ptr, TypeId output_type) {
   return new_node;
 }
 
-CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node) {
+CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, bool is_pynative = false) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
 
   std::vector<size_t> logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 0);
-  int64_t depth;
+  int64_t depth = 0;
   if (logits_shape.size() >= 1) {
     size_t index = logits_shape.size() - 1;
     depth = logits_shape[index];
@@ -66,33 +69,37 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_
   auto value_off = std::make_shared<tensor::Tensor>(0.0, kFloat32);
   auto value_off_node = CreateValueNode(value_off, kNumberTypeFloat32);
   MS_EXCEPTION_IF_NULL(value_off_node);
 
   auto kernel_graph = graph->cast<KernelGraphPtr>();
   kernel_graph->AddValueNodeToGraph(value_on_node);
   kernel_graph->AddValueNodeToGraph(value_off_node);
 
-  auto depth_node = NewValueNode(depth);
-  MS_EXCEPTION_IF_NULL(depth_node);
-
-  auto depth_abstract = std::make_shared<abstract::AbstractScalar>();
-  depth_abstract->set_type(kInt64);
-  depth_node->set_abstract(depth_abstract);
-
   auto one_hot_primitive = std::make_shared<Primitive>(kOneHotOpName);
   std::vector<std::string> input_names = {"indices", "depth", "on_value", "off_value"};
   std::vector<std::string> output_names = {"output"};
   one_hot_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
   one_hot_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
-  std::vector<AnfNodePtr> one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), depth_node,
-                                            value_on_node, value_off_node};
+  std::vector<AnfNodePtr> one_hot_inputs;
+  if (is_pynative) {
+    one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), value_on_node, value_off_node};
+  } else {
+    auto depth_node = NewValueNode(depth);
+    MS_EXCEPTION_IF_NULL(depth_node);
+    auto depth_abstract = std::make_shared<abstract::AbstractScalar>();
+    depth_abstract->set_type(kInt64);
+    depth_node->set_abstract(depth_abstract);
+    one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), depth_node, value_on_node,
+                      value_off_node};
+  }
   auto one_hot_node = graph->NewCNode(one_hot_inputs);
   MS_EXCEPTION_IF_NULL(one_hot_node);
 
   one_hot_node->set_scope(sparse_softmax_node->scope());
   std::vector<size_t> labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 1);
   labels_shape.emplace_back(depth);
   AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {labels_shape}, one_hot_node.get());
-  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(-1), one_hot_node);
+  if (is_pynative) {
+    AnfAlgo::SetNodeAttr(kAttrDepth, MakeValue(depth), one_hot_node);
+  }
   return one_hot_node;
 }
 
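Note on the branch above: the two modes differ only in how the constant depth reaches the OneHot node. Graph mode materializes it as a ValueNode input registered with the kernel graph; pynative mode drops the extra input and passes it as the kAttrDepth attribute instead. A minimal toy sketch of that decision in plain Python (the dict is a stand-in for a CNode, not MindSpore source):

    # Toy model: an op node with operand inputs and scalar attributes.
    def build_one_hot(is_pynative, depth):
        node = {"op": "OneHot", "inputs": ["indices", "on_value", "off_value"], "attrs": {}}
        if is_pynative:
            node["attrs"]["depth"] = depth      # constant travels as an attribute
        else:
            node["inputs"].insert(1, "depth")   # constant becomes an explicit input
        return node

    print(build_one_hot(True, 10))   # three inputs, attrs={'depth': 10}
    print(build_one_hot(False, 10))  # depth spliced in as the second input, attrs empty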
@@ -106,9 +113,6 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN
     MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal "
                       << kSparseSoftmaxCrossEntropyWithLogitsInputNum;
   }
-  if (one_hot_node->size() != kOneHotInputNum) {
-    MS_LOG(EXCEPTION) << "ont_hot's input size not equal " << kOneHotInputNum;
-  }
 
   std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSoftmaxCrossEntropyWithLogitsOpName)),
                                     sparse_softmax_node->input(1), one_hot_node};
@@ -131,7 +135,7 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN
   return softmax_node;
 }
 
-ValueNodePtr GetAxis(const AnfNodePtr &node) {
+std::vector<int64_t> GetAxis(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   std::vector<size_t> output_shape = AnfAlgo::GetOutputInferShape(node, 0);
   if (output_shape.empty()) {
@@ -141,13 +145,19 @@ ValueNodePtr GetAxis(const AnfNodePtr &node) {
   for (size_t i = 0; i < output_shape.size(); i++) {
     range.emplace_back(i);
   }
+  return range;
+}
+
+ValueNodePtr GetAxisNode(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto range = GetAxis(node);
   auto axis_node = CreateValueNode(MakeValue(range), kNumberTypeInt64);
   MS_EXCEPTION_IF_NULL(axis_node);
   return axis_node;
 }
 
 CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node,
-                          const AnfNodePtr &softmax_output_node) {
+                          const AnfNodePtr &softmax_output_node, bool is_pynative = false) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
   MS_EXCEPTION_IF_NULL(softmax_output_node);
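The split above lets GetAxis return the raw axis list, which the pynative branch attaches as an attribute, while the new GetAxisNode wraps the same list in a ValueNode for graph mode. The computed value is simply every axis in [0, rank) of the loss output, so the subsequent ReduceMean collapses all dimensions to a scalar. A standalone sketch of the computation:

    def full_axis_range(rank):
        # Mirrors the GetAxis loop: all axes [0, rank) of the output shape.
        return list(range(rank))

    print(full_axis_range(2))  # [0, 1] -> ReduceMean over both axes yields a scalar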
@@ -155,10 +165,10 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
     MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal "
                       << kSparseSoftmaxCrossEntropyWithLogitsInputNum;
   }
-  auto axis_node = GetAxis(softmax_output_node);
+  auto axis_value = GetAxis(softmax_output_node);
+  auto axis_node = GetAxisNode(softmax_output_node);
   MS_EXCEPTION_IF_NULL(axis_node);
-  auto kernel_graph = graph->cast<KernelGraphPtr>();
-  kernel_graph->AddValueNodeToGraph(axis_node);
 
   auto reduce_primitive = std::make_shared<Primitive>(kReduceMeanOpName);
   std::vector<std::string> input_names = {"x", "axis"};
@@ -166,14 +176,23 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
   reduce_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
   reduce_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
 
-  std::vector<AnfNodePtr> inputs = {NewValueNode(reduce_primitive), softmax_output_node, axis_node};
+  auto kernel_graph = graph->cast<KernelGraphPtr>();
+  std::vector<AnfNodePtr> inputs;
+  if (is_pynative) {
+    inputs = {NewValueNode(reduce_primitive), softmax_output_node};
+  } else {
+    kernel_graph->AddValueNodeToGraph(axis_node);
+    inputs = {NewValueNode(reduce_primitive), softmax_output_node, axis_node};
+  }
   auto reduce_node = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(reduce_node);
 
   reduce_node->set_scope(sparse_softmax_node->scope());
   auto reduce_abstract = softmax_output_node->abstract();
   reduce_abstract->set_shape(std::make_shared<abstract::Shape>());
   reduce_node->set_abstract(reduce_abstract);
+  if (is_pynative) {
+    AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_value), reduce_node);
+  }
   return reduce_node;
 }
 
@@ -207,8 +226,33 @@ CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_no
                                       expand_dims_node.get());
   return expand_dims_node;
 }
 
+CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &real_div_node) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(real_div_node);
+  if (real_div_node->size() != kRealDivInputNum) {
+    MS_LOG(EXCEPTION) << "Op real_div's input num not equal " << kRealDivInputNum;
+  }
+  int64_t axis = -1;
+  auto expand_dims_primitive = std::make_shared<Primitive>(kExpandDimsOpName);
+  std::vector<std::string> input_names = {"x"};
+  std::vector<std::string> output_names = {"output"};
+  expand_dims_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
+  expand_dims_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
+  std::vector<AnfNodePtr> expand_dims_inputs = {NewValueNode(expand_dims_primitive), real_div_node};
+  auto expand_dims_node = graph->NewCNode(expand_dims_inputs);
+  MS_EXCEPTION_IF_NULL(expand_dims_node);
+  expand_dims_node->set_scope(real_div_node->scope());
+  std::vector<size_t> y_shape = AnfAlgo::GetOutputInferShape(real_div_node, 0);
+  y_shape.emplace_back(1);
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(real_div_node, 0)}, {y_shape},
+                                      expand_dims_node.get());
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), expand_dims_node);
+  return expand_dims_node;
+}
+
-CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr &mul_node) {
+CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr &mul_node,
+                    bool is_pynative = false) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
   MS_EXCEPTION_IF_NULL(mul_node);
@@ -224,24 +268,37 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
   std::vector<int64_t> multiple_value;
   std::transform(labels_shape.begin(), labels_shape.end(), std::back_inserter(multiple_value),
                  [](size_t label) { return static_cast<int64_t>(label); });
-  auto mutiples = MakeValue(multiple_value);
-  auto mutiples_node = CreateValueNode(mutiples, kNumberTypeInt64);
-  MS_EXCEPTION_IF_NULL(mutiples_node);
-  auto kernel_graph = graph->cast<KernelGraphPtr>();
-  kernel_graph->AddValueNodeToGraph(mutiples_node);
+  auto multiples = MakeValue(multiple_value);
+  auto multiples_node = CreateValueNode(multiples, kNumberTypeInt64);
+  MS_EXCEPTION_IF_NULL(multiples_node);
 
   auto tile_primitive = std::make_shared<Primitive>(kTileOpName);
   std::vector<std::string> input_names = {"x", "multiples"};
   std::vector<std::string> output_names = {"output"};
   tile_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
   tile_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
-  std::vector<AnfNodePtr> tile_inputs = {NewValueNode(tile_primitive), mul_node->input(2), mutiples_node};
+  std::vector<AnfNodePtr> tile_inputs;
+  if (is_pynative) {
+    tile_inputs = {NewValueNode(tile_primitive), mul_node->input(2)};
+  } else {
+    auto kernel_graph = graph->cast<KernelGraphPtr>();
+    kernel_graph->AddValueNodeToGraph(multiples_node);
+    tile_inputs = {NewValueNode(tile_primitive), mul_node->input(2), multiples_node};
+  }
 
   auto tile_node = graph->NewCNode(tile_inputs);
   MS_EXCEPTION_IF_NULL(tile_node);
 
   tile_node->set_scope(mul_node->scope());
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1)}, {labels_shape},
                                       tile_node.get());
+  if (is_pynative) {
+    AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), tile_node);
+  }
+  // feature map set
+  std::vector<size_t> feature_map_input_indexs;
+  feature_map_input_indexs.push_back(0);
+  AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), tile_node);
   return tile_node;
 }
 
@@ -368,7 +425,6 @@ const AnfNodePtr SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const F
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
   auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0]);
-
   return reduce_node;
 }
 
@@ -450,5 +506,76 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c
   manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]);
   return mul_node;
 }
 
+const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const FuncGraphPtr &graph,
+                                                                                 const AnfNodePtr &node,
+                                                                                 const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+
+  auto sparse_softmax_node = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(sparse_softmax_node);
+  if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
+    MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
+                      << kSparseSoftmaxCrossEntropyWithLogitsInputNum;
+  }
+  if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) &&
+      AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
+    return nullptr;
+  }
+
+  CNodePtr softmax_node;
+  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, true);
+  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node);
+
+  std::vector<AnfNodePtr> softmax_node_outputs;
+  CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
+  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0], true);
+  return reduce_node;
+}
+
+const BaseRef PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern() const {
+  VarPtr x1 = std::make_shared<Var>();
+  VarPtr x2 = std::make_shared<Var>();
+  VarPtr x3 = std::make_shared<Var>();
+  VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
+  return VectorRef({prim::kPrimMul, sparse_softmax_cross_entropy_with_logits, x3});
+}
+
+const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const FuncGraphPtr &graph,
+                                                                                     const AnfNodePtr &node,
+                                                                                     const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+
+  auto mul_node = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(mul_node);
+  if (mul_node->size() != kMulInputNum) {
+    MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum;
+  }
+  auto sparse_softmax_node = mul_node->input(1);
+  auto sparse_softmax_node_grad = sparse_softmax_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(sparse_softmax_node_grad);
+
+  if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
+    MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
+                      << kSparseSoftmaxCrossEntropyWithLogitsInputNum;
+  }
+
+  CNodePtr softmax_node;
+  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad, true);
+  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node);
+
+  std::vector<AnfNodePtr> softmax_node_outputs;
+  CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
+  auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node, true);
+  auto real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
+  auto expand_dims_node = CreateExpandDimsPynative(graph, real_div_node);
+
+  mul_node->set_input(1, softmax_node_outputs[1]);
+  mul_node->set_input(2, expand_dims_node);
+
+  return mul_node;
+}
 } // namespace opt
 } // namespace mindspore
@@ -18,14 +18,16 @@
 #define MINDSPORE_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_UNIFY_MINDIR_H
 
 #include <memory>
+#include <string>
 #include "backend/optimizer/common/optimizer.h"
 
 namespace mindspore {
 namespace opt {
 class SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public PatternProcessPass {
  public:
-  explicit SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(bool multigraph = true)
-      : PatternProcessPass("sparse_softmax_cross_entropy_with_logits_unify_mindir", multigraph) {}
+  explicit SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(
+    const std::string &name = "sparse_softmax_cross_entropy_with_logits_unify_mindir", bool multigraph = true)
+      : PatternProcessPass(name, multigraph) {}
   ~SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
@@ -49,6 +51,24 @@ class GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2 : public PatternProce
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 };
 
+class PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR {
+ public:
+  explicit PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(bool multigraph = true)
+      : SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR("pynative_sparse_softmax_cross_entropy_with_logits_unify_mindir",
+                                                       multigraph) {}
+  ~PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+
+class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public PatternProcessPass {
+ public:
+  explicit PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(bool multigraph = true)
+      : PatternProcessPass("pynative_grad_sparse_softmax_cross_entropy_with_logits_unify_mindir", multigraph) {}
+  ~PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+
 } // namespace opt
 } // namespace mindspore
 #endif // MINDSPORE_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_UNIFY_MINDIR_H
@@ -449,9 +449,14 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
     unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
     unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR>());
+    unify_mindir_pm->AddPass(std::make_shared<opt::SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
+    unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
+    unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2>());
   } else {
     unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIRPynative>());
     unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIRPynative>());
+    unify_mindir_pm->AddPass(std::make_shared<opt::PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
+    unify_mindir_pm->AddPass(std::make_shared<opt::PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
   }
 
   optimizer->AddPassManager(unify_mindir_pm);
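Which branch registers its passes is driven by the session's execution mode. For reference, selecting pynative mode through the standard context API is what routes a network into the two new Pynative passes:

    import mindspore.context as context

    # Run in pynative mode on Ascend so UnifyMindIR picks the Pynative* passes.
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")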
@@ -281,15 +281,13 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0., mstype.float32)
         self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"]
-        if self.is_cpugpu:
-            self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits()
+        self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits()
 
     def construct(self, logits, labels):
-        if self.is_cpugpu and self.sparse and self.reduction == 'mean':
-            x = self.sparse_softmax_cross_entropy(logits, labels)
-            return x
-
         if self.sparse:
+            if self.reduction == 'mean':
+                x = self.sparse_softmax_cross_entropy(logits, labels)
+                return x
             labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value)
         x = self.softmax_cross_entropy(logits, labels)[0]
         return self.get_loss(x)
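After this change, the sparse fast path in construct applies on every backend whenever reduction is 'mean'; on Ascend, the unify-MindIR passes above then expand the fused op at the IR level. A typical use of the cell through the public API (shapes chosen for illustration):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    logits = Tensor(np.random.randn(3, 5).astype(np.float32))  # batch of 3, 5 classes
    labels = Tensor(np.array([1, 0, 4]).astype(np.int32))      # sparse class indices
    print(loss(logits, labels))  # scalar mean loss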