!12212 add pb parser

From: @yeyunpeng2020
Reviewed-by: @HilbertDavid
Signed-off-by: @HilbertDavid
This commit is contained in:
mindspore-ci-bot 2021-02-09 21:47:18 +08:00 committed by Gitee
commit 695052e88c
19 changed files with 616 additions and 40 deletions

View File

@ -276,6 +276,8 @@ union PrimitiveType {
StridedSliceGrad,
IsFinite,
BatchMatMul,
LinSpace,
UniformReal
}
enum QuantType: int {

View File

@ -1281,6 +1281,14 @@ table IsFinite {
}
table BatchMatMul {
adj_x : bool = false;
adj_y : bool = false;
transpose_a :bool;
transpose_b :bool;
}
table LinSpace {
}
table UniformReal {
seed : int;
seed2 : int;
}

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/batch_matmul.h"
#include <memory>
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
@ -22,14 +22,17 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool BatchMatMul::GetAdjX() const { return this->primitive_->value.AsBatchMatMul()->adj_x; }
bool BatchMatMul::GetTransposeA() const { return this->primitive_->value.AsBatchMatMul()->transpose_a; }
void BatchMatMul::SetAdjX(bool adj_x) { this->primitive_->value.AsBatchMatMul()->adj_x = adj_x; }
bool BatchMatMul::GetTransposeB() const { return this->primitive_->value.AsBatchMatMul()->transpose_b; }
bool BatchMatMul::GetAdjY() const { return this->primitive_->value.AsBatchMatMul()->adj_y; }
void BatchMatMul::SetAdjY(bool adj_y) { this->primitive_->value.AsBatchMatMul()->adj_y = adj_y; }
void BatchMatMul::SetTransposeA(bool transpose_a) {
this->primitive_->value.AsBatchMatMul()->transpose_a = transpose_a;
}
void BatchMatMul::SetTransposeB(bool transpose_b) {
this->primitive_->value.AsBatchMatMul()->transpose_b = transpose_b;
}
int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
@ -51,31 +54,32 @@ int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
this->primitive_ = nullptr;
return RET_ERROR;
}
attr->adj_x = GetValue<bool>(prim.GetAttr("adj_x"));
attr->adj_y = GetValue<bool>(prim.GetAttr("adj_y"));
attr->transpose_a = GetValue<bool>(prim.GetAttr("transpose_a"));
attr->transpose_b = GetValue<bool>(prim.GetAttr("transpose_b"));
this->primitive_->value.value = attr;
}
return RET_OK;
}
#else
bool BatchMatMul::GetTransposeA() const { return this->primitive_->value_as_BatchMatMul()->transpose_a(); }
bool BatchMatMul::GetTransposeB() const { return this->primitive_->value_as_BatchMatMul()->transpose_b(); }
int BatchMatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateBatchMatMul(*fbb);
auto attr = primitive->value_as_BatchMatMul();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Add return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateBatchMatMul(*fbb, attr->transpose_a(), attr->transpose_b());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchMatMul, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
bool BatchMatMul::GetAdjX() const { return this->primitive_->value_as_BatchMatMul()->adj_x(); }
bool BatchMatMul::GetAdjY() const { return this->primitive_->value_as_BatchMatMul()->adj_y(); }
PrimitiveC *BatchMatMulCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<BatchMatMul>(primitive);
}
Registry BatchMatMulRegistry(schema::PrimitiveType_BatchMatMul, BatchMatMulCreator);
#endif
} // namespace lite
} // namespace mindspore

View File

@ -32,15 +32,14 @@ class BatchMatMul : public PrimitiveC {
MS_DECLARE_PARENT(BatchMatMul, PrimitiveC);
explicit BatchMatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetAdjX(bool adj_x);
void SetAdjY(bool adj_y);
void SetTransposeA(bool transpose_a);
void SetTransposeB(bool transpose_b);
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
bool GetAdjX() const;
bool GetAdjY() const;
bool GetTransposeA() const;
bool GetTransposeB() const;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_

View File

@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/lin_space.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
int LinSpace::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateLinSpace(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_LinSpace, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *LinSpaceCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<LinSpace>(primitive);
}
Registry LinSpaceRegistry(schema::PrimitiveType_LinSpace, LinSpaceCreator);
#endif
int LinSpace::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
auto input = inputs.front();
MS_ASSERT(input != nullptr);
auto output = outputs.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->set_format(input->format());
auto num = inputs.at(2)->data_c();
if (num == nullptr) {
return RET_INFER_INVALID;
}
output->set_shape({reinterpret_cast<int *>(num)[0]});
return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
#ifndef LITE_MINDSPORE_LITE_C_OPS_LIN_SPACE_H_
#define LITE_MINDSPORE_LITE_C_OPS_LIN_SPACE_H_
namespace mindspore {
namespace lite {
class LinSpace : public PrimitiveC {
public:
LinSpace() = default;
~LinSpace() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(LinSpace, PrimitiveC);
explicit LinSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_LIN_SPACE_H_

View File

@ -170,8 +170,11 @@
#include "src/ops/crop_and_resize.h"
#include "src/ops/nonzero.h"
#include "src/ops/erf.h"
#include "src/ops/is_finite.h"
#include "src/ops/batch_matmul.h"
#include "src/ops/lin_space.h"
#include "src/ops/uniform_real.h"
#include "src/ops/rank.h"
#include "src/ops/is_finite.h"
#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@ -1047,6 +1050,12 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) IsFinite(primitive);
case schema::PrimitiveType_BatchMatMul:
return new (std::nothrow) BatchMatMul(primitive);
case schema::PrimitiveType_LinSpace:
return new (std::nothrow) LinSpace(primitive);
case schema::PrimitiveType_UniformReal:
return new (std::nothrow) UniformReal(primitive);
case schema::PrimitiveType_Rank:
return new (std::nothrow) Rank(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
return new (std::nothrow) ActivationGrad(primitive);

View File

@ -0,0 +1,101 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/uniform_real.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int UniformReal::GetSeed() const { return this->primitive_->value.AsUniformReal()->seed; }
int UniformReal::GetSeed2() const { return this->primitive_->value.AsUniformReal()->seed2; }
int UniformReal::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_UniformReal;
}
if (this->primitive_->value.type != schema::PrimitiveType_UniformReal) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::UniformRealT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int UniformReal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_UniformReal();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_UniformReal return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateUniformReal(*fbb, attr->seed(), attr->seed2());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_UniformReal, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int UniformReal::GetSeed() const { return this->primitive_->value_as_UniformReal()->seed(); }
int UniformReal::GetSeed2() const { return this->primitive_->value_as_UniformReal()->seed2(); }
PrimitiveC *UniformRealCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<UniformReal>(primitive);
}
Registry UniformRealRegistry(schema::PrimitiveType_UniformReal, UniformRealCreator);
#endif
int UniformReal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
if (!infer_flag()) {
return RET_INFER_INVALID;
}
auto input_data = static_cast<int32_t *>(inputs_[0]->data_c());
if (input_data == nullptr) {
return RET_INFER_INVALID;
}
auto input_num = inputs_[0]->ElementsNum();
std::vector<int> output_shape(input_num);
for (int i = 0; i < input_num; i++) {
output_shape[i] = input_data[i];
}
outputs_[0]->set_shape(output_shape);
outputs_[0]->set_data_type(kNumberTypeFloat32);
return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_UNIFORM_REAL_H_
#define LITE_MINDSPORE_LITE_C_OPS_UNIFORM_REAL_H_
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class UniformReal : public PrimitiveC {
public:
UniformReal() = default;
~UniformReal() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(UniformReal, PrimitiveC);
explicit UniformReal(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
int GetSeed() const;
int GetSeed2() const;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_UNIFORM_REAL_H_

View File

@ -52,6 +52,8 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
attr->type = schema::ActivationType_TANH;
} else if (tf_op.op() == "LeakyRelu") {
attr->type = schema::ActivationType_LEAKY_RELU;
} else if (tf_op.op() == "Selu") {
attr->type = schema::ActivationType_SELU;
} else {
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
return RET_ERROR;
@ -63,7 +65,7 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
auto attr_leaky_relu = std::make_unique<schema::LeakyReLUT>();
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) {
MS_LOG(ERROR) << "The attribute alpha shoud be specified.";
MS_LOG(ERROR) << "The attribute alpha should be specified.";
return RET_ERROR;
}
attr_leaky_relu->negativeSlope = attr_value.f();
@ -85,5 +87,6 @@ TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser());
TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser());
TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser());
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFActivationParser());
TFNodeRegistrar g_tfSeLUParser("Selu", new TFActivationParser());
} // namespace lite
} // namespace mindspore

View File

