!14354 [MS_LITE] fix leakrelu

From: @YeFeng_24
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
This commit is contained in:
mindspore-ci-bot 2021-03-30 20:21:29 +08:00 committed by Gitee
commit 9bc02b1e69
2 changed files with 9 additions and 32 deletions

View File

@ -41,6 +41,14 @@ ops::PrimitiveC *TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
prim->set_activation_type(mindspore::ActivationType::SELU);
} else if (tf_op.op() == "Softplus") {
prim->set_activation_type(mindspore::ActivationType::SOFTPLUS);
} else if (tf_op.op() == "LeakyRelu") {
prim->set_activation_type(mindspore::ActivationType::LEAKY_RELU);
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) {
MS_LOG(ERROR) << "The attribute alpha should be specified.";
return nullptr;
}
prim->set_alpha(attr_value.f());
} else {
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
return nullptr;
@ -55,33 +63,12 @@ ops::PrimitiveC *TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
return prim.release();
}
ops::PrimitiveC *TFLeakyReluParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto prim = std::make_unique<ops::LeakyRelu>();
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) {
MS_LOG(ERROR) << "The attribute alpha should be specified.";
return nullptr;
}
prim->set_negative_slope(attr_value.f());
*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK) {
MS_LOG(ERROR) << "add op input failed";
return nullptr;
}
return prim.release();
}
TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser());
TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser());
TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser());
TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser());
TFNodeRegistrar g_tfSeLUParser("Selu", new TFActivationParser());
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFLeakyReluParser());
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFActivationParser());
TFNodeRegistrar g_tfSoftplusParser("Softplus", new TFActivationParser());
} // namespace lite
} // namespace mindspore

View File

@ -33,16 +33,6 @@ class TFActivationParser : public TFNodeParser {
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override;
};
class TFLeakyReluParser : public TFNodeParser {
public:
TFLeakyReluParser() = default;
~TFLeakyReluParser() override = default;
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore