forked from mindspore-Ecosystem/mindspore
!31577 [MS][LITE]codex parser tf and tflite fix
Merge pull request !31577 from luoyuan/codex-parser-0315
This commit is contained in:
commit
21bed8c9d8
|
@ -49,7 +49,7 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
}
|
||||
prim->set_stride(strides);
|
||||
|
||||
auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1));
|
||||
auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(SECOND_INPUT));
|
||||
if (weight_node != nullptr) {
|
||||
std::vector<int64_t> kernels(4);
|
||||
if (ParseKernels(*weight_node, format, &kernels) != RET_OK) {
|
||||
|
|
|
@ -53,7 +53,7 @@ ops::PrimitiveC *TFDeconvParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
}
|
||||
prim->set_stride({strides[0], strides[1]});
|
||||
|
||||
auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1));
|
||||
auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(SECOND_INPUT));
|
||||
if (weight_node != nullptr) {
|
||||
std::vector<int64_t> kernels(4);
|
||||
if (ParseKernels(*weight_node, format, &kernels) != RET_OK) {
|
||||
|
|
|
@ -38,7 +38,7 @@ ops::PrimitiveC *TFGatherParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
bool axis_is_set = false;
|
||||
if (tf_op.input_size() == 3) {
|
||||
axis_is_set = true;
|
||||
auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(2));
|
||||
auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(THIRD_INPUT));
|
||||
if (axis_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Find Gather input axis failed";
|
||||
return nullptr;
|
||||
|
|
|
@ -32,7 +32,7 @@ ops::PrimitiveC *TFRangeParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
if (TensorFlowUtils::FindAttrValue(tf_op, "starts", &attr_value)) {
|
||||
prim->set_start(attr_value.i());
|
||||
} else {
|
||||
auto input_0_name = TensorFlowUtils::GetFlattenNodeName(tf_op.input(0));
|
||||
auto input_0_name = TensorFlowUtils::GetFlattenNodeName(tf_op.input(FIRST_INPUT));
|
||||
if (tf_node_map.find(input_0_name) == tf_node_map.end()) {
|
||||
MS_LOG(ERROR) << "not find start node.";
|
||||
return nullptr;
|
||||
|
@ -48,7 +48,7 @@ ops::PrimitiveC *TFRangeParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
if (TensorFlowUtils::FindAttrValue(tf_op, "limits", &attr_value)) {
|
||||
prim->set_limit(attr_value.i());
|
||||
} else {
|
||||
auto input_1_name = TensorFlowUtils::GetFlattenNodeName(tf_op.input(1));
|
||||
auto input_1_name = TensorFlowUtils::GetFlattenNodeName(tf_op.input(SECOND_INPUT));
|
||||
if (tf_node_map.find(input_1_name) == tf_node_map.end()) {
|
||||
MS_LOG(ERROR) << "not find limit node.";
|
||||
return nullptr;
|
||||
|
@ -64,7 +64,7 @@ ops::PrimitiveC *TFRangeParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
if (TensorFlowUtils::FindAttrValue(tf_op, "deltas", &attr_value)) {
|
||||
prim->set_delta(attr_value.i());
|
||||
} else {
|
||||
auto input_2_name = TensorFlowUtils::GetFlattenNodeName(tf_op.input(2));
|
||||
auto input_2_name = TensorFlowUtils::GetFlattenNodeName(tf_op.input(THIRD_INPUT));
|
||||
if (tf_node_map.find(input_2_name) == tf_node_map.end()) {
|
||||
MS_LOG(ERROR) << "not find delta node.";
|
||||
return nullptr;
|
||||
|
|
|
@ -56,7 +56,7 @@ ops::PrimitiveC *TFResizeParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
} else {
|
||||
prim->set_method(mindspore::ResizeMethod::UNKNOWN);
|
||||
}
|
||||
auto size_node = tf_node_map.at(tf_op.input(1));
|
||||
auto size_node = tf_node_map.at(tf_op.input(SECOND_INPUT));
|
||||
if (size_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Find size input failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -29,7 +29,7 @@ ops::PrimitiveC *TFReverseParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
auto prim = std::make_unique<ops::ReverseV2>();
|
||||
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
|
||||
tensorflow::AttrValue attr_value;
|
||||
auto value = GetConstInputNode(tf_node_map, tf_op.input(1));
|
||||
auto value = GetConstInputNode(tf_node_map, tf_op.input(SECOND_INPUT));
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "Find axis failed";
|
||||
return nullptr;
|
||||
|
|
|
@ -30,7 +30,7 @@ ops::PrimitiveC *TFSliceParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
|
||||
// begin
|
||||
tensorflow::AttrValue attr_value;
|
||||
auto begin_node = GetConstInputNode(tf_node_map, tf_op.input(1));
|
||||
auto begin_node = GetConstInputNode(tf_node_map, tf_op.input(SECOND_INPUT));
|
||||
if (begin_node != nullptr) {
|
||||
if (!TensorFlowUtils::FindAttrValue(*begin_node, "value", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The value attr should be specified";
|
||||
|
|
|
@ -61,7 +61,7 @@ ops::PrimitiveC *TFSplitParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
prim->set_axis(splitDim);
|
||||
|
||||
if (tf_op.op() == "SplitV") {
|
||||
auto size_splits_node = GetConstInputNode(tf_node_map, tf_op.input(1));
|
||||
auto size_splits_node = GetConstInputNode(tf_node_map, tf_op.input(SECOND_INPUT));
|
||||
if (size_splits_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Find Split input size_splits failed";
|
||||
return nullptr;
|
||||
|
|
|
@ -33,7 +33,7 @@ ops::PrimitiveC *TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::Operato
|
|||
prim->set_top_k(1);
|
||||
|
||||
std::vector<int64_t> axes;
|
||||
auto ret = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, &axes);
|
||||
auto ret = GetTfliteData(tflite_op->inputs[SECOND_INPUT], tflite_subgraph->tensors, tflite_model->buffers, &axes);
|
||||
if (ret != RET_OK && ret != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "get axes value failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -33,7 +33,7 @@ ops::PrimitiveC *TfliteArgminParser::Parse(const std::unique_ptr<tflite::Operato
|
|||
prim->set_top_k(1);
|
||||
|
||||
std::vector<int64_t> axes;
|
||||
auto ret = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, &axes);
|
||||
auto ret = GetTfliteData(tflite_op->inputs[SECOND_INPUT], tflite_subgraph->tensors, tflite_model->buffers, &axes);
|
||||
if (ret != RET_OK && ret != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "get axes value failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -31,7 +31,7 @@ ops::PrimitiveC *TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::Op
|
|||
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
|
||||
|
||||
std::vector<int64_t> dst_shape;
|
||||
if (GetTfliteData(tflite_op->inputs.at(1), tflite_subgraph->tensors, tflite_model->buffers, &dst_shape)) {
|
||||
if (GetTfliteData(tflite_op->inputs.at(SECOND_INPUT), tflite_subgraph->tensors, tflite_model->buffers, &dst_shape)) {
|
||||
MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -118,8 +118,8 @@ ops::PrimitiveC *TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT
|
|||
MS_LOG(ERROR) << "the tflite_op shape is illegal";
|
||||
return nullptr;
|
||||
}
|
||||
MS_CHECK_TRUE_RET(static_cast<size_t>(tflite_op->inputs[1]) < tflite_subgraph->tensors.size(), nullptr);
|
||||
const auto &weight_tensor = tflite_subgraph->tensors.at(tflite_op->inputs[1]);
|
||||
MS_CHECK_TRUE_RET(static_cast<size_t>(tflite_op->inputs[SECOND_INPUT]) < tflite_subgraph->tensors.size(), nullptr);
|
||||
const auto &weight_tensor = tflite_subgraph->tensors.at(tflite_op->inputs[SECOND_INPUT]);
|
||||
if (weight_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the weight tensor is null";
|
||||
return nullptr;
|
||||
|
@ -134,7 +134,7 @@ ops::PrimitiveC *TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT
|
|||
prim->set_kernel_size({weight_shape[kWeightKernelH], weight_shape[kWeightKernelW]});
|
||||
|
||||
// calculate pad params
|
||||
const auto &dataTensor = tflite_subgraph->tensors.at(tflite_op->inputs[0]);
|
||||
const auto &dataTensor = tflite_subgraph->tensors.at(tflite_op->inputs[FIRST_INPUT]);
|
||||
std::vector<int64_t> params;
|
||||
int status = GetConvPaddingParam(dataTensor, padMode, prim.get(), ¶ms);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
|
@ -172,7 +172,7 @@ ops::PrimitiveC *TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite
|
|||
MS_LOG(ERROR) << "the tflite_op shape is illegal";
|
||||
return nullptr;
|
||||
}
|
||||
const auto &weight_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(1));
|
||||
const auto &weight_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(SECOND_INPUT));
|
||||
if (weight_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the weight tensor is null";
|
||||
return nullptr;
|
||||
|
@ -191,7 +191,7 @@ ops::PrimitiveC *TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite
|
|||
prim->set_group(weight_shape[kWeightChannelIn] / tflite_attr->depth_multiplier);
|
||||
|
||||
// get data tensor
|
||||
const auto &data_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(0));
|
||||
const auto &data_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(FIRST_INPUT));
|
||||
if (data_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "data_tensor is nullptr";
|
||||
return nullptr;
|
||||
|
|
|
@ -49,7 +49,7 @@ ops::PrimitiveC *TfliteDeConvParser::Parse(const std::unique_ptr<tflite::Operato
|
|||
MS_LOG(ERROR) << "the tflite_op shape is illegal";
|
||||
return nullptr;
|
||||
}
|
||||
const auto &weight_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(1));
|
||||
const auto &weight_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(SECOND_INPUT));
|
||||
if (weight_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the weight tensor is null";
|
||||
return nullptr;
|
||||
|
@ -64,7 +64,7 @@ ops::PrimitiveC *TfliteDeConvParser::Parse(const std::unique_ptr<tflite::Operato
|
|||
prim->set_kernel_size({weight_shape[kWeightKernelH], weight_shape[kWeightKernelW]});
|
||||
|
||||
// calculate pad params
|
||||
const auto &data_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(2));
|
||||
const auto &data_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(THIRD_INPUT));
|
||||
std::vector<int64_t> params;
|
||||
int status = getPaddingParam(data_tensor, padMode, tflite_attr->stride_h, tflite_attr->stride_w,
|
||||
weight_shape[kWeightKernelH], weight_shape[kWeightKernelW], ¶ms);
|
||||
|
|
|
@ -165,7 +165,7 @@ ops::PrimitiveC *TfliteCustomParser::Rfft(const std::vector<uint8_t> &custom_att
|
|||
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
|
||||
|
||||
std::vector<int64_t> fft_length;
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, &fft_length)) {
|
||||
if (GetTfliteData(tflite_op->inputs[SECOND_INPUT], tflite_subgraph->tensors, tflite_model->buffers, &fft_length)) {
|
||||
MS_LOG(ERROR) << "rfft -> fftLength get failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ ops::PrimitiveC *TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::Ope
|
|||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
MS_CHECK_TRUE_RET(!tflite_op->inputs.empty(), nullptr);
|
||||
MS_CHECK_TRUE_RET(!tflite_op->outputs.empty(), nullptr);
|
||||
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs.at(0)];
|
||||
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs.at(FIRST_INPUT)];
|
||||
if (in_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "input tensor is null";
|
||||
return nullptr;
|
||||
|
|
|
@ -30,7 +30,7 @@ ops::PrimitiveC *TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite:
|
|||
|
||||
prim->set_axis(1);
|
||||
prim->set_use_axis(false);
|
||||
prim->set_has_bias(tflite_op->inputs.size() > 2 && tflite_op->inputs.at(2) != -1);
|
||||
prim->set_has_bias(tflite_op->inputs.size() > kInputSize1 && tflite_op->inputs.at(THIRD_INPUT) != -1);
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsFullyConnectedOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
|
|
|
@ -47,7 +47,7 @@ ops::PrimitiveC *TfliteAvgPoolParser::Parse(const std::unique_ptr<tflite::Operat
|
|||
prim->set_activation_type(GetActivationFunctionType(tflite_attr->fused_activation_function));
|
||||
|
||||
// calculate pad params
|
||||
const auto &dataTensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(0));
|
||||
const auto &dataTensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(FIRST_INPUT));
|
||||
std::vector<int64_t> params;
|
||||
int status = getPaddingParam(dataTensor, padMode, tflite_attr->stride_h, tflite_attr->stride_w,
|
||||
tflite_attr->filter_height, tflite_attr->filter_width, ¶ms);
|
||||
|
@ -84,7 +84,7 @@ ops::PrimitiveC *TfliteMaxPoolParser::Parse(const std::unique_ptr<tflite::Operat
|
|||
prim->set_activation_type(GetActivationFunctionType(tflite_attr->fused_activation_function));
|
||||
|
||||
// calculate pad params
|
||||
const auto &dataTensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(0));
|
||||
const auto &dataTensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(FIRST_INPUT));
|
||||
std::vector<int64_t> params;
|
||||
int status = getPaddingParam(dataTensor, padMode, tflite_attr->stride_h, tflite_attr->stride_w,
|
||||
tflite_attr->filter_height, tflite_attr->filter_width, ¶ms);
|
||||
|
|
|
@ -27,7 +27,7 @@ ops::PrimitiveC *TfliteQuantizeParser::Parse(const std::unique_ptr<tflite::Opera
|
|||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
MS_CHECK_TRUE_RET(!tflite_op->inputs.empty(), nullptr);
|
||||
MS_CHECK_TRUE_RET(!tflite_op->outputs.empty(), nullptr);
|
||||
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
|
||||
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[FIRST_INPUT]];
|
||||
if (in_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "input tensor is null";
|
||||
return nullptr;
|
||||
|
|
|
@ -77,7 +77,7 @@ ops::PrimitiveC *TfliteResizeParser::Parse(const std::unique_ptr<tflite::Operato
|
|||
}
|
||||
|
||||
std::vector<int64_t> dims;
|
||||
auto ret = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, &dims);
|
||||
auto ret = GetTfliteData(tflite_op->inputs[SECOND_INPUT], tflite_subgraph->tensors, tflite_model->buffers, &dims);
|
||||
if (ret != RET_OK && ret != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "get axes value failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -29,7 +29,7 @@ ops::PrimitiveC *TfliteReverseParser::Parse(const std::unique_ptr<tflite::Operat
|
|||
auto prim = std::make_unique<ops::ReverseV2>();
|
||||
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
|
||||
std::vector<int64_t> axis;
|
||||
if (GetTfliteData(tflite_op->inputs.at(1), tflite_subgraph->tensors, tflite_model->buffers, &axis)) {
|
||||
if (GetTfliteData(tflite_op->inputs.at(SECOND_INPUT), tflite_subgraph->tensors, tflite_model->buffers, &axis)) {
|
||||
MS_LOG(ERROR) << "get reverse -> axis failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ ops::PrimitiveC *TfliteSliceParser::Parse(const std::unique_ptr<tflite::Operator
|
|||
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
|
||||
|
||||
std::vector<int64_t> begin;
|
||||
auto ret = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, &begin);
|
||||
auto ret = GetTfliteData(tflite_op->inputs[SECOND_INPUT], tflite_subgraph->tensors, tflite_model->buffers, &begin);
|
||||
if (ret != RET_OK && ret != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "get slice -> begin failed";
|
||||
return nullptr;
|
||||
|
|
|
@ -31,13 +31,14 @@ ops::PrimitiveC *TfliteSpaceToBatchNDParser::Parse(const std::unique_ptr<tflite:
|
|||
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
|
||||
|
||||
std::vector<int64_t> blockShape;
|
||||
if (GetTfliteData(tflite_op->inputs.at(1), tflite_subgraph->tensors, tflite_model->buffers, &blockShape)) {
|
||||
if (GetTfliteData(tflite_op->inputs.at(SECOND_INPUT), tflite_subgraph->tensors, tflite_model->buffers, &blockShape)) {
|
||||
MS_LOG(ERROR) << "get spaceToBatchND -> blockShape failed";
|
||||
return nullptr;
|
||||
}
|
||||
prim->set_block_shape(blockShape);
|
||||
std::vector<std::vector<int64_t>> paddings;
|
||||
if (TransTfliteDataToVec2D(tflite_op->inputs.at(2), tflite_subgraph->tensors, tflite_model->buffers, &paddings)) {
|
||||
if (TransTfliteDataToVec2D(tflite_op->inputs.at(THIRD_INPUT), tflite_subgraph->tensors, tflite_model->buffers,
|
||||
&paddings)) {
|
||||
MS_LOG(ERROR) << "get spaceToBatchND -> paddings failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -36,14 +36,14 @@ ops::PrimitiveC *TfliteSplitParser::Parse(const std::unique_ptr<tflite::Operator
|
|||
return nullptr;
|
||||
}
|
||||
auto num_splits = tflite_attr->num_splits;
|
||||
const auto &shape_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(1));
|
||||
const auto &shape_tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(SECOND_INPUT));
|
||||
if (shape_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "shape_tensor is null";
|
||||
return nullptr;
|
||||
}
|
||||
const auto tensor_shape = shape_tensor->shape;
|
||||
std::vector<int64_t> axes;
|
||||
auto ret = GetTfliteData(tflite_op->inputs.at(0), tflite_subgraph->tensors, tflite_model->buffers, &axes);
|
||||
auto ret = GetTfliteData(tflite_op->inputs.at(FIRST_INPUT), tflite_subgraph->tensors, tflite_model->buffers, &axes);
|
||||
if (ret != RET_OK && ret != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "get axes value failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -38,20 +38,20 @@ ops::PrimitiveC *TfliteSplitVParser::Parse(const std::unique_ptr<tflite::Operato
|
|||
prim->set_output_num(tflite_attr->num_splits);
|
||||
|
||||
std::vector<int64_t> size_splits;
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, &size_splits)) {
|
||||
if (GetTfliteData(tflite_op->inputs[SECOND_INPUT], tflite_subgraph->tensors, tflite_model->buffers, &size_splits)) {
|
||||
MS_LOG(ERROR) << "get spliteV -> sizeSplits failed";
|
||||
return nullptr;
|
||||
}
|
||||
prim->set_size_splits(size_splits);
|
||||
|
||||
const auto &tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(0));
|
||||
const auto &tensor = tflite_subgraph->tensors.at(tflite_op->inputs.at(FIRST_INPUT));
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_shape is null";
|
||||
return nullptr;
|
||||
}
|
||||
auto tensor_shape = tensor->shape;
|
||||
std::vector<int64_t> axes;
|
||||
auto ret = GetTfliteData(tflite_op->inputs.at(2), tflite_subgraph->tensors, tflite_model->buffers, &axes);
|
||||
auto ret = GetTfliteData(tflite_op->inputs.at(THIRD_INPUT), tflite_subgraph->tensors, tflite_model->buffers, &axes);
|
||||
if (ret != RET_OK && ret != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "get axes value failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -110,8 +110,8 @@ STATUS getPaddingParam(const std::unique_ptr<tflite::TensorT> &tensor, mindspore
|
|||
if (pad_mode == mindspore::PadMode::SAME) {
|
||||
auto shape = tensor->shape;
|
||||
MS_CHECK_TRUE_RET(shape.size() == DIMENSION_4D, RET_ERROR);
|
||||
int H_input = shape.at(1);
|
||||
int W_input = shape.at(2);
|
||||
int H_input = shape.at(kNHWC_H);
|
||||
int W_input = shape.at(kNHWC_W);
|
||||
if (strideH == 0) {
|
||||
MS_LOG(ERROR) << "strideH is zero";
|
||||
return RET_ERROR;
|
||||
|
|
Loading…
Reference in New Issue