forked from mindspore-Ecosystem/mindspore
!15814 [lite]add tf padv2 parse and fix bug
From: @xu_anyue Reviewed-by: @hangangqiang,@jpc_chenjianping Signed-off-by: @hangangqiang
This commit is contained in:
commit
1ea424810a
|
@ -31,7 +31,28 @@ ops::PrimitiveC *TFPadParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
if (tf_op.op() == "Pad") {
|
||||
prim->set_padding_mode(mindspore::PaddingMode::CONSTANT);
|
||||
prim->set_constant_value(0.0f);
|
||||
|
||||
} else if (tf_op.op() == "PadV2") {
|
||||
prim->set_padding_mode(mindspore::PaddingMode::CONSTANT);
|
||||
if (tf_op.input_size() < 3) {
|
||||
MS_LOG(ERROR) << "tf padv2 input size less than 3, which is " << tf_op.input_size();
|
||||
return nullptr;
|
||||
}
|
||||
auto &const_value_name = tf_op.input(2);
|
||||
if (tf_node_map.find(const_value_name) == tf_node_map.end()) {
|
||||
MS_LOG(ERROR) << "cannot find the input.";
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(*tf_node_map.at(const_value_name), "value", &attr_value)) {
|
||||
MS_LOG(ERROR) << "the input may be not const, which is not support now.";
|
||||
return nullptr;
|
||||
}
|
||||
auto &tensor_proto = attr_value.tensor();
|
||||
if (tensor_proto.dtype() != tensorflow::DT_FLOAT) {
|
||||
MS_LOG(ERROR) << "input data type only support float now.";
|
||||
return nullptr;
|
||||
}
|
||||
prim->set_constant_value(tensor_proto.float_val(0));
|
||||
} else if (tf_op.op() == "MirrorPad") {
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(tf_op, "mode", &attr_value)) {
|
||||
|
@ -58,6 +79,7 @@ ops::PrimitiveC *TFPadParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
return prim.release();
|
||||
}
|
||||
TFNodeRegistrar g_tfPadParser("Pad", new TFPadParser());
|
||||
TFNodeRegistrar g_tfPadV2Parser("PadV2", new TFPadParser());
|
||||
TFNodeRegistrar g_tfMirrorPadParser("MirrorPad", new TFPadParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -807,7 +807,6 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra
|
|||
return false;
|
||||
}
|
||||
}
|
||||
ResetSubGraphInput();
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -868,7 +867,6 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap
|
|||
return false;
|
||||
}
|
||||
}
|
||||
ResetSubGraphInput();
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1029,6 +1027,7 @@ bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(ERROR) << "run framework transpose unify failed.";
|
||||
return false;
|
||||
}
|
||||
ResetSubGraphInput();
|
||||
// delete insert transpose op and update op output shape.
|
||||
if (!ResetFuncGraph(func_graph)) {
|
||||
MS_LOG(ERROR) << "reset func_graph failed.";
|
||||
|
@ -1059,11 +1058,13 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(ERROR) << "run framework transpose unify failed.";
|
||||
return false;
|
||||
}
|
||||
ResetSubGraphInput();
|
||||
// if input format of a certain op can be NHWC, can try transform this op to decrease the number of transpose op.
|
||||
if (!DecreaseTransposeForSingleOp(func_graph)) {
|
||||
MS_LOG(ERROR) << "run local trans insert optimizer failed.";
|
||||
return false;
|
||||
}
|
||||
ResetSubGraphInput();
|
||||
// if input format of several ops surrounded only by transpose op all can be NHWC,
|
||||
// we can delete these transpose ops, and at the same time, transform these middle ops.
|
||||
if (!DecreaseTransposeForMultiOp(func_graph)) {
|
||||
|
|
|
@ -41,8 +41,8 @@ STATUS UnusedNodeRemovePass::ProcessGraph(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
}
|
||||
}
|
||||
if (utils::isa<FuncGraphPtr>(node)) {
|
||||
auto sub_graph = utils::cast<FuncGraphPtr>(node);
|
||||
if (utils::isa<ValueNode>(node) && GetValueNode<FuncGraphPtr>(node) != nullptr) {
|
||||
auto sub_graph = GetValueNode<FuncGraphPtr>(node);
|
||||
auto status = ProcessGraph(sub_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "process sub graph failed";
|
||||
|
@ -51,8 +51,10 @@ STATUS UnusedNodeRemovePass::ProcessGraph(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
}
|
||||
auto nodes = func_graph->nodes();
|
||||
auto graph_inputs = func_graph->get_inputs();
|
||||
for (auto &node : nodes) {
|
||||
if (vis.find(node) == vis.end()) {
|
||||
if (vis.find(node) == vis.end() &&
|
||||
std::find(graph_inputs.begin(), graph_inputs.end(), node) == graph_inputs.end()) {
|
||||
func_graph->DropNode(node);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue