!11365 [TF_LITE] fix tf stride slice parser
From: @YeFeng_24 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
dfaa52009c
|
@ -359,7 +359,9 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|||
MS_ASSERT(input != nullptr);
|
||||
auto input_shape = input->shape();
|
||||
auto inferflag = infer_flag();
|
||||
|
||||
if (!infer_flag()) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
in_shape_.clear();
|
||||
if (inferflag) {
|
||||
in_shape_.assign(input_shape.begin(), input_shape.end());
|
||||
|
|
|
@ -72,76 +72,6 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
return RET_ERROR;
|
||||
}
|
||||
attr->shrinkAxisMask = attr_value.i();
|
||||
|
||||
// begin
|
||||
auto begin_node = GetConstInputNode(tf_node_map, tf_op.input(1));
|
||||
if (begin_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Find StridedSlice input begin failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!TensorFlowUtils::FindAttrValue(*begin_node, "value", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The value attr should be specified";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto tensor_proto = attr_value.tensor();
|
||||
if (tensor_proto.int_val_size() > 0) {
|
||||
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
|
||||
attr->begin.push_back(tensor_proto.int_val(i));
|
||||
}
|
||||
} else {
|
||||
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
|
||||
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
|
||||
for (size_t i = 0; i < data_num; ++i) {
|
||||
attr->begin.push_back(data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// end
|
||||
auto end_node = GetConstInputNode(tf_node_map, tf_op.input(2));
|
||||
if (end_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Find StridedSlice input end failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!TensorFlowUtils::FindAttrValue(*end_node, "value", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The value attr should be specified";
|
||||
return RET_ERROR;
|
||||
}
|
||||
tensor_proto = attr_value.tensor();
|
||||
if (tensor_proto.int_val_size() > 0) {
|
||||
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
|
||||
attr->end.push_back(tensor_proto.int_val(i));
|
||||
}
|
||||
} else {
|
||||
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
|
||||
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
|
||||
for (size_t i = 0; i < data_num; ++i) {
|
||||
attr->end.push_back(data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// strides
|
||||
auto stride_node = GetConstInputNode(tf_node_map, tf_op.input(3));
|
||||
if (stride_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Find StridedSlice input strides failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!TensorFlowUtils::FindAttrValue(*stride_node, "value", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The value attr should be specified";
|
||||
return RET_ERROR;
|
||||
}
|
||||
tensor_proto = attr_value.tensor();
|
||||
if (tensor_proto.int_val_size() > 0) {
|
||||
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
|
||||
attr->stride.push_back(tensor_proto.int_val(i));
|
||||
}
|
||||
} else {
|
||||
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
|
||||
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
|
||||
for (size_t i = 0; i < data_num; ++i) {
|
||||
attr->stride.push_back(data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_StridedSlice;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
|
@ -151,7 +81,14 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
}
|
||||
|
||||
*output_size = 1;
|
||||
auto status = AddOpInput(tf_op, 0, inputs);
|
||||
STATUS status = RET_OK;
|
||||
for (int i = 0; i < tf_op.input_size(); i++) {
|
||||
status = AddOpInput(tf_op, i, inputs);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Add Op input failed.";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
TFNodeRegistrar g_tfStrideSliceParser("StridedSlice", new TFStrideSliceParser());
|
||||
|
|
|
@ -71,8 +71,8 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
|
|||
|
||||
auto fw_shape =
|
||||
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), transpose_input_});
|
||||
auto fw_stride =
|
||||
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), fw_shape});
|
||||
auto fw_stride = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)),
|
||||
fw_shape, std::make_shared<SeqVar>()});
|
||||
auto fw_min =
|
||||
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), fw_stride, fw_max2});
|
||||
|
||||
|
@ -106,8 +106,8 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
|
|||
bw_reverse_seq, std::make_shared<Var>()});
|
||||
auto bw_shape =
|
||||
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), bw_trans});
|
||||
auto bw_stride =
|
||||
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), bw_shape});
|
||||
auto bw_stride = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)),
|
||||
bw_shape, std::make_shared<SeqVar>()});
|
||||
auto bw_min =
|
||||
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), bw_stride, bw_max2});
|
||||
auto bw_reserve =
|
||||
|
|
Loading…
Reference in New Issue