Add MS_DECLARE_PARENT and UnPackAttr judge null and change setter position

This commit is contained in:
yeyunpeng 2020-08-23 18:17:39 +08:00
parent 80d570f003
commit 6c1eb3c22d
213 changed files with 897 additions and 788 deletions

View File

@ -19,11 +19,6 @@
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/arithmetic_self.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif
#ifndef LITE_MINDSPORE_LITE_C_OPS_ABS_H_
#define LITE_MINDSPORE_LITE_C_OPS_ABS_H_
@ -33,6 +28,7 @@ namespace lite {
class Abs : public ArithmeticSelf {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Abs, ArithmeticSelf);
Abs() = default;
explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else

View File

@ -27,7 +27,18 @@ void Activation::SetType(int type) { this->primitive_->value.AsActivation()->typ
void Activation::SetAlpha(float alpha) { this->primitive_->value.AsActivation()->alpha = alpha; }
int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Activation;
}
if (this->primitive_->value.type != schema::PrimitiveType_Activation) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
auto attr = std::make_unique<schema::ActivationT>();
if (prim.name() == "ReLU") {
attr->type = schema::ActivationType_RELU;
@ -36,18 +47,17 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
} else if (prim.name() == "ReLU6") {
attr->type = schema::ActivationType_RELU6;
}
this->primitive_->value.type = schema::PrimitiveType_Activation;
this->primitive_->value.value = attr.release();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
return RET_OK;
}
#else
int Activation::GetType() const { return this->primitive_->value_as_Activation()->type(); }
float Activation::GetAlpha() const { return this->primitive_->value_as_Activation()->alpha(); }
void Activation::SetType(int type) {}
void Activation::SetAlpha(float alpha) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -27,16 +27,17 @@ namespace lite {
class Activation : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Activation, PrimitiveC);
Activation() = default;
explicit Activation(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void SetType(int type);
void SetAlpha(float alpha);
#else
explicit Activation(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int GetType() const;
float GetAlpha() const;
void SetType(int type);
void SetAlpha(float alpha);
};
} // namespace lite
} // namespace mindspore

View File

@ -29,7 +29,6 @@ void ActivationGrad::SetType(int type) {
int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); }
void ActivationGrad::SetType(int type) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,13 +28,14 @@ namespace lite {
class ActivationGrad : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ActivationGrad, PrimitiveC);
ActivationGrad() = default;
explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetType(int type);
#else
explicit ActivationGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int GetType() const;
void SetType(int type);
};
} // namespace lite
} // namespace mindspore

View File

@ -36,7 +36,7 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
this->primitive_->value.type = schema::PrimitiveType_Add;
}
if (this->primitive_->value.type != schema::PrimitiveType_Add) {
MS_LOG(ERROR) << "Primitive type should be add";
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
@ -53,7 +53,6 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); }
void Add::SetActivationType(int activation_type) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -22,25 +22,21 @@
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/arithmetic.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif
namespace mindspore {
namespace lite {
class Add : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Add, Arithmetic);
Add() = default;
explicit Add(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void SetActivationType(int activation_type);
#else
explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {}
#endif
int GetActivationType() const;
void SetActivationType(int activation_type);
};
} // namespace lite
} // namespace mindspore

View File

@ -27,7 +27,6 @@ void AddN::SetN(int n) { this->primitive_->value.AsAddN()->N = n; }
int AddN::GetN() const { return this->primitive_->value_as_AddN()->N(); }
void AddN::SetN(int n) {}
#endif
namespace {

View File

@ -28,14 +28,15 @@ namespace lite {
class AddN : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(AddN, PrimitiveC);
AddN() = default;
explicit AddN(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetN(int n);
#else
explicit AddN(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetN() const;
void SetN(int n);
};
} // namespace lite
} // namespace mindspore

View File

@ -39,11 +39,6 @@ int ArgMax::GetTopK() const { return this->primitive_->value_as_ArgMax()->topK()
bool ArgMax::GetKeepDims() const { return this->primitive_->value_as_ArgMax()->keepDims(); }
int ArgMax::GetAxisType() const { return this->primitive_->value_as_ArgMax()->axisType(); }
void ArgMax::SetAxis(int axis) {}
void ArgMax::SetOutMaxValue(bool out_max_value) {}
void ArgMax::SetTopK(int top_k) {}
void ArgMax::SetKeepDims(bool keep_dims) {}
void ArgMax::SetAxisType(int axis_type) {}
#endif
int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {

View File

@ -28,8 +28,14 @@ namespace lite {
class ArgMax : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ArgMax, PrimitiveC);
ArgMax() = default;
explicit ArgMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(int axis);
void SetOutMaxValue(bool out_max_value);
void SetTopK(int top_k);
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
#else
explicit ArgMax(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
@ -39,11 +45,6 @@ class ArgMax : public PrimitiveC {
int GetTopK() const;
bool GetKeepDims() const;
int GetAxisType() const;
void SetAxis(int axis);
void SetOutMaxValue(bool out_max_value);
void SetTopK(int top_k);
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
};
} // namespace lite
} // namespace mindspore

View File

@ -39,11 +39,6 @@ int ArgMin::GetTopK() const { return this->primitive_->value_as_ArgMin()->topK()
bool ArgMin::GetKeepDims() const { return this->primitive_->value_as_ArgMin()->keepDims(); }
int ArgMin::GetAxisType() const { return this->primitive_->value_as_ArgMin()->axisType(); }
void ArgMin::SetAxis(int axis) {}
void ArgMin::SetOutMaxValue(bool out_max_value) {}
void ArgMin::SetTopK(int top_k) {}
void ArgMin::SetKeepDims(bool keep_dims) {}
void ArgMin::SetAxisType(int axis_type) {}
#endif
int ArgMin::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {

View File

@ -28,8 +28,14 @@ namespace lite {
class ArgMin : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ArgMin, PrimitiveC);
ArgMin() = default;
explicit ArgMin(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(int axis);
void SetOutMaxValue(bool out_max_value);
void SetTopK(int top_k);
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
#else
explicit ArgMin(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
@ -39,11 +45,6 @@ class ArgMin : public PrimitiveC {
int GetTopK() const;
bool GetKeepDims() const;
int GetAxisType() const;
void SetAxis(int axis);
void SetOutMaxValue(bool out_max_value);
void SetTopK(int top_k);
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
};
} // namespace lite
} // namespace mindspore

View File

@ -28,6 +28,7 @@ namespace lite {
class Arithmetic : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Arithmetic, PrimitiveC);
Arithmetic() = default;
explicit Arithmetic(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else

View File

@ -25,6 +25,7 @@ namespace lite {
class ArithmeticSelf : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ArithmeticSelf, PrimitiveC);
ArithmeticSelf() = default;
explicit ArithmeticSelf(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else

View File

@ -24,11 +24,27 @@ float BatchNorm::GetEpsilon() const { return this->primitive_->value.AsBatchNorm
void BatchNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsBatchNorm()->epsilon = epsilon; }
int BatchNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
auto attr = std::make_unique<schema::FusedBatchNormT>();
attr->epsilon = GetValue<float>(prim.GetAttr("epsilon"));
this->primitive_->value.type = schema::PrimitiveType_FusedBatchNorm;
this->primitive_->value.value = attr.release();
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_FusedBatchNorm;
}
if (this->primitive_->value.type != schema::PrimitiveType_FusedBatchNorm) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::FusedBatchNormT();
attr->epsilon = GetValue<float>(prim.GetAttr("epsilon"));
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}
@ -36,7 +52,6 @@ int BatchNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &
float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); }
void BatchNorm::SetEpsilon(float epsilon) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,14 +28,15 @@ namespace lite {
class BatchNorm : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BatchNorm, PrimitiveC);
BatchNorm() = default;
explicit BatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void SetEpsilon(float epsilon);
#else
explicit BatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
float GetEpsilon() const;
void SetEpsilon(float epsilon);
};
} // namespace lite
} // namespace mindspore

View File

@ -42,8 +42,6 @@ std::vector<int> BatchToSpace::GetCrops() const {
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void BatchToSpace::SetBlockShape(const std::vector<int> &block_shape) {}
void BatchToSpace::SetCrops(const std::vector<int> &crops) {}
#endif
namespace {
constexpr int kBatchToSpaceOutputNum = 1;

View File

@ -28,16 +28,17 @@ namespace lite {
class BatchToSpace : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BatchToSpace, PrimitiveC);
BatchToSpace() = default;
explicit BatchToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetBlockShape(const std::vector<int> &block_shape);
void SetCrops(const std::vector<int> &crops);
#else
explicit BatchToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetBlockShape() const;
std::vector<int> GetCrops() const;
void SetBlockShape(const std::vector<int> &block_shape);
void SetCrops(const std::vector<int> &crops);
};
} // namespace lite
} // namespace mindspore

View File

@ -25,12 +25,31 @@ std::vector<int> BiasAdd::GetAxis() const { return this->primitive_->value.AsBia
void BiasAdd::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasAdd()->axis = axis; }
int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
auto attr = std::make_unique<schema::BiasAddT>();
attr->axis = {0};
this->primitive_->value.type = schema::PrimitiveType_BiasAdd;
this->primitive_->value.value = attr.release();
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_BiasAdd;
}
if (this->primitive_->value.type != schema::PrimitiveType_BiasAdd) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::BiasAddT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->axis = {0};
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
@ -41,7 +60,6 @@ std::vector<int> BiasAdd::GetAxis() const {
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void BiasAdd::SetAxis(const std::vector<int> &axis) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,14 +28,15 @@ namespace lite {
class BiasAdd : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BiasAdd, PrimitiveC);
BiasAdd() = default;
explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void SetAxis(const std::vector<int> &axis);
#else
explicit BiasAdd(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
std::vector<int> GetAxis() const;
void SetAxis(const std::vector<int> &axis);
};
} // namespace lite
} // namespace mindspore

View File

@ -30,7 +30,6 @@ std::vector<int> BiasGrad::GetAxis() const {
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void BiasGrad::SetAxis(const std::vector<int> &axis) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,13 +28,15 @@ namespace lite {
class BiasGrad : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BiasGrad, PrimitiveC);
BiasGrad() = default;
explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(const std::vector<int> &axis);
#else
explicit BiasGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
std::vector<int> GetAxis() const;
void SetAxis(const std::vector<int> &axis);
};
} // namespace lite
} // namespace mindspore

View File

@ -30,8 +30,6 @@ void BNGradInput::SetChannels(int channels) { this->primitive_->value.AsBNGradIn
float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); }
int BNGradInput::GetChannels() const { return this->primitive_->value_as_BNGradInput()->channels(); }
void BNGradInput::SetEps(float eps) {}
void BNGradInput::SetChannels(int channels) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,15 +28,16 @@ namespace lite {
class BNGradInput : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BNGradInput, PrimitiveC);
BNGradInput() = default;
explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetEps(float eps);
void SetChannels(int channels);
#else
explicit BNGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
float GetEps() const;
int GetChannels() const;
void SetEps(float eps);
void SetChannels(int channels);
};
} // namespace lite
} // namespace mindspore

View File

@ -32,7 +32,6 @@ std::vector<int> BroadcastTo::GetDstShape() const {
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void BroadcastTo::SetDstShape(const std::vector<int> &dst_shape) {}
#endif
namespace {
constexpr int kBroadcastToInputNum = 1;

View File

@ -28,14 +28,16 @@ namespace lite {
class BroadcastTo : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BroadcastTo, PrimitiveC);
BroadcastTo() = default;
explicit BroadcastTo(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetDstShape(const std::vector<int> &dst_shape);
#else
explicit BroadcastTo(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetDstShape() const;
void SetDstShape(const std::vector<int> &dst_shape);
};
} // namespace lite
} // namespace mindspore

View File

@ -29,7 +29,6 @@ void CaffePReLU::SetChannelShared(bool channel_shared) {
bool CaffePReLU::GetChannelShared() const { return this->primitive_->value_as_CaffePReLU()->channelShared(); }
void CaffePReLU::SetChannelShared(bool channel_shared) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,13 +28,15 @@ namespace lite {
class CaffePReLU : public Activation {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(CaffePReLU, Activation);
CaffePReLU() = default;
explicit CaffePReLU(schema::PrimitiveT *primitive) : Activation(primitive) {}
void SetChannelShared(bool channel_shared);
#else
explicit CaffePReLU(schema::Primitive *primitive) : Activation(primitive) {}
#endif
bool GetChannelShared() const;
void SetChannelShared(bool channel_shared);
};
} // namespace lite
} // namespace mindspore

View File

@ -30,8 +30,6 @@ void Cast::SetDstT(int dst_t) { this->primitive_->value.AsCast()->dstT = dst_t;
int Cast::GetSrcT() const { return this->primitive_->value_as_Cast()->srcT(); }
int Cast::GetDstT() const { return this->primitive_->value_as_Cast()->dstT(); }
void Cast::SetSrcT(int src_t) {}
void Cast::SetDstT(int dst_t) {}
#endif
int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {

View File

@ -28,16 +28,17 @@ namespace lite {
class Cast : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Cast, PrimitiveC);
Cast() = default;
explicit Cast(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetSrcT(int src_t);
void SetDstT(int dst_t);
#else
explicit Cast(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetSrcT() const;
int GetDstT() const;
void SetSrcT(int src_t);
void SetDstT(int dst_t);
};
} // namespace lite
} // namespace mindspore

View File

@ -20,14 +20,15 @@
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/arithmetic_self.h"
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class Ceil : public ArithmeticSelf {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Ceil, ArithmeticSelf);
Ceil() = default;
explicit Ceil(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else

View File

@ -30,8 +30,6 @@ void Clip::SetMin(float min) { this->primitive_->value.AsClip()->min = min; }
float Clip::GetMax() const { return this->primitive_->value_as_Clip()->max(); }
float Clip::GetMin() const { return this->primitive_->value_as_Clip()->min(); }
void Clip::SetMax(float max) {}
void Clip::SetMin(float min) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,15 +28,16 @@ namespace lite {
class Clip : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Clip, PrimitiveC);
Clip() = default;
explicit Clip(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetMax(float max);
void SetMin(float min);
#else
explicit Clip(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
float GetMax() const;
float GetMin() const;
void SetMax(float max);
void SetMin(float min);
};
} // namespace lite
} // namespace mindspore

View File

@ -30,12 +30,32 @@ void Concat::SetAxis(int axis) { this->primitive_->value.AsConcat()->axis = axis
void Concat::SetN(int n) { this->primitive_->value.AsConcat()->n = n; }
int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
auto attr = std::make_unique<schema::ConcatT>();
auto prim_axis = GetValue<int>(prim.GetAttr("axis"));
attr->axis = prim_axis;
this->primitive_->value.type = schema::PrimitiveType_Concat;
this->primitive_->value.value = attr.release();
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Concat;
}
if (this->primitive_->value.type != schema::PrimitiveType_Concat) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::ConcatT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
auto prim_axis = GetValue<int>(prim.GetAttr("axis"));
attr->axis = prim_axis;
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
@ -44,8 +64,6 @@ int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); }
int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); }
void Concat::SetAxis(int axis) {}
void Concat::SetN(int n) {}
#endif
namespace {

View File

@ -28,17 +28,18 @@ namespace lite {
class Concat : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Concat, PrimitiveC);
Concat() = default;
explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void SetAxis(int axis);
void SetN(int n);
#else
explicit Concat(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;
int GetN() const;
void SetAxis(int axis);
void SetN(int n);
};
} // namespace lite
} // namespace mindspore

View File

@ -33,7 +33,6 @@ void ConstantOfShape::SetValue(float value) { this->primitive_->value.AsConstant
float ConstantOfShape::GetValue() const { return this->primitive_->value_as_ConstantOfShape()->value(); }
void ConstantOfShape::SetValue(float value) {}
#endif
int ConstantOfShape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {

View File

@ -28,14 +28,15 @@ namespace lite {
class ConstantOfShape : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ConstantOfShape, PrimitiveC);
ConstantOfShape() = default;
explicit ConstantOfShape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetValue(float value);
#else
explicit ConstantOfShape(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
float GetValue() const;
void SetValue(float value);
};
} // namespace lite
} // namespace mindspore

View File

