[MSLITE] Support convert softplus to activation

This commit is contained in:
wang_shaocong 2022-01-08 10:22:30 +08:00
parent f6a20f1e62
commit 18ff1a5248
3 changed files with 9 additions and 3 deletions

View File

@ -27,11 +27,12 @@
namespace mindspore {
namespace ops {
constexpr auto kNameSoftplus = "Softplus";
/// \brief Softplus activation function. Refer to Python API @ref mindspore.ops.Softplus for more details.
class MS_CORE_API Softplus : public PrimitiveC {
public:
/// \brief Constructor.
Softplus() : PrimitiveC(prim::kPrimSoftplus->name()) { InitIOName({"x"}, {"output"}); }
Softplus() : PrimitiveC(kNameSoftplus) { InitIOName({"x"}, {"output"}); }
/// \brief Destructor.
~Softplus() = default;
MS_DECLARE_PARENT(Softplus, PrimitiveC);

View File

@ -72,6 +72,7 @@
#include "ops/sigmoid.h"
#include "ops/stack.h"
#include "ops/tanh.h"
#include "ops/softplus.h"
#include "ops/sparse_softmax_cross_entropy_with_logits.h"
#include "ops/grad/resize_grad.h"
#include "ops/random_standard_normal.h"
@ -118,6 +119,7 @@ using mindspore::ops::kNameResizeBilinear;
using mindspore::ops::kNameResizeNearestNeighbor;
using mindspore::ops::kNameScale;
using mindspore::ops::kNameSigmoid;
using mindspore::ops::kNameSoftplus;
using mindspore::ops::kNameSparseSoftmaxCrossEntropyWithLogits;
using mindspore::ops::kNameSub;
using mindspore::ops::kNameTanh;
@ -160,6 +162,7 @@ std::map<std::string, mindspore::ActivationType> activation_map = {{ops::kNameEl
{ops::kNameReLU6, mindspore::RELU6},
{ops::kNameSigmoid, mindspore::SIGMOID},
{ops::kNameTanh, mindspore::TANH},
{ops::kNameSoftplus, mindspore::SOFTPLUS},
{kNameHSigmoid, mindspore::HSIGMOID},
{kNameHSigmoidGrad, mindspore::HSIGMOID},
{kNameHSwish, mindspore::HSWISH},
@ -671,5 +674,6 @@ REGIST_PRIMITIVE_ADJUST(kNameSparseSoftmaxCrossEntropyWithLogits,
MoveAttrMapCommon<ops::SparseSoftmaxCrossEntropyWithLogits>)
REGIST_PRIMITIVE_ADJUST(kNameResizeBilinearGrad, MoveAttrMapResizeGrad)
REGIST_PRIMITIVE_ADJUST(kNameResizeNearestNeighborGrad, MoveAttrMapResizeGrad)
REGIST_PRIMITIVE_ADJUST(kNameSoftplus, MoveAttrMapActivation)
} // namespace lite
} // namespace mindspore

View File

@ -49,8 +49,9 @@ bool ConstFoldAlongInferShape::CheckCanFold(const FuncGraphPtr &func_graph, cons
if (!is_inferred) {
return false;
}
if (CheckPrimitiveType(cnode, prim::kPrimShape)) {
return lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() != 0;
if (CheckPrimitiveType(cnode, prim::kPrimShape) &&
lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() != 0) {
return true;
}
auto inputs = cnode->inputs();
auto graph_inputs =