forked from mindspore-Ecosystem/mindspore
refactor tflite parsers.
append ut refactor tflite parsers modify tflite parser, ut and model supplement caffe flatten parser fix the weight tensor format of deconv bug fix bug when idx=-1 fix the weight tensor format of depthConv bug.
This commit is contained in:
parent
a0c12e7aa7
commit
123c2024a5
|
@ -12,19 +12,19 @@ cp -r ${CUR_DIR}/ut/tools/converter/parser/tflite/test_data/* ./
|
|||
TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/
|
||||
cp -fr $TEST_DATA_DIR/testPK ./data
|
||||
|
||||
#./lite-test --gtest_filter="*MindDataTestTensorDE*"
|
||||
#./lite-test --gtest_filter="*MindDataTestEager*"
|
||||
#
|
||||
#./lite-test --gtest_filter="TestTfliteParser*"
|
||||
#
|
||||
#./lite-test --gtest_filter="*TestHebing*"
|
||||
#
|
||||
#./lite-test --gtest_filter=TestFcFp32*
|
||||
#./lite-test --gtest_filter=TestConv1x1Fp32*
|
||||
#./lite-test --gtest_filter=TestStrassenFp32*
|
||||
#./lite-test --gtest_filter=TestDeConvolutionFp32*
|
||||
#
|
||||
#./lite-test --gtest_filter=TestPadInt8.*
|
||||
#./lite-test --gtest_filter=TestDeconvInt8.*
|
||||
./lite-test --gtest_filter="*MindDataTestTensorDE*"
|
||||
./lite-test --gtest_filter="*MindDataTestEager*"
|
||||
|
||||
./lite-test --gtest_filter="TestTfliteParser*"
|
||||
|
||||
./lite-test --gtest_filter="*TestHebing*"
|
||||
|
||||
./lite-test --gtest_filter=TestFcFp32*
|
||||
./lite-test --gtest_filter=TestConv1x1Fp32*
|
||||
./lite-test --gtest_filter=TestStrassenFp32*
|
||||
./lite-test --gtest_filter=TestDeConvolutionFp32*
|
||||
|
||||
./lite-test --gtest_filter=TestPadInt8.*
|
||||
./lite-test --gtest_filter=TestDeconvInt8.*
|
||||
|
||||
./lite-test --gtest_filter="TestTfliteParser*"
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -31,6 +31,12 @@ TEST_F(TestTfliteParserRelu, OpType) {
|
|||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserRelu, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
|
||||
ASSERT_EQ(val->type, schema::ActivationType_RELU);
|
||||
}
|
||||
|
||||
class TestTfliteParserRelu6 : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserRelu6() = default;
|
||||
|
@ -43,6 +49,12 @@ TEST_F(TestTfliteParserRelu6, OpType) {
|
|||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserRelu6, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
|
||||
ASSERT_EQ(val->type, schema::ActivationType_RELU6);
|
||||
}
|
||||
|
||||
class TestTfliteParserTanh : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserTanh() = default;
|
||||
|
@ -55,7 +67,45 @@ TEST_F(TestTfliteParserTanh, OpType) {
|
|||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
|
||||
}
|
||||
|
||||
// logistic
|
||||
TEST_F(TestTfliteParserTanh, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
|
||||
ASSERT_EQ(val->type, schema::ActivationType_TANH);
|
||||
}
|
||||
|
||||
class TestTfliteParserLogistic : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserLogistic() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./logistic.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserLogistic, OpType) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
|
||||
}
|
||||
TEST_F(TestTfliteParserLogistic, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
|
||||
ASSERT_EQ(val->type, schema::ActivationType_SIGMOID);
|
||||
}
|
||||
|
||||
class TestTfliteParserHardSwish : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserHardSwish() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./hardswish.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserHardSwish, OpType) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type";
|
||||
}
|
||||
TEST_F(TestTfliteParserHardSwish, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsActivation();
|
||||
ASSERT_EQ(val->type, schema::ActivationType_SIGMOID);
|
||||
}
|
||||
|
||||
class TestTfliteParserPrelu : public TestTfliteParser {
|
||||
public:
|
||||
|
@ -73,12 +123,11 @@ TEST_F(TestTfliteParserPrelu, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserPrelu, AttrValue) {
|
||||
std::vector<float> slope(20, 0);
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsPrelu()->slope, slope);
|
||||
auto val = meta_graph->nodes.front()->primitive->value;
|
||||
std::vector<float> slope(20, 0);
|
||||
ASSERT_EQ(val.AsPrelu()->slope, slope);
|
||||
ASSERT_EQ(val.type, schema::PrimitiveType_Prelu);
|
||||
}
|
||||
|
||||
class TestTfliteParserLeakyRelu : public TestTfliteParser {
|
||||
|
@ -94,12 +143,10 @@ TEST_F(TestTfliteParserLeakyRelu, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserLeakyRelu, AttrValue) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsLeakyReLU();
|
||||
ASSERT_NE(val, nullptr);
|
||||
ASSERT_EQ(val->negativeSlope, 0.20000000298023224);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLeakyReLU(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value;
|
||||
ASSERT_EQ(val.AsLeakyReLU()->negativeSlope, 0.20000000298023224);
|
||||
ASSERT_EQ(val.type, schema::PrimitiveType_LeakyReLU);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,10 +35,8 @@ TEST_F(TestTfliteParserAddN, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserAddN, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsAddN(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsAddN()->N, 4);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsAddN();
|
||||
ASSERT_EQ(val->N, 4);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||
#include <iostream>
|
||||
#include "common/common_test.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestTfliteParserArgmax : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserArgmax() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./argmax.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserArgmax, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMax) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserArgmax, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMax(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsArgMax();
|
||||
ASSERT_EQ(val->axis, 1);
|
||||
ASSERT_EQ(val->topK, 1);
|
||||
ASSERT_EQ(val->axisType, 1);
|
||||
ASSERT_EQ(val->keepDims, false);
|
||||
ASSERT_EQ(val->outMaxValue, false);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
|
@ -25,15 +25,14 @@ class TestTfliteParserArgmin : public TestTfliteParser {
|
|||
};
|
||||
|
||||
TEST_F(TestTfliteParserArgmin, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMin) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserArgmin, AttrValue) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMin(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsArgMin();
|
||||
ASSERT_EQ(val->axis, 1);
|
||||
ASSERT_EQ(val->topK, 1);
|
||||
|
|
|
@ -19,234 +19,57 @@
|
|||
|
||||
namespace mindspore {
|
||||
// doubleInputOp
|
||||
class TestTfliteParserAdd1 : public TestTfliteParser {
|
||||
class TestTfliteParserAdd : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserAdd1() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./add1.tflite", ""); }
|
||||
TestTfliteParserAdd() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./add.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserAdd1, OpType) {
|
||||
TEST_F(TestTfliteParserAdd, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserAdd1, Tensor) {
|
||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
||||
}
|
||||
|
||||
class TestTfliteParserAdd2 : public TestTfliteParser {
|
||||
class TestTfliteParserSub : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserAdd2() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./add2.tflite", ""); }
|
||||
TestTfliteParserSub() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./sub.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserAdd2, OpType) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserAdd2, Tensor) {
|
||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
||||
}
|
||||
|
||||
class TestTfliteParserAdd3 : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserAdd3() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./add3.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserAdd3, OpType) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserAdd3, Tensor) {
|
||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
||||
}
|
||||
|
||||
class TestTfliteParserSub1 : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserSub1() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./sub1.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserSub1, OpType) {
|
||||
TEST_F(TestTfliteParserSub, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserSub1, Tensor) {
|
||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
||||
}
|
||||
|
||||
class TestTfliteParserSub2 : public TestTfliteParser {
|
||||
class TestTfliteParserMul : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserSub2() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./sub2.tflite", ""); }
|
||||
TestTfliteParserMul() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./mul.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserSub2, OpType) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserSub2, Tensor) {
|
||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
||||
}
|
||||
|
||||
class TestTfliteParserSub3 : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserSub3() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./sub3.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserSub3, OpType) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserSub3, Tensor) {
|
||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
||||
}
|
||||
|
||||
class TestTfliteParserMul1 : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserMul1() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./mul1.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserMul1, OpType) {
|
||||
TEST_F(TestTfliteParserMul, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserMul1, Tensor) {
|
||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
||||
}
|
||||
|
||||
class TestTfliteParserMul2 : public TestTfliteParser {
|
||||
class TestTfliteParserDiv : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserMul2() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./mul2.tflite", ""); }
|
||||
TestTfliteParserDiv() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./div.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserMul2, OpType) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserMul2, Tensor) {
|
||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
||||
}
|
||||
|
||||
class TestTfliteParserMul3 : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserMul3() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./mul3.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserMul3, OpType) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserMul3, Tensor) {
|
||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
||||
}
|
||||
|
||||
class TestTfliteParserDiv1 : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserDiv1() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./div1.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserDiv1, OpType) {
|
||||
TEST_F(TestTfliteParserDiv, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserDiv1, Tensor) {
|
||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
||||
}
|
||||
|
||||
class TestTfliteParserDiv2 : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserDiv2() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./div2.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserDiv2, OpType) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserDiv2, Tensor) {
|
||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
||||
ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
||||
}
|
||||
|
||||
class TestTfliteParserDiv3 : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserDiv3() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./div3.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserDiv3, OpType) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserDiv3, Tensor) {
|
||||
ASSERT_GT(meta_graph->allTensors.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0);
|
||||
ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0);
|
||||
}
|
||||
|
||||
class TestTfliteParserFloorDiv : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserFloorDiv() = default;
|
||||
|
@ -254,6 +77,7 @@ class TestTfliteParserFloorDiv : public TestTfliteParser {
|
|||
};
|
||||
|
||||
TEST_F(TestTfliteParserFloorDiv, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorDiv) << "wrong Op Type";
|
||||
|
@ -266,12 +90,26 @@ class TestTfliteParserFloorMod : public TestTfliteParser {
|
|||
};
|
||||
|
||||
TEST_F(TestTfliteParserFloorMod, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorMod) << "wrong Op Type";
|
||||
}
|
||||
|
||||
// realDiv
|
||||
class TestTfliteParserRealDiv : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserRealDiv() = default;
|
||||
void SetUp() override {
|
||||
meta_graph = LoadAndConvert("./realdiv.tflite");
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserRealDiv, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type";
|
||||
}
|
||||
|
||||
class TestTfliteParserSquaredDifference : public TestTfliteParser {
|
||||
public:
|
||||
|
@ -296,17 +134,15 @@ class TestTfliteParserPow : public TestTfliteParser {
|
|||
};
|
||||
|
||||
TEST_F(TestTfliteParserPow, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Power) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserPow, AttrValue) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPower(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsPower();
|
||||
|
||||
ASSERT_EQ(val->scale, 1.0);
|
||||
ASSERT_EQ(val->shift, 0.0);
|
||||
ASSERT_EQ(val->power, 0.0);
|
||||
|
@ -477,6 +313,7 @@ class TestTfliteParserFloor : public TestTfliteParser {
|
|||
};
|
||||
|
||||
TEST_F(TestTfliteParserFloor, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Floor) << "wrong Op Type";
|
||||
|
|
|
@ -32,14 +32,12 @@ TEST_F(TestTfliteParserBatchToSpaceNd, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserBatchToSpaceNd, AttrValue) {
|
||||
const std::vector<int> blockShape{2, 2};
|
||||
const std::vector<int> crops{0, 0, 2, 0};
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsBatchToSpace(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->blockShape, blockShape);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->crops, crops);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsBatchToSpace();
|
||||
const std::vector<int> blockShape = {2, 2};
|
||||
ASSERT_EQ(val->blockShape, blockShape);
|
||||
const std::vector<int> crops = {0, 0, 2, 0};
|
||||
ASSERT_EQ(val->crops, crops);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,12 +35,9 @@ TEST_F(TestTfliteParserCast, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserCast, AttrValue) {
|
||||
// float32 --> int32
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsCast(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->srcT, 43);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->dstT, 34);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsCast();
|
||||
ASSERT_EQ(val->srcT, 43);
|
||||
ASSERT_EQ(val->dstT, 34);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||
#include <iostream>
|
||||
#include "common/common_test.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestTfliteParserConcat : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserConcat() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./concat.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserConcat, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Concat) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserConcat, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConcat(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsConcat();
|
||||
ASSERT_EQ(val->axis, 1);
|
||||
ASSERT_EQ(val->n, 2);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||
#include <iostream>
|
||||
#include "common/common_test.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestTfliteParserConv : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserConv() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./conv.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserConv, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
|
||||
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserConv, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr);
|
||||
auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D();
|
||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||
ASSERT_EQ(val->group, 1);
|
||||
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
|
||||
ASSERT_EQ(val->hasBias, true);
|
||||
ASSERT_EQ(val->channelIn, 1);
|
||||
ASSERT_EQ(val->channelOut, 4);
|
||||
ASSERT_EQ(val->kernelH, 3);
|
||||
ASSERT_EQ(val->kernelW, 3);
|
||||
ASSERT_EQ(val->strideH, 1);
|
||||
ASSERT_EQ(val->strideW, 1);
|
||||
ASSERT_EQ(val->dilateH, 1);
|
||||
ASSERT_EQ(val->dilateW, 1);
|
||||
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
|
||||
ASSERT_EQ(val->padUp, 1);
|
||||
ASSERT_EQ(val->padDown, 1);
|
||||
ASSERT_EQ(val->padLeft, 1);
|
||||
ASSERT_EQ(val->padRight, 1);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||
#include <iostream>
|
||||
#include "common/common_test.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestTfliteParserDeConv : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserDeConv() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./deconv.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserDeConv, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
|
||||
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserDeConv, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(), nullptr);
|
||||
auto val = meta_graph->nodes.at(1)->primitive->value.AsDeConv2D();
|
||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||
ASSERT_EQ(val->group, 1);
|
||||
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
|
||||
ASSERT_EQ(val->hasBias, true);
|
||||
|
||||
ASSERT_EQ(val->channelIn, 1);
|
||||
ASSERT_EQ(val->channelOut, 4);
|
||||
ASSERT_EQ(val->kernelH, 3);
|
||||
ASSERT_EQ(val->kernelW, 3);
|
||||
ASSERT_EQ(val->strideH, 1);
|
||||
ASSERT_EQ(val->strideW, 1);
|
||||
ASSERT_EQ(val->dilateH, 1);
|
||||
ASSERT_EQ(val->dilateW, 1);
|
||||
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
|
||||
ASSERT_EQ(val->padUp, 1);
|
||||
ASSERT_EQ(val->padDown, 1);
|
||||
ASSERT_EQ(val->padLeft, 1);
|
||||
ASSERT_EQ(val->padRight, 1);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
|
@ -35,11 +35,9 @@ TEST_F(TestTfliteParserDepthToSpace, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserDepthToSpace, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthToSpace(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->blockSize, 4);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->format, schema::Format_NHWC);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsDepthToSpace();
|
||||
ASSERT_EQ(val->blockSize, 4);
|
||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||
#include <iostream>
|
||||
#include "common/common_test.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestTfliteParserDepthwiseConv1 : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserDepthwiseConv1() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv1.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserDepthwiseConv1, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type";
|
||||
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserDepthwiseConv1, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr);
|
||||
auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D();
|
||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||
ASSERT_EQ(val->group, 0);
|
||||
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
|
||||
ASSERT_EQ(val->hasBias, true);
|
||||
ASSERT_EQ(val->channelIn, 1);
|
||||
ASSERT_EQ(val->channelOut, 4);
|
||||
ASSERT_EQ(val->kernelH, 3);
|
||||
ASSERT_EQ(val->kernelW, 3);
|
||||
ASSERT_EQ(val->strideH, 1);
|
||||
ASSERT_EQ(val->strideW, 1);
|
||||
ASSERT_EQ(val->dilateH, 1);
|
||||
ASSERT_EQ(val->dilateW, 1);
|
||||
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
|
||||
ASSERT_EQ(val->padUp, 1);
|
||||
ASSERT_EQ(val->padDown, 1);
|
||||
ASSERT_EQ(val->padLeft, 1);
|
||||
ASSERT_EQ(val->padRight, 1);
|
||||
}
|
||||
|
||||
class TestTfliteParserDepthwiseConv2 : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserDepthwiseConv2() = default;
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv2.tflite", ""); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserDepthwiseConv2, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type";
|
||||
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserDepthwiseConv2, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(), nullptr);
|
||||
auto val = meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D();
|
||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
|
||||
ASSERT_EQ(val->hasBias, true);
|
||||
ASSERT_EQ(val->channelIn, 2);
|
||||
ASSERT_EQ(val->channelMultiplier, 1);
|
||||
ASSERT_EQ(val->kernelH, 3);
|
||||
ASSERT_EQ(val->kernelW, 3);
|
||||
ASSERT_EQ(val->strideH, 1);
|
||||
ASSERT_EQ(val->strideW, 1);
|
||||
ASSERT_EQ(val->dilateH, 1);
|
||||
ASSERT_EQ(val->dilateW, 1);
|
||||
ASSERT_EQ(val->padMode, schema::PadMode_SAME);
|
||||
ASSERT_EQ(val->padUp, 1);
|
||||
ASSERT_EQ(val->padDown, 1);
|
||||
ASSERT_EQ(val->padLeft, 1);
|
||||
ASSERT_EQ(val->padRight, 1);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
|
@ -25,17 +25,15 @@ class TestTfliteParserFill : public TestTfliteParser {
|
|||
};
|
||||
|
||||
TEST_F(TestTfliteParserFill, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Fill) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserFill, AttrValue) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
TEST_F(TestTfliteParserFill, AttrValue) {;
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsFill(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsFill();
|
||||
|
||||
std::vector<int32_t> dims = {9};
|
||||
ASSERT_EQ(val->dims, dims);
|
||||
}
|
||||
|
|
|
@ -25,15 +25,14 @@ class TestTfliteParserGatherNd : public TestTfliteParser {
|
|||
};
|
||||
|
||||
TEST_F(TestTfliteParserGatherNd, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_GatherNd) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserGatherNd, AttrValue) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGatherNd(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsGatherNd();
|
||||
ASSERT_EQ(val->batchDims, 0);
|
||||
}
|
||||
|
|
|
@ -25,15 +25,14 @@ class TestTfliteParserGather : public TestTfliteParser {
|
|||
};
|
||||
|
||||
TEST_F(TestTfliteParserGather, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Gather) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserGather, AttrValue) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGather(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsGather();
|
||||
ASSERT_EQ(val->axis, 0);
|
||||
ASSERT_EQ(val->batchDims, 0);
|
||||
|
|
|
@ -25,6 +25,7 @@ class TestTfliteParserLRN : public TestTfliteParser {
|
|||
};
|
||||
|
||||
TEST_F(TestTfliteParserLRN, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type,
|
||||
|
@ -32,9 +33,7 @@ TEST_F(TestTfliteParserLRN, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserLRN, AttrValue) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization();
|
||||
ASSERT_EQ(val->alpha, 1);
|
||||
ASSERT_EQ(val->beta, 0.5);
|
||||
|
|
|
@ -32,12 +32,9 @@ TEST_F(TestTfliteParserOneHot, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserOneHot, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsOneHot(), nullptr);
|
||||
// in OneHot parser axis = axis > 0 ? axis : axis + tensor_shape.size()
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsOneHot()->axis, 2);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsOneHot();
|
||||
ASSERT_EQ(val->axis, 2);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,17 +25,15 @@ class TestTfliteParserPad : public TestTfliteParser {
|
|||
};
|
||||
|
||||
TEST_F(TestTfliteParserPad, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Pad) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserPad, AttrValue) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPad(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsPad();
|
||||
|
||||
std::vector<int32_t> paddings = {1, 1, 2, 2, 3, 3, 4, 4};
|
||||
ASSERT_EQ(val->paddings, paddings);
|
||||
}
|
||||
|
|
|
@ -35,12 +35,8 @@ TEST_F(TestTfliteParserMaxPooling, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserMaxPooling, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsPooling();
|
||||
ASSERT_NE(val, nullptr);
|
||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||
ASSERT_EQ(val->poolingMode, schema::PoolMode_MAX_POOLING);
|
||||
ASSERT_EQ(val->global, false);
|
||||
|
@ -72,12 +68,8 @@ TEST_F(TestTfliteParserAvgPooling, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserAvgPooling, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsPooling();
|
||||
ASSERT_NE(val, nullptr);
|
||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||
ASSERT_EQ(val->poolingMode, schema::PoolMode_MEAN_POOLING);
|
||||
ASSERT_EQ(val->global, false);
|
||||
|
|
|
@ -32,13 +32,9 @@ TEST_F(TestTfliteParserReduceMax, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserReduceMax, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
|
||||
ASSERT_NE(val, nullptr);
|
||||
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax) << "wrong reduce mode";
|
||||
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax);
|
||||
ASSERT_EQ(val->keepDims, false);
|
||||
std::vector<int32_t> axes = {2};
|
||||
ASSERT_EQ(val->axes, axes);
|
||||
|
@ -58,13 +54,9 @@ TEST_F(TestTfliteParserReduceMin, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserReduceMin, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
|
||||
ASSERT_NE(val, nullptr);
|
||||
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin) << "wrong reduce mode";
|
||||
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin);
|
||||
ASSERT_EQ(val->keepDims, false);
|
||||
std::vector<int32_t> axes = {2};
|
||||
ASSERT_EQ(val->axes, axes);
|
||||
|
@ -84,13 +76,9 @@ TEST_F(TestTfliteParserReduceProd, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserReduceProd, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
|
||||
ASSERT_NE(val, nullptr);
|
||||
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd) << "wrong reduce mode";
|
||||
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd);
|
||||
ASSERT_EQ(val->keepDims, false);
|
||||
std::vector<int32_t> axes = {2};
|
||||
ASSERT_EQ(val->axes, axes);
|
||||
|
@ -111,13 +99,9 @@ TEST_F(TestTfliteParserSum, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserSum, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
|
||||
ASSERT_NE(val, nullptr);
|
||||
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum) << "wrong reduce mode";
|
||||
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum);
|
||||
ASSERT_EQ(val->keepDims, false);
|
||||
std::vector<int32_t> axes = {2};
|
||||
ASSERT_EQ(val->axes, axes);
|
||||
|
@ -138,13 +122,9 @@ TEST_F(TestTfliteParserMean, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserMean, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsReduce();
|
||||
ASSERT_NE(val, nullptr);
|
||||
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean) << "wrong reduce mode";
|
||||
ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean);
|
||||
ASSERT_EQ(val->keepDims, true);
|
||||
std::vector<int32_t> axes = {2, 3};
|
||||
ASSERT_EQ(val->axes, axes);
|
||||
|
|
|
@ -35,12 +35,9 @@ TEST_F(TestTfliteParserReshape, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserReshape, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReshape(), nullptr);
|
||||
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsReshape();
|
||||
std::vector<int64_t> shape = {3, 5, 20};
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReshape()->shape, shape); // int32
|
||||
ASSERT_EQ(val->shape, shape);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,17 +26,15 @@ class TestTfliteParserResizeNN : public TestTfliteParser {
|
|||
};
|
||||
|
||||
TEST_F(TestTfliteParserResizeNN, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserResizeNN, AttrValue) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsResize();
|
||||
ASSERT_NE(val, nullptr);
|
||||
ASSERT_EQ(val->alignCorners, false);
|
||||
ASSERT_EQ(val->newHeight, 3);
|
||||
ASSERT_EQ(val->newWidth, 100);
|
||||
|
@ -52,17 +50,15 @@ class TestTfliteParserResizeBilinear : public TestTfliteParser {
|
|||
};
|
||||
|
||||
TEST_F(TestTfliteParserResizeBilinear, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserResizeBilinear, AttrValue) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsResize();
|
||||
ASSERT_NE(val, nullptr);
|
||||
ASSERT_EQ(val->alignCorners, false);
|
||||
ASSERT_EQ(val->newHeight, 75);
|
||||
ASSERT_EQ(val->newWidth, 4);
|
||||
|
|
|
@ -25,17 +25,15 @@ class TestTfliteParserReverse : public TestTfliteParser {
|
|||
};
|
||||
|
||||
TEST_F(TestTfliteParserReverse, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reverse) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserReverse, AttrValue) {
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverse(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsReverse();
|
||||
|
||||
std::vector<int32_t> axis = {3};
|
||||
ASSERT_EQ(val->axis, axis);
|
||||
}
|
||||
|
|
|
@ -35,13 +35,11 @@ TEST_F(TestTfliteParserReverseSequence, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserReverseSequence, AttrValue) {
|
||||
std::vector<int> seq_length{7, 2, 3, 5};
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverseSequence(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqLengths, seq_length);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsReverseSequence();
|
||||
ASSERT_EQ(val->seqAxis, 1);
|
||||
ASSERT_EQ(val->seqAxis, 1);
|
||||
std::vector<int> seq_length = {7, 2, 3, 5};
|
||||
ASSERT_EQ(val->seqLengths, seq_length);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||
#include <iostream>
|
||||
#include "common/common_test.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestTfliteParserSlice : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserSlice() = default;
|
||||
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./slice.tflite"); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserSlice, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Slice) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserSlice, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSlice(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsSlice();
|
||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||
std::vector<int32_t> begin = {1, 0, 0};
|
||||
ASSERT_EQ(val->begin, begin);
|
||||
std::vector<int32_t> size = {1, 1, 3};
|
||||
ASSERT_EQ(val->size, size);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
|
@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSoftmax, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserSoftmax, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSoftMax()->axis, -1);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsSoftMax();
|
||||
ASSERT_EQ(val->axis, -1);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,13 +35,11 @@ TEST_F(TestTfliteParserSpaceToBatchND, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserSpaceToBatchND, AttrValue) {
|
||||
std::vector<int> blockshape{2, 2};
|
||||
std::vector<int> padding{0, 0, 2, 0};
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->blockShape, blockshape);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->paddings, padding);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND();
|
||||
std::vector<int> blockshape = {2, 2};
|
||||
ASSERT_EQ(val->blockShape, blockshape);
|
||||
std::vector<int> padding = {0, 0, 2, 0};
|
||||
ASSERT_EQ(val->paddings, padding);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSpaceToDepth, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserSpaceToDepth, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->blockSize, 2);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->format, schema::Format_NHWC);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToDepth();
|
||||
ASSERT_EQ(val->blockSize, 2);
|
||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,16 +35,14 @@ TEST_F(TestTfliteParserSparseToDense, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserSparseToDense, AttrValue) {
|
||||
std::vector<int> outputShape{5, 5};
|
||||
std::vector<int> sparseValue{1};
|
||||
std::vector<int> defaultValue{0};
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->outputShape, outputShape);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->sparseValue, sparseValue);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->defaultValue, defaultValue);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->validateIndices, false);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsSparseToDense();
|
||||
std::vector<int> outputShape = {5, 5};
|
||||
ASSERT_EQ(val->outputShape, outputShape);
|
||||
std::vector<int> sparseValue = {1};
|
||||
ASSERT_EQ(val->sparseValue, sparseValue);
|
||||
std::vector<int> defaultValue = {0};
|
||||
ASSERT_EQ(val->defaultValue, defaultValue);
|
||||
ASSERT_EQ(val->validateIndices, false);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,14 +33,12 @@ TEST_F(TestTfliteParserSplit, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserSplit, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr);
|
||||
const std::vector<int> sizeSplits{2, 2};
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 2);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsSplit();
|
||||
ASSERT_EQ(val->splitDim, 2);
|
||||
ASSERT_EQ(val->numberSplit, 2);
|
||||
const std::vector<int> sizeSplits = {2, 2};
|
||||
ASSERT_EQ(val->sizeSplits, sizeSplits);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,14 +33,12 @@ TEST_F(TestTfliteParserSplitV, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserSplitV, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr);
|
||||
const std::vector<int> sizeSplits{1, 3};
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 0);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsSplit();
|
||||
ASSERT_EQ(val->splitDim, 0);
|
||||
ASSERT_EQ(val->numberSplit, 2);
|
||||
const std::vector<int> sizeSplits = {1, 3};
|
||||
ASSERT_EQ(val->sizeSplits, sizeSplits);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||
#include <iostream>
|
||||
#include "common/common_test.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestTfliteParserStack : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserStack() = default;
|
||||
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./stack.tflite"); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserStack, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Stack) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserStack, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStack(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsStack();
|
||||
ASSERT_EQ(val->axis, 1);
|
||||
ASSERT_EQ(val->n, 2);
|
||||
const std::vector<int> isScale = {3, 2, 3};
|
||||
ASSERT_EQ(val->isScale, isScale);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
|
@ -35,21 +35,19 @@ TEST_F(TestTfliteParserStridedSlice, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserStridedSlice, AttrValue) {
|
||||
std::vector<int> begin{1, -1, 0};
|
||||
std::vector<int> end{2, -3, 3};
|
||||
std::vector<int> stride{1, -1, 1};
|
||||
std::vector<int> isscale{3, 2, 3};
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStridedSlice(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->endMask, 0);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->begin, begin);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->end, end);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->stride, stride);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->isScale, isscale);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsStridedSlice();
|
||||
ASSERT_EQ(val->beginMask, 0);
|
||||
ASSERT_EQ(val->endMask, 0);
|
||||
ASSERT_EQ(val->beginMask, 0);
|
||||
ASSERT_EQ(val->beginMask, 0);
|
||||
std::vector<int> begin = {1, -1, 0};
|
||||
ASSERT_EQ(val->begin, begin);
|
||||
std::vector<int> end = {2, -3, 3};
|
||||
ASSERT_EQ(val->end, end);
|
||||
std::vector<int> stride = {1, -1, 1};
|
||||
ASSERT_EQ(val->stride, stride);
|
||||
std::vector<int> isscale = {3, 2, 3};
|
||||
ASSERT_EQ(val->isScale, isscale);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,11 +35,9 @@ TEST_F(TestTfliteParserTile, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserTile, AttrValue) {
|
||||
std::vector<int> multiply{2, 3, 4};
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTile(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTile()->multiples, multiply);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsTile();
|
||||
std::vector<int> multiply = {2, 3, 4};
|
||||
ASSERT_EQ(val->multiples, multiply);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,13 +35,10 @@ TEST_F(TestTfliteParserTopKV2, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserTopKV2, AttrValue) {
|
||||
// attr->sorted default is true
|
||||
std::vector<int> k{3};
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->k, k);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->sorted, true);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsTopKV2();
|
||||
std::vector<int> k = {3};
|
||||
ASSERT_EQ(val->k, k);
|
||||
ASSERT_EQ(val->sorted, true);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
|
||||
#include <iostream>
|
||||
#include "common/common_test.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestTfliteParserTranspose : public TestTfliteParser {
|
||||
public:
|
||||
TestTfliteParserTranspose() = default;
|
||||
|
||||
void SetUp() override { meta_graph = LoadAndConvert("./transpose.tflite"); }
|
||||
};
|
||||
|
||||
TEST_F(TestTfliteParserTranspose, OpType) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserTranspose, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTranspose(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsTranspose();
|
||||
ASSERT_EQ(val->conjugate, false);
|
||||
std::vector<int32_t> perm = {1, 0};
|
||||
ASSERT_EQ(val->perm, perm);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
|
@ -35,10 +35,9 @@ TEST_F(TestTfliteParserUnique, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserUnique, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnique()->outType, 34); // int32
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsUnique();
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr);
|
||||
ASSERT_EQ(val->outType, 34);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,11 +35,9 @@ TEST_F(TestTfliteParserUnstack, OpType) {
|
|||
}
|
||||
|
||||
TEST_F(TestTfliteParserUnstack, AttrValue) {
|
||||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnstack(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->num, 5);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->axis, 1);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsUnstack();
|
||||
ASSERT_EQ(val->num, 5);
|
||||
ASSERT_EQ(val->axis, 1);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -353,7 +353,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
|
|||
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
|
||||
} else if (weightTensor->format == schema::Format_KCHW) {
|
||||
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
|
||||
} else if (weightTensor->format == schema::Format_CHWK) {
|
||||
} else if (weightTensor->format == schema::Format_CHWK) { // from tflite
|
||||
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
|
||||
|
@ -369,7 +369,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
|
|||
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
|
||||
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
|
||||
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
|
||||
} else if (weightTensor->format == schema::Format_CHWK) { // from tf
|
||||
} else if (weightTensor->format == schema::Format_CHWK) { // from tflite
|
||||
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
|
||||
|
|
|
@ -21,11 +21,16 @@ namespace lite {
|
|||
STATUS CaffeFlattenParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight,
|
||||
schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) {
|
||||
if (op == nullptr) {
|
||||
// MS_LOGE("null pointer dereferencing.");
|
||||
// MS_LOG(ERROR) << "null pointer dereferencing.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
std::unique_ptr<schema::ReshapeT> attr(new schema::ReshapeT());
|
||||
attr->format = schema::Format_NCHW;
|
||||
std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT());
|
||||
const caffe::FlattenParameter flattenParam = proto.flatten_param();
|
||||
|
||||
attr->axis = (int32_t)flattenParam.axis();
|
||||
attr->useAxis = true;
|
||||
attr->hasBias = false;
|
||||
attr->activationType = schema::ActivationType_NO_ACTIVATION;
|
||||
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Flatten;
|
||||
|
|
|
@ -14,18 +14,21 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/tflite/tflite_activation_parser.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "tools/converter/parser/tflite/tflite_activation_parser.h"
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -35,13 +38,11 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
|
|||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());
|
||||
|
||||
std::vector<std::string> node_name_str;
|
||||
Split(op->name, &node_name_str, "-");
|
||||
const char *node_name = node_name_str.data()->c_str();
|
||||
|
||||
if (std::strcmp(node_name, "Relu") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteReluParser";
|
||||
attr->type = schema::ActivationType_RELU;
|
||||
|
@ -54,29 +55,65 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
|
|||
} else if (std::strcmp(node_name, "Logistic") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteLogisticParser";
|
||||
attr->type = schema::ActivationType_SIGMOID;
|
||||
} else if (std::strcmp(node_name, "LeakyRelu") == 0) {
|
||||
const auto &option = tfliteOp->builtin_options.AsLeakyReluOptions();
|
||||
if (option == nullptr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->type = schema::ActivationType_LEAKY_RELU;
|
||||
attr->alpha = option->alpha;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "wrong activation type";
|
||||
return RET_ERROR;
|
||||
} else if (std::strcmp(node_name, "HardSwish") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteHardSwishParser";
|
||||
attr->type = schema::ActivationType_SIGMOID;
|
||||
}
|
||||
|
||||
attr->alpha = 0.2f;
|
||||
op->primitive->value.type = schema::PrimitiveType_Activation;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) {
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TflitePreluParser";
|
||||
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
std::unique_ptr<schema::PreluT> attr(new schema::PreluT());
|
||||
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) {
|
||||
MS_LOG(ERROR) << "get pRelu -> slope failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_Prelu;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TfliteLeakyReluParser";
|
||||
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -87,22 +124,29 @@ STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "paser TflitePreluParser";
|
||||
std::unique_ptr<schema::PreluT> attr(new schema::PreluT());
|
||||
std::unique_ptr<schema::LeakyReLUT> attr(new schema::LeakyReLUT());
|
||||
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) {
|
||||
MS_LOG(ERROR) << "get pRelu -> slope failed";
|
||||
return RET_ERROR;
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->negativeSlope = tflite_attr->alpha;
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Prelu;
|
||||
op->primitive->value.type = schema::PrimitiveType_LeakyReLU;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
|
||||
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
|
||||
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());
|
||||
TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser());
|
||||
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
|
||||
TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser());
|
||||
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());
|
||||
|
|
|
@ -14,13 +14,14 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef PREDICT_TFLITE_RELU_PARSER_H
|
||||
#define PREDICT_TFLITE_RELU_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -29,11 +30,13 @@ class TfliteActivationParser : public TfliteNodeParser {
|
|||
public:
|
||||
TfliteActivationParser() : TfliteNodeParser("node_name") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache, bool quantizedModel) override;
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
|
||||
class TfliteReluParser : public TfliteActivationParser {
|
||||
|
@ -56,9 +59,9 @@ class TfliteLogisticParser : public TfliteActivationParser {
|
|||
TfliteLogisticParser() : TfliteActivationParser() {}
|
||||
};
|
||||
|
||||
class TfliteLeakyReluParser : public TfliteActivationParser {
|
||||
class TfliteHardSwishParser : public TfliteActivationParser {
|
||||
public:
|
||||
TfliteLeakyReluParser() : TfliteActivationParser() {}
|
||||
TfliteHardSwishParser() : TfliteActivationParser() {}
|
||||
};
|
||||
|
||||
class TflitePreluParser : public TfliteNodeParser {
|
||||
|
@ -68,12 +71,27 @@ class TflitePreluParser : public TfliteNodeParser {
|
|||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache, bool quantized_model) override;
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
|
||||
class TfliteLeakyReluParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteLeakyReluParser() : TfliteNodeParser("LeakyRelu") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // PREDICT_TFLITE_RELU_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H
|
||||
|
||||
|
|
|
@ -18,14 +18,20 @@
|
|||
#include "tools/converter/parser/tflite/tflite_addn_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TfliteAddNParser";
|
||||
|
||||
// set attr
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -36,13 +42,19 @@ STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "parse TfliteAddNParser";
|
||||
std::unique_ptr<schema::AddNT> attr(new schema::AddNT());
|
||||
|
||||
attr->N = tfliteTensors.size() - 1;
|
||||
|
||||
attr->N = tflite_tensors.size() - 1;
|
||||
op->primitive->value.type = schema::PrimitiveType_AddN;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
// set input
|
||||
for (int i = 0; i < tflite_op->inputs.size(); i++) {
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_TFLITE_ADDN_PARSER_H
|
||||
#define LITE_TFLITE_ADDN_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
|
@ -31,11 +32,12 @@ class TfliteAddNParser : public TfliteNodeParser {
|
|||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache,
|
||||
bool quantized_model) override;
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_TFLITE_ADDN_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H
|
||||
|
|
|
@ -17,16 +17,19 @@
|
|||
#include "tools/converter/parser/tflite/tflite_argmax_parser.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
TensorCache *tensor_cache,
|
||||
bool quantizedModel) {
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TfliteArgmaxParser";
|
||||
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -37,7 +40,6 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "parse TfliteArgmaxParser";
|
||||
std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT());
|
||||
|
||||
attr->outMaxValue = false;
|
||||
|
@ -45,9 +47,10 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
|||
attr->keepDims = false;
|
||||
attr->axisType = 1;
|
||||
|
||||
auto axis_idx = tfliteOp->inputs[1];
|
||||
std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){});
|
||||
auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer];
|
||||
// get axis attr
|
||||
auto axis_idx = tflite_op->inputs[1];
|
||||
std::for_each(tflite_tensors[axis_idx]->shape.begin(), tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){});
|
||||
auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer];
|
||||
if (buf_data == nullptr) {
|
||||
MS_LOG(ERROR) << "the buf data is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -61,6 +64,11 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
|||
|
||||
op->primitive->value.type = schema::PrimitiveType_ArgMax;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef PREDICT_TFLITE_ARGMAX_PARSER_H
|
||||
#define PREDICT_TFLITE_ARGMAX_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
|
@ -28,14 +29,15 @@ class TfliteArgmaxParser : public TfliteNodeParser {
|
|||
public:
|
||||
TfliteArgmaxParser() : TfliteNodeParser("Argmax") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache,
|
||||
bool quantizedModel) override;
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // PREDICT_TFLITE_ARGMAX_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H
|
||||
|
|
|
@ -17,14 +17,19 @@
|
|||
#include "tools/converter/parser/tflite/tflite_argmin_parser.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TfliteArgminParser";
|
||||
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -35,7 +40,6 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "parse TfliteArgminParser";
|
||||
std::unique_ptr<schema::ArgMinT> attr(new schema::ArgMinT());
|
||||
|
||||
attr->outMaxValue = false;
|
||||
|
@ -43,9 +47,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
|||
attr->keepDims = false;
|
||||
attr->axisType = 1;
|
||||
|
||||
auto axis_idx = tfliteOp->inputs[1];
|
||||
std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){});
|
||||
auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer];
|
||||
// get axis attr
|
||||
auto axis_idx = tflite_op->inputs[1];
|
||||
std::for_each(tflite_tensors[axis_idx]->shape.begin(),
|
||||
tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){});
|
||||
auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer];
|
||||
if (buf_data == nullptr) {
|
||||
MS_LOG(ERROR) << "the buf data is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -59,6 +65,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
|||
|
||||
op->primitive->value.type = schema::PrimitiveType_ArgMin;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef PREDICT_TFLITE_ARGMIN_PARSER_H
|
||||
#define PREDICT_TFLITE_ARGMIN_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
|
@ -28,14 +29,15 @@ class TfliteArgminParser : public TfliteNodeParser {
|
|||
public:
|
||||
TfliteArgminParser() : TfliteNodeParser("Argmin") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache,
|
||||
bool quantizedModel) override;
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // PREDICT_TFLITE_ARGMIN_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H
|
||||
|
|
|
@ -18,14 +18,17 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -37,124 +40,72 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
|
|||
}
|
||||
|
||||
std::vector<std::string> node_name_str;
|
||||
Split(op->name.data(), &node_name_str, "-");
|
||||
Split(op->name, &node_name_str, "-");
|
||||
const char *node_name = node_name_str.data()->c_str();
|
||||
|
||||
if (std::strcmp(node_name, "Add") == 0
|
||||
|| std::strcmp(node_name, "Sub") == 0
|
||||
|| std::strcmp(node_name, "Mul") == 0
|
||||
|| std::strcmp(node_name, "Div") == 0) {
|
||||
auto x_index = tfliteOp->inputs[0];
|
||||
const auto &x_tensor = tfliteTensors[x_index];
|
||||
if (x_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the first input is null";
|
||||
if (std::strcmp(node_name, "Add") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteAddParser";
|
||||
std::unique_ptr<schema::AddT> attr(new schema::AddT());
|
||||
const auto &tfliteAttr = tflite_op->builtin_options.AsAddOptions();
|
||||
if (nullptr == tfliteAttr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto &x_data = tfliteModelBuffer.at(x_tensor->buffer);
|
||||
if (x_data == nullptr) {
|
||||
MS_LOG(ERROR) << "the data of the first input is null";
|
||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||
op->primitive->value.type = schema::PrimitiveType_Add;
|
||||
op->primitive->value.value = attr.release();
|
||||
} else if (std::strcmp(node_name, "Sub") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSubParser";
|
||||
std::unique_ptr<schema::SubT> attr(new schema::SubT());
|
||||
const auto &tfliteAttr = tflite_op->builtin_options.AsSubOptions();
|
||||
if (nullptr == tfliteAttr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (!x_data->data.empty()) {
|
||||
std::vector<tflite::TensorT *> x_tensors{x_tensor.get()};
|
||||
if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
||||
MS_LOG(ERROR) << "parse the first tensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
auto y_index = tfliteOp->inputs[1];
|
||||
const auto &y_tensor = tfliteTensors[y_index];
|
||||
if (y_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the second input is null";
|
||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||
op->primitive->value.type = schema::PrimitiveType_Sub;
|
||||
op->primitive->value.value = attr.release();
|
||||
} else if (std::strcmp(node_name, "Mul") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteMulParser";
|
||||
std::unique_ptr<schema::MulT> attr(new schema::MulT());
|
||||
const auto &tfliteAttr = tflite_op->builtin_options.AsMulOptions();
|
||||
if (nullptr == tfliteAttr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto &y_data = tfliteModelBuffer.at(y_tensor->buffer);
|
||||
if (y_data == nullptr) {
|
||||
MS_LOG(ERROR) << "the data of the second input is null";
|
||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||
op->primitive->value.type = schema::PrimitiveType_Mul;
|
||||
op->primitive->value.value = attr.release();
|
||||
} else if (std::strcmp(node_name, "Div") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteDivParser";
|
||||
std::unique_ptr<schema::DivT> attr(new schema::DivT());
|
||||
const auto &tfliteAttr = tflite_op->builtin_options.AsDivOptions();
|
||||
if (nullptr == tfliteAttr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (!y_data->data.empty()) {
|
||||
std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
|
||||
if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
||||
MS_LOG(ERROR) << "parse the second tensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
if (std::strcmp(node_name, "Add") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteAddParser";
|
||||
std::unique_ptr<schema::AddT> attr(new schema::AddT());
|
||||
const auto &tfliteAttr = tfliteOp->builtin_options.AsAddOptions();
|
||||
if (nullptr == tfliteAttr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||
op->primitive->value.type = schema::PrimitiveType_Add;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Sub") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSubParser";
|
||||
std::unique_ptr<schema::SubT> attr(new schema::SubT());
|
||||
const auto &tfliteAttr = tfliteOp->builtin_options.AsSubOptions();
|
||||
if (nullptr == tfliteAttr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||
op->primitive->value.type = schema::PrimitiveType_Sub;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Mul") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteMulParser";
|
||||
std::unique_ptr<schema::MulT> attr(new schema::MulT());
|
||||
const auto &tfliteAttr = tfliteOp->builtin_options.AsMulOptions();
|
||||
if (nullptr == tfliteAttr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||
op->primitive->value.type = schema::PrimitiveType_Mul;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Div") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteDivParser";
|
||||
std::unique_ptr<schema::DivT> attr(new schema::DivT());
|
||||
const auto &tfliteAttr = tfliteOp->builtin_options.AsDivOptions();
|
||||
if (nullptr == tfliteAttr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||
op->primitive->value.type = schema::PrimitiveType_Div;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||
op->primitive->value.type = schema::PrimitiveType_Div;
|
||||
op->primitive->value.value = attr.release();
|
||||
} else if (std::strcmp(node_name, "FloorDiv") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteFloorDivParser";
|
||||
std::unique_ptr<schema::FloorDivT> attr(new schema::FloorDivT());
|
||||
op->primitive->value.type = schema::PrimitiveType_FloorDiv;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "FloorMod") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteFloorModParser";
|
||||
std::unique_ptr<schema::FloorModT> attr(new schema::FloorModT());
|
||||
op->primitive->value.type = schema::PrimitiveType_FloorMod;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "RealDiv") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteRealDivParser";
|
||||
std::unique_ptr<schema::RealDivT> attr(new schema::RealDivT());
|
||||
op->primitive->value.type = schema::PrimitiveType_RealDiv;
|
||||
op->primitive->value.type = schema::PrimitiveType_Div;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "SquaredDifference") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser";
|
||||
std::unique_ptr<schema::SquaredDifferenceT> attr(new schema::SquaredDifferenceT());
|
||||
op->primitive->value.type = schema::PrimitiveType_SquaredDifference;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Pow") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TflitePowParser";
|
||||
std::unique_ptr<schema::PowerT> attr(new schema::PowerT());
|
||||
|
@ -163,31 +114,35 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
|
|||
attr->shift = 0.0f;
|
||||
op->primitive->value.type = schema::PrimitiveType_Power;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Maximum") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteMaximumParser";
|
||||
std::unique_ptr<schema::MaximumT> attr(new schema::MaximumT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Maximum;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Minimum") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteMinimumParser";
|
||||
std::unique_ptr<schema::MinimumT> attr(new schema::MinimumT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Minimum;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "wrong op type";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// set input
|
||||
for (int i = 0; i < tflite_op->inputs.size(); i++) {
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -199,85 +154,79 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
|
|||
}
|
||||
|
||||
std::vector<std::string> node_name_str;
|
||||
Split(op->name.data(), &node_name_str, "-");
|
||||
Split(op->name, &node_name_str, "-");
|
||||
const char *node_name = node_name_str.data()->c_str();
|
||||
if (std::strcmp(node_name, "Abs") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteAbsParser";
|
||||
std::unique_ptr<schema::AbsT> attr(new schema::AbsT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Abs;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Exp") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteExpParser";
|
||||
std::unique_ptr<schema::ExpT> attr(new schema::ExpT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Exp;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Sqrt") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSqrtParser";
|
||||
std::unique_ptr<schema::SqrtT> attr(new schema::SqrtT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Sqrt;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Rsqrt") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteRsqrtParser";
|
||||
std::unique_ptr<schema::RsqrtT> attr(new schema::RsqrtT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Rsqrt;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Square") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSquareParser";
|
||||
std::unique_ptr<schema::SquareT> attr(new schema::SquareT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Square;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Sin") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSinParser";
|
||||
std::unique_ptr<schema::SinT> attr(new schema::SinT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Sin;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Cos") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteCosParser";
|
||||
std::unique_ptr<schema::CosT> attr(new schema::CosT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Cos;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Log") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteLogParser";
|
||||
std::unique_ptr<schema::LogT> attr(new schema::LogT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Log;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Round") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteRoundParser";
|
||||
std::unique_ptr<schema::RoundT> attr(new schema::RoundT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Round;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Ceil") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteCeilParser";
|
||||
std::unique_ptr<schema::CeilT> attr(new schema::CeilT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Ceil;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "flOOR") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteFloorParser";
|
||||
std::unique_ptr<schema::FloorT> attr(new schema::FloorT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Floor;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "wrong op type";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -289,48 +238,47 @@ STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
|
|||
}
|
||||
|
||||
std::vector<std::string> node_name_str;
|
||||
Split(op->name.data(), &node_name_str, "-");
|
||||
Split(op->name, &node_name_str, "-");
|
||||
const char *node_name = node_name_str.data()->c_str();
|
||||
if (std::strcmp(node_name, "Equal") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteEqualParser";
|
||||
std::unique_ptr<schema::EqualT> attr(new schema::EqualT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Equal;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "NotEqual") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteNotEqualParser";
|
||||
std::unique_ptr<schema::NotEqualT> attr(new schema::NotEqualT());
|
||||
op->primitive->value.type = schema::PrimitiveType_NotEqual;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Greater") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteGreaterParser";
|
||||
std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Greater;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "GreaterEqual") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser";
|
||||
std::unique_ptr<schema::GreaterEqualT> attr(new schema::GreaterEqualT());
|
||||
op->primitive->value.type = schema::PrimitiveType_GreaterEqual;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "Less") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteLessParser";
|
||||
std::unique_ptr<schema::LessT> attr(new schema::LessT());
|
||||
op->primitive->value.type = schema::PrimitiveType_Less;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else if (std::strcmp(node_name, "LessEqual") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteLessEqualParser";
|
||||
std::unique_ptr<schema::LessEqualT> attr(new schema::LessEqualT());
|
||||
op->primitive->value.type = schema::PrimitiveType_LessEqual;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "wrong op type";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
for (int i = 0; i < tflite_op->inputs.size(); i++) {
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser());
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef PREDICT_TFLITE_MATH_PARSER_H
|
||||
#define PREDICT_TFLITE_MATH_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
|
@ -29,11 +30,13 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser {
|
|||
public:
|
||||
TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache, bool quantizedModel) override;
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
|
||||
class TfliteAddParser : public TfliteDoubleInputOpParser {
|
||||
|
@ -96,11 +99,13 @@ class TfliteSingleInputOpParser : public TfliteNodeParser {
|
|||
public:
|
||||
TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache, bool quantizedModel) override;
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
|
||||
class TfliteAbsParser : public TfliteSingleInputOpParser {
|
||||
|
@ -163,11 +168,13 @@ class TfliteCompareOpParser : public TfliteNodeParser {
|
|||
public:
|
||||
TfliteCompareOpParser() : TfliteNodeParser("node_name") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache, bool quantizedModel) override;
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
|
||||
class TfliteEqualParser : public TfliteCompareOpParser {
|
||||
|
@ -203,5 +210,5 @@ class TfliteLessEqualParser : public TfliteCompareOpParser {
|
|||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // PREDICT_TFLITE_MATH_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H
|
||||
|
||||
|
|
|
@ -19,14 +19,17 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -38,30 +41,32 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
|
|||
}
|
||||
|
||||
std::vector<std::string> node_name_str;
|
||||
Split(op->name.data(), &node_name_str, "-");
|
||||
Split(op->name, &node_name_str, "-");
|
||||
const char *node_name = node_name_str.data()->c_str();
|
||||
if (std::strcmp(node_name, "BatchToSpace") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser";
|
||||
} else if (std::strcmp(node_name, "BatchToSpaceND") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser";
|
||||
// in tflite
|
||||
// blockShape should be a 1D tensor with dimension [spatial_dims_num]
|
||||
// crops should be a 2D tensor with dimension [spatial_dims_num, 2]
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT());
|
||||
|
||||
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->blockShape)) {
|
||||
MS_LOG(ERROR) << "get batchToSpace -> blockShape failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->crops)) {
|
||||
if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->crops)) {
|
||||
MS_LOG(ERROR) << "get batchToSpace -> crops failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_BatchToSpace;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_TFLITE_BATCH_TO_SPACE_PARSER_H
|
||||
#define LITE_TFLITE_BATCH_TO_SPACE_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
|
@ -31,8 +32,10 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser {
|
|||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache, bool quantized_model) override;
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
|
||||
class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser {
|
||||
|
@ -43,4 +46,4 @@ class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser {
|
|||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_TFLITE_BATCH_TO_SPACE_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H
|
||||
|
|
|
@ -18,14 +18,19 @@
|
|||
#include "tools/converter/parser/tflite/tflite_broadcast_to_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TfliteBroadcastToParser";
|
||||
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -36,16 +41,20 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "parse TfliteBroadcastToParser";
|
||||
std::unique_ptr<schema::BroadcastToT> attr(new schema::BroadcastToT());
|
||||
|
||||
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dst_shape)) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dst_shape)) {
|
||||
MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_TFLITE_BROADCAST_TO_PARSER_H
|
||||
#define LITE_TFLITE_BROADCAST_TO_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
|
@ -31,11 +32,12 @@ class TfliteBroadcastToParser : public TfliteNodeParser {
|
|||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache,
|
||||
bool quantized_model) override;
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_TFLITE_BROADCAST_TO_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H
|
||||
|
|
|
@ -14,18 +14,22 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/tflite/tflite_cast_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TfliteCastParser";
|
||||
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -36,25 +40,28 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "parse TfliteCastParser";
|
||||
std::unique_ptr<schema::CastT> attr(new schema::CastT());
|
||||
|
||||
const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]];
|
||||
const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]];
|
||||
if (in_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->srcT = dtype_map[in_tensor->type];
|
||||
|
||||
const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]];
|
||||
attr->srcT = GetTfliteDataType(in_tensor->type);
|
||||
const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]];
|
||||
if (out_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->dstT = dtype_map[out_tensor->type];
|
||||
attr->dstT = GetTfliteDataType(out_tensor->type);
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Cast;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_TFLITE_CAST_PARSER_
|
||||
#define LITE_TFLITE_CAST_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
|
@ -31,11 +32,12 @@ class TfliteCastParser : public TfliteNodeParser {
|
|||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache,
|
||||
bool quantized_model) override;
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_TFLITE_CAST_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H
|
||||
|
|
|
@ -17,14 +17,20 @@
|
|||
#include "tools/converter/parser/tflite/tflite_concat_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TfliteConcatParser";
|
||||
|
||||
// set attr
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -35,20 +41,25 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "parse TfliteConcatParser";
|
||||
std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT());
|
||||
|
||||
const auto &tfliteAttr = tfliteOp->builtin_options.AsConcatenationOptions();
|
||||
const auto &tfliteAttr = tflite_op->builtin_options.AsConcatenationOptions();
|
||||
if (tfliteAttr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->axis = tfliteAttr->axis;
|
||||
|
||||
attr->n = tfliteOp->inputs.size();
|
||||
attr->n = tflite_op->inputs.size();
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Concat;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
for (int i = 0; i < tflite_op->inputs.size(); i++) {
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef PREDICT_TFLITE_CONCAT_PARSER_H
|
||||
#define PREDICT_TFLITE_CONCAT_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
|
@ -28,15 +29,16 @@ class TfliteConcatParser : public TfliteNodeParser {
|
|||
public:
|
||||
TfliteConcatParser() : TfliteNodeParser("Concat") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache,
|
||||
bool quantizedModel) override;
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // PREDICT_TFLITE_CONCAT_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H
|
||||
|
||||
|
|
|
@ -17,14 +17,19 @@
|
|||
#include "tools/converter/parser/tflite/tflite_conv_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TfliteConvParser";
|
||||
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -35,60 +40,61 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "parse TfliteConvParser";
|
||||
std::unique_ptr<schema::Conv2DT> attr(new schema::Conv2DT());
|
||||
const auto &tfliteAttr = tfliteOp->builtin_options.AsConv2DOptions();
|
||||
if (tfliteAttr == nullptr) {
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsConv2DOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->group = 1;
|
||||
attr->strideW = tfliteAttr->stride_w;
|
||||
attr->strideH = tfliteAttr->stride_h;
|
||||
attr->dilateH = tfliteAttr->dilation_h_factor;
|
||||
attr->dilateW = tfliteAttr->dilation_w_factor;
|
||||
attr->padMode = GetPadMode(tfliteAttr->padding);
|
||||
attr->strideW = tflite_attr->stride_w;
|
||||
attr->strideH = tflite_attr->stride_h;
|
||||
attr->dilateH = tflite_attr->dilation_h_factor;
|
||||
attr->dilateW = tflite_attr->dilation_w_factor;
|
||||
attr->padMode = GetPadMode(tflite_attr->padding);
|
||||
attr->format = schema::Format_NHWC;
|
||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
|
||||
attr->hasBias = true;
|
||||
|
||||
// get the conv op weight tensor
|
||||
auto weight_index = tfliteOp->inputs[1];
|
||||
const auto &weight_tensor = tfliteTensors[weight_index];
|
||||
auto weight_index = tflite_op->inputs[1];
|
||||
const auto &weight_tensor = tflite_tensors[weight_index];
|
||||
if (weight_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "weight_tensor is null";
|
||||
MS_LOG(ERROR) << "the weight tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
|
||||
MS_LOG(ERROR) << "parse weight failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto weight_shape = weight_tensor->shape;
|
||||
attr->channelIn = weight_shape[KHWC_C];
|
||||
attr->channelOut = weight_shape[KHWC_K];
|
||||
attr->kernelW = weight_shape[KHWC_W];
|
||||
attr->kernelH = weight_shape[KHWC_H];
|
||||
|
||||
// get the conv op bias tensor
|
||||
if (tfliteOp->inputs.size() == 3) {
|
||||
attr->hasBias = true;
|
||||
auto bias_index = tfliteOp->inputs[2];
|
||||
const auto &bias_tensor = tfliteTensors[bias_index];
|
||||
if (bias_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "bias_tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
|
||||
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
||||
MS_LOG(ERROR) << "parse bias failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
attr->channelIn = weight_shape[3];
|
||||
attr->channelOut = weight_shape[0];
|
||||
attr->kernelH = weight_shape[1];
|
||||
attr->kernelW = weight_shape[2];
|
||||
|
||||
// calculate pad params
|
||||
auto data_index = tflite_op->inputs[0];
|
||||
const auto &data_tensor = tflite_tensors[data_index];
|
||||
std::vector<int> params;
|
||||
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH,
|
||||
attr->strideW, attr->kernelH, attr->kernelW, ¶ms) != RET_OK) {
|
||||
MS_LOG(ERROR) << "get padding params failed";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
attr->padUp = params.at(0);
|
||||
attr->padDown = params.at(1);
|
||||
attr->padLeft = params.at(2);
|
||||
attr->padRight = params.at(3);
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef PREDICT_TFLITE_CONV_PARSER_H
|
||||
#define PREDICT_TFLITE_CONV_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
|
@ -28,15 +29,16 @@ class TfliteConvParser : public TfliteNodeParser {
|
|||
public:
|
||||
TfliteConvParser() : TfliteNodeParser("Conv2D") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache,
|
||||
bool quantizedModel) override;
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // PREDICT_TFLITE_CONV_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "tools/converter/converter.h"
|
||||
#include "tools/converter/parser/tflite/tflite_model_parser.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
|
|
|
@ -17,14 +17,19 @@
|
|||
#include "tools/converter/parser/tflite/tflite_deconv_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser";
|
||||
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -35,11 +40,10 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser";
|
||||
std::unique_ptr<schema::DeConv2DT> attr(new schema::DeConv2DT());
|
||||
const auto &tflite_attr = tfliteOp->builtin_options.AsTransposeConvOptions();
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str();
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
|
@ -50,26 +54,48 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
|||
attr->dilateW = 1;
|
||||
attr->padMode = GetPadMode(tflite_attr->padding);
|
||||
attr->format = schema::Format_NHWC;
|
||||
attr->activationType = schema::ActivationType_NO_ACTIVATION;
|
||||
attr->hasBias = true;
|
||||
|
||||
// get the conv op weight tensor
|
||||
auto weight_index = tfliteOp->inputs[1];
|
||||
const auto &weight_tensor = tfliteTensors[weight_index];
|
||||
auto weight_index = tflite_op->inputs[1];
|
||||
const auto &weight_tensor = tflite_tensors[weight_index];
|
||||
if (weight_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "weight_tensor is null";
|
||||
MS_LOG(ERROR) << "the weight tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto weight_shape = weight_tensor->shape;
|
||||
attr->channelIn = weight_shape[CHWK_K];
|
||||
attr->channelOut = weight_shape[CHWK_C];
|
||||
attr->kernelW = weight_shape[CHWK_W];
|
||||
attr->kernelH = weight_shape[CHWK_H];
|
||||
attr->channelIn = weight_shape[3];
|
||||
attr->channelOut = weight_shape[0];
|
||||
attr->kernelH = weight_shape[1];
|
||||
attr->kernelW = weight_shape[2];
|
||||
|
||||
// calculate pad params
|
||||
auto data_index = tflite_op->inputs[2];
|
||||
const auto &data_tensor = tflite_tensors[data_index];
|
||||
std::vector<int> params;
|
||||
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH,
|
||||
attr->strideW, attr->kernelH, attr->kernelW, ¶ms) != RET_OK) {
|
||||
MS_LOG(ERROR) << "get padding params failed";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
attr->padUp = params.at(0);
|
||||
attr->padDown = params.at(1);
|
||||
attr->padLeft = params.at(2);
|
||||
attr->padRight = params.at(3);
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef PREDICT_TFLITE_DECONV_PARSER_H
|
||||
#define PREDICT_TFLITE_DECONV_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
|
@ -31,11 +32,12 @@ class TfliteDeConvParser : public TfliteNodeParser {
|
|||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_op_set, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache,
|
||||
bool quantizedModel) override;
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // PREDICT_TFLITE_DECONV_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H
|
||||
|
|
|
@ -18,14 +18,19 @@
|
|||
#include "tools/converter/parser/tflite/tflite_depth_to_space_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser";
|
||||
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -36,20 +41,23 @@ STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser";
|
||||
std::unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT());
|
||||
|
||||
const auto &tflite_attr = tfliteOp->builtin_options.AsDepthToSpaceOptions();
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsDepthToSpaceOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str();
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->blockSize = tflite_attr->block_size;
|
||||
|
||||
attr->format = schema::Format_NHWC;
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H
|
||||
#define LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
|
@ -31,11 +32,12 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser {
|
|||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache,
|
||||
bool quantized_model) override;
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H
|
||||
|
|
|
@ -17,65 +17,22 @@
|
|||
#include "tools/converter/parser/tflite/tflite_depthwise_conv_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "tools/common/node_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteDepthwiseConv2DParser::ParseGroupDepthwiseConv(schema::CNodeT *op,
|
||||
const std::unique_ptr<schema::DepthwiseConv2DT> &attr,
|
||||
const std::unique_ptr<tflite::TensorT> &weightTensor,
|
||||
TensorCache *tensor_cache) {
|
||||
std::unique_ptr<schema::Conv2DT> convAttr(new schema::Conv2DT);
|
||||
|
||||
convAttr->format = attr->format;
|
||||
convAttr->channelIn = attr->channelIn;
|
||||
convAttr->channelOut = attr->channelIn * attr->channelMultiplier;
|
||||
convAttr->kernelH = attr->kernelH;
|
||||
convAttr->kernelW = attr->kernelW;
|
||||
convAttr->strideH = attr->strideH;
|
||||
convAttr->strideW = attr->strideW;
|
||||
convAttr->padMode = attr->padMode;
|
||||
convAttr->padUp = attr->padUp;
|
||||
convAttr->padDown = attr->padDown;
|
||||
convAttr->padLeft = attr->padLeft;
|
||||
convAttr->padRight = attr->padRight;
|
||||
convAttr->dilateH = attr->dilateH;
|
||||
convAttr->dilateW = attr->dilateW;
|
||||
convAttr->hasBias = attr->hasBias;
|
||||
convAttr->activationType = attr->activationType;
|
||||
|
||||
auto weightTensorIndex = tensor_cache->FindTensor(weightTensor->name);
|
||||
if (weightTensorIndex >= 0 && weightTensorIndex < tensor_cache->GetCachedTensor().size()) {
|
||||
auto liteWeightTensor = tensor_cache->GetCachedTensor()[weightTensorIndex];
|
||||
if (liteWeightTensor->dataType == TypeId::kNumberTypeUInt8) {
|
||||
// convert weight format KHWC -> CHWK
|
||||
auto status = TransFilterFormat<uint8_t>(liteWeightTensor, kKHWC2CHWK);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
if (liteWeightTensor->dataType == kNumberTypeFloat32 || liteWeightTensor->dataType == kNumberTypeFloat) {
|
||||
// convert weight format KHWC -> CHWK
|
||||
auto status = TransFilterFormat<float>(liteWeightTensor, kKHWC2CHWK);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
op->primitive->value.value = convAttr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser";
|
||||
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -86,7 +43,6 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser";
|
||||
std::unique_ptr<schema::DepthwiseConv2DT> attr(new schema::DepthwiseConv2DT());
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
|
@ -100,15 +56,20 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
|
|||
attr->padMode = GetPadMode(tflite_attr->padding);
|
||||
attr->format = schema::Format_NHWC;
|
||||
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
|
||||
// get the conv op weight tensor
|
||||
auto input_index = tflite_op->inputs[0];
|
||||
const auto &input_tenosr = tflite_tensors[input_index];
|
||||
if (input_tenosr == nullptr) {
|
||||
MS_LOG(ERROR) << "the first input is null";
|
||||
attr->hasBias = true;
|
||||
attr->channelMultiplier = tflite_attr->depth_multiplier;
|
||||
|
||||
// get the data tensor
|
||||
auto data_index = tflite_op->inputs[1];
|
||||
const auto &data_tensor = tflite_tensors[data_index];
|
||||
if (data_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the data tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto input_shape = input_tenosr->shape;
|
||||
auto data_shape = data_tensor->shape;
|
||||
attr->channelIn = data_shape[3];
|
||||
|
||||
// get the weight tensor
|
||||
auto weight_index = tflite_op->inputs[1];
|
||||
const auto &weight_tensor = tflite_tensors[weight_index];
|
||||
if (weight_tensor == nullptr) {
|
||||
|
@ -116,38 +77,33 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
auto weight_shape = weight_tensor->shape;
|
||||
attr->channelIn = input_shape[KHWC_C];
|
||||
attr->channelMultiplier = tflite_attr->depth_multiplier;
|
||||
attr->kernelH = weight_shape[KHWC_H];
|
||||
attr->kernelW = weight_shape[KHWC_W];
|
||||
attr->kernelH = weight_shape[1];
|
||||
attr->kernelW = weight_shape[2];
|
||||
|
||||
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
|
||||
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
|
||||
MS_LOG(ERROR) << "parse weight failed";
|
||||
// calculate pad params
|
||||
std::vector<int> params;
|
||||
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW,
|
||||
attr->kernelH, attr->kernelW, ¶ms) != RET_OK) {
|
||||
MS_LOG(ERROR) << "get padding params failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (tflite_op->inputs.size() == 3) {
|
||||
attr->hasBias = true;
|
||||
auto bias_index = tflite_op->inputs[2];
|
||||
const auto &bias_tensor = tflite_tensors[bias_index];
|
||||
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
|
||||
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
||||
MS_LOG(ERROR) << "parse bias failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
if (attr->channelMultiplier > 1) {
|
||||
if (RET_OK != ParseGroupDepthwiseConv(op, attr, weight_tensor, tensor_cache)) {
|
||||
MS_LOG(ERROR) << "Parse Group DepthwiseConv failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
op->primitive->value.value = attr.release();
|
||||
attr->padUp = params.at(0);
|
||||
attr->padDown = params.at(1);
|
||||
attr->padLeft = params.at(2);
|
||||
attr->padRight = params.at(3);
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC);
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H
|
||||
#define PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
|
@ -28,20 +29,16 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser {
|
|||
public:
|
||||
TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) override;
|
||||
|
||||
private:
|
||||
STATUS ParseGroupDepthwiseConv(schema::CNodeT *op,
|
||||
const std::unique_ptr<schema::DepthwiseConv2DT> &attr,
|
||||
const std::unique_ptr<tflite::TensorT> &weightTensor,
|
||||
TensorCache *tensor_cache);
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // PREDICT_TFLITE_CONV_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H
|
||||
|
||||
|
|
|
@ -16,15 +16,20 @@
|
|||
#include "tools/converter/parser/tflite/tflite_dequantize_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "tools/common/node_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TfliteDequantizeNParser";
|
||||
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -35,32 +40,30 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "parse TfliteDequantizeNParser";
|
||||
std::unique_ptr<schema::CastT> attr(new schema::CastT);
|
||||
|
||||
// get the dequantize input tensor
|
||||
const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]];
|
||||
const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]];
|
||||
if (in_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "weight_tensor is null";
|
||||
MS_LOG(ERROR) << "input tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->srcT = dtype_map[in_tensor->type];
|
||||
|
||||
const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]];
|
||||
attr->srcT = GetTfliteDataType(in_tensor->type);
|
||||
const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]];
|
||||
if (out_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor is null";
|
||||
MS_LOG(ERROR) << "output tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->dstT = dtype_map[out_tensor->type];
|
||||
std::vector<tflite::TensorT *> weight_tensors{in_tensor.get()};
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
|
||||
MS_LOG(ERROR) << "parse weight failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->dstT = GetTfliteDataType(out_tensor->type);
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Fp16Cast;
|
||||
op->primitive->value.value = attr.release();
|
||||
return 0;
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser());
|
||||
|
|
|
@ -13,11 +13,12 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef LITE_TFLITE_DEQUANTIZE_PARSER_H
|
||||
#define LITE_TFLITE_DEQUANTIZE_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
|
@ -27,13 +28,15 @@ class TfliteDequantizeParser : public TfliteNodeParser {
|
|||
public:
|
||||
TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache, bool quantizedModel) override;
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_TFLITE_DEQUANTIZE_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H
|
||||
|
|
|
@ -17,16 +17,17 @@
|
|||
#include "tools/converter/parser/tflite/tflite_expand_dims_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op,
|
||||
TensorCache *tensor_cache,
|
||||
bool quantizedModel) {
|
||||
STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -40,7 +41,7 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
|
|||
MS_LOG(DEBUG) << "parse TfliteExpandDimsParser";
|
||||
std::unique_ptr<schema::ExpandDimsT> attr(new schema::ExpandDimsT());
|
||||
|
||||
const auto &tflite_attr = tfliteOp->builtin_options.AsExpandDimsOptions();
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsExpandDimsOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef PREDICT_TFLITE_EXPAND_DIMS_PARSER_H
|
||||
#define PREDICT_TFLITE_EXPAND_DIMS_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
|
@ -28,15 +29,16 @@ class TfliteExpandDimsParser : public TfliteNodeParser {
|
|||
public:
|
||||
TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache,
|
||||
bool quantizedModel) override;
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // PREDICT_TFLITE_EXPAND_DIMS_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H
|
||||
|
||||
|
|
|
@ -1,75 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tflite/tflite_fakequant_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser";
|
||||
std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT());
|
||||
|
||||
auto weight_index = tfliteOp->inputs[1];
|
||||
const auto &weight_tensor = tfliteTensors[weight_index];
|
||||
if (weight_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "weight_tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
|
||||
MS_LOG(ERROR) << "parse weight failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (tfliteOp->inputs.size() == 3) {
|
||||
attr->hasBias = true;
|
||||
auto bias_index = tfliteOp->inputs[2];
|
||||
const auto &bias_tensor = tfliteTensors[bias_index];
|
||||
if (bias_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "bias_tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
|
||||
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
||||
MS_LOG(ERROR) << "parse bias failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
attr->axis = 1;
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_FullConnection;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFakeQuantParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,39 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef LITE_TFLITE_FAKEQUANT_PARSER_H
|
||||
#define LITE_TFLITE_FAKEQUANT_PARSER_H
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TfliteFakeQuantParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteFakeQuantParser() : TfliteNodeParser("FakeQuant") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
|
||||
TensorCache *tensor_cache, bool quantizedModel) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_TFLITE_FAKEQUANT_PARSER_H
|
|
@ -14,19 +14,22 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/tflite/tflite_fill_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/tflite/tflite_fill_parser.h"
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
|
||||
schema::CNodeT *op,
|
||||
TensorCache *tensor_cache,
|
||||
bool quantizedModel) {
|
||||
STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TfliteFillParser";
|
||||
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -37,18 +40,22 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "parse TfliteFillParser";
|
||||
std::unique_ptr<schema::FillT> attr(new schema::FillT());
|
||||
|
||||
if (tfliteOp->inputs.size() > 1) {
|
||||
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dims)) {
|
||||
MS_LOG(ERROR) << "get Fill -> dims failed";
|
||||
if (tflite_op->inputs.size() > 1) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dims)) {
|
||||
MS_LOG(ERROR) << "get fill -> dims failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Fill;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue