forked from mindspore-Ecosystem/mindspore
refactor tflite parsers.
append ut refactor tflite parsers modify tflite parser, ut and model supplement caffe flatten parser fix the weight tensor format of deconv bug fix bug when idx=-1 fix the weight tensor format of depthConv bug.
This commit is contained in:
parent
a0c12e7aa7
commit
123c2024a5
|
@ -12,19 +12,19 @@ cp -r ${CUR_DIR}/ut/tools/converter/parser/tflite/test_data/* ./
|
||||||
TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/
|
TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/
|
||||||
cp -fr $TEST_DATA_DIR/testPK ./data
|
cp -fr $TEST_DATA_DIR/testPK ./data
|
||||||
|
|
||||||
#./lite-test --gtest_filter="*MindDataTestTensorDE*"
|
./lite-test --gtest_filter="*MindDataTestTensorDE*"
|
||||||
#./lite-test --gtest_filter="*MindDataTestEager*"
|
./lite-test --gtest_filter="*MindDataTestEager*"
|
||||||
#
|
|
||||||
#./lite-test --gtest_filter="TestTfliteParser*"
|
./lite-test --gtest_filter="TestTfliteParser*"
|
||||||
#
|
|
||||||
#./lite-test --gtest_filter="*TestHebing*"
|
./lite-test --gtest_filter="*TestHebing*"
|
||||||
#
|
|
||||||
#./lite-test --gtest_filter=TestFcFp32*
|
./lite-test --gtest_filter=TestFcFp32*
|
||||||
#./lite-test --gtest_filter=TestConv1x1Fp32*
|
./lite-test --gtest_filter=TestConv1x1Fp32*
|
||||||
#./lite-test --gtest_filter=TestStrassenFp32*
|
./lite-test --gtest_filter=TestStrassenFp32*
|
||||||
#./lite-test --gtest_filter=TestDeConvolutionFp32*
|
./lite-test --gtest_filter=TestDeConvolutionFp32*
|
||||||
#
|
|
||||||
#./lite-test --gtest_filter=TestPadInt8.*
|
./lite-test --gtest_filter=TestPadInt8.*
|
||||||
#./lite-test --gtest_filter=TestDeconvInt8.*
|
./lite-test --gtest_filter=TestDeconvInt8.*
|
||||||
|
|
||||||
./lite-test --gtest_filter="TestTfliteParser*"
|
./lite-test --gtest_filter="TestTfliteParser*"
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -31,6 +31,12 @@ TEST_F(TestTfliteParserRelu, OpType) {
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserRelu, AttrValue) {
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
|
||||||
|
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
|
||||||
|
ASSERT_EQ(val->type, schema::ActivationType_RELU);
|
||||||
|
}
|
||||||
|
|
||||||
class TestTfliteParserRelu6 : public TestTfliteParser {
|
class TestTfliteParserRelu6 : public TestTfliteParser {
|
||||||
public:
|
public:
|
||||||
TestTfliteParserRelu6() = default;
|
TestTfliteParserRelu6() = default;
|
||||||
|
@ -43,6 +49,12 @@ TEST_F(TestTfliteParserRelu6, OpType) {
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserRelu6, AttrValue) {
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
|
||||||
|
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
|
||||||
|
ASSERT_EQ(val->type, schema::ActivationType_RELU6);
|
||||||
|
}
|
||||||
|
|
||||||
class TestTfliteParserTanh : public TestTfliteParser {
|
class TestTfliteParserTanh : public TestTfliteParser {
|
||||||
public:
|
public:
|
||||||
TestTfliteParserTanh() = default;
|
TestTfliteParserTanh() = default;
|
||||||
|
@ -55,7 +67,45 @@ TEST_F(TestTfliteParserTanh, OpType) {
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
// logistic
|
TEST_F(TestTfliteParserTanh, AttrValue) {
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
|
||||||
|
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
|
||||||
|
ASSERT_EQ(val->type, schema::ActivationType_TANH);
|
||||||
|
}
|
||||||
|
|
||||||
|
class TestTfliteParserLogistic : public TestTfliteParser {
|
||||||
|
public:
|
||||||
|
TestTfliteParserLogistic() = default;
|
||||||
|
void SetUp() override { meta_graph = LoadAndConvert("./logistic.tflite", ""); }
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserLogistic, OpType) {
|
||||||
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
|
||||||
|
}
|
||||||
|
TEST_F(TestTfliteParserLogistic, AttrValue) {
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
|
||||||
|
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
|
||||||
|
ASSERT_EQ(val->type, schema::ActivationType_SIGMOID);
|
||||||
|
}
|
||||||
|
|
||||||
|
class TestTfliteParserHardSwish : public TestTfliteParser {
|
||||||
|
public:
|
||||||
|
TestTfliteParserHardSwish() = default;
|
||||||
|
void SetUp() override { meta_graph = LoadAndConvert("./hardswish.tflite", ""); }
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserHardSwish, OpType) {
|
||||||
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
|
||||||
|
}
|
||||||
|
TEST_F(TestTfliteParserHardSwish, AttrValue) {
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
|
||||||
|
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
|
||||||
|
ASSERT_EQ(val->type, schema::ActivationType_SIGMOID);
|
||||||
|
}
|
||||||
|
|
||||||
class TestTfliteParserPrelu : public TestTfliteParser {
|
class TestTfliteParserPrelu : public TestTfliteParser {
|
||||||
public:
|
public:
|
||||||
|
@ -73,12 +123,11 @@ TEST_F(TestTfliteParserPrelu, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserPrelu, AttrValue) {
|
TEST_F(TestTfliteParserPrelu, AttrValue) {
|
||||||
std::vector<float> slope(20, 0);
|
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsPrelu()->slope, slope);
|
auto val = meta_graph->nodes.front()->primitive->value;
|
||||||
|
std::vector<float> slope(20, 0);
|
||||||
|
ASSERT_EQ(val.AsPrelu()->slope, slope);
|
||||||
|
ASSERT_EQ(val.type, schema::PrimitiveType_Prelu);
|
||||||
}
|
}
|
||||||
|
|
||||||
class TestTfliteParserLeakyRelu : public TestTfliteParser {
|
class TestTfliteParserLeakyRelu : public TestTfliteParser {
|
||||||
|
@ -94,12 +143,10 @@ TEST_F(TestTfliteParserLeakyRelu, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserLeakyRelu, AttrValue) {
|
TEST_F(TestTfliteParserLeakyRelu, AttrValue) {
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLeakyReLU(), nullptr);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
auto val = meta_graph->nodes.front()->primitive->value;
|
||||||
|
ASSERT_EQ(val.AsLeakyReLU()->negativeSlope, 0.20000000298023224);
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsLeakyReLU();
|
ASSERT_EQ(val.type, schema::PrimitiveType_LeakyReLU);
|
||||||
ASSERT_NE(val, nullptr);
|
|
||||||
ASSERT_EQ(val->negativeSlope, 0.20000000298023224);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -35,10 +35,8 @@ TEST_F(TestTfliteParserAddN, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserAddN, AttrValue) {
|
TEST_F(TestTfliteParserAddN, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsAddN(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsAddN(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsAddN()->N, 4);
|
auto val = meta_graph->nodes.front()->primitive->value.AsAddN();
|
||||||
|
ASSERT_EQ(val->N, 4);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include "common/common_test.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class TestTfliteParserArgmax : public TestTfliteParser {
|
||||||
|
public:
|
||||||
|
TestTfliteParserArgmax() = default;
|
||||||
|
void SetUp() override { meta_graph = LoadAndConvert("./argmax.tflite", ""); }
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserArgmax, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMax) << "wrong Op Type";
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserArgmax, AttrValue) {
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMax(), nullptr);
|
||||||
|
auto val = meta_graph->nodes.front()->primitive->value.AsArgMax();
|
||||||
|
ASSERT_EQ(val->axis, 1);
|
||||||
|
ASSERT_EQ(val->topK, 1);
|
||||||
|
ASSERT_EQ(val->axisType, 1);
|
||||||
|
ASSERT_EQ(val->keepDims, false);
|
||||||
|
ASSERT_EQ(val->outMaxValue, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -25,15 +25,14 @@ class TestTfliteParserArgmin : public TestTfliteParser {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserArgmin, OpType) {
|
TEST_F(TestTfliteParserArgmin, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMin) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMin) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserArgmin, AttrValue) {
|
TEST_F(TestTfliteParserArgmin, AttrValue) {
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMin(), nullptr);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsArgMin();
|
auto val = meta_graph->nodes.front()->primitive->value.AsArgMin();
|
||||||
ASSERT_EQ(val->axis, 1);
|
ASSERT_EQ(val->axis, 1);
|
||||||
ASSERT_EQ(val->topK, 1);
|
ASSERT_EQ(val->topK, 1);
|
||||||
|
|
|
@ -19,234 +19,57 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
// doubleInputOp
|
// doubleInputOp
|
||||||
class TestTfliteParserAdd1 : public TestTfliteParser {
|
class TestTfliteParserAdd : public TestTfliteParser {
|
||||||
public:
|
public:
|
||||||
TestTfliteParserAdd1() = default;
|
TestTfliteParserAdd() = default;
|
||||||
void SetUp() override { meta_graph = LoadAndConvert("./add1.tflite", ""); }
|
void SetUp() override { meta_graph = LoadAndConvert("./add.tflite", ""); }
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserAdd1, OpType) {
|
TEST_F(TestTfliteParserAdd, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserAdd1, Tensor) {
|
class TestTfliteParserSub : public TestTfliteParser {
|
||||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
|
||||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
class TestTfliteParserAdd2 : public TestTfliteParser {
|
|
||||||
public:
|
public:
|
||||||
TestTfliteParserAdd2() = default;
|
TestTfliteParserSub() = default;
|
||||||
void SetUp() override { meta_graph = LoadAndConvert("./add2.tflite", ""); }
|
void SetUp() override { meta_graph = LoadAndConvert("./sub.tflite", ""); }
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserAdd2, OpType) {
|
TEST_F(TestTfliteParserSub, OpType) {
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserAdd2, Tensor) {
|
|
||||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
|
||||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
class TestTfliteParserAdd3 : public TestTfliteParser {
|
|
||||||
public:
|
|
||||||
TestTfliteParserAdd3() = default;
|
|
||||||
void SetUp() override { meta_graph = LoadAndConvert("./add3.tflite", ""); }
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserAdd3, OpType) {
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserAdd3, Tensor) {
|
|
||||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
class TestTfliteParserSub1 : public TestTfliteParser {
|
|
||||||
public:
|
|
||||||
TestTfliteParserSub1() = default;
|
|
||||||
void SetUp() override { meta_graph = LoadAndConvert("./sub1.tflite", ""); }
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserSub1, OpType) {
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserSub1, Tensor) {
|
class TestTfliteParserMul : public TestTfliteParser {
|
||||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
|
||||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
class TestTfliteParserSub2 : public TestTfliteParser {
|
|
||||||
public:
|
public:
|
||||||
TestTfliteParserSub2() = default;
|
TestTfliteParserMul() = default;
|
||||||
void SetUp() override { meta_graph = LoadAndConvert("./sub2.tflite", ""); }
|
void SetUp() override { meta_graph = LoadAndConvert("./mul.tflite", ""); }
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserSub2, OpType) {
|
TEST_F(TestTfliteParserMul, OpType) {
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserSub2, Tensor) {
|
|
||||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
|
||||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
class TestTfliteParserSub3 : public TestTfliteParser {
|
|
||||||
public:
|
|
||||||
TestTfliteParserSub3() = default;
|
|
||||||
void SetUp() override { meta_graph = LoadAndConvert("./sub3.tflite", ""); }
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserSub3, OpType) {
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserSub3, Tensor) {
|
|
||||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
class TestTfliteParserMul1 : public TestTfliteParser {
|
|
||||||
public:
|
|
||||||
TestTfliteParserMul1() = default;
|
|
||||||
void SetUp() override { meta_graph = LoadAndConvert("./mul1.tflite", ""); }
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserMul1, OpType) {
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserMul1, Tensor) {
|
class TestTfliteParserDiv : public TestTfliteParser {
|
||||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
|
||||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
class TestTfliteParserMul2 : public TestTfliteParser {
|
|
||||||
public:
|
public:
|
||||||
TestTfliteParserMul2() = default;
|
TestTfliteParserDiv() = default;
|
||||||
void SetUp() override { meta_graph = LoadAndConvert("./mul2.tflite", ""); }
|
void SetUp() override { meta_graph = LoadAndConvert("./div.tflite", ""); }
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserMul2, OpType) {
|
TEST_F(TestTfliteParserDiv, OpType) {
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserMul2, Tensor) {
|
|
||||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
|
||||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
class TestTfliteParserMul3 : public TestTfliteParser {
|
|
||||||
public:
|
|
||||||
TestTfliteParserMul3() = default;
|
|
||||||
void SetUp() override { meta_graph = LoadAndConvert("./mul3.tflite", ""); }
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserMul3, OpType) {
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserMul3, Tensor) {
|
|
||||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
class TestTfliteParserDiv1 : public TestTfliteParser {
|
|
||||||
public:
|
|
||||||
TestTfliteParserDiv1() = default;
|
|
||||||
void SetUp() override { meta_graph = LoadAndConvert("./div1.tflite", ""); }
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserDiv1, OpType) {
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserDiv1, Tensor) {
|
|
||||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
|
||||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
class TestTfliteParserDiv2 : public TestTfliteParser {
|
|
||||||
public:
|
|
||||||
TestTfliteParserDiv2() = default;
|
|
||||||
void SetUp() override { meta_graph = LoadAndConvert("./div2.tflite", ""); }
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserDiv2, OpType) {
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserDiv2, Tensor) {
|
|
||||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
|
||||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
class TestTfliteParserDiv3 : public TestTfliteParser {
|
|
||||||
public:
|
|
||||||
TestTfliteParserDiv3() = default;
|
|
||||||
void SetUp() override { meta_graph = LoadAndConvert("./div3.tflite", ""); }
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserDiv3, OpType) {
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TestTfliteParserDiv3, Tensor) {
|
|
||||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
|
|
||||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
class TestTfliteParserFloorDiv : public TestTfliteParser {
|
class TestTfliteParserFloorDiv : public TestTfliteParser {
|
||||||
public:
|
public:
|
||||||
TestTfliteParserFloorDiv() = default;
|
TestTfliteParserFloorDiv() = default;
|
||||||
|
@ -254,6 +77,7 @@ class TestTfliteParserFloorDiv : public TestTfliteParser {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserFloorDiv, OpType) {
|
TEST_F(TestTfliteParserFloorDiv, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorDiv) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorDiv) << "wrong Op Type";
|
||||||
|
@ -266,12 +90,26 @@ class TestTfliteParserFloorMod : public TestTfliteParser {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserFloorMod, OpType) {
|
TEST_F(TestTfliteParserFloorMod, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorMod) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorMod) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
// realDiv
|
class TestTfliteParserRealDiv : public TestTfliteParser {
|
||||||
|
public:
|
||||||
|
TestTfliteParserRealDiv() = default;
|
||||||
|
void SetUp() override {
|
||||||
|
meta_graph = LoadAndConvert("./realdiv.tflite");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserRealDiv, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
|
||||||
|
}
|
||||||
|
|
||||||
class TestTfliteParserSquaredDifference : public TestTfliteParser {
|
class TestTfliteParserSquaredDifference : public TestTfliteParser {
|
||||||
public:
|
public:
|
||||||
|
@ -296,17 +134,15 @@ class TestTfliteParserPow : public TestTfliteParser {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserPow, OpType) {
|
TEST_F(TestTfliteParserPow, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Power) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Power) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserPow, AttrValue) {
|
TEST_F(TestTfliteParserPow, AttrValue) {
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPower(), nullptr);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsPower();
|
auto val = meta_graph->nodes.front()->primitive->value.AsPower();
|
||||||
|
|
||||||
ASSERT_EQ(val->scale, 1.0);
|
ASSERT_EQ(val->scale, 1.0);
|
||||||
ASSERT_EQ(val->shift, 0.0);
|
ASSERT_EQ(val->shift, 0.0);
|
||||||
ASSERT_EQ(val->power, 0.0);
|
ASSERT_EQ(val->power, 0.0);
|
||||||
|
@ -477,6 +313,7 @@ class TestTfliteParserFloor : public TestTfliteParser {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserFloor, OpType) {
|
TEST_F(TestTfliteParserFloor, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Floor) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Floor) << "wrong Op Type";
|
||||||
|
|
|
@ -32,14 +32,12 @@ TEST_F(TestTfliteParserBatchToSpaceNd, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserBatchToSpaceNd, AttrValue) {
|
TEST_F(TestTfliteParserBatchToSpaceNd, AttrValue) {
|
||||||
const std::vector<int> blockShape{2, 2};
|
|
||||||
const std::vector<int> crops{0, 0, 2, 0};
|
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsBatchToSpace(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsBatchToSpace(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->blockShape, blockShape);
|
auto val = meta_graph->nodes.front()->primitive->value.AsBatchToSpace();
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->crops, crops);
|
const std::vector<int> blockShape = {2, 2};
|
||||||
|
ASSERT_EQ(val->blockShape, blockShape);
|
||||||
|
const std::vector<int> crops = {0, 0, 2, 0};
|
||||||
|
ASSERT_EQ(val->crops, crops);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -35,12 +35,9 @@ TEST_F(TestTfliteParserCast, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserCast, AttrValue) {
|
TEST_F(TestTfliteParserCast, AttrValue) {
|
||||||
// float32 --> int32
|
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsCast(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsCast(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->srcT, 43);
|
auto val = meta_graph->nodes.front()->primitive->value.AsCast();
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->dstT, 34);
|
ASSERT_EQ(val->srcT, 43);
|
||||||
|
ASSERT_EQ(val->dstT, 34);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include "common/common_test.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class TestTfliteParserConcat : public TestTfliteParser {
|
||||||
|
public:
|
||||||
|
TestTfliteParserConcat() = default;
|
||||||
|
void SetUp() override { meta_graph = LoadAndConvert("./concat.tflite", ""); }
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserConcat, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Concat) << "wrong Op Type";
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserConcat, AttrValue) {
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConcat(), nullptr);
|
||||||
|
auto val = meta_graph->nodes.front()->primitive->value.AsConcat();
|
||||||
|
ASSERT_EQ(val->axis, 1);
|
||||||
|
ASSERT_EQ(val->n, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include "common/common_test.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class TestTfliteParserConv : public TestTfliteParser {
|
||||||
|
public:
|
||||||
|
TestTfliteParserConv() = default;
|
||||||
|
void SetUp() override { meta_graph = LoadAndConvert("./conv.tflite", ""); }
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserConv, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
|
||||||
|
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserConv, AttrValue) {
|
||||||
|
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr);
|
||||||
|
auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D();
|
||||||
|
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||||
|
ASSERT_EQ(val->group, 1);
|
||||||
|
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
|
||||||
|
ASSERT_EQ(val->hasBias, true);
|
||||||
|
ASSERT_EQ(val->channelIn, 1);
|
||||||
|
ASSERT_EQ(val->channelOut, 4);
|
||||||
|
ASSERT_EQ(val->kernelH, 3);
|
||||||
|
ASSERT_EQ(val->kernelW, 3);
|
||||||
|
ASSERT_EQ(val->strideH, 1);
|
||||||
|
ASSERT_EQ(val->strideW, 1);
|
||||||
|
ASSERT_EQ(val->dilateH, 1);
|
||||||
|
ASSERT_EQ(val->dilateW, 1);
|
||||||
|
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
|
||||||
|
ASSERT_EQ(val->padUp, 1);
|
||||||
|
ASSERT_EQ(val->padDown, 1);
|
||||||
|
ASSERT_EQ(val->padLeft, 1);
|
||||||
|
ASSERT_EQ(val->padRight, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include "common/common_test.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class TestTfliteParserDeConv : public TestTfliteParser {
|
||||||
|
public:
|
||||||
|
TestTfliteParserDeConv() = default;
|
||||||
|
void SetUp() override { meta_graph = LoadAndConvert("./deconv.tflite", ""); }
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserDeConv, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
|
||||||
|
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type";
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserDeConv, AttrValue) {
|
||||||
|
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(), nullptr);
|
||||||
|
auto val = meta_graph->nodes.at(1)->primitive->value.AsDeConv2D();
|
||||||
|
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||||
|
ASSERT_EQ(val->group, 1);
|
||||||
|
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
|
||||||
|
ASSERT_EQ(val->hasBias, true);
|
||||||
|
|
||||||
|
ASSERT_EQ(val->channelIn, 1);
|
||||||
|
ASSERT_EQ(val->channelOut, 4);
|
||||||
|
ASSERT_EQ(val->kernelH, 3);
|
||||||
|
ASSERT_EQ(val->kernelW, 3);
|
||||||
|
ASSERT_EQ(val->strideH, 1);
|
||||||
|
ASSERT_EQ(val->strideW, 1);
|
||||||
|
ASSERT_EQ(val->dilateH, 1);
|
||||||
|
ASSERT_EQ(val->dilateW, 1);
|
||||||
|
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
|
||||||
|
ASSERT_EQ(val->padUp, 1);
|
||||||
|
ASSERT_EQ(val->padDown, 1);
|
||||||
|
ASSERT_EQ(val->padLeft, 1);
|
||||||
|
ASSERT_EQ(val->padRight, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -35,11 +35,9 @@ TEST_F(TestTfliteParserDepthToSpace, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserDepthToSpace, AttrValue) {
|
TEST_F(TestTfliteParserDepthToSpace, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthToSpace(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthToSpace(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->blockSize, 4);
|
auto val = meta_graph->nodes.front()->primitive->value.AsDepthToSpace();
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->format, schema::Format_NHWC);
|
ASSERT_EQ(val->blockSize, 4);
|
||||||
|
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include "common/common_test.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class TestTfliteParserDepthwiseConv1 : public TestTfliteParser {
|
||||||
|
public:
|
||||||
|
TestTfliteParserDepthwiseConv1() = default;
|
||||||
|
void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv1.tflite", ""); }
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserDepthwiseConv1, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type";
|
||||||
|
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserDepthwiseConv1, AttrValue) {
|
||||||
|
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr);
|
||||||
|
auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D();
|
||||||
|
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||||
|
ASSERT_EQ(val->group, 0);
|
||||||
|
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
|
||||||
|
ASSERT_EQ(val->hasBias, true);
|
||||||
|
ASSERT_EQ(val->channelIn, 1);
|
||||||
|
ASSERT_EQ(val->channelOut, 4);
|
||||||
|
ASSERT_EQ(val->kernelH, 3);
|
||||||
|
ASSERT_EQ(val->kernelW, 3);
|
||||||
|
ASSERT_EQ(val->strideH, 1);
|
||||||
|
ASSERT_EQ(val->strideW, 1);
|
||||||
|
ASSERT_EQ(val->dilateH, 1);
|
||||||
|
ASSERT_EQ(val->dilateW, 1);
|
||||||
|
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
|
||||||
|
ASSERT_EQ(val->padUp, 1);
|
||||||
|
ASSERT_EQ(val->padDown, 1);
|
||||||
|
ASSERT_EQ(val->padLeft, 1);
|
||||||
|
ASSERT_EQ(val->padRight, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
class TestTfliteParserDepthwiseConv2 : public TestTfliteParser {
|
||||||
|
public:
|
||||||
|
TestTfliteParserDepthwiseConv2() = default;
|
||||||
|
void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv2.tflite", ""); }
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserDepthwiseConv2, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type";
|
||||||
|
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type";
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserDepthwiseConv2, AttrValue) {
|
||||||
|
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(), nullptr);
|
||||||
|
auto val = meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D();
|
||||||
|
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||||
|
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
|
||||||
|
ASSERT_EQ(val->hasBias, true);
|
||||||
|
ASSERT_EQ(val->channelIn, 2);
|
||||||
|
ASSERT_EQ(val->channelMultiplier, 1);
|
||||||
|
ASSERT_EQ(val->kernelH, 3);
|
||||||
|
ASSERT_EQ(val->kernelW, 3);
|
||||||
|
ASSERT_EQ(val->strideH, 1);
|
||||||
|
ASSERT_EQ(val->strideW, 1);
|
||||||
|
ASSERT_EQ(val->dilateH, 1);
|
||||||
|
ASSERT_EQ(val->dilateW, 1);
|
||||||
|
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
|
||||||
|
ASSERT_EQ(val->padUp, 1);
|
||||||
|
ASSERT_EQ(val->padDown, 1);
|
||||||
|
ASSERT_EQ(val->padLeft, 1);
|
||||||
|
ASSERT_EQ(val->padRight, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -25,17 +25,15 @@ class TestTfliteParserFill : public TestTfliteParser {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserFill, OpType) {
|
TEST_F(TestTfliteParserFill, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Fill) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Fill) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserFill, AttrValue) {
|
TEST_F(TestTfliteParserFill, AttrValue) {;
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsFill(), nullptr);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsFill();
|
auto val = meta_graph->nodes.front()->primitive->value.AsFill();
|
||||||
|
|
||||||
std::vector<int32_t> dims = {9};
|
std::vector<int32_t> dims = {9};
|
||||||
ASSERT_EQ(val->dims, dims);
|
ASSERT_EQ(val->dims, dims);
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,15 +25,14 @@ class TestTfliteParserGatherNd : public TestTfliteParser {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserGatherNd, OpType) {
|
TEST_F(TestTfliteParserGatherNd, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_GatherNd) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_GatherNd) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserGatherNd, AttrValue) {
|
TEST_F(TestTfliteParserGatherNd, AttrValue) {
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGatherNd(), nullptr);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsGatherNd();
|
auto val = meta_graph->nodes.front()->primitive->value.AsGatherNd();
|
||||||
ASSERT_EQ(val->batchDims, 0);
|
ASSERT_EQ(val->batchDims, 0);
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,15 +25,14 @@ class TestTfliteParserGather : public TestTfliteParser {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserGather, OpType) {
|
TEST_F(TestTfliteParserGather, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Gather) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Gather) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserGather, AttrValue) {
|
TEST_F(TestTfliteParserGather, AttrValue) {
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGather(), nullptr);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsGather();
|
auto val = meta_graph->nodes.front()->primitive->value.AsGather();
|
||||||
ASSERT_EQ(val->axis, 0);
|
ASSERT_EQ(val->axis, 0);
|
||||||
ASSERT_EQ(val->batchDims, 0);
|
ASSERT_EQ(val->batchDims, 0);
|
||||||
|
|
|
@ -25,6 +25,7 @@ class TestTfliteParserLRN : public TestTfliteParser {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserLRN, OpType) {
|
TEST_F(TestTfliteParserLRN, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type,
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type,
|
||||||
|
@ -32,9 +33,7 @@ TEST_F(TestTfliteParserLRN, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserLRN, AttrValue) {
|
TEST_F(TestTfliteParserLRN, AttrValue) {
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization(), nullptr);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization();
|
auto val = meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization();
|
||||||
ASSERT_EQ(val->alpha, 1);
|
ASSERT_EQ(val->alpha, 1);
|
||||||
ASSERT_EQ(val->beta, 0.5);
|
ASSERT_EQ(val->beta, 0.5);
|
||||||
|
|
|
@ -32,12 +32,9 @@ TEST_F(TestTfliteParserOneHot, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserOneHot, AttrValue) {
|
TEST_F(TestTfliteParserOneHot, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsOneHot(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsOneHot(), nullptr);
|
||||||
// in OneHot parser axis = axis > 0 ? axis : axis + tensor_shape.size()
|
auto val = meta_graph->nodes.front()->primitive->value.AsOneHot();
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsOneHot()->axis, 2);
|
ASSERT_EQ(val->axis, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -25,17 +25,15 @@ class TestTfliteParserPad : public TestTfliteParser {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserPad, OpType) {
|
TEST_F(TestTfliteParserPad, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Pad) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Pad) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserPad, AttrValue) {
|
TEST_F(TestTfliteParserPad, AttrValue) {
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPad(), nullptr);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsPad();
|
auto val = meta_graph->nodes.front()->primitive->value.AsPad();
|
||||||
|
|
||||||
std::vector<int32_t> paddings = {1, 1, 2, 2, 3, 3, 4, 4};
|
std::vector<int32_t> paddings = {1, 1, 2, 2, 3, 3, 4, 4};
|
||||||
ASSERT_EQ(val->paddings, paddings);
|
ASSERT_EQ(val->paddings, paddings);
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,12 +35,8 @@ TEST_F(TestTfliteParserMaxPooling, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserMaxPooling, AttrValue) {
|
TEST_F(TestTfliteParserMaxPooling, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsPooling();
|
auto val = meta_graph->nodes.front()->primitive->value.AsPooling();
|
||||||
ASSERT_NE(val, nullptr);
|
|
||||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||||
ASSERT_EQ(val->poolingMode, schema::PoolMode_MAX_POOLING);
|
ASSERT_EQ(val->poolingMode, schema::PoolMode_MAX_POOLING);
|
||||||
ASSERT_EQ(val->global, false);
|
ASSERT_EQ(val->global, false);
|
||||||
|
@ -72,12 +68,8 @@ TEST_F(TestTfliteParserAvgPooling, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserAvgPooling, AttrValue) {
|
TEST_F(TestTfliteParserAvgPooling, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsPooling();
|
auto val = meta_graph->nodes.front()->primitive->value.AsPooling();
|
||||||
ASSERT_NE(val, nullptr);
|
|
||||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||||
ASSERT_EQ(val->poolingMode, schema::PoolMode_MEAN_POOLING);
|
ASSERT_EQ(val->poolingMode, schema::PoolMode_MEAN_POOLING);
|
||||||
ASSERT_EQ(val->global, false);
|
ASSERT_EQ(val->global, false);
|
||||||
|
|
|
@ -32,13 +32,9 @@ TEST_F(TestTfliteParserReduceMax, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserReduceMax, AttrValue) {
|
TEST_F(TestTfliteParserReduceMax, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
|
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
|
||||||
ASSERT_NE(val, nullptr);
|
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax);
|
||||||
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax) << "wrong reduce mode";
|
|
||||||
ASSERT_EQ(val->keepDims, false);
|
ASSERT_EQ(val->keepDims, false);
|
||||||
std::vector<int32_t> axes = {2};
|
std::vector<int32_t> axes = {2};
|
||||||
ASSERT_EQ(val->axes, axes);
|
ASSERT_EQ(val->axes, axes);
|
||||||
|
@ -58,13 +54,9 @@ TEST_F(TestTfliteParserReduceMin, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserReduceMin, AttrValue) {
|
TEST_F(TestTfliteParserReduceMin, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
|
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
|
||||||
ASSERT_NE(val, nullptr);
|
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin);
|
||||||
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin) << "wrong reduce mode";
|
|
||||||
ASSERT_EQ(val->keepDims, false);
|
ASSERT_EQ(val->keepDims, false);
|
||||||
std::vector<int32_t> axes = {2};
|
std::vector<int32_t> axes = {2};
|
||||||
ASSERT_EQ(val->axes, axes);
|
ASSERT_EQ(val->axes, axes);
|
||||||
|
@ -84,13 +76,9 @@ TEST_F(TestTfliteParserReduceProd, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserReduceProd, AttrValue) {
|
TEST_F(TestTfliteParserReduceProd, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
|
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
|
||||||
ASSERT_NE(val, nullptr);
|
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd);
|
||||||
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd) << "wrong reduce mode";
|
|
||||||
ASSERT_EQ(val->keepDims, false);
|
ASSERT_EQ(val->keepDims, false);
|
||||||
std::vector<int32_t> axes = {2};
|
std::vector<int32_t> axes = {2};
|
||||||
ASSERT_EQ(val->axes, axes);
|
ASSERT_EQ(val->axes, axes);
|
||||||
|
@ -111,13 +99,9 @@ TEST_F(TestTfliteParserSum, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserSum, AttrValue) {
|
TEST_F(TestTfliteParserSum, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
|
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
|
||||||
ASSERT_NE(val, nullptr);
|
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum);
|
||||||
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum) << "wrong reduce mode";
|
|
||||||
ASSERT_EQ(val->keepDims, false);
|
ASSERT_EQ(val->keepDims, false);
|
||||||
std::vector<int32_t> axes = {2};
|
std::vector<int32_t> axes = {2};
|
||||||
ASSERT_EQ(val->axes, axes);
|
ASSERT_EQ(val->axes, axes);
|
||||||
|
@ -138,13 +122,9 @@ TEST_F(TestTfliteParserMean, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserMean, AttrValue) {
|
TEST_F(TestTfliteParserMean, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
|
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
|
||||||
ASSERT_NE(val, nullptr);
|
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean);
|
||||||
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean) << "wrong reduce mode";
|
|
||||||
ASSERT_EQ(val->keepDims, true);
|
ASSERT_EQ(val->keepDims, true);
|
||||||
std::vector<int32_t> axes = {2, 3};
|
std::vector<int32_t> axes = {2, 3};
|
||||||
ASSERT_EQ(val->axes, axes);
|
ASSERT_EQ(val->axes, axes);
|
||||||
|
|
|
@ -35,12 +35,9 @@ TEST_F(TestTfliteParserReshape, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserReshape, AttrValue) {
|
TEST_F(TestTfliteParserReshape, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReshape(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReshape(), nullptr);
|
||||||
|
auto val = meta_graph->nodes.front()->primitive->value.AsReshape();
|
||||||
std::vector<int64_t> shape = {3, 5, 20};
|
std::vector<int64_t> shape = {3, 5, 20};
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReshape()->shape, shape); // int32
|
ASSERT_EQ(val->shape, shape);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -26,17 +26,15 @@ class TestTfliteParserResizeNN : public TestTfliteParser {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserResizeNN, OpType) {
|
TEST_F(TestTfliteParserResizeNN, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserResizeNN, AttrValue) {
|
TEST_F(TestTfliteParserResizeNN, AttrValue) {
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsResize();
|
auto val = meta_graph->nodes.front()->primitive->value.AsResize();
|
||||||
ASSERT_NE(val, nullptr);
|
|
||||||
ASSERT_EQ(val->alignCorners, false);
|
ASSERT_EQ(val->alignCorners, false);
|
||||||
ASSERT_EQ(val->newHeight, 3);
|
ASSERT_EQ(val->newHeight, 3);
|
||||||
ASSERT_EQ(val->newWidth, 100);
|
ASSERT_EQ(val->newWidth, 100);
|
||||||
|
@ -52,17 +50,15 @@ class TestTfliteParserResizeBilinear : public TestTfliteParser {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserResizeBilinear, OpType) {
|
TEST_F(TestTfliteParserResizeBilinear, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserResizeBilinear, AttrValue) {
|
TEST_F(TestTfliteParserResizeBilinear, AttrValue) {
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsResize();
|
auto val = meta_graph->nodes.front()->primitive->value.AsResize();
|
||||||
ASSERT_NE(val, nullptr);
|
|
||||||
ASSERT_EQ(val->alignCorners, false);
|
ASSERT_EQ(val->alignCorners, false);
|
||||||
ASSERT_EQ(val->newHeight, 75);
|
ASSERT_EQ(val->newHeight, 75);
|
||||||
ASSERT_EQ(val->newWidth, 4);
|
ASSERT_EQ(val->newWidth, 4);
|
||||||
|
|
|
@ -25,17 +25,15 @@ class TestTfliteParserReverse : public TestTfliteParser {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestTfliteParserReverse, OpType) {
|
TEST_F(TestTfliteParserReverse, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reverse) << "wrong Op Type";
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reverse) << "wrong Op Type";
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserReverse, AttrValue) {
|
TEST_F(TestTfliteParserReverse, AttrValue) {
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverse(), nullptr);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
|
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsReverse();
|
auto val = meta_graph->nodes.front()->primitive->value.AsReverse();
|
||||||
|
|
||||||
std::vector<int32_t> axis = {3};
|
std::vector<int32_t> axis = {3};
|
||||||
ASSERT_EQ(val->axis, axis);
|
ASSERT_EQ(val->axis, axis);
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,13 +35,11 @@ TEST_F(TestTfliteParserReverseSequence, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserReverseSequence, AttrValue) {
|
TEST_F(TestTfliteParserReverseSequence, AttrValue) {
|
||||||
std::vector<int> seq_length{7, 2, 3, 5};
|
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverseSequence(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverseSequence(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1);
|
auto val = meta_graph->nodes.front()->primitive->value.AsReverseSequence();
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1);
|
ASSERT_EQ(val->seqAxis, 1);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqLengths, seq_length);
|
ASSERT_EQ(val->seqAxis, 1);
|
||||||
|
std::vector<int> seq_length = {7, 2, 3, 5};
|
||||||
|
ASSERT_EQ(val->seqLengths, seq_length);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include "common/common_test.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class TestTfliteParserSlice : public TestTfliteParser {
|
||||||
|
public:
|
||||||
|
TestTfliteParserSlice() = default;
|
||||||
|
|
||||||
|
void SetUp() override { meta_graph = LoadAndConvert("./slice.tflite"); }
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserSlice, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Slice) << "wrong Op Type";
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserSlice, AttrValue) {
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSlice(), nullptr);
|
||||||
|
auto val = meta_graph->nodes.front()->primitive->value.AsSlice();
|
||||||
|
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||||
|
std::vector<int32_t> begin = {1, 0, 0};
|
||||||
|
ASSERT_EQ(val->begin, begin);
|
||||||
|
std::vector<int32_t> size = {1, 1, 3};
|
||||||
|
ASSERT_EQ(val->size, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSoftmax, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserSoftmax, AttrValue) {
|
TEST_F(TestTfliteParserSoftmax, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr);
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
auto val = meta_graph->nodes.front()->primitive->value.AsSoftMax();
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
ASSERT_EQ(val->axis, -1);
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr);
|
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSoftMax()->axis, -1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -35,13 +35,11 @@ TEST_F(TestTfliteParserSpaceToBatchND, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserSpaceToBatchND, AttrValue) {
|
TEST_F(TestTfliteParserSpaceToBatchND, AttrValue) {
|
||||||
std::vector<int> blockshape{2, 2};
|
|
||||||
std::vector<int> padding{0, 0, 2, 0};
|
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->blockShape, blockshape);
|
auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND();
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->paddings, padding);
|
std::vector<int> blockshape = {2, 2};
|
||||||
|
ASSERT_EQ(val->blockShape, blockshape);
|
||||||
|
std::vector<int> padding = {0, 0, 2, 0};
|
||||||
|
ASSERT_EQ(val->paddings, padding);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSpaceToDepth, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserSpaceToDepth, AttrValue) {
|
TEST_F(TestTfliteParserSpaceToDepth, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->blockSize, 2);
|
auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToDepth();
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->format, schema::Format_NHWC);
|
ASSERT_EQ(val->blockSize, 2);
|
||||||
|
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -35,16 +35,14 @@ TEST_F(TestTfliteParserSparseToDense, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserSparseToDense, AttrValue) {
|
TEST_F(TestTfliteParserSparseToDense, AttrValue) {
|
||||||
std::vector<int> outputShape{5, 5};
|
|
||||||
std::vector<int> sparseValue{1};
|
|
||||||
std::vector<int> defaultValue{0};
|
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->outputShape, outputShape);
|
auto val = meta_graph->nodes.front()->primitive->value.AsSparseToDense();
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->sparseValue, sparseValue);
|
std::vector<int> outputShape = {5, 5};
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->defaultValue, defaultValue);
|
ASSERT_EQ(val->outputShape, outputShape);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->validateIndices, false);
|
std::vector<int> sparseValue = {1};
|
||||||
|
ASSERT_EQ(val->sparseValue, sparseValue);
|
||||||
|
std::vector<int> defaultValue = {0};
|
||||||
|
ASSERT_EQ(val->defaultValue, defaultValue);
|
||||||
|
ASSERT_EQ(val->validateIndices, false);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -33,14 +33,12 @@ TEST_F(TestTfliteParserSplit, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserSplit, AttrValue) {
|
TEST_F(TestTfliteParserSplit, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr);
|
||||||
const std::vector<int> sizeSplits{2, 2};
|
auto val = meta_graph->nodes.front()->primitive->value.AsSplit();
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 2);
|
ASSERT_EQ(val->splitDim, 2);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2);
|
ASSERT_EQ(val->numberSplit, 2);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits);
|
const std::vector<int> sizeSplits = {2, 2};
|
||||||
|
ASSERT_EQ(val->sizeSplits, sizeSplits);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -33,14 +33,12 @@ TEST_F(TestTfliteParserSplitV, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserSplitV, AttrValue) {
|
TEST_F(TestTfliteParserSplitV, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr);
|
||||||
const std::vector<int> sizeSplits{1, 3};
|
auto val = meta_graph->nodes.front()->primitive->value.AsSplit();
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 0);
|
ASSERT_EQ(val->splitDim, 0);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2);
|
ASSERT_EQ(val->numberSplit, 2);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits);
|
const std::vector<int> sizeSplits = {1, 3};
|
||||||
|
ASSERT_EQ(val->sizeSplits, sizeSplits);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include "common/common_test.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class TestTfliteParserStack : public TestTfliteParser {
|
||||||
|
public:
|
||||||
|
TestTfliteParserStack() = default;
|
||||||
|
|
||||||
|
void SetUp() override { meta_graph = LoadAndConvert("./stack.tflite"); }
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserStack, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Stack) << "wrong Op Type";
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserStack, AttrValue) {
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStack(), nullptr);
|
||||||
|
auto val = meta_graph->nodes.front()->primitive->value.AsStack();
|
||||||
|
ASSERT_EQ(val->axis, 1);
|
||||||
|
ASSERT_EQ(val->n, 2);
|
||||||
|
const std::vector<int> isScale = {3, 2, 3};
|
||||||
|
ASSERT_EQ(val->isScale, isScale);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -35,21 +35,19 @@ TEST_F(TestTfliteParserStridedSlice, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserStridedSlice, AttrValue) {
|
TEST_F(TestTfliteParserStridedSlice, AttrValue) {
|
||||||
std::vector<int> begin{1, -1, 0};
|
|
||||||
std::vector<int> end{2, -3, 3};
|
|
||||||
std::vector<int> stride{1, -1, 1};
|
|
||||||
std::vector<int> isscale{3, 2, 3};
|
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStridedSlice(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStridedSlice(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0);
|
auto val = meta_graph->nodes.front()->primitive->value.AsStridedSlice();
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->endMask, 0);
|
ASSERT_EQ(val->beginMask, 0);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0);
|
ASSERT_EQ(val->endMask, 0);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0);
|
ASSERT_EQ(val->beginMask, 0);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->begin, begin);
|
ASSERT_EQ(val->beginMask, 0);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->end, end);
|
std::vector<int> begin = {1, -1, 0};
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->stride, stride);
|
ASSERT_EQ(val->begin, begin);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->isScale, isscale);
|
std::vector<int> end = {2, -3, 3};
|
||||||
|
ASSERT_EQ(val->end, end);
|
||||||
|
std::vector<int> stride = {1, -1, 1};
|
||||||
|
ASSERT_EQ(val->stride, stride);
|
||||||
|
std::vector<int> isscale = {3, 2, 3};
|
||||||
|
ASSERT_EQ(val->isScale, isscale);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -35,11 +35,9 @@ TEST_F(TestTfliteParserTile, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserTile, AttrValue) {
|
TEST_F(TestTfliteParserTile, AttrValue) {
|
||||||
std::vector<int> multiply{2, 3, 4};
|
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTile(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTile(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTile()->multiples, multiply);
|
auto val = meta_graph->nodes.front()->primitive->value.AsTile();
|
||||||
|
std::vector<int> multiply = {2, 3, 4};
|
||||||
|
ASSERT_EQ(val->multiples, multiply);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -35,13 +35,10 @@ TEST_F(TestTfliteParserTopKV2, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserTopKV2, AttrValue) {
|
TEST_F(TestTfliteParserTopKV2, AttrValue) {
|
||||||
// attr->sorted default is true
|
|
||||||
std::vector<int> k{3};
|
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->k, k);
|
auto val = meta_graph->nodes.front()->primitive->value.AsTopKV2();
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->sorted, true);
|
std::vector<int> k = {3};
|
||||||
|
ASSERT_EQ(val->k, k);
|
||||||
|
ASSERT_EQ(val->sorted, true);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include "common/common_test.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class TestTfliteParserTranspose : public TestTfliteParser {
|
||||||
|
public:
|
||||||
|
TestTfliteParserTranspose() = default;
|
||||||
|
|
||||||
|
void SetUp() override { meta_graph = LoadAndConvert("./transpose.tflite"); }
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserTranspose, OpType) {
|
||||||
|
ASSERT_NE(meta_graph, nullptr);
|
||||||
|
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||||
|
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestTfliteParserTranspose, AttrValue) {
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTranspose(), nullptr);
|
||||||
|
auto val = meta_graph->nodes.front()->primitive->value.AsTranspose();
|
||||||
|
ASSERT_EQ(val->conjugate, false);
|
||||||
|
std::vector<int32_t> perm = {1, 0};
|
||||||
|
ASSERT_EQ(val->perm, perm);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -35,10 +35,9 @@ TEST_F(TestTfliteParserUnique, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserUnique, AttrValue) {
|
TEST_F(TestTfliteParserUnique, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnique()->outType, 34); // int32
|
auto val = meta_graph->nodes.front()->primitive->value.AsUnique();
|
||||||
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr);
|
||||||
|
ASSERT_EQ(val->outType, 34);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -35,11 +35,9 @@ TEST_F(TestTfliteParserUnstack, OpType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestTfliteParserUnstack, AttrValue) {
|
TEST_F(TestTfliteParserUnstack, AttrValue) {
|
||||||
ASSERT_NE(meta_graph, nullptr);
|
|
||||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnstack(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnstack(), nullptr);
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->num, 5);
|
auto val = meta_graph->nodes.front()->primitive->value.AsUnstack();
|
||||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->axis, 1);
|
ASSERT_EQ(val->num, 5);
|
||||||
|
ASSERT_EQ(val->axis, 1);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -353,7 +353,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
|
||||||
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
|
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
|
||||||
} else if (weightTensor->format == schema::Format_KCHW) {
|
} else if (weightTensor->format == schema::Format_KCHW) {
|
||||||
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
|
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
|
||||||
} else if (weightTensor->format == schema::Format_CHWK) {
|
} else if (weightTensor->format == schema::Format_CHWK) { // from tflite
|
||||||
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
|
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
|
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
|
||||||
|
@ -369,7 +369,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
|
||||||
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
|
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
|
||||||
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
|
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
|
||||||
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
|
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
|
||||||
} else if (weightTensor->format == schema::Format_CHWK) { // from tf
|
} else if (weightTensor->format == schema::Format_CHWK) { // from tflite
|
||||||
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
|
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
|
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
|
||||||
|
|
|
@ -21,11 +21,16 @@ namespace lite {
|
||||||
STATUS CaffeFlattenParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight,
|
STATUS CaffeFlattenParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight,
|
||||||
schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) {
|
schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) {
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
// MS_LOGE("null pointer dereferencing.");
|
// MS_LOG(ERROR) << "null pointer dereferencing.";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
std::unique_ptr<schema::ReshapeT> attr(new schema::ReshapeT());
|
std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT());
|
||||||
attr->format = schema::Format_NCHW;
|
const caffe::FlattenParameter flattenParam = proto.flatten_param();
|
||||||
|
|
||||||
|
attr->axis = (int32_t)flattenParam.axis();
|
||||||
|
attr->useAxis = true;
|
||||||
|
attr->hasBias = false;
|
||||||
|
attr->activationType = schema::ActivationType_NO_ACTIVATION;
|
||||||
|
|
||||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||||
op->primitive->value.type = schema::PrimitiveType_Flatten;
|
op->primitive->value.type = schema::PrimitiveType_Flatten;
|
||||||
|
|
|
@ -14,18 +14,21 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include "tools/converter/parser/tflite/tflite_activation_parser.h"
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "tools/converter/parser/tflite/tflite_activation_parser.h"
|
#include <map>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -35,13 +38,11 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
|
||||||
MS_LOG(ERROR) << "op->primitive is null";
|
MS_LOG(ERROR) << "op->primitive is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());
|
std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());
|
||||||
|
|
||||||
std::vector<std::string> node_name_str;
|
std::vector<std::string> node_name_str;
|
||||||
Split(op->name, &node_name_str, "-");
|
Split(op->name, &node_name_str, "-");
|
||||||
const char *node_name = node_name_str.data()->c_str();
|
const char *node_name = node_name_str.data()->c_str();
|
||||||
|
|
||||||
if (std::strcmp(node_name, "Relu") == 0) {
|
if (std::strcmp(node_name, "Relu") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteReluParser";
|
MS_LOG(DEBUG) << "parse TfliteReluParser";
|
||||||
attr->type = schema::ActivationType_RELU;
|
attr->type = schema::ActivationType_RELU;
|
||||||
|
@ -54,29 +55,65 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
|
||||||
} else if (std::strcmp(node_name, "Logistic") == 0) {
|
} else if (std::strcmp(node_name, "Logistic") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteLogisticParser";
|
MS_LOG(DEBUG) << "parse TfliteLogisticParser";
|
||||||
attr->type = schema::ActivationType_SIGMOID;
|
attr->type = schema::ActivationType_SIGMOID;
|
||||||
} else if (std::strcmp(node_name, "LeakyRelu") == 0) {
|
} else if (std::strcmp(node_name, "HardSwish") == 0) {
|
||||||
const auto &option = tfliteOp->builtin_options.AsLeakyReluOptions();
|
MS_LOG(DEBUG) << "parse TfliteHardSwishParser";
|
||||||
if (option == nullptr) {
|
attr->type = schema::ActivationType_SIGMOID;
|
||||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
attr->type = schema::ActivationType_LEAKY_RELU;
|
|
||||||
attr->alpha = option->alpha;
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "wrong activation type";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
attr->alpha = 0.2f;
|
||||||
op->primitive->value.type = schema::PrimitiveType_Activation;
|
op->primitive->value.type = schema::PrimitiveType_Activation;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
|
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
|
MS_LOG(DEBUG) << "parse TflitePreluParser";
|
||||||
|
|
||||||
|
if (op == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "op is null";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||||
|
if (op->primitive == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "op->primitive is null";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
std::unique_ptr<schema::PreluT> attr(new schema::PreluT());
|
||||||
|
|
||||||
|
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) {
|
||||||
|
MS_LOG(ERROR) << "get pRelu -> slope failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
op->primitive->value.type = schema::PrimitiveType_Prelu;
|
||||||
|
op->primitive->value.value = attr.release();
|
||||||
|
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
|
schema::CNodeT *op,
|
||||||
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
|
MS_LOG(DEBUG) << "parse TfliteLeakyReluParser";
|
||||||
|
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -87,22 +124,29 @@ STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "paser TflitePreluParser";
|
std::unique_ptr<schema::LeakyReLUT> attr(new schema::LeakyReLUT());
|
||||||
std::unique_ptr<schema::PreluT> attr(new schema::PreluT());
|
|
||||||
|
|
||||||
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) {
|
const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions();
|
||||||
MS_LOG(ERROR) << "get pRelu -> slope failed";
|
if (tflite_attr == nullptr) {
|
||||||
return RET_ERROR;
|
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
attr->negativeSlope = tflite_attr->alpha;
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_Prelu;
|
op->primitive->value.type = schema::PrimitiveType_LeakyReLU;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
|
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
|
TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
|
||||||
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
|
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
|
||||||
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());
|
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());
|
||||||
|
TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser());
|
||||||
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
|
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
|
||||||
TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser());
|
TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser());
|
||||||
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());
|
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());
|
||||||
|
|
|
@ -14,13 +14,14 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef PREDICT_TFLITE_RELU_PARSER_H
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H
|
||||||
#define PREDICT_TFLITE_RELU_PARSER_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H
|
||||||
|
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||||
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
|
@ -29,11 +30,13 @@ class TfliteActivationParser : public TfliteNodeParser {
|
||||||
public:
|
public:
|
||||||
TfliteActivationParser() : TfliteNodeParser("node_name") {}
|
TfliteActivationParser() : TfliteNodeParser("node_name") {}
|
||||||
|
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache, bool quantizedModel) override;
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TfliteReluParser : public TfliteActivationParser {
|
class TfliteReluParser : public TfliteActivationParser {
|
||||||
|
@ -56,9 +59,9 @@ class TfliteLogisticParser : public TfliteActivationParser {
|
||||||
TfliteLogisticParser() : TfliteActivationParser() {}
|
TfliteLogisticParser() : TfliteActivationParser() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
class TfliteLeakyReluParser : public TfliteActivationParser {
|
class TfliteHardSwishParser : public TfliteActivationParser {
|
||||||
public:
|
public:
|
||||||
TfliteLeakyReluParser() : TfliteActivationParser() {}
|
TfliteHardSwishParser() : TfliteActivationParser() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
class TflitePreluParser : public TfliteNodeParser {
|
class TflitePreluParser : public TfliteNodeParser {
|
||||||
|
@ -68,12 +71,27 @@ class TflitePreluParser : public TfliteNodeParser {
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache, bool quantized_model) override;
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TfliteLeakyReluParser : public TfliteNodeParser {
|
||||||
|
public:
|
||||||
|
TfliteLeakyReluParser() : TfliteNodeParser("LeakyRelu") {}
|
||||||
|
|
||||||
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
|
schema::CNodeT *op,
|
||||||
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // PREDICT_TFLITE_RELU_PARSER_H
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H
|
||||||
|
|
||||||
|
|
|
@ -18,14 +18,20 @@
|
||||||
#include "tools/converter/parser/tflite/tflite_addn_parser.h"
|
#include "tools/converter/parser/tflite/tflite_addn_parser.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
|
MS_LOG(DEBUG) << "parse TfliteAddNParser";
|
||||||
|
|
||||||
|
// set attr
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -36,13 +42,19 @@ STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteAddNParser";
|
|
||||||
std::unique_ptr<schema::AddNT> attr(new schema::AddNT());
|
std::unique_ptr<schema::AddNT> attr(new schema::AddNT());
|
||||||
|
|
||||||
attr->N = tfliteTensors.size() - 1;
|
attr->N = tflite_tensors.size() - 1;
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_AddN;
|
op->primitive->value.type = schema::PrimitiveType_AddN;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
|
|
||||||
|
// set input
|
||||||
|
for (int i = 0; i < tflite_op->inputs.size(); i++) {
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
}
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef LITE_TFLITE_ADDN_PARSER_H
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H
|
||||||
#define LITE_TFLITE_ADDN_PARSER_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
|
|
||||||
|
@ -31,11 +32,12 @@ class TfliteAddNParser : public TfliteNodeParser {
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache,
|
std::vector<int32_t> *tensors_id,
|
||||||
bool quantized_model) override;
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // LITE_TFLITE_ADDN_PARSER_H
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H
|
||||||
|
|
|
@ -17,16 +17,19 @@
|
||||||
#include "tools/converter/parser/tflite/tflite_argmax_parser.h"
|
#include "tools/converter/parser/tflite/tflite_argmax_parser.h"
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
|
||||||
schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache,
|
std::vector<int32_t> *tensors_id,
|
||||||
bool quantizedModel) {
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
|
MS_LOG(DEBUG) << "parse TfliteArgmaxParser";
|
||||||
|
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -37,7 +40,6 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteArgmaxParser";
|
|
||||||
std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT());
|
std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT());
|
||||||
|
|
||||||
attr->outMaxValue = false;
|
attr->outMaxValue = false;
|
||||||
|
@ -45,9 +47,10 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
||||||
attr->keepDims = false;
|
attr->keepDims = false;
|
||||||
attr->axisType = 1;
|
attr->axisType = 1;
|
||||||
|
|
||||||
auto axis_idx = tfliteOp->inputs[1];
|
// get axis attr
|
||||||
std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){});
|
auto axis_idx = tflite_op->inputs[1];
|
||||||
auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer];
|
std::for_each(tflite_tensors[axis_idx]->shape.begin(), tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){});
|
||||||
|
auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer];
|
||||||
if (buf_data == nullptr) {
|
if (buf_data == nullptr) {
|
||||||
MS_LOG(ERROR) << "the buf data is null";
|
MS_LOG(ERROR) << "the buf data is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -61,6 +64,11 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_ArgMax;
|
op->primitive->value.type = schema::PrimitiveType_ArgMax;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
|
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef PREDICT_TFLITE_ARGMAX_PARSER_H
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H
|
||||||
#define PREDICT_TFLITE_ARGMAX_PARSER_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
|
|
||||||
|
@ -28,14 +29,15 @@ class TfliteArgmaxParser : public TfliteNodeParser {
|
||||||
public:
|
public:
|
||||||
TfliteArgmaxParser() : TfliteNodeParser("Argmax") {}
|
TfliteArgmaxParser() : TfliteNodeParser("Argmax") {}
|
||||||
|
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache,
|
std::vector<int32_t> *tensors_id,
|
||||||
bool quantizedModel) override;
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // PREDICT_TFLITE_ARGMAX_PARSER_H
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H
|
||||||
|
|
|
@ -17,14 +17,19 @@
|
||||||
#include "tools/converter/parser/tflite/tflite_argmin_parser.h"
|
#include "tools/converter/parser/tflite/tflite_argmin_parser.h"
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
|
MS_LOG(DEBUG) << "parse TfliteArgminParser";
|
||||||
|
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -35,7 +40,6 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteArgminParser";
|
|
||||||
std::unique_ptr<schema::ArgMinT> attr(new schema::ArgMinT());
|
std::unique_ptr<schema::ArgMinT> attr(new schema::ArgMinT());
|
||||||
|
|
||||||
attr->outMaxValue = false;
|
attr->outMaxValue = false;
|
||||||
|
@ -43,9 +47,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
||||||
attr->keepDims = false;
|
attr->keepDims = false;
|
||||||
attr->axisType = 1;
|
attr->axisType = 1;
|
||||||
|
|
||||||
auto axis_idx = tfliteOp->inputs[1];
|
// get axis attr
|
||||||
std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){});
|
auto axis_idx = tflite_op->inputs[1];
|
||||||
auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer];
|
std::for_each(tflite_tensors[axis_idx]->shape.begin(),
|
||||||
|
tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){});
|
||||||
|
auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer];
|
||||||
if (buf_data == nullptr) {
|
if (buf_data == nullptr) {
|
||||||
MS_LOG(ERROR) << "the buf data is null";
|
MS_LOG(ERROR) << "the buf data is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -59,6 +65,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_ArgMin;
|
op->primitive->value.type = schema::PrimitiveType_ArgMin;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
|
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef PREDICT_TFLITE_ARGMIN_PARSER_H
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H
|
||||||
#define PREDICT_TFLITE_ARGMIN_PARSER_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
|
|
||||||
|
@ -28,14 +29,15 @@ class TfliteArgminParser : public TfliteNodeParser {
|
||||||
public:
|
public:
|
||||||
TfliteArgminParser() : TfliteNodeParser("Argmin") {}
|
TfliteArgminParser() : TfliteNodeParser("Argmin") {}
|
||||||
|
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache,
|
std::vector<int32_t> *tensors_id,
|
||||||
bool quantizedModel) override;
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // PREDICT_TFLITE_ARGMIN_PARSER_H
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H
|
||||||
|
|
|
@ -18,14 +18,17 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -37,124 +40,72 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> node_name_str;
|
std::vector<std::string> node_name_str;
|
||||||
Split(op->name.data(), &node_name_str, "-");
|
Split(op->name, &node_name_str, "-");
|
||||||
const char *node_name = node_name_str.data()->c_str();
|
const char *node_name = node_name_str.data()->c_str();
|
||||||
|
if (std::strcmp(node_name, "Add") == 0) {
|
||||||
if (std::strcmp(node_name, "Add") == 0
|
MS_LOG(DEBUG) << "parse TfliteAddParser";
|
||||||
|| std::strcmp(node_name, "Sub") == 0
|
std::unique_ptr<schema::AddT> attr(new schema::AddT());
|
||||||
|| std::strcmp(node_name, "Mul") == 0
|
const auto &tfliteAttr = tflite_op->builtin_options.AsAddOptions();
|
||||||
|| std::strcmp(node_name, "Div") == 0) {
|
if (nullptr == tfliteAttr) {
|
||||||
auto x_index = tfliteOp->inputs[0];
|
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||||
const auto &x_tensor = tfliteTensors[x_index];
|
|
||||||
if (x_tensor == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "the first input is null";
|
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
auto &x_data = tfliteModelBuffer.at(x_tensor->buffer);
|
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||||
if (x_data == nullptr) {
|
op->primitive->value.type = schema::PrimitiveType_Add;
|
||||||
MS_LOG(ERROR) << "the data of the first input is null";
|
op->primitive->value.value = attr.release();
|
||||||
|
} else if (std::strcmp(node_name, "Sub") == 0) {
|
||||||
|
MS_LOG(DEBUG) << "parse TfliteSubParser";
|
||||||
|
std::unique_ptr<schema::SubT> attr(new schema::SubT());
|
||||||
|
const auto &tfliteAttr = tflite_op->builtin_options.AsSubOptions();
|
||||||
|
if (nullptr == tfliteAttr) {
|
||||||
|
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
if (!x_data->data.empty()) {
|
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||||
std::vector<tflite::TensorT *> x_tensors{x_tensor.get()};
|
op->primitive->value.type = schema::PrimitiveType_Sub;
|
||||||
if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
op->primitive->value.value = attr.release();
|
||||||
MS_LOG(ERROR) << "parse the first tensor failed";
|
} else if (std::strcmp(node_name, "Mul") == 0) {
|
||||||
return RET_ERROR;
|
MS_LOG(DEBUG) << "parse TfliteMulParser";
|
||||||
}
|
std::unique_ptr<schema::MulT> attr(new schema::MulT());
|
||||||
}
|
const auto &tfliteAttr = tflite_op->builtin_options.AsMulOptions();
|
||||||
|
if (nullptr == tfliteAttr) {
|
||||||
auto y_index = tfliteOp->inputs[1];
|
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||||
const auto &y_tensor = tfliteTensors[y_index];
|
|
||||||
if (y_tensor == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "the second input is null";
|
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
auto &y_data = tfliteModelBuffer.at(y_tensor->buffer);
|
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||||
if (y_data == nullptr) {
|
op->primitive->value.type = schema::PrimitiveType_Mul;
|
||||||
MS_LOG(ERROR) << "the data of the second input is null";
|
op->primitive->value.value = attr.release();
|
||||||
|
} else if (std::strcmp(node_name, "Div") == 0) {
|
||||||
|
MS_LOG(DEBUG) << "parse TfliteDivParser";
|
||||||
|
std::unique_ptr<schema::DivT> attr(new schema::DivT());
|
||||||
|
const auto &tfliteAttr = tflite_op->builtin_options.AsDivOptions();
|
||||||
|
if (nullptr == tfliteAttr) {
|
||||||
|
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
if (!y_data->data.empty()) {
|
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||||
std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
|
op->primitive->value.type = schema::PrimitiveType_Div;
|
||||||
if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
op->primitive->value.value = attr.release();
|
||||||
MS_LOG(ERROR) << "parse the second tensor failed";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (std::strcmp(node_name, "Add") == 0) {
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteAddParser";
|
|
||||||
std::unique_ptr<schema::AddT> attr(new schema::AddT());
|
|
||||||
const auto &tfliteAttr = tfliteOp->builtin_options.AsAddOptions();
|
|
||||||
if (nullptr == tfliteAttr) {
|
|
||||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_Add;
|
|
||||||
op->primitive->value.value = attr.release();
|
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Sub") == 0) {
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteSubParser";
|
|
||||||
std::unique_ptr<schema::SubT> attr(new schema::SubT());
|
|
||||||
const auto &tfliteAttr = tfliteOp->builtin_options.AsSubOptions();
|
|
||||||
if (nullptr == tfliteAttr) {
|
|
||||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_Sub;
|
|
||||||
op->primitive->value.value = attr.release();
|
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Mul") == 0) {
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteMulParser";
|
|
||||||
std::unique_ptr<schema::MulT> attr(new schema::MulT());
|
|
||||||
const auto &tfliteAttr = tfliteOp->builtin_options.AsMulOptions();
|
|
||||||
if (nullptr == tfliteAttr) {
|
|
||||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_Mul;
|
|
||||||
op->primitive->value.value = attr.release();
|
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Div") == 0) {
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteDivParser";
|
|
||||||
std::unique_ptr<schema::DivT> attr(new schema::DivT());
|
|
||||||
const auto &tfliteAttr = tfliteOp->builtin_options.AsDivOptions();
|
|
||||||
if (nullptr == tfliteAttr) {
|
|
||||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_Div;
|
|
||||||
op->primitive->value.value = attr.release();
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
} else if (std::strcmp(node_name, "FloorDiv") == 0) {
|
} else if (std::strcmp(node_name, "FloorDiv") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteFloorDivParser";
|
MS_LOG(DEBUG) << "parse TfliteFloorDivParser";
|
||||||
std::unique_ptr<schema::FloorDivT> attr(new schema::FloorDivT());
|
std::unique_ptr<schema::FloorDivT> attr(new schema::FloorDivT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_FloorDiv;
|
op->primitive->value.type = schema::PrimitiveType_FloorDiv;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "FloorMod") == 0) {
|
} else if (std::strcmp(node_name, "FloorMod") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteFloorModParser";
|
MS_LOG(DEBUG) << "parse TfliteFloorModParser";
|
||||||
std::unique_ptr<schema::FloorModT> attr(new schema::FloorModT());
|
std::unique_ptr<schema::FloorModT> attr(new schema::FloorModT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_FloorMod;
|
op->primitive->value.type = schema::PrimitiveType_FloorMod;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "RealDiv") == 0) {
|
} else if (std::strcmp(node_name, "RealDiv") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteRealDivParser";
|
MS_LOG(DEBUG) << "parse TfliteRealDivParser";
|
||||||
std::unique_ptr<schema::RealDivT> attr(new schema::RealDivT());
|
std::unique_ptr<schema::RealDivT> attr(new schema::RealDivT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_RealDiv;
|
op->primitive->value.type = schema::PrimitiveType_Div;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "SquaredDifference") == 0) {
|
} else if (std::strcmp(node_name, "SquaredDifference") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser";
|
MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser";
|
||||||
std::unique_ptr<schema::SquaredDifferenceT> attr(new schema::SquaredDifferenceT());
|
std::unique_ptr<schema::SquaredDifferenceT> attr(new schema::SquaredDifferenceT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_SquaredDifference;
|
op->primitive->value.type = schema::PrimitiveType_SquaredDifference;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Pow") == 0) {
|
} else if (std::strcmp(node_name, "Pow") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TflitePowParser";
|
MS_LOG(DEBUG) << "parse TflitePowParser";
|
||||||
std::unique_ptr<schema::PowerT> attr(new schema::PowerT());
|
std::unique_ptr<schema::PowerT> attr(new schema::PowerT());
|
||||||
|
@ -163,31 +114,35 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
|
||||||
attr->shift = 0.0f;
|
attr->shift = 0.0f;
|
||||||
op->primitive->value.type = schema::PrimitiveType_Power;
|
op->primitive->value.type = schema::PrimitiveType_Power;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Maximum") == 0) {
|
} else if (std::strcmp(node_name, "Maximum") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteMaximumParser";
|
MS_LOG(DEBUG) << "parse TfliteMaximumParser";
|
||||||
std::unique_ptr<schema::MaximumT> attr(new schema::MaximumT());
|
std::unique_ptr<schema::MaximumT> attr(new schema::MaximumT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Maximum;
|
op->primitive->value.type = schema::PrimitiveType_Maximum;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Minimum") == 0) {
|
} else if (std::strcmp(node_name, "Minimum") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteMinimumParser";
|
MS_LOG(DEBUG) << "parse TfliteMinimumParser";
|
||||||
std::unique_ptr<schema::MinimumT> attr(new schema::MinimumT());
|
std::unique_ptr<schema::MinimumT> attr(new schema::MinimumT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Minimum;
|
op->primitive->value.type = schema::PrimitiveType_Minimum;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "wrong op type";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set input
|
||||||
|
for (int i = 0; i < tflite_op->inputs.size(); i++) {
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
}
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -199,85 +154,79 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> node_name_str;
|
std::vector<std::string> node_name_str;
|
||||||
Split(op->name.data(), &node_name_str, "-");
|
Split(op->name, &node_name_str, "-");
|
||||||
const char *node_name = node_name_str.data()->c_str();
|
const char *node_name = node_name_str.data()->c_str();
|
||||||
if (std::strcmp(node_name, "Abs") == 0) {
|
if (std::strcmp(node_name, "Abs") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteAbsParser";
|
MS_LOG(DEBUG) << "parse TfliteAbsParser";
|
||||||
std::unique_ptr<schema::AbsT> attr(new schema::AbsT());
|
std::unique_ptr<schema::AbsT> attr(new schema::AbsT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Abs;
|
op->primitive->value.type = schema::PrimitiveType_Abs;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Exp") == 0) {
|
} else if (std::strcmp(node_name, "Exp") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteExpParser";
|
MS_LOG(DEBUG) << "parse TfliteExpParser";
|
||||||
std::unique_ptr<schema::ExpT> attr(new schema::ExpT());
|
std::unique_ptr<schema::ExpT> attr(new schema::ExpT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Exp;
|
op->primitive->value.type = schema::PrimitiveType_Exp;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Sqrt") == 0) {
|
} else if (std::strcmp(node_name, "Sqrt") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteSqrtParser";
|
MS_LOG(DEBUG) << "parse TfliteSqrtParser";
|
||||||
std::unique_ptr<schema::SqrtT> attr(new schema::SqrtT());
|
std::unique_ptr<schema::SqrtT> attr(new schema::SqrtT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Sqrt;
|
op->primitive->value.type = schema::PrimitiveType_Sqrt;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Rsqrt") == 0) {
|
} else if (std::strcmp(node_name, "Rsqrt") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteRsqrtParser";
|
MS_LOG(DEBUG) << "parse TfliteRsqrtParser";
|
||||||
std::unique_ptr<schema::RsqrtT> attr(new schema::RsqrtT());
|
std::unique_ptr<schema::RsqrtT> attr(new schema::RsqrtT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Rsqrt;
|
op->primitive->value.type = schema::PrimitiveType_Rsqrt;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Square") == 0) {
|
} else if (std::strcmp(node_name, "Square") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteSquareParser";
|
MS_LOG(DEBUG) << "parse TfliteSquareParser";
|
||||||
std::unique_ptr<schema::SquareT> attr(new schema::SquareT());
|
std::unique_ptr<schema::SquareT> attr(new schema::SquareT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Square;
|
op->primitive->value.type = schema::PrimitiveType_Square;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Sin") == 0) {
|
} else if (std::strcmp(node_name, "Sin") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteSinParser";
|
MS_LOG(DEBUG) << "parse TfliteSinParser";
|
||||||
std::unique_ptr<schema::SinT> attr(new schema::SinT());
|
std::unique_ptr<schema::SinT> attr(new schema::SinT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Sin;
|
op->primitive->value.type = schema::PrimitiveType_Sin;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Cos") == 0) {
|
} else if (std::strcmp(node_name, "Cos") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteCosParser";
|
MS_LOG(DEBUG) << "parse TfliteCosParser";
|
||||||
std::unique_ptr<schema::CosT> attr(new schema::CosT());
|
std::unique_ptr<schema::CosT> attr(new schema::CosT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Cos;
|
op->primitive->value.type = schema::PrimitiveType_Cos;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Log") == 0) {
|
} else if (std::strcmp(node_name, "Log") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteLogParser";
|
MS_LOG(DEBUG) << "parse TfliteLogParser";
|
||||||
std::unique_ptr<schema::LogT> attr(new schema::LogT());
|
std::unique_ptr<schema::LogT> attr(new schema::LogT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Log;
|
op->primitive->value.type = schema::PrimitiveType_Log;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Round") == 0) {
|
} else if (std::strcmp(node_name, "Round") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteRoundParser";
|
MS_LOG(DEBUG) << "parse TfliteRoundParser";
|
||||||
std::unique_ptr<schema::RoundT> attr(new schema::RoundT());
|
std::unique_ptr<schema::RoundT> attr(new schema::RoundT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Round;
|
op->primitive->value.type = schema::PrimitiveType_Round;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Ceil") == 0) {
|
} else if (std::strcmp(node_name, "Ceil") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteCeilParser";
|
MS_LOG(DEBUG) << "parse TfliteCeilParser";
|
||||||
std::unique_ptr<schema::CeilT> attr(new schema::CeilT());
|
std::unique_ptr<schema::CeilT> attr(new schema::CeilT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Ceil;
|
op->primitive->value.type = schema::PrimitiveType_Ceil;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "flOOR") == 0) {
|
} else if (std::strcmp(node_name, "flOOR") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteFloorParser";
|
MS_LOG(DEBUG) << "parse TfliteFloorParser";
|
||||||
std::unique_ptr<schema::FloorT> attr(new schema::FloorT());
|
std::unique_ptr<schema::FloorT> attr(new schema::FloorT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Floor;
|
op->primitive->value.type = schema::PrimitiveType_Floor;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "wrong op type";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -289,48 +238,47 @@ STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> node_name_str;
|
std::vector<std::string> node_name_str;
|
||||||
Split(op->name.data(), &node_name_str, "-");
|
Split(op->name, &node_name_str, "-");
|
||||||
const char *node_name = node_name_str.data()->c_str();
|
const char *node_name = node_name_str.data()->c_str();
|
||||||
if (std::strcmp(node_name, "Equal") == 0) {
|
if (std::strcmp(node_name, "Equal") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteEqualParser";
|
MS_LOG(DEBUG) << "parse TfliteEqualParser";
|
||||||
std::unique_ptr<schema::EqualT> attr(new schema::EqualT());
|
std::unique_ptr<schema::EqualT> attr(new schema::EqualT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Equal;
|
op->primitive->value.type = schema::PrimitiveType_Equal;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "NotEqual") == 0) {
|
} else if (std::strcmp(node_name, "NotEqual") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteNotEqualParser";
|
MS_LOG(DEBUG) << "parse TfliteNotEqualParser";
|
||||||
std::unique_ptr<schema::NotEqualT> attr(new schema::NotEqualT());
|
std::unique_ptr<schema::NotEqualT> attr(new schema::NotEqualT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_NotEqual;
|
op->primitive->value.type = schema::PrimitiveType_NotEqual;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Greater") == 0) {
|
} else if (std::strcmp(node_name, "Greater") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteGreaterParser";
|
MS_LOG(DEBUG) << "parse TfliteGreaterParser";
|
||||||
std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT());
|
std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Greater;
|
op->primitive->value.type = schema::PrimitiveType_Greater;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "GreaterEqual") == 0) {
|
} else if (std::strcmp(node_name, "GreaterEqual") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser";
|
MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser";
|
||||||
std::unique_ptr<schema::GreaterEqualT> attr(new schema::GreaterEqualT());
|
std::unique_ptr<schema::GreaterEqualT> attr(new schema::GreaterEqualT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_GreaterEqual;
|
op->primitive->value.type = schema::PrimitiveType_GreaterEqual;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "Less") == 0) {
|
} else if (std::strcmp(node_name, "Less") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteLessParser";
|
MS_LOG(DEBUG) << "parse TfliteLessParser";
|
||||||
std::unique_ptr<schema::LessT> attr(new schema::LessT());
|
std::unique_ptr<schema::LessT> attr(new schema::LessT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_Less;
|
op->primitive->value.type = schema::PrimitiveType_Less;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else if (std::strcmp(node_name, "LessEqual") == 0) {
|
} else if (std::strcmp(node_name, "LessEqual") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteLessEqualParser";
|
MS_LOG(DEBUG) << "parse TfliteLessEqualParser";
|
||||||
std::unique_ptr<schema::LessEqualT> attr(new schema::LessEqualT());
|
std::unique_ptr<schema::LessEqualT> attr(new schema::LessEqualT());
|
||||||
op->primitive->value.type = schema::PrimitiveType_LessEqual;
|
op->primitive->value.type = schema::PrimitiveType_LessEqual;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "wrong op type";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < tflite_op->inputs.size(); i++) {
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
}
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser());
|
TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser());
|
||||||
|
|
|
@ -14,11 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef PREDICT_TFLITE_MATH_PARSER_H
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H
|
||||||
#define PREDICT_TFLITE_MATH_PARSER_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
|
|
||||||
|
@ -29,11 +30,13 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser {
|
||||||
public:
|
public:
|
||||||
TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {}
|
TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {}
|
||||||
|
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache, bool quantizedModel) override;
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TfliteAddParser : public TfliteDoubleInputOpParser {
|
class TfliteAddParser : public TfliteDoubleInputOpParser {
|
||||||
|
@ -96,11 +99,13 @@ class TfliteSingleInputOpParser : public TfliteNodeParser {
|
||||||
public:
|
public:
|
||||||
TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {}
|
TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {}
|
||||||
|
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache, bool quantizedModel) override;
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TfliteAbsParser : public TfliteSingleInputOpParser {
|
class TfliteAbsParser : public TfliteSingleInputOpParser {
|
||||||
|
@ -163,11 +168,13 @@ class TfliteCompareOpParser : public TfliteNodeParser {
|
||||||
public:
|
public:
|
||||||
TfliteCompareOpParser() : TfliteNodeParser("node_name") {}
|
TfliteCompareOpParser() : TfliteNodeParser("node_name") {}
|
||||||
|
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache, bool quantizedModel) override;
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TfliteEqualParser : public TfliteCompareOpParser {
|
class TfliteEqualParser : public TfliteCompareOpParser {
|
||||||
|
@ -203,5 +210,5 @@ class TfliteLessEqualParser : public TfliteCompareOpParser {
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // PREDICT_TFLITE_MATH_PARSER_H
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H
|
||||||
|
|
||||||
|
|
|
@ -19,14 +19,17 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -38,30 +41,32 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> node_name_str;
|
std::vector<std::string> node_name_str;
|
||||||
Split(op->name.data(), &node_name_str, "-");
|
Split(op->name, &node_name_str, "-");
|
||||||
const char *node_name = node_name_str.data()->c_str();
|
const char *node_name = node_name_str.data()->c_str();
|
||||||
if (std::strcmp(node_name, "BatchToSpace") == 0) {
|
if (std::strcmp(node_name, "BatchToSpace") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser";
|
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser";
|
||||||
} else if (std::strcmp(node_name, "BatchToSpaceND") == 0) {
|
} else if (std::strcmp(node_name, "BatchToSpaceND") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser";
|
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser";
|
||||||
// in tflite
|
|
||||||
// blockShape should be a 1D tensor with dimension [spatial_dims_num]
|
|
||||||
// crops should be a 2D tensor with dimension [spatial_dims_num, 2]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT());
|
std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT());
|
||||||
|
|
||||||
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) {
|
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->blockShape)) {
|
||||||
MS_LOG(ERROR) << "get batchToSpace -> blockShape failed";
|
MS_LOG(ERROR) << "get batchToSpace -> blockShape failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->crops)) {
|
if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->crops)) {
|
||||||
MS_LOG(ERROR) << "get batchToSpace -> crops failed";
|
MS_LOG(ERROR) << "get batchToSpace -> crops failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_BatchToSpace;
|
op->primitive->value.type = schema::PrimitiveType_BatchToSpace;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
|
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef LITE_TFLITE_BATCH_TO_SPACE_PARSER_H
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H
|
||||||
#define LITE_TFLITE_BATCH_TO_SPACE_PARSER_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
|
|
||||||
|
@ -31,8 +32,10 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser {
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache, bool quantized_model) override;
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser {
|
class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser {
|
||||||
|
@ -43,4 +46,4 @@ class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser {
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // LITE_TFLITE_BATCH_TO_SPACE_PARSER_H
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H
|
||||||
|
|
|
@ -18,14 +18,19 @@
|
||||||
#include "tools/converter/parser/tflite/tflite_broadcast_to_parser.h"
|
#include "tools/converter/parser/tflite/tflite_broadcast_to_parser.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
|
MS_LOG(DEBUG) << "parse TfliteBroadcastToParser";
|
||||||
|
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -36,16 +41,20 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteBroadcastToParser";
|
|
||||||
std::unique_ptr<schema::BroadcastToT> attr(new schema::BroadcastToT());
|
std::unique_ptr<schema::BroadcastToT> attr(new schema::BroadcastToT());
|
||||||
|
|
||||||
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dst_shape)) {
|
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dst_shape)) {
|
||||||
MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed";
|
MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
|
op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
|
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef LITE_TFLITE_BROADCAST_TO_PARSER_H
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H
|
||||||
#define LITE_TFLITE_BROADCAST_TO_PARSER_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
|
|
||||||
|
@ -31,11 +32,12 @@ class TfliteBroadcastToParser : public TfliteNodeParser {
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache,
|
std::vector<int32_t> *tensors_id,
|
||||||
bool quantized_model) override;
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // LITE_TFLITE_BROADCAST_TO_PARSER_H
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H
|
||||||
|
|
|
@ -14,18 +14,22 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "tools/converter/parser/tflite/tflite_cast_parser.h"
|
#include "tools/converter/parser/tflite/tflite_cast_parser.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
|
MS_LOG(DEBUG) << "parse TfliteCastParser";
|
||||||
|
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -36,25 +40,28 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteCastParser";
|
|
||||||
std::unique_ptr<schema::CastT> attr(new schema::CastT());
|
std::unique_ptr<schema::CastT> attr(new schema::CastT());
|
||||||
|
|
||||||
const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]];
|
const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]];
|
||||||
if (in_tensor == nullptr) {
|
if (in_tensor == nullptr) {
|
||||||
MS_LOG(ERROR) << "tensor is null";
|
MS_LOG(ERROR) << "tensor is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
attr->srcT = dtype_map[in_tensor->type];
|
attr->srcT = GetTfliteDataType(in_tensor->type);
|
||||||
|
const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]];
|
||||||
const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]];
|
|
||||||
if (out_tensor == nullptr) {
|
if (out_tensor == nullptr) {
|
||||||
MS_LOG(ERROR) << "tensor is null";
|
MS_LOG(ERROR) << "tensor is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
attr->dstT = dtype_map[out_tensor->type];
|
attr->dstT = GetTfliteDataType(out_tensor->type);
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_Cast;
|
op->primitive->value.type = schema::PrimitiveType_Cast;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
|
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef LITE_TFLITE_CAST_PARSER_
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H
|
||||||
#define LITE_TFLITE_CAST_PARSER_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
|
|
||||||
|
@ -31,11 +32,12 @@ class TfliteCastParser : public TfliteNodeParser {
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache,
|
std::vector<int32_t> *tensors_id,
|
||||||
bool quantized_model) override;
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // LITE_TFLITE_CAST_PARSER_H
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H
|
||||||
|
|
|
@ -17,14 +17,20 @@
|
||||||
#include "tools/converter/parser/tflite/tflite_concat_parser.h"
|
#include "tools/converter/parser/tflite/tflite_concat_parser.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
|
MS_LOG(DEBUG) << "parse TfliteConcatParser";
|
||||||
|
|
||||||
|
// set attr
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -35,20 +41,25 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteConcatParser";
|
|
||||||
std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT());
|
std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT());
|
||||||
|
|
||||||
const auto &tfliteAttr = tfliteOp->builtin_options.AsConcatenationOptions();
|
const auto &tfliteAttr = tflite_op->builtin_options.AsConcatenationOptions();
|
||||||
if (tfliteAttr == nullptr) {
|
if (tfliteAttr == nullptr) {
|
||||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
attr->axis = tfliteAttr->axis;
|
attr->axis = tfliteAttr->axis;
|
||||||
|
attr->n = tflite_op->inputs.size();
|
||||||
attr->n = tfliteOp->inputs.size();
|
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_Concat;
|
op->primitive->value.type = schema::PrimitiveType_Concat;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
|
|
||||||
|
for (int i = 0; i < tflite_op->inputs.size(); i++) {
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
}
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef PREDICT_TFLITE_CONCAT_PARSER_H
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H
|
||||||
#define PREDICT_TFLITE_CONCAT_PARSER_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
|
|
||||||
|
@ -28,15 +29,16 @@ class TfliteConcatParser : public TfliteNodeParser {
|
||||||
public:
|
public:
|
||||||
TfliteConcatParser() : TfliteNodeParser("Concat") {}
|
TfliteConcatParser() : TfliteNodeParser("Concat") {}
|
||||||
|
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
schema::CNodeT *,
|
||||||
TensorCache *tensor_cache,
|
std::vector<int32_t> *tensors_id,
|
||||||
bool quantizedModel) override;
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // PREDICT_TFLITE_CONCAT_PARSER_H
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H
|
||||||
|
|
||||||
|
|
|
@ -17,14 +17,19 @@
|
||||||
#include "tools/converter/parser/tflite/tflite_conv_parser.h"
|
#include "tools/converter/parser/tflite/tflite_conv_parser.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
|
MS_LOG(DEBUG) << "parse TfliteConvParser";
|
||||||
|
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -35,60 +40,61 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteConvParser";
|
|
||||||
std::unique_ptr<schema::Conv2DT> attr(new schema::Conv2DT());
|
std::unique_ptr<schema::Conv2DT> attr(new schema::Conv2DT());
|
||||||
const auto &tfliteAttr = tfliteOp->builtin_options.AsConv2DOptions();
|
const auto &tflite_attr = tflite_op->builtin_options.AsConv2DOptions();
|
||||||
if (tfliteAttr == nullptr) {
|
if (tflite_attr == nullptr) {
|
||||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
attr->group = 1;
|
attr->group = 1;
|
||||||
attr->strideW = tfliteAttr->stride_w;
|
attr->strideW = tflite_attr->stride_w;
|
||||||
attr->strideH = tfliteAttr->stride_h;
|
attr->strideH = tflite_attr->stride_h;
|
||||||
attr->dilateH = tfliteAttr->dilation_h_factor;
|
attr->dilateH = tflite_attr->dilation_h_factor;
|
||||||
attr->dilateW = tfliteAttr->dilation_w_factor;
|
attr->dilateW = tflite_attr->dilation_w_factor;
|
||||||
attr->padMode = GetPadMode(tfliteAttr->padding);
|
attr->padMode = GetPadMode(tflite_attr->padding);
|
||||||
attr->format = schema::Format_NHWC;
|
attr->format = schema::Format_NHWC;
|
||||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
|
||||||
|
attr->hasBias = true;
|
||||||
|
|
||||||
// get the conv op weight tensor
|
// get the conv op weight tensor
|
||||||
auto weight_index = tfliteOp->inputs[1];
|
auto weight_index = tflite_op->inputs[1];
|
||||||
const auto &weight_tensor = tfliteTensors[weight_index];
|
const auto &weight_tensor = tflite_tensors[weight_index];
|
||||||
if (weight_tensor == nullptr) {
|
if (weight_tensor == nullptr) {
|
||||||
MS_LOG(ERROR) << "weight_tensor is null";
|
MS_LOG(ERROR) << "the weight tensor is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
|
|
||||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
|
|
||||||
MS_LOG(ERROR) << "parse weight failed";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
auto weight_shape = weight_tensor->shape;
|
auto weight_shape = weight_tensor->shape;
|
||||||
attr->channelIn = weight_shape[KHWC_C];
|
attr->channelIn = weight_shape[3];
|
||||||
attr->channelOut = weight_shape[KHWC_K];
|
attr->channelOut = weight_shape[0];
|
||||||
attr->kernelW = weight_shape[KHWC_W];
|
attr->kernelH = weight_shape[1];
|
||||||
attr->kernelH = weight_shape[KHWC_H];
|
attr->kernelW = weight_shape[2];
|
||||||
|
|
||||||
// get the conv op bias tensor
|
|
||||||
if (tfliteOp->inputs.size() == 3) {
|
|
||||||
attr->hasBias = true;
|
|
||||||
auto bias_index = tfliteOp->inputs[2];
|
|
||||||
const auto &bias_tensor = tfliteTensors[bias_index];
|
|
||||||
if (bias_tensor == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "bias_tensor is null";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
|
|
||||||
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
|
||||||
MS_LOG(ERROR) << "parse bias failed";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// calculate pad params
|
// calculate pad params
|
||||||
|
auto data_index = tflite_op->inputs[0];
|
||||||
|
const auto &data_tensor = tflite_tensors[data_index];
|
||||||
|
std::vector<int> params;
|
||||||
|
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH,
|
||||||
|
attr->strideW, attr->kernelH, attr->kernelW, ¶ms) != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "get padding params failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
} else {
|
||||||
|
attr->padUp = params.at(0);
|
||||||
|
attr->padDown = params.at(1);
|
||||||
|
attr->padLeft = params.at(2);
|
||||||
|
attr->padRight = params.at(3);
|
||||||
|
}
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_Conv2D;
|
op->primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
|
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef PREDICT_TFLITE_CONV_PARSER_H
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H
|
||||||
#define PREDICT_TFLITE_CONV_PARSER_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
|
|
||||||
|
@ -28,15 +29,16 @@ class TfliteConvParser : public TfliteNodeParser {
|
||||||
public:
|
public:
|
||||||
TfliteConvParser() : TfliteNodeParser("Conv2D") {}
|
TfliteConvParser() : TfliteNodeParser("Conv2D") {}
|
||||||
|
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache,
|
std::vector<int32_t> *tensors_id,
|
||||||
bool quantizedModel) override;
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // PREDICT_TFLITE_CONV_PARSER_H
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
#include "tools/converter/converter.h"
|
#include "tools/converter/converter.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_model_parser.h"
|
#include "tools/converter/parser/tflite/tflite_model_parser.h"
|
||||||
#include "tools/converter/graphdef_transform.h"
|
#include "tools/converter/graphdef_transform.h"
|
||||||
|
|
|
@ -17,14 +17,19 @@
|
||||||
#include "tools/converter/parser/tflite/tflite_deconv_parser.h"
|
#include "tools/converter/parser/tflite/tflite_deconv_parser.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
|
MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser";
|
||||||
|
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -35,11 +40,10 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser";
|
|
||||||
std::unique_ptr<schema::DeConv2DT> attr(new schema::DeConv2DT());
|
std::unique_ptr<schema::DeConv2DT> attr(new schema::DeConv2DT());
|
||||||
const auto &tflite_attr = tfliteOp->builtin_options.AsTransposeConvOptions();
|
const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions();
|
||||||
if (tflite_attr == nullptr) {
|
if (tflite_attr == nullptr) {
|
||||||
MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str();
|
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,26 +54,48 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
||||||
attr->dilateW = 1;
|
attr->dilateW = 1;
|
||||||
attr->padMode = GetPadMode(tflite_attr->padding);
|
attr->padMode = GetPadMode(tflite_attr->padding);
|
||||||
attr->format = schema::Format_NHWC;
|
attr->format = schema::Format_NHWC;
|
||||||
|
attr->activationType = schema::ActivationType_NO_ACTIVATION;
|
||||||
|
attr->hasBias = true;
|
||||||
|
|
||||||
// get the conv op weight tensor
|
// get the conv op weight tensor
|
||||||
auto weight_index = tfliteOp->inputs[1];
|
auto weight_index = tflite_op->inputs[1];
|
||||||
const auto &weight_tensor = tfliteTensors[weight_index];
|
const auto &weight_tensor = tflite_tensors[weight_index];
|
||||||
if (weight_tensor == nullptr) {
|
if (weight_tensor == nullptr) {
|
||||||
MS_LOG(ERROR) << "weight_tensor is null";
|
MS_LOG(ERROR) << "the weight tensor is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
|
|
||||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
auto weight_shape = weight_tensor->shape;
|
auto weight_shape = weight_tensor->shape;
|
||||||
attr->channelIn = weight_shape[CHWK_K];
|
attr->channelIn = weight_shape[3];
|
||||||
attr->channelOut = weight_shape[CHWK_C];
|
attr->channelOut = weight_shape[0];
|
||||||
attr->kernelW = weight_shape[CHWK_W];
|
attr->kernelH = weight_shape[1];
|
||||||
attr->kernelH = weight_shape[CHWK_H];
|
attr->kernelW = weight_shape[2];
|
||||||
|
|
||||||
|
// calculate pad params
|
||||||
|
auto data_index = tflite_op->inputs[2];
|
||||||
|
const auto &data_tensor = tflite_tensors[data_index];
|
||||||
|
std::vector<int> params;
|
||||||
|
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH,
|
||||||
|
attr->strideW, attr->kernelH, attr->kernelW, ¶ms) != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "get padding params failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
} else {
|
||||||
|
attr->padUp = params.at(0);
|
||||||
|
attr->padDown = params.at(1);
|
||||||
|
attr->padLeft = params.at(2);
|
||||||
|
attr->padRight = params.at(3);
|
||||||
|
}
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
|
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
|
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef PREDICT_TFLITE_DECONV_PARSER_H
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H
|
||||||
#define PREDICT_TFLITE_DECONV_PARSER_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
|
|
||||||
|
@ -31,11 +32,12 @@ class TfliteDeConvParser : public TfliteNodeParser {
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_op_set, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache,
|
std::vector<int32_t> *tensors_id,
|
||||||
bool quantizedModel) override;
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // PREDICT_TFLITE_DECONV_PARSER_H
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H
|
||||||
|
|
|
@ -18,14 +18,19 @@
|
||||||
#include "tools/converter/parser/tflite/tflite_depth_to_space_parser.h"
|
#include "tools/converter/parser/tflite/tflite_depth_to_space_parser.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
|
MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser";
|
||||||
|
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -36,20 +41,23 @@ STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser";
|
|
||||||
std::unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT());
|
std::unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT());
|
||||||
|
|
||||||
const auto &tflite_attr = tfliteOp->builtin_options.AsDepthToSpaceOptions();
|
const auto &tflite_attr = tflite_op->builtin_options.AsDepthToSpaceOptions();
|
||||||
if (tflite_attr == nullptr) {
|
if (tflite_attr == nullptr) {
|
||||||
MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str();
|
MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str();
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
attr->blockSize = tflite_attr->block_size;
|
attr->blockSize = tflite_attr->block_size;
|
||||||
|
|
||||||
attr->format = schema::Format_NHWC;
|
attr->format = schema::Format_NHWC;
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
|
op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
|
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H
|
||||||
#define LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
|
|
||||||
|
@ -31,11 +32,12 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser {
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache,
|
std::vector<int32_t> *tensors_id,
|
||||||
bool quantized_model) override;
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H
|
||||||
|
|
|
@ -17,65 +17,22 @@
|
||||||
#include "tools/converter/parser/tflite/tflite_depthwise_conv_parser.h"
|
#include "tools/converter/parser/tflite/tflite_depthwise_conv_parser.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
#include "tools/common/node_util.h"
|
#include "tools/common/node_util.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteDepthwiseConv2DParser::ParseGroupDepthwiseConv(schema::CNodeT *op,
|
|
||||||
const std::unique_ptr<schema::DepthwiseConv2DT> &attr,
|
|
||||||
const std::unique_ptr<tflite::TensorT> &weightTensor,
|
|
||||||
TensorCache *tensor_cache) {
|
|
||||||
std::unique_ptr<schema::Conv2DT> convAttr(new schema::Conv2DT);
|
|
||||||
|
|
||||||
convAttr->format = attr->format;
|
|
||||||
convAttr->channelIn = attr->channelIn;
|
|
||||||
convAttr->channelOut = attr->channelIn * attr->channelMultiplier;
|
|
||||||
convAttr->kernelH = attr->kernelH;
|
|
||||||
convAttr->kernelW = attr->kernelW;
|
|
||||||
convAttr->strideH = attr->strideH;
|
|
||||||
convAttr->strideW = attr->strideW;
|
|
||||||
convAttr->padMode = attr->padMode;
|
|
||||||
convAttr->padUp = attr->padUp;
|
|
||||||
convAttr->padDown = attr->padDown;
|
|
||||||
convAttr->padLeft = attr->padLeft;
|
|
||||||
convAttr->padRight = attr->padRight;
|
|
||||||
convAttr->dilateH = attr->dilateH;
|
|
||||||
convAttr->dilateW = attr->dilateW;
|
|
||||||
convAttr->hasBias = attr->hasBias;
|
|
||||||
convAttr->activationType = attr->activationType;
|
|
||||||
|
|
||||||
auto weightTensorIndex = tensor_cache->FindTensor(weightTensor->name);
|
|
||||||
if (weightTensorIndex >= 0 && weightTensorIndex < tensor_cache->GetCachedTensor().size()) {
|
|
||||||
auto liteWeightTensor = tensor_cache->GetCachedTensor()[weightTensorIndex];
|
|
||||||
if (liteWeightTensor->dataType == TypeId::kNumberTypeUInt8) {
|
|
||||||
// convert weight format KHWC -> CHWK
|
|
||||||
auto status = TransFilterFormat<uint8_t>(liteWeightTensor, kKHWC2CHWK);
|
|
||||||
if (status != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (liteWeightTensor->dataType == kNumberTypeFloat32 || liteWeightTensor->dataType == kNumberTypeFloat) {
|
|
||||||
// convert weight format KHWC -> CHWK
|
|
||||||
auto status = TransFilterFormat<float>(liteWeightTensor, kKHWC2CHWK);
|
|
||||||
if (status != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_Conv2D;
|
|
||||||
op->primitive->value.value = convAttr.release();
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
|
MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser";
|
||||||
|
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -86,7 +43,6 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser";
|
|
||||||
std::unique_ptr<schema::DepthwiseConv2DT> attr(new schema::DepthwiseConv2DT());
|
std::unique_ptr<schema::DepthwiseConv2DT> attr(new schema::DepthwiseConv2DT());
|
||||||
const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions();
|
const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions();
|
||||||
if (tflite_attr == nullptr) {
|
if (tflite_attr == nullptr) {
|
||||||
|
@ -100,15 +56,20 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
|
||||||
attr->padMode = GetPadMode(tflite_attr->padding);
|
attr->padMode = GetPadMode(tflite_attr->padding);
|
||||||
attr->format = schema::Format_NHWC;
|
attr->format = schema::Format_NHWC;
|
||||||
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
|
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
|
||||||
// get the conv op weight tensor
|
attr->hasBias = true;
|
||||||
auto input_index = tflite_op->inputs[0];
|
attr->channelMultiplier = tflite_attr->depth_multiplier;
|
||||||
const auto &input_tenosr = tflite_tensors[input_index];
|
|
||||||
if (input_tenosr == nullptr) {
|
// get the data tensor
|
||||||
MS_LOG(ERROR) << "the first input is null";
|
auto data_index = tflite_op->inputs[1];
|
||||||
|
const auto &data_tensor = tflite_tensors[data_index];
|
||||||
|
if (data_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "the data tensor is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
auto input_shape = input_tenosr->shape;
|
auto data_shape = data_tensor->shape;
|
||||||
|
attr->channelIn = data_shape[3];
|
||||||
|
|
||||||
|
// get the weight tensor
|
||||||
auto weight_index = tflite_op->inputs[1];
|
auto weight_index = tflite_op->inputs[1];
|
||||||
const auto &weight_tensor = tflite_tensors[weight_index];
|
const auto &weight_tensor = tflite_tensors[weight_index];
|
||||||
if (weight_tensor == nullptr) {
|
if (weight_tensor == nullptr) {
|
||||||
|
@ -116,38 +77,33 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
auto weight_shape = weight_tensor->shape;
|
auto weight_shape = weight_tensor->shape;
|
||||||
attr->channelIn = input_shape[KHWC_C];
|
attr->kernelH = weight_shape[1];
|
||||||
attr->channelMultiplier = tflite_attr->depth_multiplier;
|
attr->kernelW = weight_shape[2];
|
||||||
attr->kernelH = weight_shape[KHWC_H];
|
|
||||||
attr->kernelW = weight_shape[KHWC_W];
|
|
||||||
|
|
||||||
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
|
// calculate pad params
|
||||||
|
std::vector<int> params;
|
||||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
|
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW,
|
||||||
MS_LOG(ERROR) << "parse weight failed";
|
attr->kernelH, attr->kernelW, ¶ms) != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "get padding params failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
|
||||||
|
|
||||||
if (tflite_op->inputs.size() == 3) {
|
|
||||||
attr->hasBias = true;
|
|
||||||
auto bias_index = tflite_op->inputs[2];
|
|
||||||
const auto &bias_tensor = tflite_tensors[bias_index];
|
|
||||||
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
|
|
||||||
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
|
||||||
MS_LOG(ERROR) << "parse bias failed";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (attr->channelMultiplier > 1) {
|
|
||||||
if (RET_OK != ParseGroupDepthwiseConv(op, attr, weight_tensor, tensor_cache)) {
|
|
||||||
MS_LOG(ERROR) << "Parse Group DepthwiseConv failed";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
attr->padUp = params.at(0);
|
||||||
op->primitive->value.value = attr.release();
|
attr->padDown = params.at(1);
|
||||||
|
attr->padLeft = params.at(2);
|
||||||
|
attr->padRight = params.at(3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||||
|
op->primitive->value.value = attr.release();
|
||||||
|
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H
|
||||||
#define PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
|
|
||||||
|
@ -28,20 +29,16 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser {
|
||||||
public:
|
public:
|
||||||
TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {}
|
TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {}
|
||||||
|
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) override;
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
private:
|
std::map<int, int> *tensors_id_map) override;
|
||||||
STATUS ParseGroupDepthwiseConv(schema::CNodeT *op,
|
|
||||||
const std::unique_ptr<schema::DepthwiseConv2DT> &attr,
|
|
||||||
const std::unique_ptr<tflite::TensorT> &weightTensor,
|
|
||||||
TensorCache *tensor_cache);
|
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // PREDICT_TFLITE_CONV_PARSER_H
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H
|
||||||
|
|
||||||
|
|
|
@ -16,15 +16,20 @@
|
||||||
#include "tools/converter/parser/tflite/tflite_dequantize_parser.h"
|
#include "tools/converter/parser/tflite/tflite_dequantize_parser.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
#include "tools/common/node_util.h"
|
#include "tools/common/node_util.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) {
|
||||||
|
MS_LOG(DEBUG) << "parse TfliteDequantizeNParser";
|
||||||
|
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -35,32 +40,30 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteDequantizeNParser";
|
|
||||||
std::unique_ptr<schema::CastT> attr(new schema::CastT);
|
std::unique_ptr<schema::CastT> attr(new schema::CastT);
|
||||||
|
|
||||||
// get the dequantize input tensor
|
// get the dequantize input tensor
|
||||||
const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]];
|
const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]];
|
||||||
if (in_tensor == nullptr) {
|
if (in_tensor == nullptr) {
|
||||||
MS_LOG(ERROR) << "weight_tensor is null";
|
MS_LOG(ERROR) << "input tensor is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
attr->srcT = dtype_map[in_tensor->type];
|
attr->srcT = GetTfliteDataType(in_tensor->type);
|
||||||
|
const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]];
|
||||||
const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]];
|
|
||||||
if (out_tensor == nullptr) {
|
if (out_tensor == nullptr) {
|
||||||
MS_LOG(ERROR) << "tensor is null";
|
MS_LOG(ERROR) << "output tensor is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
attr->dstT = dtype_map[out_tensor->type];
|
attr->dstT = GetTfliteDataType(out_tensor->type);
|
||||||
std::vector<tflite::TensorT *> weight_tensors{in_tensor.get()};
|
|
||||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
|
|
||||||
MS_LOG(ERROR) << "parse weight failed";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_Fp16Cast;
|
op->primitive->value.type = schema::PrimitiveType_Fp16Cast;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return 0;
|
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser());
|
TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser());
|
||||||
|
|
|
@ -13,11 +13,12 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#ifndef LITE_TFLITE_DEQUANTIZE_PARSER_H
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H
|
||||||
#define LITE_TFLITE_DEQUANTIZE_PARSER_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
|
|
||||||
|
@ -27,13 +28,15 @@ class TfliteDequantizeParser : public TfliteNodeParser {
|
||||||
public:
|
public:
|
||||||
TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {}
|
TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {}
|
||||||
|
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache, bool quantizedModel) override;
|
std::vector<int32_t> *tensors_id,
|
||||||
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // LITE_TFLITE_DEQUANTIZE_PARSER_H
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H
|
||||||
|
|
|
@ -17,16 +17,17 @@
|
||||||
#include "tools/converter/parser/tflite/tflite_expand_dims_parser.h"
|
#include "tools/converter/parser/tflite/tflite_expand_dims_parser.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op,
|
std::vector<int32_t> *tensors_id,
|
||||||
TensorCache *tensor_cache,
|
std::vector<schema::Format> *tensors_format,
|
||||||
bool quantizedModel) {
|
std::map<int, int> *tensors_id_map) {
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -40,7 +41,7 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
|
||||||
MS_LOG(DEBUG) << "parse TfliteExpandDimsParser";
|
MS_LOG(DEBUG) << "parse TfliteExpandDimsParser";
|
||||||
std::unique_ptr<schema::ExpandDimsT> attr(new schema::ExpandDimsT());
|
std::unique_ptr<schema::ExpandDimsT> attr(new schema::ExpandDimsT());
|
||||||
|
|
||||||
const auto &tflite_attr = tfliteOp->builtin_options.AsExpandDimsOptions();
|
const auto &tflite_attr = tflite_op->builtin_options.AsExpandDimsOptions();
|
||||||
if (tflite_attr == nullptr) {
|
if (tflite_attr == nullptr) {
|
||||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
|
|
@ -14,11 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef PREDICT_TFLITE_EXPAND_DIMS_PARSER_H
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H
|
||||||
#define PREDICT_TFLITE_EXPAND_DIMS_PARSER_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
|
|
||||||
|
@ -28,15 +29,16 @@ class TfliteExpandDimsParser : public TfliteNodeParser {
|
||||||
public:
|
public:
|
||||||
TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {}
|
TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {}
|
||||||
|
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
schema::CNodeT *op,
|
||||||
TensorCache *tensor_cache,
|
std::vector<int32_t> *tensors_id,
|
||||||
bool quantizedModel) override;
|
std::vector<schema::Format> *tensors_format,
|
||||||
|
std::map<int, int> *tensors_id_map) override;
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // PREDICT_TFLITE_EXPAND_DIMS_PARSER_H
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H
|
||||||
|
|
||||||
|
|
|
@ -1,75 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "tools/converter/parser/tflite/tflite_fakequant_parser.h"
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
|
||||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
|
||||||
if (op == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "op is null";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
|
||||||
if (op->primitive == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "op->primitive is null";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser";
|
|
||||||
std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT());
|
|
||||||
|
|
||||||
auto weight_index = tfliteOp->inputs[1];
|
|
||||||
const auto &weight_tensor = tfliteTensors[weight_index];
|
|
||||||
if (weight_tensor == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "weight_tensor is null";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
|
|
||||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
|
|
||||||
MS_LOG(ERROR) << "parse weight failed";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tfliteOp->inputs.size() == 3) {
|
|
||||||
attr->hasBias = true;
|
|
||||||
auto bias_index = tfliteOp->inputs[2];
|
|
||||||
const auto &bias_tensor = tfliteTensors[bias_index];
|
|
||||||
if (bias_tensor == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "bias_tensor is null";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
|
|
||||||
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
|
||||||
MS_LOG(ERROR) << "parse bias failed";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr->axis = 1;
|
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_FullConnection;
|
|
||||||
op->primitive->value.value = attr.release();
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFakeQuantParser());
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,39 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef LITE_TFLITE_FAKEQUANT_PARSER_H
|
|
||||||
#define LITE_TFLITE_FAKEQUANT_PARSER_H
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
class TfliteFakeQuantParser : public TfliteNodeParser {
|
|
||||||
public:
|
|
||||||
TfliteFakeQuantParser() : TfliteNodeParser("FakeQuant") {}
|
|
||||||
|
|
||||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
|
||||||
TensorCache *tensor_cache, bool quantizedModel) override;
|
|
||||||
};
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // LITE_TFLITE_FAKEQUANT_PARSER_H
|
|
|
@ -14,19 +14,22 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include "tools/converter/parser/tflite/tflite_fill_parser.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "tools/converter/parser/tflite/tflite_fill_parser.h"
|
#include <map>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
schema::CNodeT *op,
|
||||||
schema::CNodeT *op,
|
std::vector<int32_t> *tensors_id,
|
||||||
TensorCache *tensor_cache,
|
std::vector<schema::Format> *tensors_format,
|
||||||
bool quantizedModel) {
|
std::map<int, int> *tensors_id_map) {
|
||||||
|
MS_LOG(DEBUG) << "parse TfliteFillParser";
|
||||||
|
|
||||||
if (op == nullptr) {
|
if (op == nullptr) {
|
||||||
MS_LOG(ERROR) << "op is null";
|
MS_LOG(ERROR) << "op is null";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -37,18 +40,22 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "parse TfliteFillParser";
|
|
||||||
std::unique_ptr<schema::FillT> attr(new schema::FillT());
|
std::unique_ptr<schema::FillT> attr(new schema::FillT());
|
||||||
|
|
||||||
if (tfliteOp->inputs.size() > 1) {
|
if (tflite_op->inputs.size() > 1) {
|
||||||
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dims)) {
|
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dims)) {
|
||||||
MS_LOG(ERROR) << "get Fill -> dims failed";
|
MS_LOG(ERROR) << "get fill -> dims failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_Fill;
|
op->primitive->value.type = schema::PrimitiveType_Fill;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
|
|
||||||
|
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
|
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||||
|
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue