forked from mindspore-Ecosystem/mindspore
!12212 add pb parser
From: @yeyunpeng2020 Reviewed-by: @HilbertDavid Signed-off-by: @HilbertDavid
This commit is contained in:
commit
695052e88c
|
@ -276,6 +276,8 @@ union PrimitiveType {
|
|||
StridedSliceGrad,
|
||||
IsFinite,
|
||||
BatchMatMul,
|
||||
LinSpace,
|
||||
UniformReal
|
||||
}
|
||||
|
||||
enum QuantType: int {
|
||||
|
|
|
@ -1281,6 +1281,14 @@ table IsFinite {
|
|||
}
|
||||
|
||||
table BatchMatMul {
|
||||
adj_x : bool = false;
|
||||
adj_y : bool = false;
|
||||
transpose_a :bool;
|
||||
transpose_b :bool;
|
||||
}
|
||||
|
||||
table LinSpace {
|
||||
}
|
||||
|
||||
table UniformReal {
|
||||
seed : int;
|
||||
seed2 : int;
|
||||
}
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/batch_matmul.h"
|
||||
#include <memory>
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
@ -22,14 +22,17 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
bool BatchMatMul::GetAdjX() const { return this->primitive_->value.AsBatchMatMul()->adj_x; }
|
||||
bool BatchMatMul::GetTransposeA() const { return this->primitive_->value.AsBatchMatMul()->transpose_a; }
|
||||
|
||||
void BatchMatMul::SetAdjX(bool adj_x) { this->primitive_->value.AsBatchMatMul()->adj_x = adj_x; }
|
||||
bool BatchMatMul::GetTransposeB() const { return this->primitive_->value.AsBatchMatMul()->transpose_b; }
|
||||
|
||||
bool BatchMatMul::GetAdjY() const { return this->primitive_->value.AsBatchMatMul()->adj_y; }
|
||||
|
||||
void BatchMatMul::SetAdjY(bool adj_y) { this->primitive_->value.AsBatchMatMul()->adj_y = adj_y; }
|
||||
void BatchMatMul::SetTransposeA(bool transpose_a) {
|
||||
this->primitive_->value.AsBatchMatMul()->transpose_a = transpose_a;
|
||||
}
|
||||
|
||||
void BatchMatMul::SetTransposeB(bool transpose_b) {
|
||||
this->primitive_->value.AsBatchMatMul()->transpose_b = transpose_b;
|
||||
}
|
||||
int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
|
@ -51,31 +54,32 @@ int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
|
|||
this->primitive_ = nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->adj_x = GetValue<bool>(prim.GetAttr("adj_x"));
|
||||
attr->adj_y = GetValue<bool>(prim.GetAttr("adj_y"));
|
||||
attr->transpose_a = GetValue<bool>(prim.GetAttr("transpose_a"));
|
||||
attr->transpose_b = GetValue<bool>(prim.GetAttr("transpose_b"));
|
||||
this->primitive_->value.value = attr;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#else
|
||||
bool BatchMatMul::GetTransposeA() const { return this->primitive_->value_as_BatchMatMul()->transpose_a(); }
|
||||
bool BatchMatMul::GetTransposeB() const { return this->primitive_->value_as_BatchMatMul()->transpose_b(); }
|
||||
int BatchMatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto val_offset = schema::CreateBatchMatMul(*fbb);
|
||||
auto attr = primitive->value_as_BatchMatMul();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Add return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateBatchMatMul(*fbb, attr->transpose_a(), attr->transpose_b());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchMatMul, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
bool BatchMatMul::GetAdjX() const { return this->primitive_->value_as_BatchMatMul()->adj_x(); }
|
||||
|
||||
bool BatchMatMul::GetAdjY() const { return this->primitive_->value_as_BatchMatMul()->adj_y(); }
|
||||
|
||||
PrimitiveC *BatchMatMulCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<BatchMatMul>(primitive);
|
||||
}
|
||||
Registry BatchMatMulRegistry(schema::PrimitiveType_BatchMatMul, BatchMatMulCreator);
|
||||
#endif
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,15 +32,14 @@ class BatchMatMul : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(BatchMatMul, PrimitiveC);
|
||||
explicit BatchMatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
void SetAdjX(bool adj_x);
|
||||
void SetAdjY(bool adj_y);
|
||||
void SetTransposeA(bool transpose_a);
|
||||
void SetTransposeB(bool transpose_b);
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
bool GetAdjX() const;
|
||||
bool GetAdjY() const;
|
||||
bool GetTransposeA() const;
|
||||
bool GetTransposeB() const;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/lin_space.h"
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
int LinSpace::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto val_offset = schema::CreateLinSpace(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_LinSpace, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *LinSpaceCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<LinSpace>(primitive);
|
||||
}
|
||||
Registry LinSpaceRegistry(schema::PrimitiveType_LinSpace, LinSpaceCreator);
|
||||
#endif
|
||||
int LinSpace::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
|
||||
auto input = inputs.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_data_type(input->data_type());
|
||||
output->set_format(input->format());
|
||||
auto num = inputs.at(2)->data_c();
|
||||
if (num == nullptr) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
output->set_shape({reinterpret_cast<int *>(num)[0]});
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_LIN_SPACE_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_LIN_SPACE_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class LinSpace : public PrimitiveC {
|
||||
public:
|
||||
LinSpace() = default;
|
||||
~LinSpace() = default;
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(LinSpace, PrimitiveC);
|
||||
explicit LinSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_LIN_SPACE_H_
|
|
@ -170,8 +170,11 @@
|
|||
#include "src/ops/crop_and_resize.h"
|
||||
#include "src/ops/nonzero.h"
|
||||
#include "src/ops/erf.h"
|
||||
#include "src/ops/is_finite.h"
|
||||
#include "src/ops/batch_matmul.h"
|
||||
#include "src/ops/lin_space.h"
|
||||
#include "src/ops/uniform_real.h"
|
||||
#include "src/ops/rank.h"
|
||||
#include "src/ops/is_finite.h"
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
#include "src/ops/neg_grad.h"
|
||||
|
@ -1047,6 +1050,12 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
|
|||
return new (std::nothrow) IsFinite(primitive);
|
||||
case schema::PrimitiveType_BatchMatMul:
|
||||
return new (std::nothrow) BatchMatMul(primitive);
|
||||
case schema::PrimitiveType_LinSpace:
|
||||
return new (std::nothrow) LinSpace(primitive);
|
||||
case schema::PrimitiveType_UniformReal:
|
||||
return new (std::nothrow) UniformReal(primitive);
|
||||
case schema::PrimitiveType_Rank:
|
||||
return new (std::nothrow) Rank(primitive);
|
||||
#ifdef SUPPORT_TRAIN
|
||||
case schema::PrimitiveType_ActivationGrad:
|
||||
return new (std::nothrow) ActivationGrad(primitive);
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/uniform_real.h"
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int UniformReal::GetSeed() const { return this->primitive_->value.AsUniformReal()->seed; }
|
||||
|
||||
int UniformReal::GetSeed2() const { return this->primitive_->value.AsUniformReal()->seed2; }
|
||||
|
||||
int UniformReal::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_UniformReal;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_UniformReal) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::UniformRealT();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive value is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
#else
|
||||
int UniformReal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_UniformReal();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_UniformReal return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateUniformReal(*fbb, attr->seed(), attr->seed2());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_UniformReal, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int UniformReal::GetSeed() const { return this->primitive_->value_as_UniformReal()->seed(); }
|
||||
|
||||
int UniformReal::GetSeed2() const { return this->primitive_->value_as_UniformReal()->seed2(); }
|
||||
|
||||
PrimitiveC *UniformRealCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<UniformReal>(primitive);
|
||||
}
|
||||
Registry UniformRealRegistry(schema::PrimitiveType_UniformReal, UniformRealCreator);
|
||||
#endif
|
||||
|
||||
int UniformReal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
if (!infer_flag()) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
auto input_data = static_cast<int32_t *>(inputs_[0]->data_c());
|
||||
if (input_data == nullptr) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
auto input_num = inputs_[0]->ElementsNum();
|
||||
std::vector<int> output_shape(input_num);
|
||||
for (int i = 0; i < input_num; i++) {
|
||||
output_shape[i] = input_data[i];
|
||||
}
|
||||
outputs_[0]->set_shape(output_shape);
|
||||
outputs_[0]->set_data_type(kNumberTypeFloat32);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_UNIFORM_REAL_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_UNIFORM_REAL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class UniformReal : public PrimitiveC {
|
||||
public:
|
||||
UniformReal() = default;
|
||||
~UniformReal() = default;
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(UniformReal, PrimitiveC);
|
||||
explicit UniformReal(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
int GetSeed() const;
|
||||
int GetSeed2() const;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_UNIFORM_REAL_H_
|
|
@ -52,6 +52,8 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
attr->type = schema::ActivationType_TANH;
|
||||
} else if (tf_op.op() == "LeakyRelu") {
|
||||
attr->type = schema::ActivationType_LEAKY_RELU;
|
||||
} else if (tf_op.op() == "Selu") {
|
||||
attr->type = schema::ActivationType_SELU;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
|
||||
return RET_ERROR;
|
||||
|
@ -63,7 +65,7 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
auto attr_leaky_relu = std::make_unique<schema::LeakyReLUT>();
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The attribute alpha shoud be specified.";
|
||||
MS_LOG(ERROR) << "The attribute alpha should be specified.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr_leaky_relu->negativeSlope = attr_value.f();
|
||||
|
@ -85,5 +87,6 @@ TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser());
|
|||
TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser());
|
||||
TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser());
|
||||
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFActivationParser());
|
||||
TFNodeRegistrar g_tfSeLUParser("Selu", new TFActivationParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -58,6 +58,11 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
status = CreateOperator<schema::SquareT>(primitive, schema::PrimitiveType_Square);
|
||||
} else if (tf_op.op() == "Pow") {
|
||||
status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Power);
|
||||
} else if (tf_op.op() == "Abs") {
|
||||
status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Abs);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported arithmetic self type:" << tf_op.op();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
|
@ -85,5 +90,6 @@ TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser());
|
|||
TFNodeRegistrar g_tfLogParser("Log", new TFArithmeticSelfParser());
|
||||
TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFArithmeticSelfParser());
|
||||
TFNodeRegistrar g_tfPowParser("Pow", new TFArithmeticSelfParser());
|
||||
TFNodeRegistrar g_tfAbsParser("Abs", new TFArithmeticSelfParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,29 +22,35 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TFBatchMatmulParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
STATUS TFBatchMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(DEBUG) << "TF BatchMatMulParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is nullptr";
|
||||
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::BatchMatMulT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
MS_LOG(ERROR) << "new attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
tensorflow::AttrValue attr_value;
|
||||
TensorFlowUtils::FindAttrValue(tf_op, "adj_x", &attr_value);
|
||||
attr->adj_x = attr_value.b();
|
||||
attr->adj_y = attr_value.b();
|
||||
|
||||
if (!TensorFlowUtils::FindAttrValue(tf_op, "adj_x", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The begin_mask attr should be specified";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->transpose_a = attr_value.b();
|
||||
if (!TensorFlowUtils::FindAttrValue(tf_op, "adj_y", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The begin_mask attr should be specified";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->transpose_b = attr_value.b();
|
||||
primitive->value.type = schema::PrimitiveType_BatchMatMul;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
|
@ -52,13 +58,15 @@ STATUS TFBatchMatmulParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
*output_size = 1;
|
||||
for (int i = 0; i < tf_op.input_size(); i++) {
|
||||
inputs->emplace_back(tf_op.input(i));
|
||||
for (int i = 0; i < tf_op.input_size(); ++i) {
|
||||
auto status = AddOpInput(tf_op, i, inputs);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
TFNodeRegistrar g_tfBatchMatMulParser("BatchMatMul", new TFBatchMatmulParser());
|
||||
TFNodeRegistrar g_tfBatchMatMulParser("BatchMatMul", new TFBatchMatMulParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,15 +23,14 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFBatchMatmulParser : public TFNodeParser {
|
||||
class TFBatchMatMulParser : public TFNodeParser {
|
||||
public:
|
||||
TFBatchMatmulParser() = default;
|
||||
~TFBatchMatmulParser() override = default;
|
||||
TFBatchMatMulParser() = default;
|
||||
~TFBatchMatMulParser() override = default;
|
||||
|
||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_linspace_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TFLinSpaceParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(DEBUG) << "TF LinSpaceParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::LinSpaceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_LinSpace;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
if (*primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
*output_size = 1;
|
||||
for (int i = 0; i < tf_op.input_size(); ++i) {
|
||||
auto status = AddOpInput(tf_op, i, inputs);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
TFNodeRegistrar g_tfLinSpaceParser("LinSpace", new TFLinSpaceParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LIN_SPACE_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LIN_SPACE_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFLinSpaceParser : public TFNodeParser {
|
||||
public:
|
||||
TFLinSpaceParser() = default;
|
||||
~TFLinSpaceParser() override = default;
|
||||
|
||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LIN_SPACE_PARSER_H_
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_rank_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TFRankParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(DEBUG) << "TF RankParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::RankT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Rank;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
if (*primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
*output_size = 1;
|
||||
auto status = AddOpInput(tf_op, 0, inputs);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
TFNodeRegistrar g_tfRankParser("Rank", new TFRankParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANK_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANK_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFRankParser : public TFNodeParser {
|
||||
public:
|
||||
TFRankParser() = default;
|
||||
~TFRankParser() override = default;
|
||||
|
||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANK_PARSER_H_
|
|
@ -0,0 +1,68 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_uniform_real_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TFUniformRealParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(DEBUG) << "TF UniformRealParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::UniformRealT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(tf_op, "seed", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The seed attr should be specified";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->seed = attr_value.i();
|
||||
if (!TensorFlowUtils::FindAttrValue(tf_op, "seed2", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The seed2 attr should be specified";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->seed2 = attr_value.i();
|
||||
primitive->value.type = schema::PrimitiveType_UniformReal;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
if (*primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
*output_size = 1;
|
||||
auto status = AddOpInput(tf_op, 0, inputs);
|
||||
return status;
|
||||
}
|
||||
TFNodeRegistrar g_tfRandomUniformParser("RandomUniform", new TFUniformRealParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_UNIFORM_REAL_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_UNIFORM_REAL_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFUniformRealParser : public TFNodeParser {
|
||||
public:
|
||||
TFUniformRealParser() = default;
|
||||
~TFUniformRealParser() override = default;
|
||||
|
||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_UNIFORM_REAL_PARSER_H_
|
Loading…
Reference in New Issue