forked from mindspore-Ecosystem/mindspore
!10923 unify mindir sparse op
From: @hwjiaorui Reviewed-by: @zhoufeng54 Signed-off-by:
commit 3f0316583e
@@ -27,6 +27,9 @@
 #include "ir/dtype/type.h"
 
 constexpr auto softmax_output_shape_size = 2;
+constexpr auto kAttrDepth = "depth";
+constexpr auto kAttrMultiples = "multiples";
+constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
 namespace mindspore {
 namespace opt {
 namespace {
@@ -47,12 +50,12 @@ ValueNodePtr CreateValueNode(const ValuePtr &value_ptr, TypeId output_type) {
   return new_node;
 }
 
-CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node) {
+CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, bool is_pynative = false) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
 
   std::vector<size_t> logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 0);
-  int64_t depth;
+  int64_t depth = 0;
   if (logits_shape.size() >= 1) {
     size_t index = logits_shape.size() - 1;
     depth = logits_shape[index];
@@ -66,33 +69,37 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_
   auto value_off = std::make_shared<tensor::Tensor>(0.0, kFloat32);
   auto value_off_node = CreateValueNode(value_off, kNumberTypeFloat32);
   MS_EXCEPTION_IF_NULL(value_off_node);
 
   auto kernel_graph = graph->cast<KernelGraphPtr>();
   kernel_graph->AddValueNodeToGraph(value_on_node);
   kernel_graph->AddValueNodeToGraph(value_off_node);
 
-  auto depth_node = NewValueNode(depth);
-  MS_EXCEPTION_IF_NULL(depth_node);
-
-  auto depth_abstract = std::make_shared<abstract::AbstractScalar>();
-  depth_abstract->set_type(kInt64);
-  depth_node->set_abstract(depth_abstract);
-
   auto one_hot_primitive = std::make_shared<Primitive>(kOneHotOpName);
   std::vector<std::string> input_names = {"indices", "depth", "on_value", "off_value"};
   std::vector<std::string> output_names = {"output"};
   one_hot_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
   one_hot_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
-  std::vector<AnfNodePtr> one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), depth_node,
-                                            value_on_node, value_off_node};
+
+  std::vector<AnfNodePtr> one_hot_inputs;
+  if (is_pynative) {
+    one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), value_on_node, value_off_node};
+  } else {
+    auto depth_node = NewValueNode(depth);
+    MS_EXCEPTION_IF_NULL(depth_node);
+    auto depth_abstract = std::make_shared<abstract::AbstractScalar>();
+    depth_abstract->set_type(kInt64);
+    depth_node->set_abstract(depth_abstract);
+    one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), depth_node, value_on_node,
+                      value_off_node};
+  }
   auto one_hot_node = graph->NewCNode(one_hot_inputs);
   MS_EXCEPTION_IF_NULL(one_hot_node);
 
   one_hot_node->set_scope(sparse_softmax_node->scope());
   std::vector<size_t> labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 1);
   labels_shape.emplace_back(depth);
   AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {labels_shape}, one_hot_node.get());
   AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(-1), one_hot_node);
+  if (is_pynative) {
+    AnfAlgo::SetNodeAttr(kAttrDepth, MakeValue(depth), one_hot_node);
+  }
   return one_hot_node;
 }
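Note on the branching above: MindSpore's user-facing OneHot primitive takes depth as a positional call input rather than an attribute, which is what the graph-mode branch models with a dedicated value node; the PyNative branch drops that input and records depth in the kAttrDepth attribute instead. A minimal Python sketch of the primitive's calling convention, with illustrative shapes and values (not from the patch):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

# OneHot(axis)(indices, depth, on_value, off_value): depth is a call input.
one_hot = ops.OneHot(axis=-1)
indices = Tensor(np.array([1, 0, 2]), ms.int32)
out = one_hot(indices, 3, Tensor(1.0, ms.float32), Tensor(0.0, ms.float32))
print(out.shape)  # (3, 3)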
@@ -106,9 +113,6 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN
     MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal "
                       << kSparseSoftmaxCrossEntropyWithLogitsInputNum;
   }
-  if (one_hot_node->size() != kOneHotInputNum) {
-    MS_LOG(EXCEPTION) << "ont_hot's input size not equal " << kOneHotInputNum;
-  }
 
   std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSoftmaxCrossEntropyWithLogitsOpName)),
                                     sparse_softmax_node->input(1), one_hot_node};
@@ -131,7 +135,7 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN
   return softmax_node;
 }
 
-ValueNodePtr GetAxis(const AnfNodePtr &node) {
+std::vector<int64_t> GetAxis(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   std::vector<size_t> output_shape = AnfAlgo::GetOutputInferShape(node, 0);
   if (output_shape.empty()) {
@@ -141,13 +145,19 @@ ValueNodePtr GetAxis(const AnfNodePtr &node) {
   for (size_t i = 0; i < output_shape.size(); i++) {
     range.emplace_back(i);
   }
+  return range;
+}
+
+ValueNodePtr GetAxisNode(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto range = GetAxis(node);
   auto axis_node = CreateValueNode(MakeValue(range), kNumberTypeInt64);
   MS_EXCEPTION_IF_NULL(axis_node);
   return axis_node;
 }
 
 CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node,
-                          const AnfNodePtr &softmax_output_node) {
+                          const AnfNodePtr &softmax_output_node, bool is_pynative = false) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
   MS_EXCEPTION_IF_NULL(softmax_output_node);
@@ -155,10 +165,10 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
     MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal "
                       << kSparseSoftmaxCrossEntropyWithLogitsInputNum;
   }
-  auto axis_node = GetAxis(softmax_output_node);
+
+  auto axis_value = GetAxis(softmax_output_node);
+  auto axis_node = GetAxisNode(softmax_output_node);
   MS_EXCEPTION_IF_NULL(axis_node);
-  auto kernel_graph = graph->cast<KernelGraphPtr>();
-  kernel_graph->AddValueNodeToGraph(axis_node);
 
   auto reduce_primitive = std::make_shared<Primitive>(kReduceMeanOpName);
   std::vector<std::string> input_names = {"x", "axis"};
@@ -166,14 +176,23 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
   reduce_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
   reduce_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
 
-  std::vector<AnfNodePtr> inputs = {NewValueNode(reduce_primitive), softmax_output_node, axis_node};
+  auto kernel_graph = graph->cast<KernelGraphPtr>();
+  std::vector<AnfNodePtr> inputs;
+  if (is_pynative) {
+    inputs = {NewValueNode(reduce_primitive), softmax_output_node};
+  } else {
+    kernel_graph->AddValueNodeToGraph(axis_node);
+    inputs = {NewValueNode(reduce_primitive), softmax_output_node, axis_node};
+  }
   auto reduce_node = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(reduce_node);
 
   reduce_node->set_scope(sparse_softmax_node->scope());
   auto reduce_abstract = softmax_output_node->abstract();
   reduce_abstract->set_shape(std::make_shared<abstract::Shape>());
   reduce_node->set_abstract(reduce_abstract);
+  if (is_pynative) {
+    AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_value), reduce_node);
+  }
   return reduce_node;
 }
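The same input-versus-attribute split drives CreateReduceMean: GetAxis now returns the raw axis list (for a rank-2 loss output it is [0, 1], i.e. reduce over every axis down to the scalar mean), which becomes either a value-node input (graph mode) or the kAttrAxis attribute (PyNative). Illustrative sketch of the primitive's convention:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.arange(6).reshape(2, 3).astype(np.float32))
# Reducing over all axes of a rank-2 tensor yields the scalar mean.
out = ops.ReduceMean()(x, (0, 1))
print(out)  # 2.5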
@@ -207,8 +226,33 @@ CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_no
                                      expand_dims_node.get());
   return expand_dims_node;
 }
+
+CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &real_div_node) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(real_div_node);
+  if (real_div_node->size() != kRealDivInputNum) {
+    MS_LOG(EXCEPTION) << "Op real_div's input num not equal " << kRealDivInputNum;
+  }
+  int64_t axis = -1;
+  auto expand_dims_primitive = std::make_shared<Primitive>(kExpandDimsOpName);
+  std::vector<std::string> input_names = {"x"};
+  std::vector<std::string> output_names = {"output"};
+  expand_dims_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
+  expand_dims_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
+  std::vector<AnfNodePtr> expand_dims_inputs = {NewValueNode(expand_dims_primitive), real_div_node};
+  auto expand_dims_node = graph->NewCNode(expand_dims_inputs);
+  MS_EXCEPTION_IF_NULL(expand_dims_node);
+
+  expand_dims_node->set_scope(real_div_node->scope());
+  std::vector<size_t> y_shape = AnfAlgo::GetOutputInferShape(real_div_node, 0);
+  y_shape.emplace_back(1);
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(real_div_node, 0)}, {y_shape},
+                                      expand_dims_node.get());
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), expand_dims_node);
+  return expand_dims_node;
+}
 
-CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr &mul_node) {
+CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr &mul_node,
+                    bool is_pynative = false) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
   MS_EXCEPTION_IF_NULL(mul_node);
@@ -224,24 +268,37 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
   std::vector<int64_t> multiple_value;
   std::transform(labels_shape.begin(), labels_shape.end(), std::back_inserter(multiple_value),
                  [](size_t label) { return static_cast<int64_t>(label); });
-  auto mutiples = MakeValue(multiple_value);
-  auto mutiples_node = CreateValueNode(mutiples, kNumberTypeInt64);
-  MS_EXCEPTION_IF_NULL(mutiples_node);
-  auto kernel_graph = graph->cast<KernelGraphPtr>();
-  kernel_graph->AddValueNodeToGraph(mutiples_node);
+  auto multiples = MakeValue(multiple_value);
+  auto multiples_node = CreateValueNode(multiples, kNumberTypeInt64);
+  MS_EXCEPTION_IF_NULL(multiples_node);
 
   auto tile_primitive = std::make_shared<Primitive>(kTileOpName);
   std::vector<std::string> input_names = {"x", "multiples"};
   std::vector<std::string> output_names = {"output"};
   tile_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
   tile_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
-  std::vector<AnfNodePtr> tile_inputs = {NewValueNode(tile_primitive), mul_node->input(2), mutiples_node};
+
+  std::vector<AnfNodePtr> tile_inputs;
+  if (is_pynative) {
+    tile_inputs = {NewValueNode(tile_primitive), mul_node->input(2)};
+  } else {
+    auto kernel_graph = graph->cast<KernelGraphPtr>();
+    kernel_graph->AddValueNodeToGraph(multiples_node);
+    tile_inputs = {NewValueNode(tile_primitive), mul_node->input(2), multiples_node};
+  }
 
   auto tile_node = graph->NewCNode(tile_inputs);
   MS_EXCEPTION_IF_NULL(tile_node);
 
   tile_node->set_scope(mul_node->scope());
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1)}, {labels_shape},
                                       tile_node.get());
+  if (is_pynative) {
+    AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), tile_node);
+  }
+  // feature map set
+  std::vector<size_t> feature_map_input_indexs;
+  feature_map_input_indexs.push_back(0);
+  AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), tile_node);
   return tile_node;
 }
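Tile and ExpandDims in the grad path follow the same convention: multiples and axis are call inputs in graph mode but are carried as the kAttrMultiples / kAttrAxis attributes in PyNative. Illustrative sketch of the two primitives:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.array([[1.0], [2.0]]).astype(np.float32))
tiled = ops.Tile()(x, (1, 3))           # multiples passed at call time
expanded = ops.ExpandDims()(tiled, -1)  # axis passed at call time
print(tiled.shape, expanded.shape)      # (2, 3) (2, 3, 1)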
@@ -368,7 +425,6 @@ const AnfNodePtr SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const F
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
   auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0]);
-
   return reduce_node;
 }
 
@@ -450,5 +506,76 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c
   manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]);
   return mul_node;
 }
+
+const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const FuncGraphPtr &graph,
+                                                                                 const AnfNodePtr &node,
+                                                                                 const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+
+  auto sparse_softmax_node = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(sparse_softmax_node);
+  if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
+    MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
+                      << kSparseSoftmaxCrossEntropyWithLogitsInputNum;
+  }
+  if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) &&
+      AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
+    return nullptr;
+  }
+
+  CNodePtr softmax_node;
+  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, true);
+  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node);
+
+  std::vector<AnfNodePtr> softmax_node_outputs;
+  CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
+  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0], true);
+  return reduce_node;
+}
+
+const BaseRef PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern() const {
+  VarPtr x1 = std::make_shared<Var>();
+  VarPtr x2 = std::make_shared<Var>();
+  VarPtr x3 = std::make_shared<Var>();
+  VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
+  return VectorRef({prim::kPrimMul, sparse_softmax_cross_entropy_with_logits, x3});
+}
+
+const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const FuncGraphPtr &graph,
+                                                                                     const AnfNodePtr &node,
+                                                                                     const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+
+  auto mul_node = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(mul_node);
+  if (mul_node->size() != kMulInputNum) {
+    MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum;
+  }
+  auto sparse_softmax_node = mul_node->input(1);
+  auto sparse_softmax_node_grad = sparse_softmax_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(sparse_softmax_node_grad);
+
+  if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
+    MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
+                      << kSparseSoftmaxCrossEntropyWithLogitsInputNum;
+  }
+
+  CNodePtr softmax_node;
+  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad, true);
+  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node);
+
+  std::vector<AnfNodePtr> softmax_node_outputs;
+  CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
+  auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node, true);
+  auto real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
+  auto expand_dims_node = CreateExpandDimsPynative(graph, real_div_node);
+
+  mul_node->set_input(1, softmax_node_outputs[1]);
+  mul_node->set_input(2, expand_dims_node);
+
+  return mul_node;
+}
 } // namespace opt
 } // namespace mindspore
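Net effect of this file: in both modes, SparseSoftmaxCrossEntropyWithLogits with mean reduction is unified into OneHot -> SoftmaxCrossEntropyWithLogits -> ReduceMean, and the grad pattern (Mul over the fused op) is rewired through Tile/RealDiv/ExpandDims onto the softmax output. A rough numeric sanity check of the forward decomposition (illustrative only, not part of the patch):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

logits = Tensor(np.random.randn(4, 10).astype(np.float32))
labels = Tensor(np.array([1, 0, 3, 9]), ms.int32)

# Fused op: mean sparse softmax cross-entropy.
fused = ops.SparseSoftmaxCrossEntropyWithLogits()(logits, labels)

# Unfused equivalent mirroring the pass: OneHot -> SoftmaxCE -> ReduceMean.
dense = ops.OneHot(axis=-1)(labels, 10, Tensor(1.0, ms.float32), Tensor(0.0, ms.float32))
loss, _ = ops.SoftmaxCrossEntropyWithLogits()(logits, dense)
unfused = ops.ReduceMean()(loss, (0,))

print(fused, unfused)  # equal up to float tolerance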
@@ -18,14 +18,16 @@
 #define MINDSPORE_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_UNIFY_MINDIR_H
 
 #include <memory>
+#include <string>
 #include "backend/optimizer/common/optimizer.h"
 
 namespace mindspore {
 namespace opt {
 class SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public PatternProcessPass {
  public:
-  explicit SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(bool multigraph = true)
-      : PatternProcessPass("sparse_softmax_cross_entropy_with_logits_unify_mindir", multigraph) {}
+  explicit SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(
+    const std::string &name = "sparse_softmax_cross_entropy_with_logits_unify_mindir", bool multigraph = true)
+      : PatternProcessPass(name, multigraph) {}
   ~SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
@@ -49,6 +51,24 @@ class GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2 : public PatternProce
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 };
 
+class PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR {
+ public:
+  explicit PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(bool multigraph = true)
+      : SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR("pynative_sparse_softmax_cross_entropy_with_logits_unify_mindir",
+                                                       multigraph) {}
+  ~PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+
+class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public PatternProcessPass {
+ public:
+  explicit PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(bool multigraph = true)
+      : PatternProcessPass("pynative_grad_sparse_softmax_cross_entropy_with_logits_unify_mindir", multigraph) {}
+  ~PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+
 } // namespace opt
 } // namespace mindspore
 #endif // MINDSPORE_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_UNIFY_MINDIR_H
@@ -449,9 +449,14 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
     unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
     unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR>());
+    unify_mindir_pm->AddPass(std::make_shared<opt::SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
+    unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
+    unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2>());
   } else {
     unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIRPynative>());
     unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIRPynative>());
+    unify_mindir_pm->AddPass(std::make_shared<opt::PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
+    unify_mindir_pm->AddPass(std::make_shared<opt::PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
   }
 
   optimizer->AddPassManager(unify_mindir_pm);
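The pass set is selected purely by execution mode, so the same model script exercises the graph-mode or PyNative variants depending only on the context. Minimal sketch (assuming the Ascend backend):

from mindspore import context

# GRAPH_MODE -> *UnifyMindIR passes; PYNATIVE_MODE -> Pynative* passes.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")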
@@ -281,15 +281,13 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0., mstype.float32)
-        self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"]
-        if self.is_cpugpu:
-            self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits()
+        self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits()
 
     def construct(self, logits, labels):
-        if self.is_cpugpu and self.sparse and self.reduction == 'mean':
-            x = self.sparse_softmax_cross_entropy(logits, labels)
-            return x
-
         if self.sparse:
+            if self.reduction == 'mean':
+                x = self.sparse_softmax_cross_entropy(logits, labels)
+                return x
             labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value)
         x = self.softmax_cross_entropy(logits, labels)[0]
         return self.get_loss(x)
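With the CPU/GPU special-casing removed, the fused sparse path is taken whenever sparse=True and reduction='mean' on any backend, and the unify passes above lower it per execution mode. Minimal usage sketch:

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor

loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
logits = Tensor(np.random.randn(4, 10).astype(np.float32))
labels = Tensor(np.array([1, 0, 3, 9]), ms.int32)
print(loss_fn(logits, labels))  # scalar mean loss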