!15814 [lite] add TF PadV2 parse and fix bugs

From: @xu_anyue
Reviewed-by: @hangangqiang,@jpc_chenjianping
Signed-off-by: @hangangqiang
mindspore-ci-bot 2021-04-28 16:16:26 +08:00 committed by Gitee
commit 1ea424810a
3 changed files with 31 additions and 6 deletions

@@ -31,7 +31,28 @@ ops::PrimitiveC *TFPadParser::Parse(const tensorflow::NodeDef &tf_op,
   if (tf_op.op() == "Pad") {
     prim->set_padding_mode(mindspore::PaddingMode::CONSTANT);
     prim->set_constant_value(0.0f);
+  } else if (tf_op.op() == "PadV2") {
+    prim->set_padding_mode(mindspore::PaddingMode::CONSTANT);
+    if (tf_op.input_size() < 3) {
+      MS_LOG(ERROR) << "tf PadV2 input size is less than 3, which is " << tf_op.input_size();
+      return nullptr;
+    }
+    auto &const_value_name = tf_op.input(2);
+    if (tf_node_map.find(const_value_name) == tf_node_map.end()) {
+      MS_LOG(ERROR) << "cannot find the input.";
+      return nullptr;
+    }
+    tensorflow::AttrValue attr_value;
+    if (!TensorFlowUtils::FindAttrValue(*tf_node_map.at(const_value_name), "value", &attr_value)) {
+      MS_LOG(ERROR) << "the input may not be const, which is not supported now.";
+      return nullptr;
+    }
+    auto &tensor_proto = attr_value.tensor();
+    if (tensor_proto.dtype() != tensorflow::DT_FLOAT) {
+      MS_LOG(ERROR) << "input data type only supports float now.";
+      return nullptr;
+    }
+    prim->set_constant_value(tensor_proto.float_val(0));
   } else if (tf_op.op() == "MirrorPad") {
     tensorflow::AttrValue attr_value;
     if (!TensorFlowUtils::FindAttrValue(tf_op, "mode", &attr_value)) {
@@ -58,6 +79,7 @@ ops::PrimitiveC *TFPadParser::Parse(const tensorflow::NodeDef &tf_op,
   return prim.release();
 }
 TFNodeRegistrar g_tfPadParser("Pad", new TFPadParser());
+TFNodeRegistrar g_tfPadV2Parser("PadV2", new TFPadParser());
 TFNodeRegistrar g_tfMirrorPadParser("MirrorPad", new TFPadParser());
 } // namespace lite
 } // namespace mindspore
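
For context, the new PadV2 branch expects tf_op.input(2) to name a Const node whose "value" attribute carries a DT_FLOAT scalar. The sketch below is not part of the commit: it assumes the TensorFlow proto headers are on the include path, and the node name is made up. It builds the kind of NodeDef that would satisfy the tf_node_map lookup and the FindAttrValue("value", ...) check above:

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

// Builds a Const node suitable as input(2) of a PadV2 op: its "value"
// attribute holds a DT_FLOAT scalar, which the parser reads via float_val(0).
tensorflow::NodeDef MakePadV2ConstInput(float pad_value) {
  tensorflow::NodeDef const_node;
  const_node.set_name("padv2/constant_values");  // hypothetical name
  const_node.set_op("Const");
  tensorflow::AttrValue value_attr;
  auto *tensor = value_attr.mutable_tensor();
  tensor->set_dtype(tensorflow::DT_FLOAT);  // the only dtype the parser accepts
  tensor->add_float_val(pad_value);
  (*const_node.mutable_attr())["value"] = value_attr;
  return const_node;
}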

@@ -807,7 +807,6 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra
       return false;
     }
   }
-  ResetSubGraphInput();
   return true;
 }
@@ -868,7 +867,6 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap
       return false;
     }
   }
-  ResetSubGraphInput();
   return true;
 }
@@ -1029,6 +1027,7 @@ bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) {
     MS_LOG(ERROR) << "run framework transpose unify failed.";
     return false;
   }
+  ResetSubGraphInput();
   // delete the inserted transpose ops and update op output shapes.
   if (!ResetFuncGraph(func_graph)) {
     MS_LOG(ERROR) << "reset func_graph failed.";
@@ -1059,11 +1058,13 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
     MS_LOG(ERROR) << "run framework transpose unify failed.";
     return false;
   }
+  ResetSubGraphInput();
   // if the input format of a certain op can be NHWC, try to transform the op to decrease the number of transpose ops.
   if (!DecreaseTransposeForSingleOp(func_graph)) {
     MS_LOG(ERROR) << "run local trans insert optimizer failed.";
     return false;
   }
+  ResetSubGraphInput();
   // if the input formats of several ops surrounded only by transpose ops can all be NHWC,
   // these transpose ops can be deleted and the middle ops transformed at the same time.
   if (!DecreaseTransposeForMultiOp(func_graph)) {
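
Net effect of the four hunks above: ResetSubGraphInput() no longer runs inside BasicProcess() and DecreaseTransposeForSingleOp(); the callers (RunOnlyForShape() and Run()) now invoke it once after each phase finishes. A condensed, hypothetical skeleton of Run() after this change (the second argument to BasicProcess and the elided tail are assumptions, not taken from the diff):

// Sketch only: each phase call is followed by a single subgraph-input reset.
bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
  if (!BasicProcess(func_graph, true)) {  // assumed: "true" marks the main graph
    return false;
  }
  ResetSubGraphInput();  // hoisted out of BasicProcess
  if (!DecreaseTransposeForSingleOp(func_graph)) {
    return false;
  }
  ResetSubGraphInput();  // hoisted out of DecreaseTransposeForSingleOp
  if (!DecreaseTransposeForMultiOp(func_graph)) {
    return false;
  }
  // ... remainder of Run elided
  return true;
}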

@@ -41,8 +41,8 @@ STATUS UnusedNodeRemovePass::ProcessGraph(const FuncGraphPtr &func_graph) {
         }
       }
     }
-    if (utils::isa<FuncGraphPtr>(node)) {
-      auto sub_graph = utils::cast<FuncGraphPtr>(node);
+    if (utils::isa<ValueNode>(node) && GetValueNode<FuncGraphPtr>(node) != nullptr) {
+      auto sub_graph = GetValueNode<FuncGraphPtr>(node);
       auto status = ProcessGraph(sub_graph);
       if (status != RET_OK) {
         MS_LOG(ERROR) << "process sub graph failed";
@@ -51,8 +51,10 @@ STATUS UnusedNodeRemovePass::ProcessGraph(const FuncGraphPtr &func_graph) {
     }
   }
   auto nodes = func_graph->nodes();
+  auto graph_inputs = func_graph->get_inputs();
   for (auto &node : nodes) {
-    if (vis.find(node) == vis.end()) {
+    if (vis.find(node) == vis.end() &&
+        std::find(graph_inputs.begin(), graph_inputs.end(), node) == graph_inputs.end()) {
       func_graph->DropNode(node);
     }
   }
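
Two fixes land in this file. First, an AnfNode is never a FuncGraph itself, so the old utils::isa<FuncGraphPtr>(node) test could not match; subgraphs are carried as the payload of a ValueNode and have to be unwrapped with GetValueNode<FuncGraphPtr>(). Second, unvisited formal graph inputs must survive the sweep, because dropping them would change the graph's signature. A minimal sketch of the corrected sweep, using the MindSpore core types from the diff (the function name and the vis parameter are illustrative, not from the commit):

#include <algorithm>
#include <set>

// Drops every node that is neither reachable (in vis) nor a formal input.
void SweepUnusedNodes(const FuncGraphPtr &func_graph, const std::set<AnfNodePtr> &vis) {
  auto graph_inputs = func_graph->get_inputs();  // formal parameters: keep them
  for (auto &node : func_graph->nodes()) {
    bool unreachable = vis.find(node) == vis.end();
    bool is_input =
        std::find(graph_inputs.begin(), graph_inputs.end(), node) != graph_inputs.end();
    if (unreachable && !is_input) {
      func_graph->DropNode(node);  // safe: unused and not part of the signature
    }
  }
}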