From 152992d3a9c3167d1844f1ba5edf417ef74d079e Mon Sep 17 00:00:00 2001 From: yefeng Date: Fri, 15 Jan 2021 17:40:12 +0800 Subject: [PATCH] stride_slice-5 --- mindspore/lite/src/ops/strided_slice.cc | 4 +- .../parser/tf/tf_stride_slice_parser.cc | 79 ++----------------- .../fusion/bidirection_tf_gru_cell_fusion.cc | 8 +- 3 files changed, 15 insertions(+), 76 deletions(-) diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc index 0e9a0e401d..08ea8a90de 100644 --- a/mindspore/lite/src/ops/strided_slice.cc +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -359,7 +359,9 @@ int StridedSlice::InferShape(std::vector inputs, std::vectorshape(); auto inferflag = infer_flag(); - + if (!infer_flag()) { + return RET_INFER_INVALID; + } in_shape_.clear(); if (inferflag) { in_shape_.assign(input_shape.begin(), input_shape.end()); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc index 6d6f31f998..f7959a98b7 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc @@ -72,76 +72,6 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op, return RET_ERROR; } attr->shrinkAxisMask = attr_value.i(); - - // begin - auto begin_node = GetConstInputNode(tf_node_map, tf_op.input(1)); - if (begin_node == nullptr) { - MS_LOG(ERROR) << "Find StridedSlice input begin failed"; - return RET_ERROR; - } - if (!TensorFlowUtils::FindAttrValue(*begin_node, "value", &attr_value)) { - MS_LOG(ERROR) << "The value attr should be specified"; - return RET_ERROR; - } - auto tensor_proto = attr_value.tensor(); - if (tensor_proto.int_val_size() > 0) { - for (int i = 0; i < tensor_proto.int_val_size(); ++i) { - attr->begin.push_back(tensor_proto.int_val(i)); - } - } else { - auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); - auto data = reinterpret_cast(tensor_proto.tensor_content().data()); - for (size_t i = 0; i < data_num; ++i) { - attr->begin.push_back(data[i]); - } - } - - // end - auto end_node = GetConstInputNode(tf_node_map, tf_op.input(2)); - if (end_node == nullptr) { - MS_LOG(ERROR) << "Find StridedSlice input end failed"; - return RET_ERROR; - } - if (!TensorFlowUtils::FindAttrValue(*end_node, "value", &attr_value)) { - MS_LOG(ERROR) << "The value attr should be specified"; - return RET_ERROR; - } - tensor_proto = attr_value.tensor(); - if (tensor_proto.int_val_size() > 0) { - for (int i = 0; i < tensor_proto.int_val_size(); ++i) { - attr->end.push_back(tensor_proto.int_val(i)); - } - } else { - auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); - auto data = reinterpret_cast(tensor_proto.tensor_content().data()); - for (size_t i = 0; i < data_num; ++i) { - attr->end.push_back(data[i]); - } - } - - // strides - auto stride_node = GetConstInputNode(tf_node_map, tf_op.input(3)); - if (stride_node == nullptr) { - MS_LOG(ERROR) << "Find StridedSlice input strides failed"; - return RET_ERROR; - } - if (!TensorFlowUtils::FindAttrValue(*stride_node, "value", &attr_value)) { - MS_LOG(ERROR) << "The value attr should be specified"; - return RET_ERROR; - } - tensor_proto = attr_value.tensor(); - if (tensor_proto.int_val_size() > 0) { - for (int i = 0; i < tensor_proto.int_val_size(); ++i) { - attr->stride.push_back(tensor_proto.int_val(i)); - } - } else { - auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); - auto data = reinterpret_cast(tensor_proto.tensor_content().data()); - for (size_t i = 0; i < data_num; ++i) { - attr->stride.push_back(data[i]); - } - } - primitive->value.type = schema::PrimitiveType_StridedSlice; primitive->value.value = attr.release(); *primitiveC = PrimitiveC::Create(primitive.release()); @@ -151,7 +81,14 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op, } *output_size = 1; - auto status = AddOpInput(tf_op, 0, inputs); + STATUS status = RET_OK; + for (int i = 0; i < tf_op.input_size(); i++) { + status = AddOpInput(tf_op, i, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + } return status; } TFNodeRegistrar g_tfStrideSliceParser("StridedSlice", new TFStrideSliceParser()); diff --git a/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc b/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc index 1cc32b483b..5c09318901 100644 --- a/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc @@ -71,8 +71,8 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { auto fw_shape = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), transpose_input_}); - auto fw_stride = - VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), fw_shape}); + auto fw_stride = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), + fw_shape, std::make_shared()}); auto fw_min = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), fw_stride, fw_max2}); @@ -106,8 +106,8 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { bw_reverse_seq, std::make_shared()}); auto bw_shape = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), bw_trans}); - auto bw_stride = - VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), bw_shape}); + auto bw_stride = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), + bw_shape, std::make_shared()}); auto bw_min = VectorRef({std::make_shared(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), bw_stride, bw_max2}); auto bw_reserve =