refactor tflite parsers.

append ut

refactor tflite parsers

modify tflite parser, ut and model

supplement caffe flatten parser

fix the weight tensor format of deconv bug

fix bug when idx=-1

fix the weight tensor format of depthConv bug.
This commit is contained in:
lyvette 2020-08-12 17:26:51 +08:00
parent a0c12e7aa7
commit 123c2024a5
177 changed files with 2643 additions and 2223 deletions

View File

@ -12,19 +12,19 @@ cp -r ${CUR_DIR}/ut/tools/converter/parser/tflite/test_data/* ./
TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/ TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/
cp -fr $TEST_DATA_DIR/testPK ./data cp -fr $TEST_DATA_DIR/testPK ./data
#./lite-test --gtest_filter="*MindDataTestTensorDE*" ./lite-test --gtest_filter="*MindDataTestTensorDE*"
#./lite-test --gtest_filter="*MindDataTestEager*" ./lite-test --gtest_filter="*MindDataTestEager*"
#
#./lite-test --gtest_filter="TestTfliteParser*" ./lite-test --gtest_filter="TestTfliteParser*"
#
#./lite-test --gtest_filter="*TestHebing*" ./lite-test --gtest_filter="*TestHebing*"
#
#./lite-test --gtest_filter=TestFcFp32* ./lite-test --gtest_filter=TestFcFp32*
#./lite-test --gtest_filter=TestConv1x1Fp32* ./lite-test --gtest_filter=TestConv1x1Fp32*
#./lite-test --gtest_filter=TestStrassenFp32* ./lite-test --gtest_filter=TestStrassenFp32*
#./lite-test --gtest_filter=TestDeConvolutionFp32* ./lite-test --gtest_filter=TestDeConvolutionFp32*
#
#./lite-test --gtest_filter=TestPadInt8.* ./lite-test --gtest_filter=TestPadInt8.*
#./lite-test --gtest_filter=TestDeconvInt8.* ./lite-test --gtest_filter=TestDeconvInt8.*
./lite-test --gtest_filter="TestTfliteParser*" ./lite-test --gtest_filter="TestTfliteParser*"

View File

@ -31,6 +31,12 @@ TEST_F(TestTfliteParserRelu, OpType) {
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
} }
TEST_F(TestTfliteParserRelu, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
ASSERT_EQ(val->type, schema::ActivationType_RELU);
}
class TestTfliteParserRelu6 : public TestTfliteParser { class TestTfliteParserRelu6 : public TestTfliteParser {
public: public:
TestTfliteParserRelu6() = default; TestTfliteParserRelu6() = default;
@ -43,6 +49,12 @@ TEST_F(TestTfliteParserRelu6, OpType) {
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
} }
TEST_F(TestTfliteParserRelu6, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
ASSERT_EQ(val->type, schema::ActivationType_RELU6);
}
class TestTfliteParserTanh : public TestTfliteParser { class TestTfliteParserTanh : public TestTfliteParser {
public: public:
TestTfliteParserTanh() = default; TestTfliteParserTanh() = default;
@ -55,7 +67,45 @@ TEST_F(TestTfliteParserTanh, OpType) {
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
} }
// logistic TEST_F(TestTfliteParserTanh, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
ASSERT_EQ(val->type, schema::ActivationType_TANH);
}
class TestTfliteParserLogistic : public TestTfliteParser {
public:
TestTfliteParserLogistic() = default;
void SetUp() override { meta_graph = LoadAndConvert("./logistic.tflite", ""); }
};
TEST_F(TestTfliteParserLogistic, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
}
TEST_F(TestTfliteParserLogistic, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
ASSERT_EQ(val->type, schema::ActivationType_SIGMOID);
}
class TestTfliteParserHardSwish : public TestTfliteParser {
public:
TestTfliteParserHardSwish() = default;
void SetUp() override { meta_graph = LoadAndConvert("./hardswish.tflite", ""); }
};
TEST_F(TestTfliteParserHardSwish, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
}
TEST_F(TestTfliteParserHardSwish, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
ASSERT_EQ(val->type, schema::ActivationType_SIGMOID);
}
class TestTfliteParserPrelu : public TestTfliteParser { class TestTfliteParserPrelu : public TestTfliteParser {
public: public:
@ -73,12 +123,11 @@ TEST_F(TestTfliteParserPrelu, OpType) {
} }
TEST_F(TestTfliteParserPrelu, AttrValue) { TEST_F(TestTfliteParserPrelu, AttrValue) {
std::vector<float> slope(20, 0);
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsPrelu()->slope, slope); auto val = meta_graph->nodes.front()->primitive->value;
std::vector<float> slope(20, 0);
ASSERT_EQ(val.AsPrelu()->slope, slope);
ASSERT_EQ(val.type, schema::PrimitiveType_Prelu);
} }
class TestTfliteParserLeakyRelu : public TestTfliteParser { class TestTfliteParserLeakyRelu : public TestTfliteParser {
@ -94,12 +143,10 @@ TEST_F(TestTfliteParserLeakyRelu, OpType) {
} }
TEST_F(TestTfliteParserLeakyRelu, AttrValue) { TEST_F(TestTfliteParserLeakyRelu, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLeakyReLU(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); auto val = meta_graph->nodes.front()->primitive->value;
ASSERT_EQ(val.AsLeakyReLU()->negativeSlope, 0.20000000298023224);
auto val = meta_graph->nodes.front()->primitive->value.AsLeakyReLU(); ASSERT_EQ(val.type, schema::PrimitiveType_LeakyReLU);
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->negativeSlope, 0.20000000298023224);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -35,10 +35,8 @@ TEST_F(TestTfliteParserAddN, OpType) {
} }
TEST_F(TestTfliteParserAddN, AttrValue) { TEST_F(TestTfliteParserAddN, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsAddN(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsAddN(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsAddN()->N, 4); auto val = meta_graph->nodes.front()->primitive->value.AsAddN();
ASSERT_EQ(val->N, 4);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserArgmax : public TestTfliteParser {
public:
TestTfliteParserArgmax() = default;
void SetUp() override { meta_graph = LoadAndConvert("./argmax.tflite", ""); }
};
TEST_F(TestTfliteParserArgmax, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMax) << "wrong Op Type";
}
TEST_F(TestTfliteParserArgmax, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMax(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsArgMax();
ASSERT_EQ(val->axis, 1);
ASSERT_EQ(val->topK, 1);
ASSERT_EQ(val->axisType, 1);
ASSERT_EQ(val->keepDims, false);
ASSERT_EQ(val->outMaxValue, false);
}
} // namespace mindspore

View File

@ -25,15 +25,14 @@ class TestTfliteParserArgmin : public TestTfliteParser {
}; };
TEST_F(TestTfliteParserArgmin, OpType) { TEST_F(TestTfliteParserArgmin, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMin) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMin) << "wrong Op Type";
} }
TEST_F(TestTfliteParserArgmin, AttrValue) { TEST_F(TestTfliteParserArgmin, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMin(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsArgMin(); auto val = meta_graph->nodes.front()->primitive->value.AsArgMin();
ASSERT_EQ(val->axis, 1); ASSERT_EQ(val->axis, 1);
ASSERT_EQ(val->topK, 1); ASSERT_EQ(val->topK, 1);

View File

@ -19,234 +19,57 @@
namespace mindspore { namespace mindspore {
// doubleInputOp // doubleInputOp
class TestTfliteParserAdd1 : public TestTfliteParser { class TestTfliteParserAdd : public TestTfliteParser {
public: public:
TestTfliteParserAdd1() = default; TestTfliteParserAdd() = default;
void SetUp() override { meta_graph = LoadAndConvert("./add1.tflite", ""); } void SetUp() override { meta_graph = LoadAndConvert("./add.tflite", ""); }
}; };
TEST_F(TestTfliteParserAdd1, OpType) { TEST_F(TestTfliteParserAdd, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
} }
TEST_F(TestTfliteParserAdd1, Tensor) { class TestTfliteParserSub : public TestTfliteParser {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}
class TestTfliteParserAdd2 : public TestTfliteParser {
public: public:
TestTfliteParserAdd2() = default; TestTfliteParserSub() = default;
void SetUp() override { meta_graph = LoadAndConvert("./add2.tflite", ""); } void SetUp() override { meta_graph = LoadAndConvert("./sub.tflite", ""); }
}; };
TEST_F(TestTfliteParserAdd2, OpType) { TEST_F(TestTfliteParserSub, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph, nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
}
TEST_F(TestTfliteParserAdd2, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}
class TestTfliteParserAdd3 : public TestTfliteParser {
public:
TestTfliteParserAdd3() = default;
void SetUp() override { meta_graph = LoadAndConvert("./add3.tflite", ""); }
};
TEST_F(TestTfliteParserAdd3, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
}
TEST_F(TestTfliteParserAdd3, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}
class TestTfliteParserSub1 : public TestTfliteParser {
public:
TestTfliteParserSub1() = default;
void SetUp() override { meta_graph = LoadAndConvert("./sub1.tflite", ""); }
};
TEST_F(TestTfliteParserSub1, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
} }
TEST_F(TestTfliteParserSub1, Tensor) { class TestTfliteParserMul : public TestTfliteParser {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}
class TestTfliteParserSub2 : public TestTfliteParser {
public: public:
TestTfliteParserSub2() = default; TestTfliteParserMul() = default;
void SetUp() override { meta_graph = LoadAndConvert("./sub2.tflite", ""); } void SetUp() override { meta_graph = LoadAndConvert("./mul.tflite", ""); }
}; };
TEST_F(TestTfliteParserSub2, OpType) { TEST_F(TestTfliteParserMul, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph, nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
}
TEST_F(TestTfliteParserSub2, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}
class TestTfliteParserSub3 : public TestTfliteParser {
public:
TestTfliteParserSub3() = default;
void SetUp() override { meta_graph = LoadAndConvert("./sub3.tflite", ""); }
};
TEST_F(TestTfliteParserSub3, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
}
TEST_F(TestTfliteParserSub3, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}
class TestTfliteParserMul1 : public TestTfliteParser {
public:
TestTfliteParserMul1() = default;
void SetUp() override { meta_graph = LoadAndConvert("./mul1.tflite", ""); }
};
TEST_F(TestTfliteParserMul1, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
} }
TEST_F(TestTfliteParserMul1, Tensor) { class TestTfliteParserDiv : public TestTfliteParser {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}
class TestTfliteParserMul2 : public TestTfliteParser {
public: public:
TestTfliteParserMul2() = default; TestTfliteParserDiv() = default;
void SetUp() override { meta_graph = LoadAndConvert("./mul2.tflite", ""); } void SetUp() override { meta_graph = LoadAndConvert("./div.tflite", ""); }
}; };
TEST_F(TestTfliteParserMul2, OpType) { TEST_F(TestTfliteParserDiv, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph, nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
}
TEST_F(TestTfliteParserMul2, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}
class TestTfliteParserMul3 : public TestTfliteParser {
public:
TestTfliteParserMul3() = default;
void SetUp() override { meta_graph = LoadAndConvert("./mul3.tflite", ""); }
};
TEST_F(TestTfliteParserMul3, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
}
TEST_F(TestTfliteParserMul3, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}
class TestTfliteParserDiv1 : public TestTfliteParser {
public:
TestTfliteParserDiv1() = default;
void SetUp() override { meta_graph = LoadAndConvert("./div1.tflite", ""); }
};
TEST_F(TestTfliteParserDiv1, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
} }
TEST_F(TestTfliteParserDiv1, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}
class TestTfliteParserDiv2 : public TestTfliteParser {
public:
TestTfliteParserDiv2() = default;
void SetUp() override { meta_graph = LoadAndConvert("./div2.tflite", ""); }
};
TEST_F(TestTfliteParserDiv2, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
}
TEST_F(TestTfliteParserDiv2, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}
class TestTfliteParserDiv3 : public TestTfliteParser {
public:
TestTfliteParserDiv3() = default;
void SetUp() override { meta_graph = LoadAndConvert("./div3.tflite", ""); }
};
TEST_F(TestTfliteParserDiv3, OpType) {
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
}
TEST_F(TestTfliteParserDiv3, Tensor) {
ASSERT_GT(meta_graph->allTensors.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
}
class TestTfliteParserFloorDiv : public TestTfliteParser { class TestTfliteParserFloorDiv : public TestTfliteParser {
public: public:
TestTfliteParserFloorDiv() = default; TestTfliteParserFloorDiv() = default;
@ -254,6 +77,7 @@ class TestTfliteParserFloorDiv : public TestTfliteParser {
}; };
TEST_F(TestTfliteParserFloorDiv, OpType) { TEST_F(TestTfliteParserFloorDiv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorDiv) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorDiv) << "wrong Op Type";
@ -266,12 +90,26 @@ class TestTfliteParserFloorMod : public TestTfliteParser {
}; };
TEST_F(TestTfliteParserFloorMod, OpType) { TEST_F(TestTfliteParserFloorMod, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorMod) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorMod) << "wrong Op Type";
} }
// realDiv class TestTfliteParserRealDiv : public TestTfliteParser {
public:
TestTfliteParserRealDiv() = default;
void SetUp() override {
meta_graph = LoadAndConvert("./realdiv.tflite");
}
};
TEST_F(TestTfliteParserRealDiv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
}
class TestTfliteParserSquaredDifference : public TestTfliteParser { class TestTfliteParserSquaredDifference : public TestTfliteParser {
public: public:
@ -296,17 +134,15 @@ class TestTfliteParserPow : public TestTfliteParser {
}; };
TEST_F(TestTfliteParserPow, OpType) { TEST_F(TestTfliteParserPow, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Power) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Power) << "wrong Op Type";
} }
TEST_F(TestTfliteParserPow, AttrValue) { TEST_F(TestTfliteParserPow, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPower(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsPower(); auto val = meta_graph->nodes.front()->primitive->value.AsPower();
ASSERT_EQ(val->scale, 1.0); ASSERT_EQ(val->scale, 1.0);
ASSERT_EQ(val->shift, 0.0); ASSERT_EQ(val->shift, 0.0);
ASSERT_EQ(val->power, 0.0); ASSERT_EQ(val->power, 0.0);
@ -477,6 +313,7 @@ class TestTfliteParserFloor : public TestTfliteParser {
}; };
TEST_F(TestTfliteParserFloor, OpType) { TEST_F(TestTfliteParserFloor, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Floor) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Floor) << "wrong Op Type";

View File

@ -32,14 +32,12 @@ TEST_F(TestTfliteParserBatchToSpaceNd, OpType) {
} }
TEST_F(TestTfliteParserBatchToSpaceNd, AttrValue) { TEST_F(TestTfliteParserBatchToSpaceNd, AttrValue) {
const std::vector<int> blockShape{2, 2};
const std::vector<int> crops{0, 0, 2, 0};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsBatchToSpace(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsBatchToSpace(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->blockShape, blockShape); auto val = meta_graph->nodes.front()->primitive->value.AsBatchToSpace();
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->crops, crops); const std::vector<int> blockShape = {2, 2};
ASSERT_EQ(val->blockShape, blockShape);
const std::vector<int> crops = {0, 0, 2, 0};
ASSERT_EQ(val->crops, crops);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -35,12 +35,9 @@ TEST_F(TestTfliteParserCast, OpType) {
} }
TEST_F(TestTfliteParserCast, AttrValue) { TEST_F(TestTfliteParserCast, AttrValue) {
// float32 --> int32
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsCast(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsCast(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->srcT, 43); auto val = meta_graph->nodes.front()->primitive->value.AsCast();
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->dstT, 34); ASSERT_EQ(val->srcT, 43);
ASSERT_EQ(val->dstT, 34);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserConcat : public TestTfliteParser {
public:
TestTfliteParserConcat() = default;
void SetUp() override { meta_graph = LoadAndConvert("./concat.tflite", ""); }
};
TEST_F(TestTfliteParserConcat, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Concat) << "wrong Op Type";
}
TEST_F(TestTfliteParserConcat, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConcat(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsConcat();
ASSERT_EQ(val->axis, 1);
ASSERT_EQ(val->n, 2);
}
} // namespace mindspore

View File

@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserConv : public TestTfliteParser {
public:
TestTfliteParserConv() = default;
void SetUp() override { meta_graph = LoadAndConvert("./conv.tflite", ""); }
};
TEST_F(TestTfliteParserConv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
}
TEST_F(TestTfliteParserConv, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->group, 1);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
ASSERT_EQ(val->hasBias, true);
ASSERT_EQ(val->channelIn, 1);
ASSERT_EQ(val->channelOut, 4);
ASSERT_EQ(val->kernelH, 3);
ASSERT_EQ(val->kernelW, 3);
ASSERT_EQ(val->strideH, 1);
ASSERT_EQ(val->strideW, 1);
ASSERT_EQ(val->dilateH, 1);
ASSERT_EQ(val->dilateW, 1);
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
ASSERT_EQ(val->padUp, 1);
ASSERT_EQ(val->padDown, 1);
ASSERT_EQ(val->padLeft, 1);
ASSERT_EQ(val->padRight, 1);
}
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserDeConv : public TestTfliteParser {
public:
TestTfliteParserDeConv() = default;
void SetUp() override { meta_graph = LoadAndConvert("./deconv.tflite", ""); }
};
TEST_F(TestTfliteParserDeConv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type";
}
TEST_F(TestTfliteParserDeConv, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsDeConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->group, 1);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
ASSERT_EQ(val->hasBias, true);
ASSERT_EQ(val->channelIn, 1);
ASSERT_EQ(val->channelOut, 4);
ASSERT_EQ(val->kernelH, 3);
ASSERT_EQ(val->kernelW, 3);
ASSERT_EQ(val->strideH, 1);
ASSERT_EQ(val->strideW, 1);
ASSERT_EQ(val->dilateH, 1);
ASSERT_EQ(val->dilateW, 1);
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
ASSERT_EQ(val->padUp, 1);
ASSERT_EQ(val->padDown, 1);
ASSERT_EQ(val->padLeft, 1);
ASSERT_EQ(val->padRight, 1);
}
} // namespace mindspore

View File

@ -35,11 +35,9 @@ TEST_F(TestTfliteParserDepthToSpace, OpType) {
} }
TEST_F(TestTfliteParserDepthToSpace, AttrValue) { TEST_F(TestTfliteParserDepthToSpace, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthToSpace(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthToSpace(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->blockSize, 4); auto val = meta_graph->nodes.front()->primitive->value.AsDepthToSpace();
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->format, schema::Format_NHWC); ASSERT_EQ(val->blockSize, 4);
ASSERT_EQ(val->format, schema::Format_NHWC);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,92 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserDepthwiseConv1 : public TestTfliteParser {
public:
TestTfliteParserDepthwiseConv1() = default;
void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv1.tflite", ""); }
};
TEST_F(TestTfliteParserDepthwiseConv1, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
}
TEST_F(TestTfliteParserDepthwiseConv1, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->group, 0);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
ASSERT_EQ(val->hasBias, true);
ASSERT_EQ(val->channelIn, 1);
ASSERT_EQ(val->channelOut, 4);
ASSERT_EQ(val->kernelH, 3);
ASSERT_EQ(val->kernelW, 3);
ASSERT_EQ(val->strideH, 1);
ASSERT_EQ(val->strideW, 1);
ASSERT_EQ(val->dilateH, 1);
ASSERT_EQ(val->dilateW, 1);
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
ASSERT_EQ(val->padUp, 1);
ASSERT_EQ(val->padDown, 1);
ASSERT_EQ(val->padLeft, 1);
ASSERT_EQ(val->padRight, 1);
}
class TestTfliteParserDepthwiseConv2 : public TestTfliteParser {
public:
TestTfliteParserDepthwiseConv2() = default;
void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv2.tflite", ""); }
};
TEST_F(TestTfliteParserDepthwiseConv2, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type";
}
TEST_F(TestTfliteParserDepthwiseConv2, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
ASSERT_EQ(val->hasBias, true);
ASSERT_EQ(val->channelIn, 2);
ASSERT_EQ(val->channelMultiplier, 1);
ASSERT_EQ(val->kernelH, 3);
ASSERT_EQ(val->kernelW, 3);
ASSERT_EQ(val->strideH, 1);
ASSERT_EQ(val->strideW, 1);
ASSERT_EQ(val->dilateH, 1);
ASSERT_EQ(val->dilateW, 1);
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
ASSERT_EQ(val->padUp, 1);
ASSERT_EQ(val->padDown, 1);
ASSERT_EQ(val->padLeft, 1);
ASSERT_EQ(val->padRight, 1);
}
} // namespace mindspore

View File

@ -25,17 +25,15 @@ class TestTfliteParserFill : public TestTfliteParser {
}; };
TEST_F(TestTfliteParserFill, OpType) { TEST_F(TestTfliteParserFill, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Fill) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Fill) << "wrong Op Type";
} }
TEST_F(TestTfliteParserFill, AttrValue) { TEST_F(TestTfliteParserFill, AttrValue) {;
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsFill(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsFill(); auto val = meta_graph->nodes.front()->primitive->value.AsFill();
std::vector<int32_t> dims = {9}; std::vector<int32_t> dims = {9};
ASSERT_EQ(val->dims, dims); ASSERT_EQ(val->dims, dims);
} }

View File

@ -25,15 +25,14 @@ class TestTfliteParserGatherNd : public TestTfliteParser {
}; };
TEST_F(TestTfliteParserGatherNd, OpType) { TEST_F(TestTfliteParserGatherNd, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_GatherNd) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_GatherNd) << "wrong Op Type";
} }
TEST_F(TestTfliteParserGatherNd, AttrValue) { TEST_F(TestTfliteParserGatherNd, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGatherNd(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsGatherNd(); auto val = meta_graph->nodes.front()->primitive->value.AsGatherNd();
ASSERT_EQ(val->batchDims, 0); ASSERT_EQ(val->batchDims, 0);
} }

View File

@ -25,15 +25,14 @@ class TestTfliteParserGather : public TestTfliteParser {
}; };
TEST_F(TestTfliteParserGather, OpType) { TEST_F(TestTfliteParserGather, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Gather) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Gather) << "wrong Op Type";
} }
TEST_F(TestTfliteParserGather, AttrValue) { TEST_F(TestTfliteParserGather, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGather(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsGather(); auto val = meta_graph->nodes.front()->primitive->value.AsGather();
ASSERT_EQ(val->axis, 0); ASSERT_EQ(val->axis, 0);
ASSERT_EQ(val->batchDims, 0); ASSERT_EQ(val->batchDims, 0);

View File

@ -25,6 +25,7 @@ class TestTfliteParserLRN : public TestTfliteParser {
}; };
TEST_F(TestTfliteParserLRN, OpType) { TEST_F(TestTfliteParserLRN, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type,
@ -32,9 +33,7 @@ TEST_F(TestTfliteParserLRN, OpType) {
} }
TEST_F(TestTfliteParserLRN, AttrValue) { TEST_F(TestTfliteParserLRN, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization(); auto val = meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization();
ASSERT_EQ(val->alpha, 1); ASSERT_EQ(val->alpha, 1);
ASSERT_EQ(val->beta, 0.5); ASSERT_EQ(val->beta, 0.5);

View File

@ -32,12 +32,9 @@ TEST_F(TestTfliteParserOneHot, OpType) {
} }
TEST_F(TestTfliteParserOneHot, AttrValue) { TEST_F(TestTfliteParserOneHot, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsOneHot(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsOneHot(), nullptr);
// in OneHot parser axis = axis > 0 ? axis : axis + tensor_shape.size() auto val = meta_graph->nodes.front()->primitive->value.AsOneHot();
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsOneHot()->axis, 2); ASSERT_EQ(val->axis, 2);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -25,17 +25,15 @@ class TestTfliteParserPad : public TestTfliteParser {
}; };
TEST_F(TestTfliteParserPad, OpType) { TEST_F(TestTfliteParserPad, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Pad) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Pad) << "wrong Op Type";
} }
TEST_F(TestTfliteParserPad, AttrValue) { TEST_F(TestTfliteParserPad, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPad(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsPad(); auto val = meta_graph->nodes.front()->primitive->value.AsPad();
std::vector<int32_t> paddings = {1, 1, 2, 2, 3, 3, 4, 4}; std::vector<int32_t> paddings = {1, 1, 2, 2, 3, 3, 4, 4};
ASSERT_EQ(val->paddings, paddings); ASSERT_EQ(val->paddings, paddings);
} }

View File

@ -35,12 +35,8 @@ TEST_F(TestTfliteParserMaxPooling, OpType) {
} }
TEST_F(TestTfliteParserMaxPooling, AttrValue) { TEST_F(TestTfliteParserMaxPooling, AttrValue) {
ASSERT_NE(meta_graph, nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsPooling(); auto val = meta_graph->nodes.front()->primitive->value.AsPooling();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->poolingMode, schema::PoolMode_MAX_POOLING); ASSERT_EQ(val->poolingMode, schema::PoolMode_MAX_POOLING);
ASSERT_EQ(val->global, false); ASSERT_EQ(val->global, false);
@ -72,12 +68,8 @@ TEST_F(TestTfliteParserAvgPooling, OpType) {
} }
TEST_F(TestTfliteParserAvgPooling, AttrValue) { TEST_F(TestTfliteParserAvgPooling, AttrValue) {
ASSERT_NE(meta_graph, nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsPooling(); auto val = meta_graph->nodes.front()->primitive->value.AsPooling();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->poolingMode, schema::PoolMode_MEAN_POOLING); ASSERT_EQ(val->poolingMode, schema::PoolMode_MEAN_POOLING);
ASSERT_EQ(val->global, false); ASSERT_EQ(val->global, false);

View File

@ -32,13 +32,9 @@ TEST_F(TestTfliteParserReduceMax, OpType) {
} }
TEST_F(TestTfliteParserReduceMax, AttrValue) { TEST_F(TestTfliteParserReduceMax, AttrValue) {
ASSERT_NE(meta_graph, nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
ASSERT_NE(val, nullptr); ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax);
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax) << "wrong reduce mode";
ASSERT_EQ(val->keepDims, false); ASSERT_EQ(val->keepDims, false);
std::vector<int32_t> axes = {2}; std::vector<int32_t> axes = {2};
ASSERT_EQ(val->axes, axes); ASSERT_EQ(val->axes, axes);
@ -58,13 +54,9 @@ TEST_F(TestTfliteParserReduceMin, OpType) {
} }
TEST_F(TestTfliteParserReduceMin, AttrValue) { TEST_F(TestTfliteParserReduceMin, AttrValue) {
ASSERT_NE(meta_graph, nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
ASSERT_NE(val, nullptr); ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin);
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin) << "wrong reduce mode";
ASSERT_EQ(val->keepDims, false); ASSERT_EQ(val->keepDims, false);
std::vector<int32_t> axes = {2}; std::vector<int32_t> axes = {2};
ASSERT_EQ(val->axes, axes); ASSERT_EQ(val->axes, axes);
@ -84,13 +76,9 @@ TEST_F(TestTfliteParserReduceProd, OpType) {
} }
TEST_F(TestTfliteParserReduceProd, AttrValue) { TEST_F(TestTfliteParserReduceProd, AttrValue) {
ASSERT_NE(meta_graph, nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
ASSERT_NE(val, nullptr); ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd);
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd) << "wrong reduce mode";
ASSERT_EQ(val->keepDims, false); ASSERT_EQ(val->keepDims, false);
std::vector<int32_t> axes = {2}; std::vector<int32_t> axes = {2};
ASSERT_EQ(val->axes, axes); ASSERT_EQ(val->axes, axes);
@ -111,13 +99,9 @@ TEST_F(TestTfliteParserSum, OpType) {
} }
TEST_F(TestTfliteParserSum, AttrValue) { TEST_F(TestTfliteParserSum, AttrValue) {
ASSERT_NE(meta_graph, nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
ASSERT_NE(val, nullptr); ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum);
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum) << "wrong reduce mode";
ASSERT_EQ(val->keepDims, false); ASSERT_EQ(val->keepDims, false);
std::vector<int32_t> axes = {2}; std::vector<int32_t> axes = {2};
ASSERT_EQ(val->axes, axes); ASSERT_EQ(val->axes, axes);
@ -138,13 +122,9 @@ TEST_F(TestTfliteParserMean, OpType) {
} }
TEST_F(TestTfliteParserMean, AttrValue) { TEST_F(TestTfliteParserMean, AttrValue) {
ASSERT_NE(meta_graph, nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
ASSERT_NE(val, nullptr); ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean);
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean) << "wrong reduce mode";
ASSERT_EQ(val->keepDims, true); ASSERT_EQ(val->keepDims, true);
std::vector<int32_t> axes = {2, 3}; std::vector<int32_t> axes = {2, 3};
ASSERT_EQ(val->axes, axes); ASSERT_EQ(val->axes, axes);

View File

@ -35,12 +35,9 @@ TEST_F(TestTfliteParserReshape, OpType) {
} }
TEST_F(TestTfliteParserReshape, AttrValue) { TEST_F(TestTfliteParserReshape, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReshape(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReshape(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReshape();
std::vector<int64_t> shape = {3, 5, 20}; std::vector<int64_t> shape = {3, 5, 20};
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReshape()->shape, shape); // int32 ASSERT_EQ(val->shape, shape);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -26,17 +26,15 @@ class TestTfliteParserResizeNN : public TestTfliteParser {
}; };
TEST_F(TestTfliteParserResizeNN, OpType) { TEST_F(TestTfliteParserResizeNN, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type";
} }
TEST_F(TestTfliteParserResizeNN, AttrValue) { TEST_F(TestTfliteParserResizeNN, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsResize(); auto val = meta_graph->nodes.front()->primitive->value.AsResize();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->alignCorners, false); ASSERT_EQ(val->alignCorners, false);
ASSERT_EQ(val->newHeight, 3); ASSERT_EQ(val->newHeight, 3);
ASSERT_EQ(val->newWidth, 100); ASSERT_EQ(val->newWidth, 100);
@ -52,17 +50,15 @@ class TestTfliteParserResizeBilinear : public TestTfliteParser {
}; };
TEST_F(TestTfliteParserResizeBilinear, OpType) { TEST_F(TestTfliteParserResizeBilinear, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type";
} }
TEST_F(TestTfliteParserResizeBilinear, AttrValue) { TEST_F(TestTfliteParserResizeBilinear, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsResize(); auto val = meta_graph->nodes.front()->primitive->value.AsResize();
ASSERT_NE(val, nullptr);
ASSERT_EQ(val->alignCorners, false); ASSERT_EQ(val->alignCorners, false);
ASSERT_EQ(val->newHeight, 75); ASSERT_EQ(val->newHeight, 75);
ASSERT_EQ(val->newWidth, 4); ASSERT_EQ(val->newWidth, 4);

View File

@ -25,17 +25,15 @@ class TestTfliteParserReverse : public TestTfliteParser {
}; };
TEST_F(TestTfliteParserReverse, OpType) { TEST_F(TestTfliteParserReverse, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reverse) << "wrong Op Type"; ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reverse) << "wrong Op Type";
} }
TEST_F(TestTfliteParserReverse, AttrValue) { TEST_F(TestTfliteParserReverse, AttrValue) {
ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverse(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsReverse(); auto val = meta_graph->nodes.front()->primitive->value.AsReverse();
std::vector<int32_t> axis = {3}; std::vector<int32_t> axis = {3};
ASSERT_EQ(val->axis, axis); ASSERT_EQ(val->axis, axis);
} }

View File

@ -35,13 +35,11 @@ TEST_F(TestTfliteParserReverseSequence, OpType) {
} }
TEST_F(TestTfliteParserReverseSequence, AttrValue) { TEST_F(TestTfliteParserReverseSequence, AttrValue) {
std::vector<int> seq_length{7, 2, 3, 5};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverseSequence(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverseSequence(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1); auto val = meta_graph->nodes.front()->primitive->value.AsReverseSequence();
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1); ASSERT_EQ(val->seqAxis, 1);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqLengths, seq_length); ASSERT_EQ(val->seqAxis, 1);
std::vector<int> seq_length = {7, 2, 3, 5};
ASSERT_EQ(val->seqLengths, seq_length);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserSlice : public TestTfliteParser {
public:
TestTfliteParserSlice() = default;
void SetUp() override { meta_graph = LoadAndConvert("./slice.tflite"); }
};
TEST_F(TestTfliteParserSlice, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Slice) << "wrong Op Type";
}
TEST_F(TestTfliteParserSlice, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSlice(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsSlice();
ASSERT_EQ(val->format, schema::Format_NHWC);
std::vector<int32_t> begin = {1, 0, 0};
ASSERT_EQ(val->begin, begin);
std::vector<int32_t> size = {1, 1, 3};
ASSERT_EQ(val->size, size);
}
} // namespace mindspore

View File

@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSoftmax, OpType) {
} }
TEST_F(TestTfliteParserSoftmax, AttrValue) { TEST_F(TestTfliteParserSoftmax, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSoftMax()->axis, -1); auto val = meta_graph->nodes.front()->primitive->value.AsSoftMax();
ASSERT_EQ(val->axis, -1);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -35,13 +35,11 @@ TEST_F(TestTfliteParserSpaceToBatchND, OpType) {
} }
TEST_F(TestTfliteParserSpaceToBatchND, AttrValue) { TEST_F(TestTfliteParserSpaceToBatchND, AttrValue) {
std::vector<int> blockshape{2, 2};
std::vector<int> padding{0, 0, 2, 0};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->blockShape, blockshape); auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND();
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->paddings, padding); std::vector<int> blockshape = {2, 2};
ASSERT_EQ(val->blockShape, blockshape);
std::vector<int> padding = {0, 0, 2, 0};
ASSERT_EQ(val->paddings, padding);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSpaceToDepth, OpType) {
} }
TEST_F(TestTfliteParserSpaceToDepth, AttrValue) { TEST_F(TestTfliteParserSpaceToDepth, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->blockSize, 2); auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToDepth();
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->format, schema::Format_NHWC); ASSERT_EQ(val->blockSize, 2);
ASSERT_EQ(val->format, schema::Format_NHWC);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -35,16 +35,14 @@ TEST_F(TestTfliteParserSparseToDense, OpType) {
} }
TEST_F(TestTfliteParserSparseToDense, AttrValue) { TEST_F(TestTfliteParserSparseToDense, AttrValue) {
std::vector<int> outputShape{5, 5};
std::vector<int> sparseValue{1};
std::vector<int> defaultValue{0};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->outputShape, outputShape); auto val = meta_graph->nodes.front()->primitive->value.AsSparseToDense();
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->sparseValue, sparseValue); std::vector<int> outputShape = {5, 5};
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->defaultValue, defaultValue); ASSERT_EQ(val->outputShape, outputShape);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->validateIndices, false); std::vector<int> sparseValue = {1};
ASSERT_EQ(val->sparseValue, sparseValue);
std::vector<int> defaultValue = {0};
ASSERT_EQ(val->defaultValue, defaultValue);
ASSERT_EQ(val->validateIndices, false);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -33,14 +33,12 @@ TEST_F(TestTfliteParserSplit, OpType) {
} }
TEST_F(TestTfliteParserSplit, AttrValue) { TEST_F(TestTfliteParserSplit, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr);
const std::vector<int> sizeSplits{2, 2}; auto val = meta_graph->nodes.front()->primitive->value.AsSplit();
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 2); ASSERT_EQ(val->splitDim, 2);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2); ASSERT_EQ(val->numberSplit, 2);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits); const std::vector<int> sizeSplits = {2, 2};
ASSERT_EQ(val->sizeSplits, sizeSplits);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -33,14 +33,12 @@ TEST_F(TestTfliteParserSplitV, OpType) {
} }
TEST_F(TestTfliteParserSplitV, AttrValue) { TEST_F(TestTfliteParserSplitV, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr);
const std::vector<int> sizeSplits{1, 3}; auto val = meta_graph->nodes.front()->primitive->value.AsSplit();
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 0); ASSERT_EQ(val->splitDim, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2); ASSERT_EQ(val->numberSplit, 2);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits); const std::vector<int> sizeSplits = {1, 3};
ASSERT_EQ(val->sizeSplits, sizeSplits);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserStack : public TestTfliteParser {
public:
TestTfliteParserStack() = default;
void SetUp() override { meta_graph = LoadAndConvert("./stack.tflite"); }
};
TEST_F(TestTfliteParserStack, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Stack) << "wrong Op Type";
}
TEST_F(TestTfliteParserStack, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStack(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsStack();
ASSERT_EQ(val->axis, 1);
ASSERT_EQ(val->n, 2);
const std::vector<int> isScale = {3, 2, 3};
ASSERT_EQ(val->isScale, isScale);
}
} // namespace mindspore

View File

@ -35,21 +35,19 @@ TEST_F(TestTfliteParserStridedSlice, OpType) {
} }
TEST_F(TestTfliteParserStridedSlice, AttrValue) { TEST_F(TestTfliteParserStridedSlice, AttrValue) {
std::vector<int> begin{1, -1, 0};
std::vector<int> end{2, -3, 3};
std::vector<int> stride{1, -1, 1};
std::vector<int> isscale{3, 2, 3};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStridedSlice(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStridedSlice(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0); auto val = meta_graph->nodes.front()->primitive->value.AsStridedSlice();
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->endMask, 0); ASSERT_EQ(val->beginMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0); ASSERT_EQ(val->endMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0); ASSERT_EQ(val->beginMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->begin, begin); ASSERT_EQ(val->beginMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->end, end); std::vector<int> begin = {1, -1, 0};
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->stride, stride); ASSERT_EQ(val->begin, begin);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->isScale, isscale); std::vector<int> end = {2, -3, 3};
ASSERT_EQ(val->end, end);
std::vector<int> stride = {1, -1, 1};
ASSERT_EQ(val->stride, stride);
std::vector<int> isscale = {3, 2, 3};
ASSERT_EQ(val->isScale, isscale);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -35,11 +35,9 @@ TEST_F(TestTfliteParserTile, OpType) {
} }
TEST_F(TestTfliteParserTile, AttrValue) { TEST_F(TestTfliteParserTile, AttrValue) {
std::vector<int> multiply{2, 3, 4};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTile(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTile(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTile()->multiples, multiply); auto val = meta_graph->nodes.front()->primitive->value.AsTile();
std::vector<int> multiply = {2, 3, 4};
ASSERT_EQ(val->multiples, multiply);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -35,13 +35,10 @@ TEST_F(TestTfliteParserTopKV2, OpType) {
} }
TEST_F(TestTfliteParserTopKV2, AttrValue) { TEST_F(TestTfliteParserTopKV2, AttrValue) {
// attr->sorted default is true
std::vector<int> k{3};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->k, k); auto val = meta_graph->nodes.front()->primitive->value.AsTopKV2();
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->sorted, true); std::vector<int> k = {3};
ASSERT_EQ(val->k, k);
ASSERT_EQ(val->sorted, true);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserTranspose : public TestTfliteParser {
public:
TestTfliteParserTranspose() = default;
void SetUp() override { meta_graph = LoadAndConvert("./transpose.tflite"); }
};
TEST_F(TestTfliteParserTranspose, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
}
TEST_F(TestTfliteParserTranspose, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTranspose(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsTranspose();
ASSERT_EQ(val->conjugate, false);
std::vector<int32_t> perm = {1, 0};
ASSERT_EQ(val->perm, perm);
}
} // namespace mindspore

View File

@ -35,10 +35,9 @@ TEST_F(TestTfliteParserUnique, OpType) {
} }
TEST_F(TestTfliteParserUnique, AttrValue) { TEST_F(TestTfliteParserUnique, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnique()->outType, 34); // int32 auto val = meta_graph->nodes.front()->primitive->value.AsUnique();
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr);
ASSERT_EQ(val->outType, 34);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -35,11 +35,9 @@ TEST_F(TestTfliteParserUnstack, OpType) {
} }
TEST_F(TestTfliteParserUnstack, AttrValue) { TEST_F(TestTfliteParserUnstack, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnstack(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnstack(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->num, 5); auto val = meta_graph->nodes.front()->primitive->value.AsUnstack();
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->axis, 1); ASSERT_EQ(val->num, 5);
ASSERT_EQ(val->axis, 1);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -353,7 +353,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC); status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
} else if (weightTensor->format == schema::Format_KCHW) { } else if (weightTensor->format == schema::Format_KCHW) {
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} else if (weightTensor->format == schema::Format_CHWK) { } else if (weightTensor->format == schema::Format_CHWK) { // from tflite
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
} else { } else {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
@ -369,7 +369,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} else if (weightTensor->format == schema::Format_CHWK) { // from tf } else if (weightTensor->format == schema::Format_CHWK) { // from tflite
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
} else { } else {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;

View File

@ -21,11 +21,16 @@ namespace lite {
STATUS CaffeFlattenParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, STATUS CaffeFlattenParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight,
schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) {
if (op == nullptr) { if (op == nullptr) {
// MS_LOGE("null pointer dereferencing."); // MS_LOG(ERROR) << "null pointer dereferencing.";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
std::unique_ptr<schema::ReshapeT> attr(new schema::ReshapeT()); std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT());
attr->format = schema::Format_NCHW; const caffe::FlattenParameter flattenParam = proto.flatten_param();
attr->axis = (int32_t)flattenParam.axis();
attr->useAxis = true;
attr->hasBias = false;
attr->activationType = schema::ActivationType_NO_ACTIVATION;
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Flatten; op->primitive->value.type = schema::PrimitiveType_Flatten;

View File

@ -14,18 +14,21 @@
* limitations under the License. * limitations under the License.
*/ */
#include "tools/converter/parser/tflite/tflite_activation_parser.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <string> #include <string>
#include "tools/converter/parser/tflite/tflite_activation_parser.h" #include <map>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -35,13 +38,11 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
MS_LOG(ERROR) << "op->primitive is null"; MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT()); std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());
std::vector<std::string> node_name_str; std::vector<std::string> node_name_str;
Split(op->name, &node_name_str, "-"); Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str(); const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Relu") == 0) { if (std::strcmp(node_name, "Relu") == 0) {
MS_LOG(DEBUG) << "parse TfliteReluParser"; MS_LOG(DEBUG) << "parse TfliteReluParser";
attr->type = schema::ActivationType_RELU; attr->type = schema::ActivationType_RELU;
@ -54,29 +55,65 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
} else if (std::strcmp(node_name, "Logistic") == 0) { } else if (std::strcmp(node_name, "Logistic") == 0) {
MS_LOG(DEBUG) << "parse TfliteLogisticParser"; MS_LOG(DEBUG) << "parse TfliteLogisticParser";
attr->type = schema::ActivationType_SIGMOID; attr->type = schema::ActivationType_SIGMOID;
} else if (std::strcmp(node_name, "LeakyRelu") == 0) { } else if (std::strcmp(node_name, "HardSwish") == 0) {
const auto &option = tfliteOp->builtin_options.AsLeakyReluOptions(); MS_LOG(DEBUG) << "parse TfliteHardSwishParser";
if (option == nullptr) { attr->type = schema::ActivationType_SIGMOID;
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->type = schema::ActivationType_LEAKY_RELU;
attr->alpha = option->alpha;
} else {
MS_LOG(ERROR) << "wrong activation type";
return RET_ERROR;
} }
attr->alpha = 0.2f;
op->primitive->value.type = schema::PrimitiveType_Activation; op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }
STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) { std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TflitePreluParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::PreluT> attr(new schema::PreluT());
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) {
MS_LOG(ERROR) << "get pRelu -> slope failed";
return RET_ERROR;
}
op->primitive->value.type = schema::PrimitiveType_Prelu;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}
STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteLeakyReluParser";
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -87,22 +124,29 @@ STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite
return RET_NULL_PTR; return RET_NULL_PTR;
} }
MS_LOG(DEBUG) << "paser TflitePreluParser"; std::unique_ptr<schema::LeakyReLUT> attr(new schema::LeakyReLUT());
std::unique_ptr<schema::PreluT> attr(new schema::PreluT());
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) { const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions();
MS_LOG(ERROR) << "get pRelu -> slope failed"; if (tflite_attr == nullptr) {
return RET_ERROR; MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
} }
attr->negativeSlope = tflite_attr->alpha;
op->primitive->value.type = schema::PrimitiveType_Prelu; op->primitive->value.type = schema::PrimitiveType_LeakyReLU;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }
TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser()); TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser()); TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());
TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser());
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser()); TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser());
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser()); TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());