@ -19,7 +19,6 @@
#include <memory>
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
#ifdef PRIMITIVE_WRITEABLE
#include "tools/converter/quantizer/quantize_util.h"
#endif
@ -309,8 +308,18 @@ void Conv2D::PopulaterQuantParam(const Primitive &prim,
}
int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Conv2D;
}
if (this->primitive_->value.type != schema::PrimitiveType_Conv2D) {
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
int group = GetValue<int>(prim.GetAttr("group"));
if (group > 1) {
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
@ -348,23 +357,6 @@ int Conv2D::GetDilateH() const { return this->primitive_->value_as_Conv2D()->dil
bool Conv2D::GetHasBias() const { return this->primitive_->value_as_Conv2D()->hasBias(); }
int Conv2D::GetActivationType() const { return this->primitive_->value_as_Conv2D()->activationType(); }
void Conv2D::SetFormat(int format) {}
void Conv2D::SetGroup(int group) {}
void Conv2D::SetChannelIn(int channel_in) {}
void Conv2D::SetChannelOut(int channel_out) {}
void Conv2D::SetKernelW(int kernel_w) {}
void Conv2D::SetKernelH(int kernel_h) {}
void Conv2D::SetStrideW(int stride_w) {}
void Conv2D::SetStrideH(int stride_h) {}
void Conv2D::SetPadMode(int pad_mode) {}
void Conv2D::SetPadUp(int pad_up) {}
void Conv2D::SetPadDown(int pad_down) {}
void Conv2D::SetPadLeft(int pad_left) {}
void Conv2D::SetPadRight(int pad_right) {}
void Conv2D::SetDilateW(int dilate_w) {}
void Conv2D::SetDilateH(int dilate_h) {}
void Conv2D::SetHasBias(bool has_bias) {}
void Conv2D::SetActivationType(int activation_type) {}
#endif
void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output_w) {
MS_ASSERT(this->primitive_ != nullptr);

View File

@ -28,12 +28,30 @@ namespace mindspore {
namespace lite {
class Conv2D : public PrimitiveC {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Conv2D, PrimitiveC);
public:
Conv2D() = default;
explicit Conv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
private:
void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
@ -72,23 +90,6 @@ class Conv2D : public PrimitiveC {
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
protected:
void ConvInferShape(int input_h, int input_w, int *output_h, int *output_w);

View File

@ -89,23 +89,6 @@ int Conv2DGradFilter::GetActivationType() const {
return this->primitive_->value_as_Conv2DGradFilter()->activationType();
}
void Conv2DGradFilter::SetFormat(int format) {}
void Conv2DGradFilter::SetGroup(int group) {}
void Conv2DGradFilter::SetChannelIn(int channel_in) {}
void Conv2DGradFilter::SetChannelOut(int channel_out) {}
void Conv2DGradFilter::SetKernelW(int kernel_w) {}
void Conv2DGradFilter::SetKernelH(int kernel_h) {}
void Conv2DGradFilter::SetStrideW(int stride_w) {}
void Conv2DGradFilter::SetStrideH(int stride_h) {}
void Conv2DGradFilter::SetPadMode(int pad_mode) {}
void Conv2DGradFilter::SetPadUp(int pad_up) {}
void Conv2DGradFilter::SetPadDown(int pad_down) {}
void Conv2DGradFilter::SetPadLeft(int pad_left) {}
void Conv2DGradFilter::SetPadRight(int pad_right) {}
void Conv2DGradFilter::SetDilateW(int dilate_w) {}
void Conv2DGradFilter::SetDilateH(int dilate_h) {}
void Conv2DGradFilter::SetHasBias(bool has_bias) {}
void Conv2DGradFilter::SetActivationType(int activation_type) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,8 +28,26 @@ namespace lite {
class Conv2DGradFilter : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Conv2DGradFilter, PrimitiveC);
Conv2DGradFilter() = default;
explicit Conv2DGradFilter(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
#else
explicit Conv2DGradFilter(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
@ -50,23 +68,6 @@ class Conv2DGradFilter : public PrimitiveC {
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
};
} // namespace lite
} // namespace mindspore

View File

@ -87,23 +87,6 @@ int Conv2DGradInput::GetActivationType() const {
return this->primitive_->value_as_Conv2DGradInput()->activationType();
}
void Conv2DGradInput::SetFormat(int format) {}
void Conv2DGradInput::SetGroup(int group) {}
void Conv2DGradInput::SetChannelIn(int channel_in) {}
void Conv2DGradInput::SetChannelOut(int channel_out) {}
void Conv2DGradInput::SetKernelW(int kernel_w) {}
void Conv2DGradInput::SetKernelH(int kernel_h) {}
void Conv2DGradInput::SetStrideW(int stride_w) {}
void Conv2DGradInput::SetStrideH(int stride_h) {}
void Conv2DGradInput::SetPadMode(int pad_mode) {}
void Conv2DGradInput::SetPadUp(int pad_up) {}
void Conv2DGradInput::SetPadDown(int pad_down) {}
void Conv2DGradInput::SetPadLeft(int pad_left) {}
void Conv2DGradInput::SetPadRight(int pad_right) {}
void Conv2DGradInput::SetDilateW(int dilate_w) {}
void Conv2DGradInput::SetDilateH(int dilate_h) {}
void Conv2DGradInput::SetHasBias(bool has_bias) {}
void Conv2DGradInput::SetActivationType(int activation_type) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,8 +28,26 @@ namespace lite {
class Conv2DGradInput : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Conv2DGradInput, PrimitiveC);
Conv2DGradInput() = default;
explicit Conv2DGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
#else
explicit Conv2DGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
@ -50,23 +68,6 @@ class Conv2DGradInput : public PrimitiveC {
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
};
} // namespace lite
} // namespace mindspore

View File

@ -33,8 +33,6 @@ std::vector<int64_t> Crop::GetOffsets() const {
return std::vector<int64_t>(fb_vector->begin(), fb_vector->end());
}
void Crop::SetAxis(int64_t axis) {}
void Crop::SetOffsets(const std::vector<int64_t> &offsets) {}
#endif
namespace {
constexpr int kCropOutputNum = 1;

View File

@ -28,16 +28,17 @@ namespace lite {
class Crop : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Crop, PrimitiveC);
Crop() = default;
explicit Crop(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(int64_t axis);
void SetOffsets(const std::vector<int64_t> &offsets);
#else
explicit Crop(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int64_t GetAxis() const;
std::vector<int64_t> GetOffsets() const;
void SetAxis(int64_t axis);
void SetOffsets(const std::vector<int64_t> &offsets);
};
} // namespace lite
} // namespace mindspore

View File

@ -77,23 +77,6 @@ int DeConv2D::GetDilateH() const { return this->primitive_->value_as_DeConv2D()-
bool DeConv2D::GetHasBias() const { return this->primitive_->value_as_DeConv2D()->hasBias(); }
int DeConv2D::GetActivationType() const { return this->primitive_->value_as_DeConv2D()->activationType(); }
void DeConv2D::SetFormat(int format) {}
void DeConv2D::SetGroup(int group) {}
void DeConv2D::SetChannelIn(int channel_in) {}
void DeConv2D::SetChannelOut(int channel_out) {}
void DeConv2D::SetKernelW(int kernel_w) {}
void DeConv2D::SetKernelH(int kernel_h) {}
void DeConv2D::SetStrideW(int stride_w) {}
void DeConv2D::SetStrideH(int stride_h) {}
void DeConv2D::SetPadMode(int pad_mode) {}
void DeConv2D::SetPadUp(int pad_up) {}
void DeConv2D::SetPadDown(int pad_down) {}
void DeConv2D::SetPadLeft(int pad_left) {}
void DeConv2D::SetPadRight(int pad_right) {}
void DeConv2D::SetDilateW(int dilate_w) {}
void DeConv2D::SetDilateH(int dilate_h) {}
void DeConv2D::SetHasBias(bool has_bias) {}
void DeConv2D::SetActivationType(int activation_type) {}
#endif
int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);

View File

@ -28,8 +28,26 @@ namespace lite {
class DeConv2D : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(DeConv2D, PrimitiveC);
DeConv2D() = default;
explicit DeConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
#else
explicit DeConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
@ -51,23 +69,6 @@ class DeConv2D : public PrimitiveC {
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
int PadUp() const { return this->pad_u_; }
int PadDown() const { return this->pad_d_; }

View File

@ -92,22 +92,6 @@ int DeDepthwiseConv2D::GetActivationType() const {
return this->primitive_->value_as_DeDepthwiseConv2D()->activationType();
}
void DeDepthwiseConv2D::SetFormat(int format) {}
void DeDepthwiseConv2D::SetChannelIn(int channel_in) {}
void DeDepthwiseConv2D::SetChannelMultiplier(int channel_multiplier) {}
void DeDepthwiseConv2D::SetKernelW(int kernel_w) {}
void DeDepthwiseConv2D::SetKernelH(int kernel_h) {}
void DeDepthwiseConv2D::SetStrideW(int stride_w) {}
void DeDepthwiseConv2D::SetStrideH(int stride_h) {}
void DeDepthwiseConv2D::SetPadMode(int pad_mode) {}
void DeDepthwiseConv2D::SetPadUp(int pad_up) {}
void DeDepthwiseConv2D::SetPadDown(int pad_down) {}
void DeDepthwiseConv2D::SetPadLeft(int pad_left) {}
void DeDepthwiseConv2D::SetPadRight(int pad_right) {}
void DeDepthwiseConv2D::SetDilateW(int dilate_w) {}
void DeDepthwiseConv2D::SetDilateH(int dilate_h) {}
void DeDepthwiseConv2D::SetHasBias(bool has_bias) {}
void DeDepthwiseConv2D::SetActivationType(int activation_type) {}
#endif
int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
std::vector<lite::tensor::Tensor *> outputs_) {

View File

@ -28,8 +28,25 @@ namespace lite {
class DeDepthwiseConv2D : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(DeDepthwiseConv2D, PrimitiveC);
DeDepthwiseConv2D() = default;
explicit DeDepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetFormat(int format);
void SetChannelIn(int channel_in);
void SetChannelMultiplier(int channel_multiplier);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
#else
explicit DeDepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
@ -50,22 +67,6 @@ class DeDepthwiseConv2D : public PrimitiveC {
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
void SetFormat(int format);
void SetChannelIn(int channel_in);
void SetChannelMultiplier(int channel_multiplier);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
int PadUp() const { return this->pad_u_; }
int PadDown() const { return this->pad_d_; }

View File

@ -30,8 +30,6 @@ void DepthToSpace::SetFormat(int format) { this->primitive_->value.AsDepthToSpac
int DepthToSpace::GetBlockSize() const { return this->primitive_->value_as_DepthToSpace()->blockSize(); }
int DepthToSpace::GetFormat() const { return this->primitive_->value_as_DepthToSpace()->format(); }
void DepthToSpace::SetBlockSize(int block_size) {}
void DepthToSpace::SetFormat(int format) {}
#endif
namespace {
constexpr int kDepthToSpaceOutputNum = 1;

View File

@ -28,16 +28,17 @@ namespace lite {
class DepthToSpace : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(DepthToSpace, PrimitiveC);
DepthToSpace() = default;
explicit DepthToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetBlockSize(int block_size);
void SetFormat(int format);
#else
explicit DepthToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetBlockSize() const;
int GetFormat() const;
void SetBlockSize(int block_size);
void SetFormat(int format);
};
} // namespace lite
} // namespace mindspore

View File

@ -254,22 +254,6 @@ int DepthwiseConv2D::GetActivationType() const {
return this->primitive_->value_as_DepthwiseConv2D()->activationType();
}
void DepthwiseConv2D::SetFormat(int format) {}
void DepthwiseConv2D::SetChannelIn(int channel_in) {}
void DepthwiseConv2D::SetChannelMultiplier(int channel_multiplier) {}
void DepthwiseConv2D::SetKernelW(int kernel_w) {}
void DepthwiseConv2D::SetKernelH(int kernel_h) {}
void DepthwiseConv2D::SetStrideW(int stride_w) {}
void DepthwiseConv2D::SetStrideH(int stride_h) {}
void DepthwiseConv2D::SetPadMode(int pad_mode) {}
void DepthwiseConv2D::SetPadUp(int pad_up) {}
void DepthwiseConv2D::SetPadDown(int pad_down) {}
void DepthwiseConv2D::SetPadLeft(int pad_left) {}
void DepthwiseConv2D::SetPadRight(int pad_right) {}
void DepthwiseConv2D::SetDilateW(int dilate_w) {}
void DepthwiseConv2D::SetDilateH(int dilate_h) {}
void DepthwiseConv2D::SetHasBias(bool has_bias) {}
void DepthwiseConv2D::SetActivationType(int activation_type) {}
#endif
int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
std::vector<lite::tensor::Tensor *> outputs_) {

View File

@ -27,12 +27,29 @@ namespace mindspore {
namespace lite {
class DepthwiseConv2D : public PrimitiveC {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(DepthwiseConv2D, PrimitiveC);
public:
DepthwiseConv2D() = default;
explicit DepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void SetFormat(int format);
void SetChannelIn(int channel_in);
void SetChannelMultiplier(int channel_multiplier);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
private:
void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
@ -62,22 +79,6 @@ class DepthwiseConv2D : public PrimitiveC {
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
void SetFormat(int format);
void SetChannelIn(int channel_in);
void SetChannelMultiplier(int channel_multiplier);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
int PadUp() const { return this->pad_u_; }
int PadDown() const { return this->pad_d_; }

View File

@ -21,10 +21,30 @@ namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Dequant::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
auto attr = std::make_unique<schema::OnnxInt8DequantizeT>();
this->primitive_->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
this->primitive_->value.value = attr.release();
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
}
if (this->primitive_->value.type != schema::PrimitiveType_OnnxInt8Dequantize) {
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow)(schema::OnnxInt8DequantizeT);
if (attr == nullptr) {
MS_LOG(ERROR) << "attr is nullptr";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#endif

View File

@ -25,6 +25,7 @@ namespace lite {
class Dequant : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Dequant, PrimitiveC);
Dequant() = default;
explicit Dequant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);

View File

@ -119,19 +119,6 @@ bool DetectionPostProcess::GetUseRegularNms() const {
return this->primitive_->value_as_DetectionPostProcess()->UseRegularNms();
}
void DetectionPostProcess::SetFormat(int format) {}
void DetectionPostProcess::SetInputSize(int input_size) {}
void DetectionPostProcess::SetHScale(float h_scale) {}
void DetectionPostProcess::SetWScale(float w_scale) {}
void DetectionPostProcess::SetXScale(float x_scale) {}
void DetectionPostProcess::SetYScale(float y_scale) {}
void DetectionPostProcess::SetNmsIouThreshold(float nms_iou_threshold) {}
void DetectionPostProcess::SetNmsScoreThreshold(float nms_score_threshold) {}
void DetectionPostProcess::SetMaxDetections(int64_t max_detections) {}
void DetectionPostProcess::SetDetectionsPreClass(int64_t detections_pre_class) {}
void DetectionPostProcess::SetMaxClassesPreDetection(int64_t max_classes_pre_detection) {}
void DetectionPostProcess::SetNumClasses(int64_t num_classes) {}
void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,8 +28,22 @@ namespace lite {
class DetectionPostProcess : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(DetectionPostProcess, PrimitiveC);
DetectionPostProcess() = default;
explicit DetectionPostProcess(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetFormat(int format);
void SetInputSize(int input_size);
void SetHScale(float h_scale);
void SetWScale(float w_scale);
void SetXScale(float x_scale);
void SetYScale(float y_scale);
void SetNmsIouThreshold(float nms_iou_threshold);
void SetNmsScoreThreshold(float nms_score_threshold);
void SetMaxDetections(int64_t max_detections);
void SetDetectionsPreClass(int64_t detections_pre_class);
void SetMaxClassesPreDetection(int64_t max_classes_pre_detection);
void SetNumClasses(int64_t num_classes);
void SetUseRegularNms(bool use_regular_nms);
#else
explicit DetectionPostProcess(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
@ -46,19 +60,6 @@ class DetectionPostProcess : public PrimitiveC {
int64_t GetMaxClassesPreDetection() const;
int64_t GetNumClasses() const;
bool GetUseRegularNms() const;
void SetFormat(int format);
void SetInputSize(int input_size);
void SetHScale(float h_scale);
void SetWScale(float w_scale);
void SetXScale(float x_scale);
void SetYScale(float y_scale);
void SetNmsIouThreshold(float nms_iou_threshold);
void SetNmsScoreThreshold(float nms_score_threshold);
void SetMaxDetections(int64_t max_detections);
void SetDetectionsPreClass(int64_t detections_pre_class);
void SetMaxClassesPreDetection(int64_t max_classes_pre_detection);
void SetNumClasses(int64_t num_classes);
void SetUseRegularNms(bool use_regular_nms);
};
} // namespace lite
} // namespace mindspore

View File

@ -29,7 +29,6 @@ void Div::SetActivationType(int activation_type) {
int Div::GetActivationType() const { return this->primitive_->value_as_Div()->activationType(); }
void Div::SetActivationType(int activation_type) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,13 +28,15 @@ namespace lite {
class Div : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Div, Arithmetic);
Div() = default;
explicit Div(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
void SetActivationType(int activation_type);
#else
explicit Div(schema::Primitive *primitive) : Arithmetic(primitive) {}
#endif
int GetActivationType() const;
void SetActivationType(int activation_type);
};
} // namespace lite
} // namespace mindspore

View File

@ -27,7 +27,6 @@ void Dropout::SetRatio(float ratio) { this->primitive_->value.AsDropout()->ratio
float Dropout::GetRatio() const { return this->primitive_->value_as_Dropout()->ratio(); }
void Dropout::SetRatio(float ratio) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -20,21 +20,23 @@
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#include "ir/dtype/type_id.h"
namespace mindspore {
namespace lite {
class Dropout : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Dropout, PrimitiveC);
Dropout() = default;
explicit Dropout(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetRatio(float ratio);
#else
explicit Dropout(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
float GetRatio() const;
void SetRatio(float ratio);
};
} // namespace lite
} // namespace mindspore

View File

@ -27,7 +27,6 @@ void Eltwise::SetMode(int mode) { this->primitive_->value.AsEltwise()->mode = (s
int Eltwise::GetMode() const { return this->primitive_->value_as_Eltwise()->mode(); }
void Eltwise::SetMode(int mode) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,13 +28,15 @@ namespace lite {
class Eltwise : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Eltwise, PrimitiveC);
Eltwise() = default;
explicit Eltwise(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetMode(int mode);
#else
explicit Eltwise(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int GetMode() const;
void SetMode(int mode);
};
} // namespace lite
} // namespace mindspore

View File

@ -27,7 +27,6 @@ void Elu::SetAlpha(float alpha) { this->primitive_->value.AsElu()->alpha = alpha
float Elu::GetAlpha() const { return this->primitive_->value_as_Elu()->alpha(); }
void Elu::SetAlpha(float alpha) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,13 +28,15 @@ namespace lite {
class Elu : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Elu, PrimitiveC);
Elu() = default;
explicit Elu(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAlpha(float alpha);
#else
explicit Elu(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
float GetAlpha() const;
void SetAlpha(float alpha);
};
} // namespace lite
} // namespace mindspore

View File

@ -27,7 +27,6 @@ void EmbeddingLookup::SetMaxNorm(float max_norm) { this->primitive_->value.AsEmb
float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value_as_EmbeddingLookup()->maxNorm(); }
void EmbeddingLookup::SetMaxNorm(float max_norm) {}
#endif
int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {

View File

@ -28,14 +28,16 @@ namespace lite {
class EmbeddingLookup : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(EmbeddingLookup, PrimitiveC);
EmbeddingLookup() = default;
explicit EmbeddingLookup(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetMaxNorm(float max_norm);
#else
explicit EmbeddingLookup(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
float GetMaxNorm() const;
void SetMaxNorm(float max_norm);
};
} // namespace lite
} // namespace mindspore

View File

@ -51,9 +51,6 @@ float EmbeddingLookupSparse::GetMaxNortm() const {
return this->primitive_->value_as_EmbeddingLookupSparse()->maxNortm();
}
void EmbeddingLookupSparse::SetSpIds(const std::vector<int> &sp_ids) {}
void EmbeddingLookupSparse::SetSpWeights(const std::vector<float> &sp_weights) {}
void EmbeddingLookupSparse::SetMaxNortm(float max_nortm) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,17 +28,18 @@ namespace lite {
class EmbeddingLookupSparse : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(EmbeddingLookupSparse, PrimitiveC);
EmbeddingLookupSparse() = default;
explicit EmbeddingLookupSparse(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetSpIds(const std::vector<int> &sp_ids);
void SetSpWeights(const std::vector<float> &sp_weights);
void SetMaxNortm(float max_nortm);
#else
explicit EmbeddingLookupSparse(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
std::vector<int> GetSpIds() const;
std::vector<float> GetSpWeights() const;
float GetMaxNortm() const;
void SetSpIds(const std::vector<int> &sp_ids);
void SetSpWeights(const std::vector<float> &sp_weights);
void SetMaxNortm(float max_nortm);
};
} // namespace lite
} // namespace mindspore

View File

@ -28,6 +28,7 @@ namespace lite {
class Equal : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Equal, PrimitiveC);
Equal() = default;
explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else

View File

@ -28,6 +28,7 @@ namespace lite {
class Exp : public ArithmeticSelf {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Exp, ArithmeticSelf);
Exp() = default;
explicit Exp(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else

View File

@ -27,7 +27,6 @@ void ExpandDims::SetDim(int dim) { this->primitive_->value.AsExpandDims()->dim =
int ExpandDims::GetDim() const { return this->primitive_->value_as_ExpandDims()->dim(); }
void ExpandDims::SetDim(int dim) {}
#endif
int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {

View File

@ -28,14 +28,16 @@ namespace lite {
class ExpandDims : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ExpandDims, PrimitiveC);
ExpandDims() = default;
explicit ExpandDims(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetDim(int dim);
#else
explicit ExpandDims(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetDim() const;
void SetDim(int dim);
};
} // namespace lite
} // namespace mindspore

View File

@ -40,8 +40,6 @@ int FakeQuantWithMinMaxVars::GetNumBits() const {
return this->primitive_->value_as_FakeQuantWithMinMaxVars()->numBits();
}
void FakeQuantWithMinMaxVars::SetNarrowRange(bool narrow_range) {}
void FakeQuantWithMinMaxVars::SetNumBits(int num_bits) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,15 +28,16 @@ namespace lite {
class FakeQuantWithMinMaxVars : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(FakeQuantWithMinMaxVars, PrimitiveC);
FakeQuantWithMinMaxVars() = default;
explicit FakeQuantWithMinMaxVars(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetNarrowRange(bool narrow_range);
void SetNumBits(int num_bits);
#else
explicit FakeQuantWithMinMaxVars(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
bool GetNarrowRange() const;
int GetNumBits() const;
void SetNarrowRange(bool narrow_range);
void SetNumBits(int num_bits);
};
} // namespace lite
} // namespace mindspore

View File

@ -30,7 +30,6 @@ std::vector<int> Fill::GetDims() const {
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void Fill::SetDims(const std::vector<int> &dims) {}
#endif
int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {

View File

@ -28,14 +28,16 @@ namespace lite {
class Fill : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Fill, PrimitiveC);
Fill() = default;
explicit Fill(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetDims(const std::vector<int> &dims);
#else
explicit Fill(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetDims() const;
void SetDims(const std::vector<int> &dims);
};
} // namespace lite
} // namespace mindspore

View File

@ -51,10 +51,30 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
}
#ifdef PRIMITIVE_WRITEABLE
int Flatten::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
auto attr = std::make_unique<schema::FlattenT>();
this->primitive_->value.type = schema::PrimitiveType_Flatten;
this->primitive_->value.value = attr.release();
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Flatten;
}
if (this->primitive_->value.type != schema::PrimitiveType_Flatten) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::FlattenT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#endif

View File

@ -28,6 +28,7 @@ namespace lite {
class Flatten : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Flatten, PrimitiveC);
Flatten() = default;
explicit Flatten(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else

View File

@ -28,6 +28,7 @@ namespace lite {
class Floor : public ArithmeticSelf {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Floor, ArithmeticSelf);
Floor() = default;
explicit Floor(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else

View File

@ -28,6 +28,7 @@ namespace lite {
class FloorDiv : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(FloorDiv, Arithmetic);
FloorDiv() = default;
explicit FloorDiv(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else

View File

@ -28,6 +28,7 @@ namespace lite {
class FloorMod : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(FloorMod, Arithmetic);
FloorMod() = default;
explicit FloorMod(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else

View File

@ -37,10 +37,6 @@ int FullConnection::GetAxis() const { return this->primitive_->value_as_FullConn
bool FullConnection::GetUseAxis() const { return this->primitive_->value_as_FullConnection()->useAxis(); }
int FullConnection::GetActivationType() const { return this->primitive_->value_as_FullConnection()->activationType(); }
void FullConnection::SetHasBias(bool has_bias) {}
void FullConnection::SetAxis(int axis) {}
void FullConnection::SetUseAxis(bool use_axis) {}
void FullConnection::SetActivationType(int activationType) {}
#endif
int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
std::vector<lite::tensor::Tensor *> outputs_) {

View File

@ -28,8 +28,13 @@ namespace lite {
class FullConnection : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(FullConnection, PrimitiveC);
FullConnection() = default;
explicit FullConnection(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetHasBias(bool has_bias);
void SetAxis(int axis);
void SetUseAxis(bool use_axis);
void SetActivationType(int activationType);
#else
explicit FullConnection(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
@ -38,10 +43,6 @@ class FullConnection : public PrimitiveC {
int GetAxis() const;
bool GetUseAxis() const;
int GetActivationType() const;
void SetHasBias(bool has_bias);
void SetAxis(int axis);
void SetUseAxis(bool use_axis);
void SetActivationType(int activationType);
};
} // namespace lite
} // namespace mindspore

View File

@ -33,9 +33,6 @@ float FusedBatchNorm::GetEpsilon() const { return this->primitive_->value_as_Fus
float FusedBatchNorm::GetMomentum() const { return this->primitive_->value_as_FusedBatchNorm()->momentum(); }
int FusedBatchNorm::GetSpatial() const { return this->primitive_->value_as_FusedBatchNorm()->spatial(); }
void FusedBatchNorm::SetEpsilon(float epsilon) {}
void FusedBatchNorm::SetMomentum(float momentum) {}
void FusedBatchNorm::SetSpatial(int spatial) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,17 +28,18 @@ namespace lite {
class FusedBatchNorm : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(FusedBatchNorm, PrimitiveC);
FusedBatchNorm() = default;
explicit FusedBatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetEpsilon(float epsilon);
void SetMomentum(float momentum);
void SetSpatial(int spatial);
#else
explicit FusedBatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
float GetEpsilon() const;
float GetMomentum() const;
int GetSpatial() const;
void SetEpsilon(float epsilon);
void SetMomentum(float momentum);
void SetSpatial(int spatial);
};
} // namespace lite
} // namespace mindspore

View File

@ -33,8 +33,6 @@ void Gather::SetBatchDims(int batch_dims) { this->primitive_->value.AsGather()->
int Gather::GetAxis() const { return this->primitive_->value_as_Gather()->axis(); }
int Gather::GetBatchDims() const { return this->primitive_->value_as_Gather()->batchDims(); }
void Gather::SetAxis(int axis) {}
void Gather::SetBatchDims(int batch_dims) {}
#endif
int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {

View File

@ -28,16 +28,17 @@ namespace lite {
class Gather : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Gather, PrimitiveC);
Gather() = default;
explicit Gather(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(int axis);
void SetBatchDims(int batch_dims);
#else
explicit Gather(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;
int GetBatchDims() const;
void SetAxis(int axis);
void SetBatchDims(int batch_dims);
};
} // namespace lite
} // namespace mindspore

View File

@ -27,7 +27,6 @@ void GatherNd::SetBatchDims(int batch_dims) { this->primitive_->value.AsGatherNd
int GatherNd::GetBatchDims() const { return this->primitive_->value_as_GatherNd()->batchDims(); }
void GatherNd::SetBatchDims(int batch_dims) {}
#endif
int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {

View File

@ -28,14 +28,16 @@ namespace lite {
class GatherNd : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(GatherNd, PrimitiveC);
GatherNd() = default;
explicit GatherNd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetBatchDims(int batch_dims);
#else
explicit GatherNd(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetBatchDims() const;
void SetBatchDims(int batch_dims);
};
} // namespace lite
} // namespace mindspore

View File

@ -27,6 +27,7 @@ namespace lite {
class Greater : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Greater, Arithmetic);
Greater() = default;
explicit Greater(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else

View File

@ -28,6 +28,7 @@ namespace lite {
class GreaterEqual : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(GreaterEqual, Arithmetic);
GreaterEqual() = default;
explicit GreaterEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else

View File

@ -33,8 +33,6 @@ std::vector<int> L2Norm::GetAxis() const {
}
float L2Norm::GetEpsilon() const { return this->primitive_->value_as_L2Norm()->epsilon(); }
void L2Norm::SetAxis(const std::vector<int> &axis) {}
void L2Norm::SetEpsilon(float epsilon) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,15 +28,16 @@ namespace lite {
class L2Norm : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(L2Norm, PrimitiveC);
L2Norm() = default;
explicit L2Norm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(const std::vector<int> &axis);
void SetEpsilon(float epsilon);
#else
explicit L2Norm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
std::vector<int> GetAxis() const;
float GetEpsilon() const;
void SetAxis(const std::vector<int> &axis);
void SetEpsilon(float epsilon);
};
} // namespace lite
} // namespace mindspore

View File

@ -29,7 +29,6 @@ void LeakyReLU::SetNegativeSlope(float negative_slope) {
float LeakyReLU::GetNegativeSlope() const { return this->primitive_->value_as_LeakyReLU()->negativeSlope(); }
void LeakyReLU::SetNegativeSlope(float negative_slope) {}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -28,13 +28,15 @@ namespace lite {
class LeakyReLU : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(LeakyReLU, PrimitiveC);
LeakyReLU() = default;
explicit LeakyReLU(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetNegativeSlope(float negative_slope);
#else
explicit LeakyReLU(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
float GetNegativeSlope() const;
void SetNegativeSlope(float negative_slope);
};
} // namespace lite
} // namespace mindspore

View File

@ -28,6 +28,7 @@ namespace lite {
class Less : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Less, Arithmetic);
Less() = default;
explicit Less(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else

View File

@ -28,6 +28,7 @@ namespace lite {
class LessEqual : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(LessEqual, Arithmetic);
LessEqual() = default;
explicit LessEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else

View File

@ -60,10 +60,6 @@ float LocalResponseNormalization::GetBeta() const {
return this->primitive_->value_as_LocalResponseNormalization()->beta();
}
void LocalResponseNormalization::SetDepthRadius(int depth_radius) {}
void LocalResponseNormalization::SetBias(float bias) {}
void LocalResponseNormalization::SetAlpha(float alpha) {}
void LocalResponseNormalization::SetBeta(float beta) {}
#endif
} // namespace lite
} // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More