@ -58,6 +58,11 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op,
status = CreateOperator<schema::SquareT>(primitive, schema::PrimitiveType_Square);
} else if (tf_op.op() == "Pow") {
status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Power);
} else if (tf_op.op() == "Abs") {
status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Abs);
} else {
MS_LOG(ERROR) << "unsupported arithmetic self type:" << tf_op.op();
return RET_ERROR;
}
if (status != RET_OK) {
return status;
@ -85,5 +90,6 @@ TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfLogParser("Log", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfPowParser("Pow", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfAbsParser("Abs", new TFArithmeticSelfParser());
} // namespace lite
} // namespace mindspore

View File

@ -22,29 +22,35 @@
namespace mindspore {
namespace lite {
STATUS TFBatchMatmulParser::Parse(const tensorflow::NodeDef &tf_op,
STATUS TFBatchMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF BatchMatMulParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr";
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::BatchMatMulT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
tensorflow::AttrValue attr_value;
TensorFlowUtils::FindAttrValue(tf_op, "adj_x", &attr_value);
attr->adj_x = attr_value.b();
attr->adj_y = attr_value.b();
if (!TensorFlowUtils::FindAttrValue(tf_op, "adj_x", &attr_value)) {
MS_LOG(ERROR) << "The begin_mask attr should be specified";
return RET_ERROR;
}
attr->transpose_a = attr_value.b();
if (!TensorFlowUtils::FindAttrValue(tf_op, "adj_y", &attr_value)) {
MS_LOG(ERROR) << "The begin_mask attr should be specified";
return RET_ERROR;
}
attr->transpose_b = attr_value.b();
primitive->value.type = schema::PrimitiveType_BatchMatMul;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
@ -52,13 +58,15 @@ STATUS TFBatchMatmulParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
for (int i = 0; i < tf_op.input_size(); ++i) {
auto status = AddOpInput(tf_op, i, inputs);
if (status != RET_OK) {
return status;
}
}
return RET_OK;
}
TFNodeRegistrar g_tfBatchMatMulParser("BatchMatMul", new TFBatchMatmulParser());
TFNodeRegistrar g_tfBatchMatMulParser("BatchMatMul", new TFBatchMatMulParser());
} // namespace lite
} // namespace mindspore

View File

@ -23,15 +23,14 @@
namespace mindspore {
namespace lite {
class TFBatchMatmulParser : public TFNodeParser {
class TFBatchMatMulParser : public TFNodeParser {
public:
TFBatchMatmulParser() = default;
~TFBatchMatmulParser() override = default;
TFBatchMatMulParser() = default;
~TFBatchMatMulParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_

View File

@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_linspace_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFLinSpaceParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF LinSpaceParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::LinSpaceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_LinSpace;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); ++i) {
auto status = AddOpInput(tf_op, i, inputs);
if (status != RET_OK) {
return status;
}
}
return RET_OK;
}
TFNodeRegistrar g_tfLinSpaceParser("LinSpace", new TFLinSpaceParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LIN_SPACE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LIN_SPACE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFLinSpaceParser : public TFNodeParser {
public:
TFLinSpaceParser() = default;
~TFLinSpaceParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LIN_SPACE_PARSER_H_

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_rank_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFRankParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF RankParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::RankT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_Rank;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
if (status != RET_OK) {
return status;
}
return RET_OK;
}
TFNodeRegistrar g_tfRankParser("Rank", new TFRankParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANK_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANK_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFRankParser : public TFNodeParser {
public:
TFRankParser() = default;
~TFRankParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANK_PARSER_H_

View File

@ -0,0 +1,68 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_uniform_real_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFUniformRealParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF UniformRealParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::UniformRealT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "seed", &attr_value)) {
MS_LOG(ERROR) << "The seed attr should be specified";
return RET_ERROR;
}
attr->seed = attr_value.i();
if (!TensorFlowUtils::FindAttrValue(tf_op, "seed2", &attr_value)) {
MS_LOG(ERROR) << "The seed2 attr should be specified";
return RET_ERROR;
}
attr->seed2 = attr_value.i();
primitive->value.type = schema::PrimitiveType_UniformReal;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
return status;
}
TFNodeRegistrar g_tfRandomUniformParser("RandomUniform", new TFUniformRealParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_UNIFORM_REAL_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_UNIFORM_REAL_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFUniformRealParser : public TFNodeParser {
public:
TFUniformRealParser() = default;
~TFUniformRealParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_UNIFORM_REAL_PARSER_H_