View File

@ -14,13 +14,14 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef PREDICT_TFLITE_RELU_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H
#define PREDICT_TFLITE_RELU_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -29,11 +30,13 @@ class TfliteActivationParser : public TfliteNodeParser {
public: public:
TfliteActivationParser() : TfliteNodeParser("node_name") {} TfliteActivationParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override; std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
class TfliteReluParser : public TfliteActivationParser { class TfliteReluParser : public TfliteActivationParser {
@ -56,9 +59,9 @@ class TfliteLogisticParser : public TfliteActivationParser {
TfliteLogisticParser() : TfliteActivationParser() {} TfliteLogisticParser() : TfliteActivationParser() {}
}; };
class TfliteLeakyReluParser : public TfliteActivationParser { class TfliteHardSwishParser : public TfliteActivationParser {
public: public:
TfliteLeakyReluParser() : TfliteActivationParser() {} TfliteHardSwishParser() : TfliteActivationParser() {}
}; };
class TflitePreluParser : public TfliteNodeParser { class TflitePreluParser : public TfliteNodeParser {
@ -68,12 +71,27 @@ class TflitePreluParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) override; std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
class TfliteLeakyReluParser : public TfliteNodeParser {
public:
TfliteLeakyReluParser() : TfliteNodeParser("LeakyRelu") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // PREDICT_TFLITE_RELU_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H

View File

@ -18,14 +18,20 @@
#include "tools/converter/parser/tflite/tflite_addn_parser.h" #include "tools/converter/parser/tflite/tflite_addn_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteAddNParser";
// set attr
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -36,13 +42,19 @@ STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR; return RET_NULL_PTR;
} }
MS_LOG(DEBUG) << "parse TfliteAddNParser";
std::unique_ptr<schema::AddNT> attr(new schema::AddNT()); std::unique_ptr<schema::AddNT> attr(new schema::AddNT());
attr->N = tfliteTensors.size() - 1; attr->N = tflite_tensors.size() - 1;
op->primitive->value.type = schema::PrimitiveType_AddN; op->primitive->value.type = schema::PrimitiveType_AddN;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
// set input
for (int i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }

View File

@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef LITE_TFLITE_ADDN_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H
#define LITE_TFLITE_ADDN_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
@ -31,11 +32,12 @@ class TfliteAddNParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, std::vector<int32_t> *tensors_id,
bool quantized_model) override; std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // LITE_TFLITE_ADDN_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H

