!14354 [MS_LITE] fix leakrelu
From: @YeFeng_24 Reviewed-by: @hangangqiang,@zhanghaibo5 Signed-off-by: @hangangqiang
This commit is contained in:
commit
9bc02b1e69
|
@ -41,6 +41,14 @@ ops::PrimitiveC *TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
prim->set_activation_type(mindspore::ActivationType::SELU);
|
||||
} else if (tf_op.op() == "Softplus") {
|
||||
prim->set_activation_type(mindspore::ActivationType::SOFTPLUS);
|
||||
} else if (tf_op.op() == "LeakyRelu") {
|
||||
prim->set_activation_type(mindspore::ActivationType::LEAKY_RELU);
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The attribute alpha should be specified.";
|
||||
return nullptr;
|
||||
}
|
||||
prim->set_alpha(attr_value.f());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
|
||||
return nullptr;
|
||||
|
@ -55,33 +63,12 @@ ops::PrimitiveC *TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
return prim.release();
|
||||
}
|
||||
|
||||
ops::PrimitiveC *TFLeakyReluParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
auto prim = std::make_unique<ops::LeakyRelu>();
|
||||
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The attribute alpha should be specified.";
|
||||
return nullptr;
|
||||
}
|
||||
prim->set_negative_slope(attr_value.f());
|
||||
|
||||
*output_size = 1;
|
||||
if (AddOpInput(tf_op, 0, inputs) != RET_OK) {
|
||||
MS_LOG(ERROR) << "add op input failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return prim.release();
|
||||
}
|
||||
|
||||
TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser());
|
||||
TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser());
|
||||
TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser());
|
||||
TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser());
|
||||
TFNodeRegistrar g_tfSeLUParser("Selu", new TFActivationParser());
|
||||
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFLeakyReluParser());
|
||||
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFActivationParser());
|
||||
TFNodeRegistrar g_tfSoftplusParser("Softplus", new TFActivationParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,16 +33,6 @@ class TFActivationParser : public TFNodeParser {
|
|||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
|
||||
class TFLeakyReluParser : public TFNodeParser {
|
||||
public:
|
||||
TFLeakyReluParser() = default;
|
||||
~TFLeakyReluParser() override = default;
|
||||
|
||||
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
Loading…
Reference in New Issue