!31577 [MS][LITE]codex parser tf and tflite fix

Merge pull request !31577 from luoyuan/codex-parser-0315
This commit is contained in:
i-robot 2022-03-21 08:40:17 +00:00 committed by Gitee
commit 21bed8c9d8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
25 changed files with 39 additions and 38 deletions

View File

@ -49,7 +49,7 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op,
}
prim->set_stride(strides);
auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1));
auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(SECOND_INPUT));
if (weight_node != nullptr) {
std::vector<int64_t> kernels(4);
if (ParseKernels(*weight_node, format, &kernels) != RET_OK) {

View File

@ -53,7 +53,7 @@ ops::PrimitiveC *TFDeconvParser::Parse(const tensorflow::NodeDef &tf_op,
}
prim->set_stride({strides[0], strides[1]});
auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1));
auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(SECOND_INPUT));
if (weight_node != nullptr) {
std::vector<int64_t> kernels(4);
if (ParseKernels(*weight_node, format, &kernels) != RET_OK) {

View File

@ -38,7 +38,7 @@ ops::PrimitiveC *TFGatherParser::Parse(const tensorflow::NodeDef &tf_op,
bool axis_is_set = false;
if (tf_op.input_size() == 3) {
axis_is_set = true;
auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(2));
auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(THIRD_INPUT));
if (axis_node == nullptr) {
MS_LOG(ERROR) << "Find Gather input axis failed";
return nullptr;

View File

@ -32,7 +32,7 @@ ops::PrimitiveC *TFRangeParser::Parse(const tensorflow::NodeDef &tf_op,
if (TensorFlowUtils::FindAttrValue(tf_op, "starts", &attr_value)) {
prim->set_start(attr_value.i());
} else {
auto input_0_name = TensorFlowUtils::GetFlattenNodeName(tf_op.input(0));
auto input_0_name = TensorFlowUtils::GetFlattenNodeName(tf_op.input(FIRST_INPUT));
if (tf_node_map.find(input_0_name) == tf_node_map.end()) {
MS_LOG(ERROR) << "not find start node.";
return nullptr;
@ -48,7 +48,7 @@ ops::PrimitiveC *TFRangeParser::Parse(const tensorflow::NodeDef &tf_op,
if (TensorFlowUtils::FindAttrValue(tf_op, "limits", &attr_value)) {
prim->set_limit(attr_value.i());
} else {
auto input_1_name = TensorFlowUtils::GetFlattenNodeName(tf_op.input(1));
auto input_1_name = TensorFlowUtils::GetFlattenNodeName(tf_op.input(SECOND_INPUT));
if (tf_node_map.find(input_1_name) == tf_node_map.end()) {
MS_LOG(ERROR) << "not find limit node.";
return nullptr;
@ -64,7 +64,7 @@ ops::PrimitiveC *TFRangeParser::Parse(const tensorflow::NodeDef &tf_op,
if (TensorFlowUtils::FindAttrValue(tf_op, "deltas", &attr_value)) {
prim->set_delta(attr_value.i());
} else {
auto input_2_name = TensorFlowUtils::GetFlattenNodeName(tf_op.input(2));
auto input_2_name = TensorFlowUtils::GetFlattenNodeName(tf_op.input(THIRD_INPUT));
if (tf_node_map.find(input_2_name) == tf_node_map.end()) {
MS_LOG(ERROR) << "not find delta node.";
return nullptr;

View File

@ -56,7 +56,7 @@ ops::PrimitiveC *TFResizeParser::Parse(const tensorflow::NodeDef &tf_op,
} else {
prim->set_method(mindspore::ResizeMethod::UNKNOWN);
}
auto size_node = tf_node_map.at(tf_op.input(1));
auto size_node = tf_node_map.at(tf_op.input(SECOND_INPUT));
if (size_node == nullptr) {
MS_LOG(ERROR) << "Find size input failed.";
return nullptr;

View File

@ -29,7 +29,7 @@ ops::PrimitiveC *TFReverseParser::Parse(const tensorflow::NodeDef &tf_op,
auto prim = std::make_unique<ops::ReverseV2>();
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
tensorflow::AttrValue attr_value;
auto value = GetConstInputNode(tf_node_map, tf_op.input(1));
auto value = GetConstInputNode(tf_node_map, tf_op.input(SECOND_INPUT));
if (value == nullptr) {
MS_LOG(ERROR) << "Find axis failed";
return nullptr;

View File

@ -30,7 +30,7 @@ ops::PrimitiveC *TFSliceParser::Parse(const tensorflow::NodeDef &tf_op,
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
// begin
tensorflow::AttrValue attr_value;
auto begin_node = GetConstInputNode(tf_node_map, tf_op.input(1));
auto begin_node = GetConstInputNode(tf_node_map, tf_op.input(SECOND_INPUT));
if (begin_node != nullptr) {
if (!TensorFlowUtils::FindAttrValue(*begin_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";

View File

@ -61,7 +61,7 @@ ops::PrimitiveC *TFSplitParser::Parse(const tensorflow::NodeDef &tf_op,
prim->set_axis(splitDim);
if (tf_op.op() == "SplitV") {
auto size_splits_node = GetConstInputNode(tf_node_map, tf_op.input(1));
auto size_splits_node = GetConstInputNode(tf_node_map, tf_op.input(SECOND_INPUT));
if (size_splits_node == nullptr) {
MS_LOG(ERROR) << "Find Split input size_splits failed";
return nullptr;

View File

@ -33,7 +33,7 @@ ops::PrimitiveC *TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::Operato
prim->set_top_k(1);
std::vector<int64_t> axes;
auto ret = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, &axes);
auto ret = GetTfliteData(tflite_op->inputs[SECOND_INPUT], tflite_subgraph->tensors, tflite_model->buffers, &axes);
if (ret != RET_OK && ret != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get axes value failed.";
return nullptr;

View File

@ -33,7 +33,7 @@ ops::PrimitiveC *TfliteArgminParser::Parse(const std::unique_ptr<tflite::Operato
prim->set_top_k(1);
std::vector<int64_t> axes;
auto ret = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, &axes);
auto ret = GetTfliteData(tflite_op->inputs[SECOND_INPUT], tflite_subgraph->tensors, tflite_model->buffers, &axes);
if (ret != RET_OK && ret != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get axes value failed.";
return nullptr;

View File

@ -31,7 +31,7 @@ ops::PrimitiveC *TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::Op
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
std::vector<int64_t> dst_shape;
if (GetTfliteData(tflite_op->inputs.at(1), tflite_subgraph->tensors, tflite_model->buffers, &dst_shape)) {
if (GetTfliteData(tflite_op->inputs.at(SECOND_INPUT), tflite_subgraph->tensors, tflite_model->buffers, &dst_shape)) {
MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed";
return nullptr;
}

View File

@ -118,8 +118,8 @@ ops::PrimitiveC *TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT
MS_LOG(ERROR) << "the tflite_op shape is illegal";
return nullptr;
}
MS_CHECK_TRUE_RET(static_cast<size_t>(tflite_op->inputs[1]) < tflite_subgraph->tensors.size(), nullptr);
const auto &weight_tensor = tflite_subgraph->tensors.at(tflite_op->inputs[1]);
MS_CHECK_TRUE_RET(static_cast<size_t>(tflite_op->inputs[SECOND_INPUT]) < tflite_subgraph->tensors.size(), nullptr);
const auto &weight_tensor = tflite_subgraph->tensors.at(tflite_op->inputs[SECOND_INPUT]);
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "the weight tensor is null";
return nullptr;
@ -134,7 +134,7 @@ ops::PrimitiveC *TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT
prim->set_kernel_size({weight_shape[kWeightKernelH], weight_shape[kWeightKernelW]});
// calculate pad params
const auto &dataTensor = tflite_subgraph->tensors.at(tflite_op->inputs[0]);
const auto &dataTensor = tflite_subgraph->tensors.at(tflite_op->inputs[FIRST_INPUT]);
std::vector<int64_t> params;
int status = GetConvPaddingParam(dataTensor, padMode, prim.get(), &params);
if (status != RET_OK && status != RET_NO_CHANGE) {
@ -172,7 +172,7 @@ ops::PrimitiveC *TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite
MS_LOG(ERROR) << "the tflite_op shape is illegal";
return nullptr;
}
const auto &weight_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(1));
const auto &weight_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(SECOND_INPUT));
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "the weight tensor is null";
return nullptr;
@ -191,7 +191,7 @@ ops::PrimitiveC *TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite
prim->set_group(weight_shape[kWeightChannelIn] / tflite_attr->depth_multiplier);
// get data tensor
const auto &data_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(0));
const auto &data_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(FIRST_INPUT));
if (data_tensor == nullptr) {
MS_LOG(ERROR) << "data_tensor is nullptr";
return nullptr;

View File

@ -49,7 +49,7 @@ ops::PrimitiveC *TfliteDeConvParser::Parse(const std::unique_ptr<tflite::Operato
MS_LOG(ERROR) << "the tflite_op shape is illegal";
return nullptr;
}
const auto &weight_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(1));
const auto &weight_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(SECOND_INPUT));
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "the weight tensor is null";
return nullptr;
@ -64,7 +64,7 @@ ops::PrimitiveC *TfliteDeConvParser::Parse(const std::unique_ptr<tflite::Operato
prim->set_kernel_size({weight_shape[kWeightKernelH], weight_shape[kWeightKernelW]});
// calculate pad params
const auto &data_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(2));
const auto &data_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(THIRD_INPUT));
std::vector<int64_t> params;
int status = getPaddingParam(data_tensor, padMode, tflite_attr->stride_h, tflite_attr->stride_w,
weight_shape[kWeightKernelH], weight_shape[kWeightKernelW], &params);

View File

@ -165,7 +165,7 @@ ops::PrimitiveC *TfliteCustomParser::Rfft(const std::vector<uint8_t> &custom_att
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
std::vector<int64_t> fft_length;
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, &fft_length)) {
if (GetTfliteData(tflite_op->inputs[SECOND_INPUT], tflite_subgraph->tensors, tflite_model->buffers, &fft_length)) {
MS_LOG(ERROR) << "rfft -> fftLength get failed";
return nullptr;
}

View File

@ -27,7 +27,7 @@ ops::PrimitiveC *TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::Ope
const std::unique_ptr<tflite::ModelT> &tflite_model) {
MS_CHECK_TRUE_RET(!tflite_op->inputs.empty(), nullptr);
MS_CHECK_TRUE_RET(!tflite_op->outputs.empty(), nullptr);
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs.at(0)];
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs.at(FIRST_INPUT)];
if (in_tensor == nullptr) {
MS_LOG(ERROR) << "input tensor is null";
return nullptr;

View File

@ -30,7 +30,7 @@ ops::PrimitiveC *TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite:
prim->set_axis(1);
prim->set_use_axis(false);
prim->set_has_bias(tflite_op->inputs.size() > 2 && tflite_op->inputs.at(2) != -1);
prim->set_has_bias(tflite_op->inputs.size() > kInputSize1 && tflite_op->inputs.at(THIRD_INPUT) != -1);
const auto &tflite_attr = tflite_op->builtin_options.AsFullyConnectedOptions();
if (tflite_attr == nullptr) {

View File

@ -47,7 +47,7 @@ ops::PrimitiveC *TfliteAvgPoolParser::Parse(const std::unique_ptr<tflite::Operat
prim->set_activation_type(GetActivationFunctionType(tflite_attr->fused_activation_function));
// calculate pad params
const auto &dataTensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(0));
const auto &dataTensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(FIRST_INPUT));
std::vector<int64_t> params;
int status = getPaddingParam(dataTensor, padMode, tflite_attr->stride_h, tflite_attr->stride_w,
tflite_attr->filter_height, tflite_attr->filter_width, &params);
@ -84,7 +84,7 @@ ops::PrimitiveC *TfliteMaxPoolParser::Parse(const std::unique_ptr<tflite::Operat
prim->set_activation_type(GetActivationFunctionType(tflite_attr->fused_activation_function));
// calculate pad params
const auto &dataTensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(0));
const auto &dataTensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(FIRST_INPUT));
std::vector<int64_t> params;
int status = getPaddingParam(dataTensor, padMode, tflite_attr->stride_h, tflite_attr->stride_w,
tflite_attr->filter_height, tflite_attr->filter_width, &params);

View File

@ -27,7 +27,7 @@ ops::PrimitiveC *TfliteQuantizeParser::Parse(const std::unique_ptr<tflite::Opera
const std::unique_ptr<tflite::ModelT> &tflite_model) {
MS_CHECK_TRUE_RET(!tflite_op->inputs.empty(), nullptr);
MS_CHECK_TRUE_RET(!tflite_op->outputs.empty(), nullptr);
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[FIRST_INPUT]];
if (in_tensor == nullptr) {
MS_LOG(ERROR) << "input tensor is null";
return nullptr;

View File

@ -77,7 +77,7 @@ ops::PrimitiveC *TfliteResizeParser::Parse(const std::unique_ptr<tflite::Operato
}
std::vector<int64_t> dims;
auto ret = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, &dims);
auto ret = GetTfliteData(tflite_op->inputs[SECOND_INPUT], tflite_subgraph->tensors, tflite_model->buffers, &dims);
if (ret != RET_OK && ret != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get axes value failed.";
return nullptr;

View File

@ -29,7 +29,7 @@ ops::PrimitiveC *TfliteReverseParser::Parse(const std::unique_ptr<tflite::Operat
auto prim = std::make_unique<ops::ReverseV2>();
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
std::vector<int64_t> axis;
if (GetTfliteData(tflite_op->inputs.at(1), tflite_subgraph->tensors, tflite_model->buffers, &axis)) {
if (GetTfliteData(tflite_op->inputs.at(SECOND_INPUT), tflite_subgraph->tensors, tflite_model->buffers, &axis)) {
MS_LOG(ERROR) << "get reverse -> axis failed";
return nullptr;
}

View File

@ -30,7 +30,7 @@ ops::PrimitiveC *TfliteSliceParser::Parse(const std::unique_ptr<tflite::Operator
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
std::vector<int64_t> begin;
auto ret = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, &begin);
auto ret = GetTfliteData(tflite_op->inputs[SECOND_INPUT], tflite_subgraph->tensors, tflite_model->buffers, &begin);
if (ret != RET_OK && ret != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get slice -> begin failed";
return nullptr;

View File

@ -31,13 +31,14 @@ ops::PrimitiveC *TfliteSpaceToBatchNDParser::Parse(const std::unique_ptr<tflite:
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
std::vector<int64_t> blockShape;
if (GetTfliteData(tflite_op->inputs.at(1), tflite_subgraph->tensors, tflite_model->buffers, &blockShape)) {
if (GetTfliteData(tflite_op->inputs.at(SECOND_INPUT), tflite_subgraph->tensors, tflite_model->buffers, &blockShape)) {
MS_LOG(ERROR) << "get spaceToBatchND -> blockShape failed";
return nullptr;
}
prim->set_block_shape(blockShape);
std::vector<std::vector<int64_t>> paddings;
if (TransTfliteDataToVec2D(tflite_op->inputs.at(2), tflite_subgraph->tensors, tflite_model->buffers, &paddings)) {
if (TransTfliteDataToVec2D(tflite_op->inputs.at(THIRD_INPUT), tflite_subgraph->tensors, tflite_model->buffers,
&paddings)) {
MS_LOG(ERROR) << "get spaceToBatchND -> paddings failed";
return nullptr;
}

View File

@ -36,14 +36,14 @@ ops::PrimitiveC *TfliteSplitParser::Parse(const std::unique_ptr<tflite::Operator
return nullptr;
}
auto num_splits = tflite_attr->num_splits;
const auto &shape_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(1));
const auto &shape_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(SECOND_INPUT));
if (shape_tensor == nullptr) {
MS_LOG(ERROR) << "shape_tensor is null";
return nullptr;
}
const auto tensor_shape = shape_tensor->shape;
std::vector<int64_t> axes;
auto ret = GetTfliteData(tflite_op->inputs.at(0), tflite_subgraph->tensors, tflite_model->buffers, &axes);
auto ret = GetTfliteData(tflite_op->inputs.at(FIRST_INPUT), tflite_subgraph->tensors, tflite_model->buffers, &axes);
if (ret != RET_OK && ret != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get axes value failed.";
return nullptr;

View File

@ -38,20 +38,20 @@ ops::PrimitiveC *TfliteSplitVParser::Parse(const std::unique_ptr<tflite::Operato
prim->set_output_num(tflite_attr->num_splits);
std::vector<int64_t> size_splits;
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, &size_splits)) {
if (GetTfliteData(tflite_op->inputs[SECOND_INPUT], tflite_subgraph->tensors, tflite_model->buffers, &size_splits)) {
MS_LOG(ERROR) << "get spliteV -> sizeSplits failed";
return nullptr;
}
prim->set_size_splits(size_splits);
const auto &tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(0));
const auto &tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(FIRST_INPUT));
if (tensor == nullptr) {
MS_LOG(ERROR) << "tensor_shape is null";
return nullptr;
}
auto tensor_shape = tensor->shape;
std::vector<int64_t> axes;
auto ret = GetTfliteData(tflite_op->inputs.at(2), tflite_subgraph->tensors, tflite_model->buffers, &axes);
auto ret = GetTfliteData(tflite_op->inputs.at(THIRD_INPUT), tflite_subgraph->tensors, tflite_model->buffers, &axes);
if (ret != RET_OK && ret != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get axes value failed.";
return nullptr;

View File

@ -110,8 +110,8 @@ STATUS getPaddingParam(const std::unique_ptr<tflite::TensorT> &tensor, mindspore
if (pad_mode == mindspore::PadMode::SAME) {
auto shape = tensor->shape;
MS_CHECK_TRUE_RET(shape.size() == DIMENSION_4D, RET_ERROR);
int H_input = shape.at(1);
int W_input = shape.at(2);
int H_input = shape.at(kNHWC_H);
int W_input = shape.at(kNHWC_W);
if (strideH == 0) {
MS_LOG(ERROR) << "strideH is zero";
return RET_ERROR;