View File

@ -17,16 +17,19 @@
#include "tools/converter/parser/tflite/tflite_argmax_parser.h" #include "tools/converter/parser/tflite/tflite_argmax_parser.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, std::vector<int32_t> *tensors_id,
bool quantizedModel) { std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteArgmaxParser";
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -37,7 +40,6 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR; return RET_NULL_PTR;
} }
MS_LOG(DEBUG) << "parse TfliteArgmaxParser";
std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT()); std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT());
attr->outMaxValue = false; attr->outMaxValue = false;
@ -45,9 +47,10 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
attr->keepDims = false; attr->keepDims = false;
attr->axisType = 1; attr->axisType = 1;
auto axis_idx = tfliteOp->inputs[1]; // get axis attr
std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){}); auto axis_idx = tflite_op->inputs[1];
auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer]; std::for_each(tflite_tensors[axis_idx]->shape.begin(), tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){});
auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer];
if (buf_data == nullptr) { if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null"; MS_LOG(ERROR) << "the buf data is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -61,6 +64,11 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
op->primitive->value.type = schema::PrimitiveType_ArgMax; op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }

View File

@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef PREDICT_TFLITE_ARGMAX_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H
#define PREDICT_TFLITE_ARGMAX_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
@ -28,14 +29,15 @@ class TfliteArgmaxParser : public TfliteNodeParser {
public: public:
TfliteArgmaxParser() : TfliteNodeParser("Argmax") {} TfliteArgmaxParser() : TfliteNodeParser("Argmax") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, std::vector<int32_t> *tensors_id,
bool quantizedModel) override; std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // PREDICT_TFLITE_ARGMAX_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H

