forked from mindspore-Ecosystem/mindspore
!12212 add pb parser
From: @yeyunpeng2020 Reviewed-by: @HilbertDavid Signed-off-by: @HilbertDavid
This commit is contained in:
commit
695052e88c
|
@ -276,6 +276,8 @@ union PrimitiveType {
|
||||||
StridedSliceGrad,
|
StridedSliceGrad,
|
||||||
IsFinite,
|
IsFinite,
|
||||||
BatchMatMul,
|
BatchMatMul,
|
||||||
|
LinSpace,
|
||||||
|
UniformReal
|
||||||
}
|
}
|
||||||
|
|
||||||
enum QuantType: int {
|
enum QuantType: int {
|
||||||
|
|
|
@ -1281,6 +1281,14 @@ table IsFinite {
|
||||||
}
|
}
|
||||||
|
|
||||||
table BatchMatMul {
|
table BatchMatMul {
|
||||||
adj_x : bool = false;
|
transpose_a :bool;
|
||||||
adj_y : bool = false;
|
transpose_b :bool;
|
||||||
|
}
|
||||||
|
|
||||||
|
table LinSpace {
|
||||||
|
}
|
||||||
|
|
||||||
|
table UniformReal {
|
||||||
|
seed : int;
|
||||||
|
seed2 : int;
|
||||||
}
|
}
|
|
@ -13,8 +13,8 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "src/ops/batch_matmul.h"
|
#include "src/ops/batch_matmul.h"
|
||||||
#include <memory>
|
|
||||||
#ifndef PRIMITIVE_WRITEABLE
|
#ifndef PRIMITIVE_WRITEABLE
|
||||||
#include "src/ops/ops_register.h"
|
#include "src/ops/ops_register.h"
|
||||||
#endif
|
#endif
|
||||||
|
@ -22,14 +22,17 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
#ifdef PRIMITIVE_WRITEABLE
|
#ifdef PRIMITIVE_WRITEABLE
|
||||||
bool BatchMatMul::GetAdjX() const { return this->primitive_->value.AsBatchMatMul()->adj_x; }
|
bool BatchMatMul::GetTransposeA() const { return this->primitive_->value.AsBatchMatMul()->transpose_a; }
|
||||||
|
|
||||||
void BatchMatMul::SetAdjX(bool adj_x) { this->primitive_->value.AsBatchMatMul()->adj_x = adj_x; }
|
bool BatchMatMul::GetTransposeB() const { return this->primitive_->value.AsBatchMatMul()->transpose_b; }
|
||||||
|
|
||||||
bool BatchMatMul::GetAdjY() const { return this->primitive_->value.AsBatchMatMul()->adj_y; }
|
void BatchMatMul::SetTransposeA(bool transpose_a) {
|
||||||
|
this->primitive_->value.AsBatchMatMul()->transpose_a = transpose_a;
|
||||||
void BatchMatMul::SetAdjY(bool adj_y) { this->primitive_->value.AsBatchMatMul()->adj_y = adj_y; }
|
}
|
||||||
|
|
||||||
|
void BatchMatMul::SetTransposeB(bool transpose_b) {
|
||||||
|
this->primitive_->value.AsBatchMatMul()->transpose_b = transpose_b;
|
||||||
|
}
|
||||||
int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||||
if (this->primitive_ == nullptr) {
|
if (this->primitive_ == nullptr) {
|
||||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||||
|
@ -51,31 +54,32 @@ int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
|
||||||
this->primitive_ = nullptr;
|
this->primitive_ = nullptr;
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
attr->adj_x = GetValue<bool>(prim.GetAttr("adj_x"));
|
attr->transpose_a = GetValue<bool>(prim.GetAttr("transpose_a"));
|
||||||
attr->adj_y = GetValue<bool>(prim.GetAttr("adj_y"));
|
attr->transpose_b = GetValue<bool>(prim.GetAttr("transpose_b"));
|
||||||
this->primitive_->value.value = attr;
|
this->primitive_->value.value = attr;
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
bool BatchMatMul::GetTransposeA() const { return this->primitive_->value_as_BatchMatMul()->transpose_a(); }
|
||||||
|
bool BatchMatMul::GetTransposeB() const { return this->primitive_->value_as_BatchMatMul()->transpose_b(); }
|
||||||
int BatchMatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
int BatchMatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||||
MS_ASSERT(nullptr != primitive);
|
MS_ASSERT(nullptr != primitive);
|
||||||
MS_ASSERT(nullptr != fbb);
|
MS_ASSERT(nullptr != fbb);
|
||||||
auto val_offset = schema::CreateBatchMatMul(*fbb);
|
auto attr = primitive->value_as_BatchMatMul();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "value_as_Add return nullptr";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
auto val_offset = schema::CreateBatchMatMul(*fbb, attr->transpose_a(), attr->transpose_b());
|
||||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchMatMul, val_offset.o);
|
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchMatMul, val_offset.o);
|
||||||
fbb->Finish(prim_offset);
|
fbb->Finish(prim_offset);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
bool BatchMatMul::GetAdjX() const { return this->primitive_->value_as_BatchMatMul()->adj_x(); }
|
|
||||||
|
|
||||||
bool BatchMatMul::GetAdjY() const { return this->primitive_->value_as_BatchMatMul()->adj_y(); }
|
|
||||||
|
|
||||||
PrimitiveC *BatchMatMulCreator(const schema::Primitive *primitive) {
|
PrimitiveC *BatchMatMulCreator(const schema::Primitive *primitive) {
|
||||||
return PrimitiveC::NewPrimitiveC<BatchMatMul>(primitive);
|
return PrimitiveC::NewPrimitiveC<BatchMatMul>(primitive);
|
||||||
}
|
}
|
||||||
Registry BatchMatMulRegistry(schema::PrimitiveType_BatchMatMul, BatchMatMulCreator);
|
Registry BatchMatMulRegistry(schema::PrimitiveType_BatchMatMul, BatchMatMulCreator);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -32,15 +32,14 @@ class BatchMatMul : public PrimitiveC {
|
||||||
MS_DECLARE_PARENT(BatchMatMul, PrimitiveC);
|
MS_DECLARE_PARENT(BatchMatMul, PrimitiveC);
|
||||||
explicit BatchMatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
explicit BatchMatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||||
void SetAdjX(bool adj_x);
|
void SetTransposeA(bool transpose_a);
|
||||||
void SetAdjY(bool adj_y);
|
void SetTransposeB(bool transpose_b);
|
||||||
#else
|
#else
|
||||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||||
#endif
|
#endif
|
||||||
bool GetAdjX() const;
|
bool GetTransposeA() const;
|
||||||
bool GetAdjY() const;
|
bool GetTransposeB() const;
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
|
#endif // LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/ops/lin_space.h"
|
||||||
|
#ifndef PRIMITIVE_WRITEABLE
|
||||||
|
#include "src/ops/ops_register.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
#ifndef PRIMITIVE_WRITEABLE
|
||||||
|
int LinSpace::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||||
|
MS_ASSERT(nullptr != primitive);
|
||||||
|
MS_ASSERT(nullptr != fbb);
|
||||||
|
auto val_offset = schema::CreateLinSpace(*fbb);
|
||||||
|
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_LinSpace, val_offset.o);
|
||||||
|
fbb->Finish(prim_offset);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
PrimitiveC *LinSpaceCreator(const schema::Primitive *primitive) {
|
||||||
|
return PrimitiveC::NewPrimitiveC<LinSpace>(primitive);
|
||||||
|
}
|
||||||
|
Registry LinSpaceRegistry(schema::PrimitiveType_LinSpace, LinSpaceCreator);
|
||||||
|
#endif
|
||||||
|
int LinSpace::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
|
||||||
|
auto input = inputs.front();
|
||||||
|
MS_ASSERT(input != nullptr);
|
||||||
|
auto output = outputs.front();
|
||||||
|
MS_ASSERT(output != nullptr);
|
||||||
|
output->set_data_type(input->data_type());
|
||||||
|
output->set_format(input->format());
|
||||||
|
auto num = inputs.at(2)->data_c();
|
||||||
|
if (num == nullptr) {
|
||||||
|
return RET_INFER_INVALID;
|
||||||
|
}
|
||||||
|
output->set_shape({reinterpret_cast<int *>(num)[0]});
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include "src/ops/primitive_c.h"
|
||||||
|
|
||||||
|
#ifndef LITE_MINDSPORE_LITE_C_OPS_LIN_SPACE_H_
|
||||||
|
#define LITE_MINDSPORE_LITE_C_OPS_LIN_SPACE_H_
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
class LinSpace : public PrimitiveC {
|
||||||
|
public:
|
||||||
|
LinSpace() = default;
|
||||||
|
~LinSpace() = default;
|
||||||
|
#ifdef PRIMITIVE_WRITEABLE
|
||||||
|
MS_DECLARE_PARENT(LinSpace, PrimitiveC);
|
||||||
|
explicit LinSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||||
|
#else
|
||||||
|
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||||
|
#endif
|
||||||
|
int InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) override;
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // LITE_MINDSPORE_LITE_C_OPS_LIN_SPACE_H_
|
|
@ -170,8 +170,11 @@
|
||||||
#include "src/ops/crop_and_resize.h"
|
#include "src/ops/crop_and_resize.h"
|
||||||
#include "src/ops/nonzero.h"
|
#include "src/ops/nonzero.h"
|
||||||
#include "src/ops/erf.h"
|
#include "src/ops/erf.h"
|
||||||
#include "src/ops/is_finite.h"
|
|
||||||
#include "src/ops/batch_matmul.h"
|
#include "src/ops/batch_matmul.h"
|
||||||
|
#include "src/ops/lin_space.h"
|
||||||
|
#include "src/ops/uniform_real.h"
|
||||||
|
#include "src/ops/rank.h"
|
||||||
|
#include "src/ops/is_finite.h"
|
||||||
|
|
||||||
#ifdef SUPPORT_TRAIN
|
#ifdef SUPPORT_TRAIN
|
||||||
#include "src/ops/neg_grad.h"
|
#include "src/ops/neg_grad.h"
|
||||||
|
@ -1047,6 +1050,12 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
|
||||||
return new (std::nothrow) IsFinite(primitive);
|
return new (std::nothrow) IsFinite(primitive);
|
||||||
case schema::PrimitiveType_BatchMatMul:
|
case schema::PrimitiveType_BatchMatMul:
|
||||||
return new (std::nothrow) BatchMatMul(primitive);
|
return new (std::nothrow) BatchMatMul(primitive);
|
||||||
|
case schema::PrimitiveType_LinSpace:
|
||||||
|
return new (std::nothrow) LinSpace(primitive);
|
||||||
|
case schema::PrimitiveType_UniformReal:
|
||||||
|
return new (std::nothrow) UniformReal(primitive);
|
||||||
|
case schema::PrimitiveType_Rank:
|
||||||
|
return new (std::nothrow) Rank(primitive);
|
||||||
#ifdef SUPPORT_TRAIN
|
#ifdef SUPPORT_TRAIN
|
||||||
case schema::PrimitiveType_ActivationGrad:
|
case schema::PrimitiveType_ActivationGrad:
|
||||||
return new (std::nothrow) ActivationGrad(primitive);
|
return new (std::nothrow) ActivationGrad(primitive);
|
||||||
|
|
|
@ -0,0 +1,101 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/ops/uniform_real.h"
|
||||||
|
|
||||||
|
#ifndef PRIMITIVE_WRITEABLE
|
||||||
|
#include "src/ops/ops_register.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
#ifdef PRIMITIVE_WRITEABLE
|
||||||
|
int UniformReal::GetSeed() const { return this->primitive_->value.AsUniformReal()->seed; }
|
||||||
|
|
||||||
|
int UniformReal::GetSeed2() const { return this->primitive_->value.AsUniformReal()->seed2; }
|
||||||
|
|
||||||
|
int UniformReal::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||||
|
if (this->primitive_ == nullptr) {
|
||||||
|
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||||
|
if (this->primitive_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new primitiveT failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
this->primitive_->value.type = schema::PrimitiveType_UniformReal;
|
||||||
|
}
|
||||||
|
if (this->primitive_->value.type != schema::PrimitiveType_UniformReal) {
|
||||||
|
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (this->primitive_->value.value == nullptr) {
|
||||||
|
auto attr = new (std::nothrow) schema::UniformRealT();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
this->primitive_->value.value = attr;
|
||||||
|
if (this->primitive_->value.value == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitive value is nullptr";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
int UniformReal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||||
|
MS_ASSERT(nullptr != primitive);
|
||||||
|
MS_ASSERT(nullptr != fbb);
|
||||||
|
auto attr = primitive->value_as_UniformReal();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "value_as_UniformReal return nullptr";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
auto val_offset = schema::CreateUniformReal(*fbb, attr->seed(), attr->seed2());
|
||||||
|
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_UniformReal, val_offset.o);
|
||||||
|
fbb->Finish(prim_offset);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int UniformReal::GetSeed() const { return this->primitive_->value_as_UniformReal()->seed(); }
|
||||||
|
|
||||||
|
int UniformReal::GetSeed2() const { return this->primitive_->value_as_UniformReal()->seed2(); }
|
||||||
|
|
||||||
|
PrimitiveC *UniformRealCreator(const schema::Primitive *primitive) {
|
||||||
|
return PrimitiveC::NewPrimitiveC<UniformReal>(primitive);
|
||||||
|
}
|
||||||
|
Registry UniformRealRegistry(schema::PrimitiveType_UniformReal, UniformRealCreator);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
int UniformReal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||||
|
if (!infer_flag()) {
|
||||||
|
return RET_INFER_INVALID;
|
||||||
|
}
|
||||||
|
auto input_data = static_cast<int32_t *>(inputs_[0]->data_c());
|
||||||
|
if (input_data == nullptr) {
|
||||||
|
return RET_INFER_INVALID;
|
||||||
|
}
|
||||||
|
auto input_num = inputs_[0]->ElementsNum();
|
||||||
|
std::vector<int> output_shape(input_num);
|
||||||
|
for (int i = 0; i < input_num; i++) {
|
||||||
|
output_shape[i] = input_data[i];
|
||||||
|
}
|
||||||
|
outputs_[0]->set_shape(output_shape);
|
||||||
|
outputs_[0]->set_data_type(kNumberTypeFloat32);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,46 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef LITE_MINDSPORE_LITE_C_OPS_UNIFORM_REAL_H_
|
||||||
|
#define LITE_MINDSPORE_LITE_C_OPS_UNIFORM_REAL_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
|
#include <cmath>
|
||||||
|
#include <memory>
|
||||||
|
#include "src/ops/primitive_c.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
class UniformReal : public PrimitiveC {
|
||||||
|
public:
|
||||||
|
UniformReal() = default;
|
||||||
|
~UniformReal() = default;
|
||||||
|
#ifdef PRIMITIVE_WRITEABLE
|
||||||
|
MS_DECLARE_PARENT(UniformReal, PrimitiveC);
|
||||||
|
explicit UniformReal(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||||
|
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||||
|
#else
|
||||||
|
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||||
|
#endif
|
||||||
|
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||||
|
int GetSeed() const;
|
||||||
|
int GetSeed2() const;
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // LITE_MINDSPORE_LITE_C_OPS_UNIFORM_REAL_H_
|
|
@ -52,6 +52,8 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||||
attr->type = schema::ActivationType_TANH;
|
attr->type = schema::ActivationType_TANH;
|
||||||
} else if (tf_op.op() == "LeakyRelu") {
|
} else if (tf_op.op() == "LeakyRelu") {
|
||||||
attr->type = schema::ActivationType_LEAKY_RELU;
|
attr->type = schema::ActivationType_LEAKY_RELU;
|
||||||
|
} else if (tf_op.op() == "Selu") {
|
||||||
|
attr->type = schema::ActivationType_SELU;
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
|
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -63,7 +65,7 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||||
auto attr_leaky_relu = std::make_unique<schema::LeakyReLUT>();
|
auto attr_leaky_relu = std::make_unique<schema::LeakyReLUT>();
|
||||||
tensorflow::AttrValue attr_value;
|
tensorflow::AttrValue attr_value;
|
||||||
if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) {
|
if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) {
|
||||||
MS_LOG(ERROR) << "The attribute alpha shoud be specified.";
|
MS_LOG(ERROR) << "The attribute alpha should be specified.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
attr_leaky_relu->negativeSlope = attr_value.f();
|
attr_leaky_relu->negativeSlope = attr_value.f();
|
||||||
|
@ -85,5 +87,6 @@ TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser());
|
||||||
TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser());
|
TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser());
|
||||||
TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser());
|
TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser());
|
||||||
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFActivationParser());
|
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFActivationParser());
|
||||||
|
TFNodeRegistrar g_tfSeLUParser("Selu", new TFActivationParser());
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -58,6 +58,11 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||||
status = CreateOperator<schema::SquareT>(primitive, schema::PrimitiveType_Square);
|
status = CreateOperator<schema::SquareT>(primitive, schema::PrimitiveType_Square);
|
||||||
} else if (tf_op.op() == "Pow") {
|
} else if (tf_op.op() == "Pow") {
|
||||||
status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Power);
|
status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Power);
|
||||||
|
} else if (tf_op.op() == "Abs") {
|
||||||
|
status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Abs);
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "unsupported arithmetic self type:" << tf_op.op();
|
||||||
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
return status;
|
return status;
|
||||||
|
@ -85,5 +90,6 @@ TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser());
|
||||||
TFNodeRegistrar g_tfLogParser("Log", new TFArithmeticSelfParser());
|
TFNodeRegistrar g_tfLogParser("Log", new TFArithmeticSelfParser());
|
||||||
TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFArithmeticSelfParser());
|
TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFArithmeticSelfParser());
|
||||||
TFNodeRegistrar g_tfPowParser("Pow", new TFArithmeticSelfParser());
|
TFNodeRegistrar g_tfPowParser("Pow", new TFArithmeticSelfParser());
|
||||||
|
TFNodeRegistrar g_tfAbsParser("Abs", new TFArithmeticSelfParser());
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -22,29 +22,35 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
STATUS TFBatchMatmulParser::Parse(const tensorflow::NodeDef &tf_op,
|
STATUS TFBatchMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
|
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
|
||||||
|
MS_LOG(DEBUG) << "TF BatchMatMulParser";
|
||||||
if (primitiveC == nullptr || output_size == nullptr) {
|
if (primitiveC == nullptr || output_size == nullptr) {
|
||||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||||
if (primitive == nullptr) {
|
if (primitive == nullptr) {
|
||||||
MS_LOG(ERROR) << "primitive is nullptr";
|
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
auto attr = std::make_unique<schema::BatchMatMulT>();
|
auto attr = std::make_unique<schema::BatchMatMulT>();
|
||||||
if (attr == nullptr) {
|
if (attr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new op failed";
|
MS_LOG(ERROR) << "new attr failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
tensorflow::AttrValue attr_value;
|
tensorflow::AttrValue attr_value;
|
||||||
TensorFlowUtils::FindAttrValue(tf_op, "adj_x", &attr_value);
|
if (!TensorFlowUtils::FindAttrValue(tf_op, "adj_x", &attr_value)) {
|
||||||
attr->adj_x = attr_value.b();
|
MS_LOG(ERROR) << "The begin_mask attr should be specified";
|
||||||
attr->adj_y = attr_value.b();
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
attr->transpose_a = attr_value.b();
|
||||||
|
if (!TensorFlowUtils::FindAttrValue(tf_op, "adj_y", &attr_value)) {
|
||||||
|
MS_LOG(ERROR) << "The begin_mask attr should be specified";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
attr->transpose_b = attr_value.b();
|
||||||
primitive->value.type = schema::PrimitiveType_BatchMatMul;
|
primitive->value.type = schema::PrimitiveType_BatchMatMul;
|
||||||
primitive->value.value = attr.release();
|
primitive->value.value = attr.release();
|
||||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||||
|
@ -52,13 +58,15 @@ STATUS TFBatchMatmulParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
*output_size = 1;
|
*output_size = 1;
|
||||||
for (int i = 0; i < tf_op.input_size(); i++) {
|
for (int i = 0; i < tf_op.input_size(); ++i) {
|
||||||
inputs->emplace_back(tf_op.input(i));
|
auto status = AddOpInput(tf_op, i, inputs);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
TFNodeRegistrar g_tfBatchMatMulParser("BatchMatMul", new TFBatchMatmulParser());
|
TFNodeRegistrar g_tfBatchMatMulParser("BatchMatMul", new TFBatchMatMulParser());
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -23,15 +23,14 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
class TFBatchMatmulParser : public TFNodeParser {
|
class TFBatchMatMulParser : public TFNodeParser {
|
||||||
public:
|
public:
|
||||||
TFBatchMatmulParser() = default;
|
TFBatchMatMulParser() = default;
|
||||||
~TFBatchMatmulParser() override = default;
|
~TFBatchMatMulParser() override = default;
|
||||||
|
|
||||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "tools/converter/parser/tf/tf_linspace_parser.h"
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
STATUS TFLinSpaceParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||||
|
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||||
|
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
|
||||||
|
MS_LOG(DEBUG) << "TF LinSpaceParser";
|
||||||
|
if (primitiveC == nullptr || output_size == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||||
|
if (primitive == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
auto attr = std::make_unique<schema::LinSpaceT>();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new attr failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
primitive->value.type = schema::PrimitiveType_LinSpace;
|
||||||
|
primitive->value.value = attr.release();
|
||||||
|
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||||
|
if (*primitiveC == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
*output_size = 1;
|
||||||
|
for (int i = 0; i < tf_op.input_size(); ++i) {
|
||||||
|
auto status = AddOpInput(tf_op, i, inputs);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
TFNodeRegistrar g_tfLinSpaceParser("LinSpace", new TFLinSpaceParser());
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LIN_SPACE_PARSER_H_
|
||||||
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LIN_SPACE_PARSER_H_
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
class TFLinSpaceParser : public TFNodeParser {
|
||||||
|
public:
|
||||||
|
TFLinSpaceParser() = default;
|
||||||
|
~TFLinSpaceParser() override = default;
|
||||||
|
|
||||||
|
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||||
|
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LIN_SPACE_PARSER_H_
|
|
@ -0,0 +1,59 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "tools/converter/parser/tf/tf_rank_parser.h"
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
STATUS TFRankParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||||
|
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
|
||||||
|
std::vector<std::string> *inputs, int *output_size) {
|
||||||
|
MS_LOG(DEBUG) << "TF RankParser";
|
||||||
|
if (primitiveC == nullptr || output_size == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||||
|
if (primitive == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
auto attr = std::make_unique<schema::RankT>();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new attr failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
primitive->value.type = schema::PrimitiveType_Rank;
|
||||||
|
primitive->value.value = attr.release();
|
||||||
|
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||||
|
if (*primitiveC == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
*output_size = 1;
|
||||||
|
auto status = AddOpInput(tf_op, 0, inputs);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
TFNodeRegistrar g_tfRankParser("Rank", new TFRankParser());
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANK_PARSER_H_
|
||||||
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANK_PARSER_H_
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
class TFRankParser : public TFNodeParser {
|
||||||
|
public:
|
||||||
|
TFRankParser() = default;
|
||||||
|
~TFRankParser() override = default;
|
||||||
|
|
||||||
|
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||||
|
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANK_PARSER_H_
|
|
@ -0,0 +1,68 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "tools/converter/parser/tf/tf_uniform_real_parser.h"
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
STATUS TFUniformRealParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||||
|
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||||
|
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
|
||||||
|
MS_LOG(DEBUG) << "TF UniformRealParser";
|
||||||
|
if (primitiveC == nullptr || output_size == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||||
|
if (primitive == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
auto attr = std::make_unique<schema::UniformRealT>();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new attr failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
tensorflow::AttrValue attr_value;
|
||||||
|
if (!TensorFlowUtils::FindAttrValue(tf_op, "seed", &attr_value)) {
|
||||||
|
MS_LOG(ERROR) << "The seed attr should be specified";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
attr->seed = attr_value.i();
|
||||||
|
if (!TensorFlowUtils::FindAttrValue(tf_op, "seed2", &attr_value)) {
|
||||||
|
MS_LOG(ERROR) << "The seed2 attr should be specified";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
attr->seed2 = attr_value.i();
|
||||||
|
primitive->value.type = schema::PrimitiveType_UniformReal;
|
||||||
|
primitive->value.value = attr.release();
|
||||||
|
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||||
|
if (*primitiveC == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
*output_size = 1;
|
||||||
|
auto status = AddOpInput(tf_op, 0, inputs);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
TFNodeRegistrar g_tfRandomUniformParser("RandomUniform", new TFUniformRealParser());
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_UNIFORM_REAL_PARSER_H_
|
||||||
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_UNIFORM_REAL_PARSER_H_
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
class TFUniformRealParser : public TFNodeParser {
|
||||||
|
public:
|
||||||
|
TFUniformRealParser() = default;
|
||||||
|
~TFUniformRealParser() override = default;
|
||||||
|
|
||||||
|
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||||
|
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_UNIFORM_REAL_PARSER_H_
|
Loading…
Reference in New Issue