View File

@ -17,14 +17,19 @@
#include "tools/converter/parser/tflite/tflite_argmin_parser.h" #include "tools/converter/parser/tflite/tflite_argmin_parser.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteArgminParser";
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -35,7 +40,6 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR; return RET_NULL_PTR;
} }
MS_LOG(DEBUG) << "parse TfliteArgminParser";
std::unique_ptr<schema::ArgMinT> attr(new schema::ArgMinT()); std::unique_ptr<schema::ArgMinT> attr(new schema::ArgMinT());
attr->outMaxValue = false; attr->outMaxValue = false;
@ -43,9 +47,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
attr->keepDims = false; attr->keepDims = false;
attr->axisType = 1; attr->axisType = 1;
auto axis_idx = tfliteOp->inputs[1]; // get axis attr
std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){}); auto axis_idx = tflite_op->inputs[1];
auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer]; std::for_each(tflite_tensors[axis_idx]->shape.begin(),
tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){});
auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer];
if (buf_data == nullptr) { if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null"; MS_LOG(ERROR) << "the buf data is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -59,6 +65,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
op->primitive->value.type = schema::PrimitiveType_ArgMin; op->primitive->value.type = schema::PrimitiveType_ArgMin;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }

View File

@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef PREDICT_TFLITE_ARGMIN_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H
#define PREDICT_TFLITE_ARGMIN_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
@ -28,14 +29,15 @@ class TfliteArgminParser : public TfliteNodeParser {
public: public:
TfliteArgminParser() : TfliteNodeParser("Argmin") {} TfliteArgminParser() : TfliteNodeParser("Argmin") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, std::vector<int32_t> *tensors_id,
bool quantizedModel) override; std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // PREDICT_TFLITE_ARGMIN_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H

View File

@ -18,14 +18,17 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <string> #include <string>
#include <map>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -37,55 +40,12 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
} }
std::vector<std::string> node_name_str; std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-"); Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str(); const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Add") == 0
|| std::strcmp(node_name, "Sub") == 0
|| std::strcmp(node_name, "Mul") == 0
|| std::strcmp(node_name, "Div") == 0) {
auto x_index = tfliteOp->inputs[0];
const auto &x_tensor = tfliteTensors[x_index];
if (x_tensor == nullptr) {
MS_LOG(ERROR) << "the first input is null";
return RET_NULL_PTR;
}
auto &x_data = tfliteModelBuffer.at(x_tensor->buffer);
if (x_data == nullptr) {
MS_LOG(ERROR) << "the data of the first input is null";
return RET_NULL_PTR;
}
if (!x_data->data.empty()) {
std::vector<tflite::TensorT *> x_tensors{x_tensor.get()};
if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse the first tensor failed";
return RET_ERROR;
}
}
auto y_index = tfliteOp->inputs[1];
const auto &y_tensor = tfliteTensors[y_index];
if (y_tensor == nullptr) {
MS_LOG(ERROR) << "the second input is null";
return RET_NULL_PTR;
}
auto &y_data = tfliteModelBuffer.at(y_tensor->buffer);
if (y_data == nullptr) {
MS_LOG(ERROR) << "the data of the second input is null";
return RET_NULL_PTR;
}
if (!y_data->data.empty()) {
std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse the second tensor failed";
return RET_ERROR;
}
}
if (std::strcmp(node_name, "Add") == 0) { if (std::strcmp(node_name, "Add") == 0) {
MS_LOG(DEBUG) << "parse TfliteAddParser"; MS_LOG(DEBUG) << "parse TfliteAddParser";
std::unique_ptr<schema::AddT> attr(new schema::AddT()); std::unique_ptr<schema::AddT> attr(new schema::AddT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsAddOptions(); const auto &tfliteAttr = tflite_op->builtin_options.AsAddOptions();
if (nullptr == tfliteAttr) { if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -93,11 +53,10 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Add; op->primitive->value.type = schema::PrimitiveType_Add;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Sub") == 0) { } else if (std::strcmp(node_name, "Sub") == 0) {
MS_LOG(DEBUG) << "parse TfliteSubParser"; MS_LOG(DEBUG) << "parse TfliteSubParser";
std::unique_ptr<schema::SubT> attr(new schema::SubT()); std::unique_ptr<schema::SubT> attr(new schema::SubT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsSubOptions(); const auto &tfliteAttr = tflite_op->builtin_options.AsSubOptions();
if (nullptr == tfliteAttr) { if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -105,11 +64,10 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Sub; op->primitive->value.type = schema::PrimitiveType_Sub;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Mul") == 0) { } else if (std::strcmp(node_name, "Mul") == 0) {
MS_LOG(DEBUG) << "parse TfliteMulParser"; MS_LOG(DEBUG) << "parse TfliteMulParser";
std::unique_ptr<schema::MulT> attr(new schema::MulT()); std::unique_ptr<schema::MulT> attr(new schema::MulT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsMulOptions(); const auto &tfliteAttr = tflite_op->builtin_options.AsMulOptions();
if (nullptr == tfliteAttr) { if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -117,11 +75,10 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Mul; op->primitive->value.type = schema::PrimitiveType_Mul;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Div") == 0) { } else if (std::strcmp(node_name, "Div") == 0) {
MS_LOG(DEBUG) << "parse TfliteDivParser"; MS_LOG(DEBUG) << "parse TfliteDivParser";
std::unique_ptr<schema::DivT> attr(new schema::DivT()); std::unique_ptr<schema::DivT> attr(new schema::DivT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsDivOptions(); const auto &tfliteAttr = tflite_op->builtin_options.AsDivOptions();
if (nullptr == tfliteAttr) { if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -129,32 +86,26 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
op->primitive->value.type = schema::PrimitiveType_Div; op->primitive->value.type = schema::PrimitiveType_Div;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
}
} else if (std::strcmp(node_name, "FloorDiv") == 0) { } else if (std::strcmp(node_name, "FloorDiv") == 0) {
MS_LOG(DEBUG) << "parse TfliteFloorDivParser"; MS_LOG(DEBUG) << "parse TfliteFloorDivParser";
std::unique_ptr<schema::FloorDivT> attr(new schema::FloorDivT()); std::unique_ptr<schema::FloorDivT> attr(new schema::FloorDivT());
op->primitive->value.type = schema::PrimitiveType_FloorDiv; op->primitive->value.type = schema::PrimitiveType_FloorDiv;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "FloorMod") == 0) { } else if (std::strcmp(node_name, "FloorMod") == 0) {
MS_LOG(DEBUG) << "parse TfliteFloorModParser"; MS_LOG(DEBUG) << "parse TfliteFloorModParser";
std::unique_ptr<schema::FloorModT> attr(new schema::FloorModT()); std::unique_ptr<schema::FloorModT> attr(new schema::FloorModT());
op->primitive->value.type = schema::PrimitiveType_FloorMod; op->primitive->value.type = schema::PrimitiveType_FloorMod;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "RealDiv") == 0) { } else if (std::strcmp(node_name, "RealDiv") == 0) {
MS_LOG(DEBUG) << "parse TfliteRealDivParser"; MS_LOG(DEBUG) << "parse TfliteRealDivParser";
std::unique_ptr<schema::RealDivT> attr(new schema::RealDivT()); std::unique_ptr<schema::RealDivT> attr(new schema::RealDivT());
op->primitive->value.type = schema::PrimitiveType_RealDiv; op->primitive->value.type = schema::PrimitiveType_Div;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "SquaredDifference") == 0) { } else if (std::strcmp(node_name, "SquaredDifference") == 0) {
MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser"; MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser";
std::unique_ptr<schema::SquaredDifferenceT> attr(new schema::SquaredDifferenceT()); std::unique_ptr<schema::SquaredDifferenceT> attr(new schema::SquaredDifferenceT());
op->primitive->value.type = schema::PrimitiveType_SquaredDifference; op->primitive->value.type = schema::PrimitiveType_SquaredDifference;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Pow") == 0) { } else if (std::strcmp(node_name, "Pow") == 0) {
MS_LOG(DEBUG) << "parse TflitePowParser"; MS_LOG(DEBUG) << "parse TflitePowParser";
std::unique_ptr<schema::PowerT> attr(new schema::PowerT()); std::unique_ptr<schema::PowerT> attr(new schema::PowerT());
@ -163,31 +114,35 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
attr->shift = 0.0f; attr->shift = 0.0f;
op->primitive->value.type = schema::PrimitiveType_Power; op->primitive->value.type = schema::PrimitiveType_Power;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Maximum") == 0) { } else if (std::strcmp(node_name, "Maximum") == 0) {
MS_LOG(DEBUG) << "parse TfliteMaximumParser"; MS_LOG(DEBUG) << "parse TfliteMaximumParser";
std::unique_ptr<schema::MaximumT> attr(new schema::MaximumT()); std::unique_ptr<schema::MaximumT> attr(new schema::MaximumT());
op->primitive->value.type = schema::PrimitiveType_Maximum; op->primitive->value.type = schema::PrimitiveType_Maximum;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Minimum") == 0) { } else if (std::strcmp(node_name, "Minimum") == 0) {
MS_LOG(DEBUG) << "parse TfliteMinimumParser"; MS_LOG(DEBUG) << "parse TfliteMinimumParser";
std::unique_ptr<schema::MinimumT> attr(new schema::MinimumT()); std::unique_ptr<schema::MinimumT> attr(new schema::MinimumT());
op->primitive->value.type = schema::PrimitiveType_Minimum; op->primitive->value.type = schema::PrimitiveType_Minimum;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else {
MS_LOG(ERROR) << "wrong op type";
return RET_ERROR;
} }
// set input
for (int i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }
STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -199,85 +154,79 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
} }
std::vector<std::string> node_name_str; std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-"); Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str(); const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Abs") == 0) { if (std::strcmp(node_name, "Abs") == 0) {
MS_LOG(DEBUG) << "parse TfliteAbsParser"; MS_LOG(DEBUG) << "parse TfliteAbsParser";
std::unique_ptr<schema::AbsT> attr(new schema::AbsT()); std::unique_ptr<schema::AbsT> attr(new schema::AbsT());
op->primitive->value.type = schema::PrimitiveType_Abs; op->primitive->value.type = schema::PrimitiveType_Abs;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Exp") == 0) { } else if (std::strcmp(node_name, "Exp") == 0) {
MS_LOG(DEBUG) << "parse TfliteExpParser"; MS_LOG(DEBUG) << "parse TfliteExpParser";
std::unique_ptr<schema::ExpT> attr(new schema::ExpT()); std::unique_ptr<schema::ExpT> attr(new schema::ExpT());
op->primitive->value.type = schema::PrimitiveType_Exp; op->primitive->value.type = schema::PrimitiveType_Exp;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Sqrt") == 0) { } else if (std::strcmp(node_name, "Sqrt") == 0) {
MS_LOG(DEBUG) << "parse TfliteSqrtParser"; MS_LOG(DEBUG) << "parse TfliteSqrtParser";
std::unique_ptr<schema::SqrtT> attr(new schema::SqrtT()); std::unique_ptr<schema::SqrtT> attr(new schema::SqrtT());
op->primitive->value.type = schema::PrimitiveType_Sqrt; op->primitive->value.type = schema::PrimitiveType_Sqrt;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Rsqrt") == 0) { } else if (std::strcmp(node_name, "Rsqrt") == 0) {
MS_LOG(DEBUG) << "parse TfliteRsqrtParser"; MS_LOG(DEBUG) << "parse TfliteRsqrtParser";
std::unique_ptr<schema::RsqrtT> attr(new schema::RsqrtT()); std::unique_ptr<schema::RsqrtT> attr(new schema::RsqrtT());
op->primitive->value.type = schema::PrimitiveType_Rsqrt; op->primitive->value.type = schema::PrimitiveType_Rsqrt;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Square") == 0) { } else if (std::strcmp(node_name, "Square") == 0) {
MS_LOG(DEBUG) << "parse TfliteSquareParser"; MS_LOG(DEBUG) << "parse TfliteSquareParser";
std::unique_ptr<schema::SquareT> attr(new schema::SquareT()); std::unique_ptr<schema::SquareT> attr(new schema::SquareT());
op->primitive->value.type = schema::PrimitiveType_Square; op->primitive->value.type = schema::PrimitiveType_Square;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Sin") == 0) { } else if (std::strcmp(node_name, "Sin") == 0) {
MS_LOG(DEBUG) << "parse TfliteSinParser"; MS_LOG(DEBUG) << "parse TfliteSinParser";
std::unique_ptr<schema::SinT> attr(new schema::SinT()); std::unique_ptr<schema::SinT> attr(new schema::SinT());
op->primitive->value.type = schema::PrimitiveType_Sin; op->primitive->value.type = schema::PrimitiveType_Sin;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Cos") == 0) { } else if (std::strcmp(node_name, "Cos") == 0) {
MS_LOG(DEBUG) << "parse TfliteCosParser"; MS_LOG(DEBUG) << "parse TfliteCosParser";
std::unique_ptr<schema::CosT> attr(new schema::CosT()); std::unique_ptr<schema::CosT> attr(new schema::CosT());
op->primitive->value.type = schema::PrimitiveType_Cos; op->primitive->value.type = schema::PrimitiveType_Cos;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Log") == 0) { } else if (std::strcmp(node_name, "Log") == 0) {
MS_LOG(DEBUG) << "parse TfliteLogParser"; MS_LOG(DEBUG) << "parse TfliteLogParser";
std::unique_ptr<schema::LogT> attr(new schema::LogT()); std::unique_ptr<schema::LogT> attr(new schema::LogT());
op->primitive->value.type = schema::PrimitiveType_Log; op->primitive->value.type = schema::PrimitiveType_Log;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Round") == 0) { } else if (std::strcmp(node_name, "Round") == 0) {
MS_LOG(DEBUG) << "parse TfliteRoundParser"; MS_LOG(DEBUG) << "parse TfliteRoundParser";
std::unique_ptr<schema::RoundT> attr(new schema::RoundT()); std::unique_ptr<schema::RoundT> attr(new schema::RoundT());
op->primitive->value.type = schema::PrimitiveType_Round; op->primitive->value.type = schema::PrimitiveType_Round;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Ceil") == 0) { } else if (std::strcmp(node_name, "Ceil") == 0) {
MS_LOG(DEBUG) << "parse TfliteCeilParser"; MS_LOG(DEBUG) << "parse TfliteCeilParser";
std::unique_ptr<schema::CeilT> attr(new schema::CeilT()); std::unique_ptr<schema::CeilT> attr(new schema::CeilT());
op->primitive->value.type = schema::PrimitiveType_Ceil; op->primitive->value.type = schema::PrimitiveType_Ceil;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "flOOR") == 0) { } else if (std::strcmp(node_name, "flOOR") == 0) {
MS_LOG(DEBUG) << "parse TfliteFloorParser"; MS_LOG(DEBUG) << "parse TfliteFloorParser";
std::unique_ptr<schema::FloorT> attr(new schema::FloorT()); std::unique_ptr<schema::FloorT> attr(new schema::FloorT());
op->primitive->value.type = schema::PrimitiveType_Floor; op->primitive->value.type = schema::PrimitiveType_Floor;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else {
MS_LOG(ERROR) << "wrong op type";
return RET_ERROR;
}
} }
STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { return RET_OK;
}
STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -289,48 +238,47 @@ STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
} }
std::vector<std::string> node_name_str; std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-"); Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str(); const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Equal") == 0) { if (std::strcmp(node_name, "Equal") == 0) {
MS_LOG(DEBUG) << "parse TfliteEqualParser"; MS_LOG(DEBUG) << "parse TfliteEqualParser";
std::unique_ptr<schema::EqualT> attr(new schema::EqualT()); std::unique_ptr<schema::EqualT> attr(new schema::EqualT());
op->primitive->value.type = schema::PrimitiveType_Equal; op->primitive->value.type = schema::PrimitiveType_Equal;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "NotEqual") == 0) { } else if (std::strcmp(node_name, "NotEqual") == 0) {
MS_LOG(DEBUG) << "parse TfliteNotEqualParser"; MS_LOG(DEBUG) << "parse TfliteNotEqualParser";
std::unique_ptr<schema::NotEqualT> attr(new schema::NotEqualT()); std::unique_ptr<schema::NotEqualT> attr(new schema::NotEqualT());
op->primitive->value.type = schema::PrimitiveType_NotEqual; op->primitive->value.type = schema::PrimitiveType_NotEqual;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Greater") == 0) { } else if (std::strcmp(node_name, "Greater") == 0) {
MS_LOG(DEBUG) << "parse TfliteGreaterParser"; MS_LOG(DEBUG) << "parse TfliteGreaterParser";
std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT()); std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT());
op->primitive->value.type = schema::PrimitiveType_Greater; op->primitive->value.type = schema::PrimitiveType_Greater;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "GreaterEqual") == 0) { } else if (std::strcmp(node_name, "GreaterEqual") == 0) {
MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser"; MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser";
std::unique_ptr<schema::GreaterEqualT> attr(new schema::GreaterEqualT()); std::unique_ptr<schema::GreaterEqualT> attr(new schema::GreaterEqualT());
op->primitive->value.type = schema::PrimitiveType_GreaterEqual; op->primitive->value.type = schema::PrimitiveType_GreaterEqual;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "Less") == 0) { } else if (std::strcmp(node_name, "Less") == 0) {
MS_LOG(DEBUG) << "parse TfliteLessParser"; MS_LOG(DEBUG) << "parse TfliteLessParser";
std::unique_ptr<schema::LessT> attr(new schema::LessT()); std::unique_ptr<schema::LessT> attr(new schema::LessT());
op->primitive->value.type = schema::PrimitiveType_Less; op->primitive->value.type = schema::PrimitiveType_Less;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else if (std::strcmp(node_name, "LessEqual") == 0) { } else if (std::strcmp(node_name, "LessEqual") == 0) {
MS_LOG(DEBUG) << "parse TfliteLessEqualParser"; MS_LOG(DEBUG) << "parse TfliteLessEqualParser";
std::unique_ptr<schema::LessEqualT> attr(new schema::LessEqualT()); std::unique_ptr<schema::LessEqualT> attr(new schema::LessEqualT());
op->primitive->value.type = schema::PrimitiveType_LessEqual; op->primitive->value.type = schema::PrimitiveType_LessEqual;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK;
} else {
MS_LOG(ERROR) << "wrong op type";
return RET_ERROR;
} }
for (int i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
} }
TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser()); TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser());

View File

@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef PREDICT_TFLITE_MATH_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H
#define PREDICT_TFLITE_MATH_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
@ -29,11 +30,13 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser {
public: public:
TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {} TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override; std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
class TfliteAddParser : public TfliteDoubleInputOpParser { class TfliteAddParser : public TfliteDoubleInputOpParser {
@ -96,11 +99,13 @@ class TfliteSingleInputOpParser : public TfliteNodeParser {
public: public:
TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override; std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
class TfliteAbsParser : public TfliteSingleInputOpParser { class TfliteAbsParser : public TfliteSingleInputOpParser {
@ -163,11 +168,13 @@ class TfliteCompareOpParser : public TfliteNodeParser {
public: public:
TfliteCompareOpParser() : TfliteNodeParser("node_name") {} TfliteCompareOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override; std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
class TfliteEqualParser : public TfliteCompareOpParser { class TfliteEqualParser : public TfliteCompareOpParser {
@ -203,5 +210,5 @@ class TfliteLessEqualParser : public TfliteCompareOpParser {
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // PREDICT_TFLITE_MATH_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H

View File

@ -19,14 +19,17 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <string> #include <string>
#include <map>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -38,30 +41,32 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
} }
std::vector<std::string> node_name_str; std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-"); Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str(); const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "BatchToSpace") == 0) { if (std::strcmp(node_name, "BatchToSpace") == 0) {
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser"; MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser";
} else if (std::strcmp(node_name, "BatchToSpaceND") == 0) { } else if (std::strcmp(node_name, "BatchToSpaceND") == 0) {
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser"; MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser";
// in tflite
// blockShape should be a 1D tensor with dimension [spatial_dims_num]
// crops should be a 2D tensor with dimension [spatial_dims_num, 2]
} }
std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT()); std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT());
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) { if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->blockShape)) {
MS_LOG(ERROR) << "get batchToSpace -> blockShape failed"; MS_LOG(ERROR) << "get batchToSpace -> blockShape failed";
return RET_ERROR; return RET_ERROR;
} }
if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->crops)) { if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->crops)) {
MS_LOG(ERROR) << "get batchToSpace -> crops failed"; MS_LOG(ERROR) << "get batchToSpace -> crops failed";
return RET_ERROR; return RET_ERROR;
} }
op->primitive->value.type = schema::PrimitiveType_BatchToSpace; op->primitive->value.type = schema::PrimitiveType_BatchToSpace;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }

View File

@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef LITE_TFLITE_BATCH_TO_SPACE_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H
#define LITE_TFLITE_BATCH_TO_SPACE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
@ -31,8 +32,10 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) override; std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser {
@ -43,4 +46,4 @@ class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser {
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // LITE_TFLITE_BATCH_TO_SPACE_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H

View File

@ -18,14 +18,19 @@
#include "tools/converter/parser/tflite/tflite_broadcast_to_parser.h" #include "tools/converter/parser/tflite/tflite_broadcast_to_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteBroadcastToParser";
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -36,16 +41,20 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &
return RET_NULL_PTR; return RET_NULL_PTR;
} }
MS_LOG(DEBUG) << "parse TfliteBroadcastToParser";
std::unique_ptr<schema::BroadcastToT> attr(new schema::BroadcastToT()); std::unique_ptr<schema::BroadcastToT> attr(new schema::BroadcastToT());
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dst_shape)) { if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dst_shape)) {
MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed"; MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed";
return RET_ERROR; return RET_ERROR;
} }
op->primitive->value.type = schema::PrimitiveType_BroadcastTo; op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }

View File

@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef LITE_TFLITE_BROADCAST_TO_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H
#define LITE_TFLITE_BROADCAST_TO_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
@ -31,11 +32,12 @@ class TfliteBroadcastToParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, std::vector<int32_t> *tensors_id,
bool quantized_model) override; std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // LITE_TFLITE_BROADCAST_TO_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H

View File

@ -14,18 +14,22 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "tools/converter/parser/tflite/tflite_cast_parser.h" #include "tools/converter/parser/tflite/tflite_cast_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteCastParser";
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -36,25 +40,28 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR; return RET_NULL_PTR;
} }
MS_LOG(DEBUG) << "parse TfliteCastParser";
std::unique_ptr<schema::CastT> attr(new schema::CastT()); std::unique_ptr<schema::CastT> attr(new schema::CastT());
const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]]; const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]];
if (in_tensor == nullptr) { if (in_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null"; MS_LOG(ERROR) << "tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->srcT = dtype_map[in_tensor->type]; attr->srcT = GetTfliteDataType(in_tensor->type);
const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]];
const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]];
if (out_tensor == nullptr) { if (out_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null"; MS_LOG(ERROR) << "tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->dstT = dtype_map[out_tensor->type]; attr->dstT = GetTfliteDataType(out_tensor->type);
op->primitive->value.type = schema::PrimitiveType_Cast; op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }

View File

@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef LITE_TFLITE_CAST_PARSER_ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H
#define LITE_TFLITE_CAST_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
@ -31,11 +32,12 @@ class TfliteCastParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, std::vector<int32_t> *tensors_id,
bool quantized_model) override; std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // LITE_TFLITE_CAST_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H

View File

@ -17,14 +17,20 @@
#include "tools/converter/parser/tflite/tflite_concat_parser.h" #include "tools/converter/parser/tflite/tflite_concat_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteConcatParser";
// set attr
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -35,20 +41,25 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR; return RET_NULL_PTR;
} }
MS_LOG(DEBUG) << "parse TfliteConcatParser";
std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT()); std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsConcatenationOptions(); const auto &tfliteAttr = tflite_op->builtin_options.AsConcatenationOptions();
if (tfliteAttr == nullptr) { if (tfliteAttr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->axis = tfliteAttr->axis; attr->axis = tfliteAttr->axis;
attr->n = tflite_op->inputs.size();
attr->n = tfliteOp->inputs.size();
op->primitive->value.type = schema::PrimitiveType_Concat; op->primitive->value.type = schema::PrimitiveType_Concat;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
for (int i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }

View File

@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef PREDICT_TFLITE_CONCAT_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H
#define PREDICT_TFLITE_CONCAT_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
@ -28,15 +29,16 @@ class TfliteConcatParser : public TfliteNodeParser {
public: public:
TfliteConcatParser() : TfliteNodeParser("Concat") {} TfliteConcatParser() : TfliteNodeParser("Concat") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, schema::CNodeT *,
TensorCache *tensor_cache, std::vector<int32_t> *tensors_id,
bool quantizedModel) override; std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // PREDICT_TFLITE_CONCAT_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H

View File

@ -17,14 +17,19 @@
#include "tools/converter/parser/tflite/tflite_conv_parser.h" #include "tools/converter/parser/tflite/tflite_conv_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteConvParser";
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -35,60 +40,61 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR; return RET_NULL_PTR;
} }
MS_LOG(DEBUG) << "parse TfliteConvParser";
std::unique_ptr<schema::Conv2DT> attr(new schema::Conv2DT()); std::unique_ptr<schema::Conv2DT> attr(new schema::Conv2DT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsConv2DOptions(); const auto &tflite_attr = tflite_op->builtin_options.AsConv2DOptions();
if (tfliteAttr == nullptr) { if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->group = 1; attr->group = 1;
attr->strideW = tfliteAttr->stride_w; attr->strideW = tflite_attr->stride_w;
attr->strideH = tfliteAttr->stride_h; attr->strideH = tflite_attr->stride_h;
attr->dilateH = tfliteAttr->dilation_h_factor; attr->dilateH = tflite_attr->dilation_h_factor;
attr->dilateW = tfliteAttr->dilation_w_factor; attr->dilateW = tflite_attr->dilation_w_factor;
attr->padMode = GetPadMode(tfliteAttr->padding); attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format_NHWC; attr->format = schema::Format_NHWC;
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
attr->hasBias = true;
// get the conv op weight tensor // get the conv op weight tensor
auto weight_index = tfliteOp->inputs[1]; auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tfliteTensors[weight_index]; const auto &weight_tensor = tflite_tensors[weight_index];
if (weight_tensor == nullptr) { if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "weight_tensor is null"; MS_LOG(ERROR) << "the weight tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
auto weight_shape = weight_tensor->shape; auto weight_shape = weight_tensor->shape;
attr->channelIn = weight_shape[KHWC_C]; attr->channelIn = weight_shape[3];
attr->channelOut = weight_shape[KHWC_K]; attr->channelOut = weight_shape[0];
attr->kernelW = weight_shape[KHWC_W]; attr->kernelH = weight_shape[1];
attr->kernelH = weight_shape[KHWC_H]; attr->kernelW = weight_shape[2];
// get the conv op bias tensor
if (tfliteOp->inputs.size() == 3) {
attr->hasBias = true;
auto bias_index = tfliteOp->inputs[2];
const auto &bias_tensor = tfliteTensors[bias_index];
if (bias_tensor == nullptr) {
MS_LOG(ERROR) << "bias_tensor is null";
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}
}
// calculate pad params // calculate pad params
auto data_index = tflite_op->inputs[0];
const auto &data_tensor = tflite_tensors[data_index];
std::vector<int> params;
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH,
attr->strideW, attr->kernelH, attr->kernelW, &params) != RET_OK) {
MS_LOG(ERROR) << "get padding params failed";
return RET_ERROR;
} else {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);
attr->padRight = params.at(3);
}
op->primitive->value.type = schema::PrimitiveType_Conv2D; op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }

View File

@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef PREDICT_TFLITE_CONV_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H
#define PREDICT_TFLITE_CONV_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
@ -28,15 +29,16 @@ class TfliteConvParser : public TfliteNodeParser {
public: public:
TfliteConvParser() : TfliteNodeParser("Conv2D") {} TfliteConvParser() : TfliteNodeParser("Conv2D") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, std::vector<int32_t> *tensors_id,
bool quantizedModel) override; std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // PREDICT_TFLITE_CONV_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H

View File

@ -19,6 +19,7 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include <map>
#include "tools/converter/converter.h" #include "tools/converter/converter.h"
#include "tools/converter/parser/tflite/tflite_model_parser.h" #include "tools/converter/parser/tflite/tflite_model_parser.h"
#include "tools/converter/graphdef_transform.h" #include "tools/converter/graphdef_transform.h"

View File

@ -17,14 +17,19 @@
#include "tools/converter/parser/tflite/tflite_deconv_parser.h" #include "tools/converter/parser/tflite/tflite_deconv_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser";
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -35,11 +40,10 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR; return RET_NULL_PTR;
} }
MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser";
std::unique_ptr<schema::DeConv2DT> attr(new schema::DeConv2DT()); std::unique_ptr<schema::DeConv2DT> attr(new schema::DeConv2DT());
const auto &tflite_attr = tfliteOp->builtin_options.AsTransposeConvOptions(); const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions();
if (tflite_attr == nullptr) { if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
@ -50,26 +54,48 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
attr->dilateW = 1; attr->dilateW = 1;
attr->padMode = GetPadMode(tflite_attr->padding); attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format_NHWC; attr->format = schema::Format_NHWC;
attr->activationType = schema::ActivationType_NO_ACTIVATION;
attr->hasBias = true;
// get the conv op weight tensor // get the conv op weight tensor
auto weight_index = tfliteOp->inputs[1]; auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tfliteTensors[weight_index]; const auto &weight_tensor = tflite_tensors[weight_index];
if (weight_tensor == nullptr) { if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "weight_tensor is null"; MS_LOG(ERROR) << "the weight tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
return RET_ERROR;
}
auto weight_shape = weight_tensor->shape; auto weight_shape = weight_tensor->shape;
attr->channelIn = weight_shape[CHWK_K]; attr->channelIn = weight_shape[3];
attr->channelOut = weight_shape[CHWK_C]; attr->channelOut = weight_shape[0];
attr->kernelW = weight_shape[CHWK_W]; attr->kernelH = weight_shape[1];
attr->kernelH = weight_shape[CHWK_H]; attr->kernelW = weight_shape[2];
// calculate pad params
auto data_index = tflite_op->inputs[2];
const auto &data_tensor = tflite_tensors[data_index];
std::vector<int> params;
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH,
attr->strideW, attr->kernelH, attr->kernelW, &params) != RET_OK) {
MS_LOG(ERROR) << "get padding params failed";
return RET_ERROR;
} else {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);
attr->padRight = params.at(3);
}
op->primitive->value.type = schema::PrimitiveType_DeConv2D; op->primitive->value.type = schema::PrimitiveType_DeConv2D;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }

View File

@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef PREDICT_TFLITE_DECONV_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H
#define PREDICT_TFLITE_DECONV_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
@ -31,11 +32,12 @@ class TfliteDeConvParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_op_set, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, std::vector<int32_t> *tensors_id,
bool quantizedModel) override; std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // PREDICT_TFLITE_DECONV_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H

View File

@ -18,14 +18,19 @@
#include "tools/converter/parser/tflite/tflite_depth_to_space_parser.h" #include "tools/converter/parser/tflite/tflite_depth_to_space_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser";
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -36,20 +41,23 @@ STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
return RET_NULL_PTR; return RET_NULL_PTR;
} }
MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser";
std::unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT()); std::unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT());
const auto &tflite_attr = tfliteOp->builtin_options.AsDepthToSpaceOptions(); const auto &tflite_attr = tflite_op->builtin_options.AsDepthToSpaceOptions();
if (tflite_attr == nullptr) { if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str();
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->blockSize = tflite_attr->block_size; attr->blockSize = tflite_attr->block_size;
attr->format = schema::Format_NHWC; attr->format = schema::Format_NHWC;
op->primitive->value.type = schema::PrimitiveType_DepthToSpace; op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }

View File

@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H
#define LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
@ -31,11 +32,12 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser {
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, std::vector<int32_t> *tensors_id,
bool quantized_model) override; std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H

View File

@ -17,65 +17,22 @@
#include "tools/converter/parser/tflite/tflite_depthwise_conv_parser.h" #include "tools/converter/parser/tflite/tflite_depthwise_conv_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/common/node_util.h" #include "tools/common/node_util.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteDepthwiseConv2DParser::ParseGroupDepthwiseConv(schema::CNodeT *op,
const std::unique_ptr<schema::DepthwiseConv2DT> &attr,
const std::unique_ptr<tflite::TensorT> &weightTensor,
TensorCache *tensor_cache) {
std::unique_ptr<schema::Conv2DT> convAttr(new schema::Conv2DT);
convAttr->format = attr->format;
convAttr->channelIn = attr->channelIn;
convAttr->channelOut = attr->channelIn * attr->channelMultiplier;
convAttr->kernelH = attr->kernelH;
convAttr->kernelW = attr->kernelW;
convAttr->strideH = attr->strideH;
convAttr->strideW = attr->strideW;
convAttr->padMode = attr->padMode;
convAttr->padUp = attr->padUp;
convAttr->padDown = attr->padDown;
convAttr->padLeft = attr->padLeft;
convAttr->padRight = attr->padRight;
convAttr->dilateH = attr->dilateH;
convAttr->dilateW = attr->dilateW;
convAttr->hasBias = attr->hasBias;
convAttr->activationType = attr->activationType;
auto weightTensorIndex = tensor_cache->FindTensor(weightTensor->name);
if (weightTensorIndex >= 0 && weightTensorIndex < tensor_cache->GetCachedTensor().size()) {
auto liteWeightTensor = tensor_cache->GetCachedTensor()[weightTensorIndex];
if (liteWeightTensor->dataType == TypeId::kNumberTypeUInt8) {
// convert weight format KHWC -> CHWK
auto status = TransFilterFormat<uint8_t>(liteWeightTensor, kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
return RET_ERROR;
}
}
if (liteWeightTensor->dataType == kNumberTypeFloat32 || liteWeightTensor->dataType == kNumberTypeFloat) {
// convert weight format KHWC -> CHWK
auto status = TransFilterFormat<float>(liteWeightTensor, kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
return RET_ERROR;
}
}
}
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = convAttr.release();
return RET_OK;
}
STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser";
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -86,7 +43,6 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
return RET_NULL_PTR; return RET_NULL_PTR;
} }
MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser";
std::unique_ptr<schema::DepthwiseConv2DT> attr(new schema::DepthwiseConv2DT()); std::unique_ptr<schema::DepthwiseConv2DT> attr(new schema::DepthwiseConv2DT());
const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions(); const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions();
if (tflite_attr == nullptr) { if (tflite_attr == nullptr) {
@ -100,15 +56,20 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
attr->padMode = GetPadMode(tflite_attr->padding); attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format_NHWC; attr->format = schema::Format_NHWC;
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
// get the conv op weight tensor attr->hasBias = true;
auto input_index = tflite_op->inputs[0]; attr->channelMultiplier = tflite_attr->depth_multiplier;
const auto &input_tenosr = tflite_tensors[input_index];
if (input_tenosr == nullptr) { // get the data tensor
MS_LOG(ERROR) << "the first input is null"; auto data_index = tflite_op->inputs[1];
const auto &data_tensor = tflite_tensors[data_index];
if (data_tensor == nullptr) {
MS_LOG(ERROR) << "the data tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
auto input_shape = input_tenosr->shape; auto data_shape = data_tensor->shape;
attr->channelIn = data_shape[3];
// get the weight tensor
auto weight_index = tflite_op->inputs[1]; auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tflite_tensors[weight_index]; const auto &weight_tensor = tflite_tensors[weight_index];
if (weight_tensor == nullptr) { if (weight_tensor == nullptr) {
@ -116,38 +77,33 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
return RET_NULL_PTR; return RET_NULL_PTR;
} }
auto weight_shape = weight_tensor->shape; auto weight_shape = weight_tensor->shape;
attr->channelIn = input_shape[KHWC_C]; attr->kernelH = weight_shape[1];
attr->channelMultiplier = tflite_attr->depth_multiplier; attr->kernelW = weight_shape[2];
attr->kernelH = weight_shape[KHWC_H];
attr->kernelW = weight_shape[KHWC_W];
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; // calculate pad params
std::vector<int> params;
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW,
MS_LOG(ERROR) << "parse weight failed"; attr->kernelH, attr->kernelW, &params) != RET_OK) {
MS_LOG(ERROR) << "get padding params failed";
return RET_ERROR; return RET_ERROR;
}
if (tflite_op->inputs.size() == 3) {
attr->hasBias = true;
auto bias_index = tflite_op->inputs[2];
const auto &bias_tensor = tflite_tensors[bias_index];
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}
}
if (attr->channelMultiplier > 1) {
if (RET_OK != ParseGroupDepthwiseConv(op, attr, weight_tensor, tensor_cache)) {
MS_LOG(ERROR) << "Parse Group DepthwiseConv failed";
return RET_ERROR;
}
} else { } else {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);
attr->padRight = params.at(3);
}
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
}
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }

View File

@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H
#define PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
@ -28,20 +29,16 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser {
public: public:
TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) override; std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
private: std::map<int, int> *tensors_id_map) override;
STATUS ParseGroupDepthwiseConv(schema::CNodeT *op,
const std::unique_ptr<schema::DepthwiseConv2DT> &attr,
const std::unique_ptr<tflite::TensorT> &weightTensor,
TensorCache *tensor_cache);
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // PREDICT_TFLITE_CONV_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H

View File

@ -16,15 +16,20 @@
#include "tools/converter/parser/tflite/tflite_dequantize_parser.h" #include "tools/converter/parser/tflite/tflite_dequantize_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/common/node_util.h" #include "tools/common/node_util.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteDequantizeNParser";
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -35,32 +40,30 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
return RET_NULL_PTR; return RET_NULL_PTR;
} }
MS_LOG(DEBUG) << "parse TfliteDequantizeNParser";
std::unique_ptr<schema::CastT> attr(new schema::CastT); std::unique_ptr<schema::CastT> attr(new schema::CastT);
// get the dequantize input tensor // get the dequantize input tensor
const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]]; const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]];
if (in_tensor == nullptr) { if (in_tensor == nullptr) {
MS_LOG(ERROR) << "weight_tensor is null"; MS_LOG(ERROR) << "input tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->srcT = dtype_map[in_tensor->type]; attr->srcT = GetTfliteDataType(in_tensor->type);
const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]];
const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]];
if (out_tensor == nullptr) { if (out_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null"; MS_LOG(ERROR) << "output tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->dstT = dtype_map[out_tensor->type]; attr->dstT = GetTfliteDataType(out_tensor->type);
std::vector<tflite::TensorT *> weight_tensors{in_tensor.get()};
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
op->primitive->value.type = schema::PrimitiveType_Fp16Cast; op->primitive->value.type = schema::PrimitiveType_Fp16Cast;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return 0;
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
} }
TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser()); TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser());

View File

@ -13,11 +13,12 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef LITE_TFLITE_DEQUANTIZE_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H
#define LITE_TFLITE_DEQUANTIZE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
@ -27,13 +28,15 @@ class TfliteDequantizeParser : public TfliteNodeParser {
public: public:
TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {} TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override; std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // LITE_TFLITE_DEQUANTIZE_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H

View File

@ -17,16 +17,17 @@
#include "tools/converter/parser/tflite/tflite_expand_dims_parser.h" #include "tools/converter/parser/tflite/tflite_expand_dims_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, std::vector<int32_t> *tensors_id,
bool quantizedModel) { std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -40,7 +41,7 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; MS_LOG(DEBUG) << "parse TfliteExpandDimsParser";
std::unique_ptr<schema::ExpandDimsT> attr(new schema::ExpandDimsT()); std::unique_ptr<schema::ExpandDimsT> attr(new schema::ExpandDimsT());
const auto &tflite_attr = tfliteOp->builtin_options.AsExpandDimsOptions(); const auto &tflite_attr = tflite_op->builtin_options.AsExpandDimsOptions();
if (tflite_attr == nullptr) { if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;

View File

@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef PREDICT_TFLITE_EXPAND_DIMS_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H
#define PREDICT_TFLITE_EXPAND_DIMS_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
@ -28,15 +29,16 @@ class TfliteExpandDimsParser : public TfliteNodeParser {
public: public:
TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {} TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, std::vector<int32_t> *tensors_id,
bool quantizedModel) override; std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // PREDICT_TFLITE_EXPAND_DIMS_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H

View File

@ -1,75 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_fakequant_parser.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser";
std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT());
auto weight_index = tfliteOp->inputs[1];
const auto &weight_tensor = tfliteTensors[weight_index];
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "weight_tensor is null";
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
if (tfliteOp->inputs.size() == 3) {
attr->hasBias = true;
auto bias_index = tfliteOp->inputs[2];
const auto &bias_tensor = tfliteTensors[bias_index];
if (bias_tensor == nullptr) {
MS_LOG(ERROR) << "bias_tensor is null";
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}
}
attr->axis = 1;
op->primitive->value.type = schema::PrimitiveType_FullConnection;
op->primitive->value.value = attr.release();
return RET_OK;
}
TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFakeQuantParser());
} // namespace lite
} // namespace mindspore

View File

@ -1,39 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TFLITE_FAKEQUANT_PARSER_H
#define LITE_TFLITE_FAKEQUANT_PARSER_H
#include <vector>
#include <memory>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteFakeQuantParser : public TfliteNodeParser {
public:
TfliteFakeQuantParser() : TfliteNodeParser("FakeQuant") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_TFLITE_FAKEQUANT_PARSER_H

View File

@ -14,19 +14,22 @@
* limitations under the License. * limitations under the License.
*/ */
#include "tools/converter/parser/tflite/tflite_fill_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "tools/converter/parser/tflite/tflite_fill_parser.h" #include <map>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, std::vector<int32_t> *tensors_id,
bool quantizedModel) { std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteFillParser";
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "op is null"; MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -37,18 +40,22 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR; return RET_NULL_PTR;
} }
MS_LOG(DEBUG) << "parse TfliteFillParser";
std::unique_ptr<schema::FillT> attr(new schema::FillT()); std::unique_ptr<schema::FillT> attr(new schema::FillT());
if (tfliteOp->inputs.size() > 1) { if (tflite_op->inputs.size() > 1) {
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dims)) { if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dims)) {
MS_LOG(ERROR) << "get Fill -> dims failed"; MS_LOG(ERROR) << "get fill -> dims failed";
return RET_ERROR; return RET_ERROR;
} }
} }
op->primitive->value.type = schema::PrimitiveType_Fill; op->primitive->value.type = schema::PrimitiveType_Fill;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK; return RET_OK;
} }

Some files were not shown because too many files have changed in this diff Show More