forked from mindspore-Ecosystem/mindspore
Add MS_DECLARE_PARENT and UnPackAttr judge null and change setter position
This commit is contained in:
parent
80d570f003
commit
6c1eb3c22d
|
@ -19,11 +19,6 @@
|
|||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "src/ops/arithmetic_self.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
#include "schema/model_generated.h"
|
||||
#endif
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_ABS_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_ABS_H_
|
||||
|
@ -33,6 +28,7 @@ namespace lite {
|
|||
class Abs : public ArithmeticSelf {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Abs, ArithmeticSelf);
|
||||
Abs() = default;
|
||||
explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#else
|
||||
|
|
|
@ -27,7 +27,18 @@ void Activation::SetType(int type) { this->primitive_->value.AsActivation()->typ
|
|||
void Activation::SetAlpha(float alpha) { this->primitive_->value.AsActivation()->alpha = alpha; }
|
||||
|
||||
int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_Activation;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_Activation) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::ActivationT>();
|
||||
if (prim.name() == "ReLU") {
|
||||
attr->type = schema::ActivationType_RELU;
|
||||
|
@ -36,18 +47,17 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
|
|||
} else if (prim.name() == "ReLU6") {
|
||||
attr->type = schema::ActivationType_RELU6;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_Activation;
|
||||
this->primitive_->value.value = attr.release();
|
||||
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
#else
|
||||
|
||||
int Activation::GetType() const { return this->primitive_->value_as_Activation()->type(); }
|
||||
float Activation::GetAlpha() const { return this->primitive_->value_as_Activation()->alpha(); }
|
||||
|
||||
void Activation::SetType(int type) {}
|
||||
void Activation::SetAlpha(float alpha) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,16 +27,17 @@ namespace lite {
|
|||
class Activation : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Activation, PrimitiveC);
|
||||
Activation() = default;
|
||||
explicit Activation(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
void SetType(int type);
|
||||
void SetAlpha(float alpha);
|
||||
#else
|
||||
explicit Activation(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int GetType() const;
|
||||
float GetAlpha() const;
|
||||
void SetType(int type);
|
||||
void SetAlpha(float alpha);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,7 +29,6 @@ void ActivationGrad::SetType(int type) {
|
|||
|
||||
int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); }
|
||||
|
||||
void ActivationGrad::SetType(int type) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,13 +28,14 @@ namespace lite {
|
|||
class ActivationGrad : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(ActivationGrad, PrimitiveC);
|
||||
ActivationGrad() = default;
|
||||
explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetType(int type);
|
||||
#else
|
||||
explicit ActivationGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int GetType() const;
|
||||
void SetType(int type);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,7 +36,7 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
|
|||
this->primitive_->value.type = schema::PrimitiveType_Add;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_Add) {
|
||||
MS_LOG(ERROR) << "Primitive type should be add";
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
|
@ -53,7 +53,6 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
|
|||
|
||||
int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); }
|
||||
|
||||
void Add::SetActivationType(int activation_type) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,25 +22,21 @@
|
|||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "src/ops/arithmetic.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "schema/inner/model_generated.h"
|
||||
#else
|
||||
#include "schema/model_generated.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Add : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Add, Arithmetic);
|
||||
Add() = default;
|
||||
explicit Add(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
void SetActivationType(int activation_type);
|
||||
#else
|
||||
explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
int GetActivationType() const;
|
||||
void SetActivationType(int activation_type);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,7 +27,6 @@ void AddN::SetN(int n) { this->primitive_->value.AsAddN()->N = n; }
|
|||
|
||||
int AddN::GetN() const { return this->primitive_->value_as_AddN()->N(); }
|
||||
|
||||
void AddN::SetN(int n) {}
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -28,14 +28,15 @@ namespace lite {
|
|||
class AddN : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(AddN, PrimitiveC);
|
||||
AddN() = default;
|
||||
explicit AddN(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetN(int n);
|
||||
#else
|
||||
explicit AddN(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetN() const;
|
||||
void SetN(int n);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -39,11 +39,6 @@ int ArgMax::GetTopK() const { return this->primitive_->value_as_ArgMax()->topK()
|
|||
bool ArgMax::GetKeepDims() const { return this->primitive_->value_as_ArgMax()->keepDims(); }
|
||||
int ArgMax::GetAxisType() const { return this->primitive_->value_as_ArgMax()->axisType(); }
|
||||
|
||||
void ArgMax::SetAxis(int axis) {}
|
||||
void ArgMax::SetOutMaxValue(bool out_max_value) {}
|
||||
void ArgMax::SetTopK(int top_k) {}
|
||||
void ArgMax::SetKeepDims(bool keep_dims) {}
|
||||
void ArgMax::SetAxisType(int axis_type) {}
|
||||
#endif
|
||||
|
||||
int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
|
|
|
@ -28,8 +28,14 @@ namespace lite {
|
|||
class ArgMax : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(ArgMax, PrimitiveC);
|
||||
ArgMax() = default;
|
||||
explicit ArgMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetAxis(int axis);
|
||||
void SetOutMaxValue(bool out_max_value);
|
||||
void SetTopK(int top_k);
|
||||
void SetKeepDims(bool keep_dims);
|
||||
void SetAxisType(int axis_type);
|
||||
#else
|
||||
explicit ArgMax(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
|
@ -39,11 +45,6 @@ class ArgMax : public PrimitiveC {
|
|||
int GetTopK() const;
|
||||
bool GetKeepDims() const;
|
||||
int GetAxisType() const;
|
||||
void SetAxis(int axis);
|
||||
void SetOutMaxValue(bool out_max_value);
|
||||
void SetTopK(int top_k);
|
||||
void SetKeepDims(bool keep_dims);
|
||||
void SetAxisType(int axis_type);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -39,11 +39,6 @@ int ArgMin::GetTopK() const { return this->primitive_->value_as_ArgMin()->topK()
|
|||
bool ArgMin::GetKeepDims() const { return this->primitive_->value_as_ArgMin()->keepDims(); }
|
||||
int ArgMin::GetAxisType() const { return this->primitive_->value_as_ArgMin()->axisType(); }
|
||||
|
||||
void ArgMin::SetAxis(int axis) {}
|
||||
void ArgMin::SetOutMaxValue(bool out_max_value) {}
|
||||
void ArgMin::SetTopK(int top_k) {}
|
||||
void ArgMin::SetKeepDims(bool keep_dims) {}
|
||||
void ArgMin::SetAxisType(int axis_type) {}
|
||||
#endif
|
||||
|
||||
int ArgMin::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
|
|
|
@ -28,8 +28,14 @@ namespace lite {
|
|||
class ArgMin : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(ArgMin, PrimitiveC);
|
||||
ArgMin() = default;
|
||||
explicit ArgMin(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetAxis(int axis);
|
||||
void SetOutMaxValue(bool out_max_value);
|
||||
void SetTopK(int top_k);
|
||||
void SetKeepDims(bool keep_dims);
|
||||
void SetAxisType(int axis_type);
|
||||
#else
|
||||
explicit ArgMin(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
|
@ -39,11 +45,6 @@ class ArgMin : public PrimitiveC {
|
|||
int GetTopK() const;
|
||||
bool GetKeepDims() const;
|
||||
int GetAxisType() const;
|
||||
void SetAxis(int axis);
|
||||
void SetOutMaxValue(bool out_max_value);
|
||||
void SetTopK(int top_k);
|
||||
void SetKeepDims(bool keep_dims);
|
||||
void SetAxisType(int axis_type);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,6 +28,7 @@ namespace lite {
|
|||
class Arithmetic : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Arithmetic, PrimitiveC);
|
||||
Arithmetic() = default;
|
||||
explicit Arithmetic(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
|
|
|
@ -25,6 +25,7 @@ namespace lite {
|
|||
class ArithmeticSelf : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(ArithmeticSelf, PrimitiveC);
|
||||
ArithmeticSelf() = default;
|
||||
explicit ArithmeticSelf(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
|
|
|
@ -24,11 +24,27 @@ float BatchNorm::GetEpsilon() const { return this->primitive_->value.AsBatchNorm
|
|||
void BatchNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsBatchNorm()->epsilon = epsilon; }
|
||||
|
||||
int BatchNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
auto attr = std::make_unique<schema::FusedBatchNormT>();
|
||||
attr->epsilon = GetValue<float>(prim.GetAttr("epsilon"));
|
||||
this->primitive_->value.type = schema::PrimitiveType_FusedBatchNorm;
|
||||
this->primitive_->value.value = attr.release();
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_FusedBatchNorm;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_FusedBatchNorm) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::FusedBatchNormT();
|
||||
attr->epsilon = GetValue<float>(prim.GetAttr("epsilon"));
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -36,7 +52,6 @@ int BatchNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &
|
|||
|
||||
float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); }
|
||||
|
||||
void BatchNorm::SetEpsilon(float epsilon) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,14 +28,15 @@ namespace lite {
|
|||
class BatchNorm : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(BatchNorm, PrimitiveC);
|
||||
BatchNorm() = default;
|
||||
explicit BatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
void SetEpsilon(float epsilon);
|
||||
#else
|
||||
explicit BatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
float GetEpsilon() const;
|
||||
void SetEpsilon(float epsilon);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -42,8 +42,6 @@ std::vector<int> BatchToSpace::GetCrops() const {
|
|||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
void BatchToSpace::SetBlockShape(const std::vector<int> &block_shape) {}
|
||||
void BatchToSpace::SetCrops(const std::vector<int> &crops) {}
|
||||
#endif
|
||||
namespace {
|
||||
constexpr int kBatchToSpaceOutputNum = 1;
|
||||
|
|
|
@ -28,16 +28,17 @@ namespace lite {
|
|||
class BatchToSpace : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(BatchToSpace, PrimitiveC);
|
||||
BatchToSpace() = default;
|
||||
explicit BatchToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetBlockShape(const std::vector<int> &block_shape);
|
||||
void SetCrops(const std::vector<int> &crops);
|
||||
#else
|
||||
explicit BatchToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
std::vector<int> GetBlockShape() const;
|
||||
std::vector<int> GetCrops() const;
|
||||
void SetBlockShape(const std::vector<int> &block_shape);
|
||||
void SetCrops(const std::vector<int> &crops);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,12 +25,31 @@ std::vector<int> BiasAdd::GetAxis() const { return this->primitive_->value.AsBia
|
|||
void BiasAdd::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasAdd()->axis = axis; }
|
||||
|
||||
int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
auto attr = std::make_unique<schema::BiasAddT>();
|
||||
attr->axis = {0};
|
||||
this->primitive_->value.type = schema::PrimitiveType_BiasAdd;
|
||||
this->primitive_->value.value = attr.release();
|
||||
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_BiasAdd;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_BiasAdd) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::BiasAddT();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->axis = {0};
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive value is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -41,7 +60,6 @@ std::vector<int> BiasAdd::GetAxis() const {
|
|||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
void BiasAdd::SetAxis(const std::vector<int> &axis) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,14 +28,15 @@ namespace lite {
|
|||
class BiasAdd : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(BiasAdd, PrimitiveC);
|
||||
BiasAdd() = default;
|
||||
explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
void SetAxis(const std::vector<int> &axis);
|
||||
#else
|
||||
explicit BiasAdd(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
std::vector<int> GetAxis() const;
|
||||
void SetAxis(const std::vector<int> &axis);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,7 +30,6 @@ std::vector<int> BiasGrad::GetAxis() const {
|
|||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
void BiasGrad::SetAxis(const std::vector<int> &axis) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,13 +28,15 @@ namespace lite {
|
|||
class BiasGrad : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(BiasGrad, PrimitiveC);
|
||||
BiasGrad() = default;
|
||||
explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetAxis(const std::vector<int> &axis);
|
||||
|
||||
#else
|
||||
explicit BiasGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
std::vector<int> GetAxis() const;
|
||||
void SetAxis(const std::vector<int> &axis);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,8 +30,6 @@ void BNGradInput::SetChannels(int channels) { this->primitive_->value.AsBNGradIn
|
|||
float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); }
|
||||
int BNGradInput::GetChannels() const { return this->primitive_->value_as_BNGradInput()->channels(); }
|
||||
|
||||
void BNGradInput::SetEps(float eps) {}
|
||||
void BNGradInput::SetChannels(int channels) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,15 +28,16 @@ namespace lite {
|
|||
class BNGradInput : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(BNGradInput, PrimitiveC);
|
||||
BNGradInput() = default;
|
||||
explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetEps(float eps);
|
||||
void SetChannels(int channels);
|
||||
#else
|
||||
explicit BNGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
float GetEps() const;
|
||||
int GetChannels() const;
|
||||
void SetEps(float eps);
|
||||
void SetChannels(int channels);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,7 +32,6 @@ std::vector<int> BroadcastTo::GetDstShape() const {
|
|||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
void BroadcastTo::SetDstShape(const std::vector<int> &dst_shape) {}
|
||||
#endif
|
||||
namespace {
|
||||
constexpr int kBroadcastToInputNum = 1;
|
||||
|
|
|
@ -28,14 +28,16 @@ namespace lite {
|
|||
class BroadcastTo : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(BroadcastTo, PrimitiveC);
|
||||
BroadcastTo() = default;
|
||||
explicit BroadcastTo(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetDstShape(const std::vector<int> &dst_shape);
|
||||
|
||||
#else
|
||||
explicit BroadcastTo(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
std::vector<int> GetDstShape() const;
|
||||
void SetDstShape(const std::vector<int> &dst_shape);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,7 +29,6 @@ void CaffePReLU::SetChannelShared(bool channel_shared) {
|
|||
|
||||
bool CaffePReLU::GetChannelShared() const { return this->primitive_->value_as_CaffePReLU()->channelShared(); }
|
||||
|
||||
void CaffePReLU::SetChannelShared(bool channel_shared) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,13 +28,15 @@ namespace lite {
|
|||
class CaffePReLU : public Activation {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(CaffePReLU, Activation);
|
||||
CaffePReLU() = default;
|
||||
explicit CaffePReLU(schema::PrimitiveT *primitive) : Activation(primitive) {}
|
||||
void SetChannelShared(bool channel_shared);
|
||||
|
||||
#else
|
||||
explicit CaffePReLU(schema::Primitive *primitive) : Activation(primitive) {}
|
||||
#endif
|
||||
bool GetChannelShared() const;
|
||||
void SetChannelShared(bool channel_shared);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,8 +30,6 @@ void Cast::SetDstT(int dst_t) { this->primitive_->value.AsCast()->dstT = dst_t;
|
|||
int Cast::GetSrcT() const { return this->primitive_->value_as_Cast()->srcT(); }
|
||||
int Cast::GetDstT() const { return this->primitive_->value_as_Cast()->dstT(); }
|
||||
|
||||
void Cast::SetSrcT(int src_t) {}
|
||||
void Cast::SetDstT(int dst_t) {}
|
||||
#endif
|
||||
|
||||
int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
|
|
|
@ -28,16 +28,17 @@ namespace lite {
|
|||
class Cast : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Cast, PrimitiveC);
|
||||
Cast() = default;
|
||||
explicit Cast(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetSrcT(int src_t);
|
||||
void SetDstT(int dst_t);
|
||||
#else
|
||||
explicit Cast(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetSrcT() const;
|
||||
int GetDstT() const;
|
||||
void SetSrcT(int src_t);
|
||||
void SetDstT(int dst_t);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,14 +20,15 @@
|
|||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "src/ops/arithmetic_self.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Ceil : public ArithmeticSelf {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Ceil, ArithmeticSelf);
|
||||
Ceil() = default;
|
||||
explicit Ceil(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#else
|
||||
|
|
|
@ -30,8 +30,6 @@ void Clip::SetMin(float min) { this->primitive_->value.AsClip()->min = min; }
|
|||
float Clip::GetMax() const { return this->primitive_->value_as_Clip()->max(); }
|
||||
float Clip::GetMin() const { return this->primitive_->value_as_Clip()->min(); }
|
||||
|
||||
void Clip::SetMax(float max) {}
|
||||
void Clip::SetMin(float min) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,15 +28,16 @@ namespace lite {
|
|||
class Clip : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Clip, PrimitiveC);
|
||||
Clip() = default;
|
||||
explicit Clip(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetMax(float max);
|
||||
void SetMin(float min);
|
||||
#else
|
||||
explicit Clip(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
float GetMax() const;
|
||||
float GetMin() const;
|
||||
void SetMax(float max);
|
||||
void SetMin(float min);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,12 +30,32 @@ void Concat::SetAxis(int axis) { this->primitive_->value.AsConcat()->axis = axis
|
|||
void Concat::SetN(int n) { this->primitive_->value.AsConcat()->n = n; }
|
||||
|
||||
int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
auto attr = std::make_unique<schema::ConcatT>();
|
||||
auto prim_axis = GetValue<int>(prim.GetAttr("axis"));
|
||||
attr->axis = prim_axis;
|
||||
this->primitive_->value.type = schema::PrimitiveType_Concat;
|
||||
this->primitive_->value.value = attr.release();
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_Concat;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_Concat) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::ConcatT();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto prim_axis = GetValue<int>(prim.GetAttr("axis"));
|
||||
attr->axis = prim_axis;
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive value is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -44,8 +64,6 @@ int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
|||
int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); }
|
||||
int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); }
|
||||
|
||||
void Concat::SetAxis(int axis) {}
|
||||
void Concat::SetN(int n) {}
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -28,17 +28,18 @@ namespace lite {
|
|||
class Concat : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Concat, PrimitiveC);
|
||||
Concat() = default;
|
||||
explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
void SetAxis(int axis);
|
||||
void SetN(int n);
|
||||
#else
|
||||
explicit Concat(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetAxis() const;
|
||||
int GetN() const;
|
||||
void SetAxis(int axis);
|
||||
void SetN(int n);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,7 +33,6 @@ void ConstantOfShape::SetValue(float value) { this->primitive_->value.AsConstant
|
|||
|
||||
float ConstantOfShape::GetValue() const { return this->primitive_->value_as_ConstantOfShape()->value(); }
|
||||
|
||||
void ConstantOfShape::SetValue(float value) {}
|
||||
#endif
|
||||
|
||||
int ConstantOfShape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
|
|
|
@ -28,14 +28,15 @@ namespace lite {
|
|||
class ConstantOfShape : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(ConstantOfShape, PrimitiveC);
|
||||
ConstantOfShape() = default;
|
||||
explicit ConstantOfShape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetValue(float value);
|
||||
#else
|
||||
explicit ConstantOfShape(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
float GetValue() const;
|
||||
void SetValue(float value);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include <memory>
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "src/ir/tensor.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#endif
|
||||
|
@ -309,8 +308,18 @@ void Conv2D::PopulaterQuantParam(const Primitive &prim,
|
|||
}
|
||||
|
||||
int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_Conv2D;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_Conv2D) {
|
||||
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
int group = GetValue<int>(prim.GetAttr("group"));
|
||||
if (group > 1) {
|
||||
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
|
||||
|
@ -348,23 +357,6 @@ int Conv2D::GetDilateH() const { return this->primitive_->value_as_Conv2D()->dil
|
|||
bool Conv2D::GetHasBias() const { return this->primitive_->value_as_Conv2D()->hasBias(); }
|
||||
int Conv2D::GetActivationType() const { return this->primitive_->value_as_Conv2D()->activationType(); }
|
||||
|
||||
void Conv2D::SetFormat(int format) {}
|
||||
void Conv2D::SetGroup(int group) {}
|
||||
void Conv2D::SetChannelIn(int channel_in) {}
|
||||
void Conv2D::SetChannelOut(int channel_out) {}
|
||||
void Conv2D::SetKernelW(int kernel_w) {}
|
||||
void Conv2D::SetKernelH(int kernel_h) {}
|
||||
void Conv2D::SetStrideW(int stride_w) {}
|
||||
void Conv2D::SetStrideH(int stride_h) {}
|
||||
void Conv2D::SetPadMode(int pad_mode) {}
|
||||
void Conv2D::SetPadUp(int pad_up) {}
|
||||
void Conv2D::SetPadDown(int pad_down) {}
|
||||
void Conv2D::SetPadLeft(int pad_left) {}
|
||||
void Conv2D::SetPadRight(int pad_right) {}
|
||||
void Conv2D::SetDilateW(int dilate_w) {}
|
||||
void Conv2D::SetDilateH(int dilate_h) {}
|
||||
void Conv2D::SetHasBias(bool has_bias) {}
|
||||
void Conv2D::SetActivationType(int activation_type) {}
|
||||
#endif
|
||||
void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output_w) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
|
|
|
@ -28,12 +28,30 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
class Conv2D : public PrimitiveC {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Conv2D, PrimitiveC);
|
||||
|
||||
public:
|
||||
Conv2D() = default;
|
||||
explicit Conv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
void SetFormat(int format);
|
||||
void SetGroup(int group);
|
||||
void SetChannelIn(int channel_in);
|
||||
void SetChannelOut(int channel_out);
|
||||
void SetKernelW(int kernel_w);
|
||||
void SetKernelH(int kernel_h);
|
||||
void SetStrideW(int stride_w);
|
||||
void SetStrideH(int stride_h);
|
||||
void SetPadMode(int pad_mode);
|
||||
void SetPadUp(int pad_up);
|
||||
void SetPadDown(int pad_down);
|
||||
void SetPadLeft(int pad_left);
|
||||
void SetPadRight(int pad_right);
|
||||
void SetDilateW(int dilate_w);
|
||||
void SetDilateH(int dilate_h);
|
||||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
|
||||
private:
|
||||
void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
|
||||
|
@ -72,23 +90,6 @@ class Conv2D : public PrimitiveC {
|
|||
int GetDilateH() const;
|
||||
bool GetHasBias() const;
|
||||
int GetActivationType() const;
|
||||
void SetFormat(int format);
|
||||
void SetGroup(int group);
|
||||
void SetChannelIn(int channel_in);
|
||||
void SetChannelOut(int channel_out);
|
||||
void SetKernelW(int kernel_w);
|
||||
void SetKernelH(int kernel_h);
|
||||
void SetStrideW(int stride_w);
|
||||
void SetStrideH(int stride_h);
|
||||
void SetPadMode(int pad_mode);
|
||||
void SetPadUp(int pad_up);
|
||||
void SetPadDown(int pad_down);
|
||||
void SetPadLeft(int pad_left);
|
||||
void SetPadRight(int pad_right);
|
||||
void SetDilateW(int dilate_w);
|
||||
void SetDilateH(int dilate_h);
|
||||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
|
||||
protected:
|
||||
void ConvInferShape(int input_h, int input_w, int *output_h, int *output_w);
|
||||
|
|
|
@ -89,23 +89,6 @@ int Conv2DGradFilter::GetActivationType() const {
|
|||
return this->primitive_->value_as_Conv2DGradFilter()->activationType();
|
||||
}
|
||||
|
||||
void Conv2DGradFilter::SetFormat(int format) {}
|
||||
void Conv2DGradFilter::SetGroup(int group) {}
|
||||
void Conv2DGradFilter::SetChannelIn(int channel_in) {}
|
||||
void Conv2DGradFilter::SetChannelOut(int channel_out) {}
|
||||
void Conv2DGradFilter::SetKernelW(int kernel_w) {}
|
||||
void Conv2DGradFilter::SetKernelH(int kernel_h) {}
|
||||
void Conv2DGradFilter::SetStrideW(int stride_w) {}
|
||||
void Conv2DGradFilter::SetStrideH(int stride_h) {}
|
||||
void Conv2DGradFilter::SetPadMode(int pad_mode) {}
|
||||
void Conv2DGradFilter::SetPadUp(int pad_up) {}
|
||||
void Conv2DGradFilter::SetPadDown(int pad_down) {}
|
||||
void Conv2DGradFilter::SetPadLeft(int pad_left) {}
|
||||
void Conv2DGradFilter::SetPadRight(int pad_right) {}
|
||||
void Conv2DGradFilter::SetDilateW(int dilate_w) {}
|
||||
void Conv2DGradFilter::SetDilateH(int dilate_h) {}
|
||||
void Conv2DGradFilter::SetHasBias(bool has_bias) {}
|
||||
void Conv2DGradFilter::SetActivationType(int activation_type) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,8 +28,26 @@ namespace lite {
|
|||
class Conv2DGradFilter : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Conv2DGradFilter, PrimitiveC);
|
||||
Conv2DGradFilter() = default;
|
||||
explicit Conv2DGradFilter(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetFormat(int format);
|
||||
void SetGroup(int group);
|
||||
void SetChannelIn(int channel_in);
|
||||
void SetChannelOut(int channel_out);
|
||||
void SetKernelW(int kernel_w);
|
||||
void SetKernelH(int kernel_h);
|
||||
void SetStrideW(int stride_w);
|
||||
void SetStrideH(int stride_h);
|
||||
void SetPadMode(int pad_mode);
|
||||
void SetPadUp(int pad_up);
|
||||
void SetPadDown(int pad_down);
|
||||
void SetPadLeft(int pad_left);
|
||||
void SetPadRight(int pad_right);
|
||||
void SetDilateW(int dilate_w);
|
||||
void SetDilateH(int dilate_h);
|
||||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
#else
|
||||
explicit Conv2DGradFilter(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
|
@ -50,23 +68,6 @@ class Conv2DGradFilter : public PrimitiveC {
|
|||
int GetDilateH() const;
|
||||
bool GetHasBias() const;
|
||||
int GetActivationType() const;
|
||||
void SetFormat(int format);
|
||||
void SetGroup(int group);
|
||||
void SetChannelIn(int channel_in);
|
||||
void SetChannelOut(int channel_out);
|
||||
void SetKernelW(int kernel_w);
|
||||
void SetKernelH(int kernel_h);
|
||||
void SetStrideW(int stride_w);
|
||||
void SetStrideH(int stride_h);
|
||||
void SetPadMode(int pad_mode);
|
||||
void SetPadUp(int pad_up);
|
||||
void SetPadDown(int pad_down);
|
||||
void SetPadLeft(int pad_left);
|
||||
void SetPadRight(int pad_right);
|
||||
void SetDilateW(int dilate_w);
|
||||
void SetDilateH(int dilate_h);
|
||||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -87,23 +87,6 @@ int Conv2DGradInput::GetActivationType() const {
|
|||
return this->primitive_->value_as_Conv2DGradInput()->activationType();
|
||||
}
|
||||
|
||||
void Conv2DGradInput::SetFormat(int format) {}
|
||||
void Conv2DGradInput::SetGroup(int group) {}
|
||||
void Conv2DGradInput::SetChannelIn(int channel_in) {}
|
||||
void Conv2DGradInput::SetChannelOut(int channel_out) {}
|
||||
void Conv2DGradInput::SetKernelW(int kernel_w) {}
|
||||
void Conv2DGradInput::SetKernelH(int kernel_h) {}
|
||||
void Conv2DGradInput::SetStrideW(int stride_w) {}
|
||||
void Conv2DGradInput::SetStrideH(int stride_h) {}
|
||||
void Conv2DGradInput::SetPadMode(int pad_mode) {}
|
||||
void Conv2DGradInput::SetPadUp(int pad_up) {}
|
||||
void Conv2DGradInput::SetPadDown(int pad_down) {}
|
||||
void Conv2DGradInput::SetPadLeft(int pad_left) {}
|
||||
void Conv2DGradInput::SetPadRight(int pad_right) {}
|
||||
void Conv2DGradInput::SetDilateW(int dilate_w) {}
|
||||
void Conv2DGradInput::SetDilateH(int dilate_h) {}
|
||||
void Conv2DGradInput::SetHasBias(bool has_bias) {}
|
||||
void Conv2DGradInput::SetActivationType(int activation_type) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,8 +28,26 @@ namespace lite {
|
|||
class Conv2DGradInput : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Conv2DGradInput, PrimitiveC);
|
||||
Conv2DGradInput() = default;
|
||||
explicit Conv2DGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetFormat(int format);
|
||||
void SetGroup(int group);
|
||||
void SetChannelIn(int channel_in);
|
||||
void SetChannelOut(int channel_out);
|
||||
void SetKernelW(int kernel_w);
|
||||
void SetKernelH(int kernel_h);
|
||||
void SetStrideW(int stride_w);
|
||||
void SetStrideH(int stride_h);
|
||||
void SetPadMode(int pad_mode);
|
||||
void SetPadUp(int pad_up);
|
||||
void SetPadDown(int pad_down);
|
||||
void SetPadLeft(int pad_left);
|
||||
void SetPadRight(int pad_right);
|
||||
void SetDilateW(int dilate_w);
|
||||
void SetDilateH(int dilate_h);
|
||||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
#else
|
||||
explicit Conv2DGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
|
@ -50,23 +68,6 @@ class Conv2DGradInput : public PrimitiveC {
|
|||
int GetDilateH() const;
|
||||
bool GetHasBias() const;
|
||||
int GetActivationType() const;
|
||||
void SetFormat(int format);
|
||||
void SetGroup(int group);
|
||||
void SetChannelIn(int channel_in);
|
||||
void SetChannelOut(int channel_out);
|
||||
void SetKernelW(int kernel_w);
|
||||
void SetKernelH(int kernel_h);
|
||||
void SetStrideW(int stride_w);
|
||||
void SetStrideH(int stride_h);
|
||||
void SetPadMode(int pad_mode);
|
||||
void SetPadUp(int pad_up);
|
||||
void SetPadDown(int pad_down);
|
||||
void SetPadLeft(int pad_left);
|
||||
void SetPadRight(int pad_right);
|
||||
void SetDilateW(int dilate_w);
|
||||
void SetDilateH(int dilate_h);
|
||||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,8 +33,6 @@ std::vector<int64_t> Crop::GetOffsets() const {
|
|||
return std::vector<int64_t>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
void Crop::SetAxis(int64_t axis) {}
|
||||
void Crop::SetOffsets(const std::vector<int64_t> &offsets) {}
|
||||
#endif
|
||||
namespace {
|
||||
constexpr int kCropOutputNum = 1;
|
||||
|
|
|
@ -28,16 +28,17 @@ namespace lite {
|
|||
class Crop : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Crop, PrimitiveC);
|
||||
Crop() = default;
|
||||
explicit Crop(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetAxis(int64_t axis);
|
||||
void SetOffsets(const std::vector<int64_t> &offsets);
|
||||
#else
|
||||
explicit Crop(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int64_t GetAxis() const;
|
||||
std::vector<int64_t> GetOffsets() const;
|
||||
void SetAxis(int64_t axis);
|
||||
void SetOffsets(const std::vector<int64_t> &offsets);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -77,23 +77,6 @@ int DeConv2D::GetDilateH() const { return this->primitive_->value_as_DeConv2D()-
|
|||
bool DeConv2D::GetHasBias() const { return this->primitive_->value_as_DeConv2D()->hasBias(); }
|
||||
int DeConv2D::GetActivationType() const { return this->primitive_->value_as_DeConv2D()->activationType(); }
|
||||
|
||||
void DeConv2D::SetFormat(int format) {}
|
||||
void DeConv2D::SetGroup(int group) {}
|
||||
void DeConv2D::SetChannelIn(int channel_in) {}
|
||||
void DeConv2D::SetChannelOut(int channel_out) {}
|
||||
void DeConv2D::SetKernelW(int kernel_w) {}
|
||||
void DeConv2D::SetKernelH(int kernel_h) {}
|
||||
void DeConv2D::SetStrideW(int stride_w) {}
|
||||
void DeConv2D::SetStrideH(int stride_h) {}
|
||||
void DeConv2D::SetPadMode(int pad_mode) {}
|
||||
void DeConv2D::SetPadUp(int pad_up) {}
|
||||
void DeConv2D::SetPadDown(int pad_down) {}
|
||||
void DeConv2D::SetPadLeft(int pad_left) {}
|
||||
void DeConv2D::SetPadRight(int pad_right) {}
|
||||
void DeConv2D::SetDilateW(int dilate_w) {}
|
||||
void DeConv2D::SetDilateH(int dilate_h) {}
|
||||
void DeConv2D::SetHasBias(bool has_bias) {}
|
||||
void DeConv2D::SetActivationType(int activation_type) {}
|
||||
#endif
|
||||
int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
|
|
|
@ -28,8 +28,26 @@ namespace lite {
|
|||
class DeConv2D : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(DeConv2D, PrimitiveC);
|
||||
DeConv2D() = default;
|
||||
explicit DeConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetFormat(int format);
|
||||
void SetGroup(int group);
|
||||
void SetChannelIn(int channel_in);
|
||||
void SetChannelOut(int channel_out);
|
||||
void SetKernelW(int kernel_w);
|
||||
void SetKernelH(int kernel_h);
|
||||
void SetStrideW(int stride_w);
|
||||
void SetStrideH(int stride_h);
|
||||
void SetPadMode(int pad_mode);
|
||||
void SetPadUp(int pad_up);
|
||||
void SetPadDown(int pad_down);
|
||||
void SetPadLeft(int pad_left);
|
||||
void SetPadRight(int pad_right);
|
||||
void SetDilateW(int dilate_w);
|
||||
void SetDilateH(int dilate_h);
|
||||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
#else
|
||||
explicit DeConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
|
@ -51,23 +69,6 @@ class DeConv2D : public PrimitiveC {
|
|||
int GetDilateH() const;
|
||||
bool GetHasBias() const;
|
||||
int GetActivationType() const;
|
||||
void SetFormat(int format);
|
||||
void SetGroup(int group);
|
||||
void SetChannelIn(int channel_in);
|
||||
void SetChannelOut(int channel_out);
|
||||
void SetKernelW(int kernel_w);
|
||||
void SetKernelH(int kernel_h);
|
||||
void SetStrideW(int stride_w);
|
||||
void SetStrideH(int stride_h);
|
||||
void SetPadMode(int pad_mode);
|
||||
void SetPadUp(int pad_up);
|
||||
void SetPadDown(int pad_down);
|
||||
void SetPadLeft(int pad_left);
|
||||
void SetPadRight(int pad_right);
|
||||
void SetDilateW(int dilate_w);
|
||||
void SetDilateH(int dilate_h);
|
||||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
|
||||
int PadUp() const { return this->pad_u_; }
|
||||
int PadDown() const { return this->pad_d_; }
|
||||
|
|
|
@ -92,22 +92,6 @@ int DeDepthwiseConv2D::GetActivationType() const {
|
|||
return this->primitive_->value_as_DeDepthwiseConv2D()->activationType();
|
||||
}
|
||||
|
||||
void DeDepthwiseConv2D::SetFormat(int format) {}
|
||||
void DeDepthwiseConv2D::SetChannelIn(int channel_in) {}
|
||||
void DeDepthwiseConv2D::SetChannelMultiplier(int channel_multiplier) {}
|
||||
void DeDepthwiseConv2D::SetKernelW(int kernel_w) {}
|
||||
void DeDepthwiseConv2D::SetKernelH(int kernel_h) {}
|
||||
void DeDepthwiseConv2D::SetStrideW(int stride_w) {}
|
||||
void DeDepthwiseConv2D::SetStrideH(int stride_h) {}
|
||||
void DeDepthwiseConv2D::SetPadMode(int pad_mode) {}
|
||||
void DeDepthwiseConv2D::SetPadUp(int pad_up) {}
|
||||
void DeDepthwiseConv2D::SetPadDown(int pad_down) {}
|
||||
void DeDepthwiseConv2D::SetPadLeft(int pad_left) {}
|
||||
void DeDepthwiseConv2D::SetPadRight(int pad_right) {}
|
||||
void DeDepthwiseConv2D::SetDilateW(int dilate_w) {}
|
||||
void DeDepthwiseConv2D::SetDilateH(int dilate_h) {}
|
||||
void DeDepthwiseConv2D::SetHasBias(bool has_bias) {}
|
||||
void DeDepthwiseConv2D::SetActivationType(int activation_type) {}
|
||||
#endif
|
||||
int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
|
||||
std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
|
|
|
@ -28,8 +28,25 @@ namespace lite {
|
|||
class DeDepthwiseConv2D : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(DeDepthwiseConv2D, PrimitiveC);
|
||||
DeDepthwiseConv2D() = default;
|
||||
explicit DeDepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetFormat(int format);
|
||||
void SetChannelIn(int channel_in);
|
||||
void SetChannelMultiplier(int channel_multiplier);
|
||||
void SetKernelW(int kernel_w);
|
||||
void SetKernelH(int kernel_h);
|
||||
void SetStrideW(int stride_w);
|
||||
void SetStrideH(int stride_h);
|
||||
void SetPadMode(int pad_mode);
|
||||
void SetPadUp(int pad_up);
|
||||
void SetPadDown(int pad_down);
|
||||
void SetPadLeft(int pad_left);
|
||||
void SetPadRight(int pad_right);
|
||||
void SetDilateW(int dilate_w);
|
||||
void SetDilateH(int dilate_h);
|
||||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
#else
|
||||
explicit DeDepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
|
@ -50,22 +67,6 @@ class DeDepthwiseConv2D : public PrimitiveC {
|
|||
int GetDilateH() const;
|
||||
bool GetHasBias() const;
|
||||
int GetActivationType() const;
|
||||
void SetFormat(int format);
|
||||
void SetChannelIn(int channel_in);
|
||||
void SetChannelMultiplier(int channel_multiplier);
|
||||
void SetKernelW(int kernel_w);
|
||||
void SetKernelH(int kernel_h);
|
||||
void SetStrideW(int stride_w);
|
||||
void SetStrideH(int stride_h);
|
||||
void SetPadMode(int pad_mode);
|
||||
void SetPadUp(int pad_up);
|
||||
void SetPadDown(int pad_down);
|
||||
void SetPadLeft(int pad_left);
|
||||
void SetPadRight(int pad_right);
|
||||
void SetDilateW(int dilate_w);
|
||||
void SetDilateH(int dilate_h);
|
||||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
|
||||
int PadUp() const { return this->pad_u_; }
|
||||
int PadDown() const { return this->pad_d_; }
|
||||
|
|
|
@ -30,8 +30,6 @@ void DepthToSpace::SetFormat(int format) { this->primitive_->value.AsDepthToSpac
|
|||
int DepthToSpace::GetBlockSize() const { return this->primitive_->value_as_DepthToSpace()->blockSize(); }
|
||||
int DepthToSpace::GetFormat() const { return this->primitive_->value_as_DepthToSpace()->format(); }
|
||||
|
||||
void DepthToSpace::SetBlockSize(int block_size) {}
|
||||
void DepthToSpace::SetFormat(int format) {}
|
||||
#endif
|
||||
namespace {
|
||||
constexpr int kDepthToSpaceOutputNum = 1;
|
||||
|
|
|
@ -28,16 +28,17 @@ namespace lite {
|
|||
class DepthToSpace : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(DepthToSpace, PrimitiveC);
|
||||
DepthToSpace() = default;
|
||||
explicit DepthToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetBlockSize(int block_size);
|
||||
void SetFormat(int format);
|
||||
#else
|
||||
explicit DepthToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetBlockSize() const;
|
||||
int GetFormat() const;
|
||||
void SetBlockSize(int block_size);
|
||||
void SetFormat(int format);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -254,22 +254,6 @@ int DepthwiseConv2D::GetActivationType() const {
|
|||
return this->primitive_->value_as_DepthwiseConv2D()->activationType();
|
||||
}
|
||||
|
||||
void DepthwiseConv2D::SetFormat(int format) {}
|
||||
void DepthwiseConv2D::SetChannelIn(int channel_in) {}
|
||||
void DepthwiseConv2D::SetChannelMultiplier(int channel_multiplier) {}
|
||||
void DepthwiseConv2D::SetKernelW(int kernel_w) {}
|
||||
void DepthwiseConv2D::SetKernelH(int kernel_h) {}
|
||||
void DepthwiseConv2D::SetStrideW(int stride_w) {}
|
||||
void DepthwiseConv2D::SetStrideH(int stride_h) {}
|
||||
void DepthwiseConv2D::SetPadMode(int pad_mode) {}
|
||||
void DepthwiseConv2D::SetPadUp(int pad_up) {}
|
||||
void DepthwiseConv2D::SetPadDown(int pad_down) {}
|
||||
void DepthwiseConv2D::SetPadLeft(int pad_left) {}
|
||||
void DepthwiseConv2D::SetPadRight(int pad_right) {}
|
||||
void DepthwiseConv2D::SetDilateW(int dilate_w) {}
|
||||
void DepthwiseConv2D::SetDilateH(int dilate_h) {}
|
||||
void DepthwiseConv2D::SetHasBias(bool has_bias) {}
|
||||
void DepthwiseConv2D::SetActivationType(int activation_type) {}
|
||||
#endif
|
||||
int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
|
||||
std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
|
|
|
@ -27,12 +27,29 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
class DepthwiseConv2D : public PrimitiveC {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(DepthwiseConv2D, PrimitiveC);
|
||||
|
||||
public:
|
||||
DepthwiseConv2D() = default;
|
||||
explicit DepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
void SetFormat(int format);
|
||||
void SetChannelIn(int channel_in);
|
||||
void SetChannelMultiplier(int channel_multiplier);
|
||||
void SetKernelW(int kernel_w);
|
||||
void SetKernelH(int kernel_h);
|
||||
void SetStrideW(int stride_w);
|
||||
void SetStrideH(int stride_h);
|
||||
void SetPadMode(int pad_mode);
|
||||
void SetPadUp(int pad_up);
|
||||
void SetPadDown(int pad_down);
|
||||
void SetPadLeft(int pad_left);
|
||||
void SetPadRight(int pad_right);
|
||||
void SetDilateW(int dilate_w);
|
||||
void SetDilateH(int dilate_h);
|
||||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
|
||||
private:
|
||||
void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
|
||||
|
@ -62,22 +79,6 @@ class DepthwiseConv2D : public PrimitiveC {
|
|||
int GetDilateH() const;
|
||||
bool GetHasBias() const;
|
||||
int GetActivationType() const;
|
||||
void SetFormat(int format);
|
||||
void SetChannelIn(int channel_in);
|
||||
void SetChannelMultiplier(int channel_multiplier);
|
||||
void SetKernelW(int kernel_w);
|
||||
void SetKernelH(int kernel_h);
|
||||
void SetStrideW(int stride_w);
|
||||
void SetStrideH(int stride_h);
|
||||
void SetPadMode(int pad_mode);
|
||||
void SetPadUp(int pad_up);
|
||||
void SetPadDown(int pad_down);
|
||||
void SetPadLeft(int pad_left);
|
||||
void SetPadRight(int pad_right);
|
||||
void SetDilateW(int dilate_w);
|
||||
void SetDilateH(int dilate_h);
|
||||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
|
||||
int PadUp() const { return this->pad_u_; }
|
||||
int PadDown() const { return this->pad_d_; }
|
||||
|
|
|
@ -21,10 +21,30 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int Dequant::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
auto attr = std::make_unique<schema::OnnxInt8DequantizeT>();
|
||||
this->primitive_->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
|
||||
this->primitive_->value.value = attr.release();
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_OnnxInt8Dequantize) {
|
||||
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow)(schema::OnnxInt8DequantizeT);
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "attr is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive value is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -25,6 +25,7 @@ namespace lite {
|
|||
class Dequant : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Dequant, PrimitiveC);
|
||||
Dequant() = default;
|
||||
explicit Dequant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
|
|
|
@ -119,19 +119,6 @@ bool DetectionPostProcess::GetUseRegularNms() const {
|
|||
return this->primitive_->value_as_DetectionPostProcess()->UseRegularNms();
|
||||
}
|
||||
|
||||
void DetectionPostProcess::SetFormat(int format) {}
|
||||
void DetectionPostProcess::SetInputSize(int input_size) {}
|
||||
void DetectionPostProcess::SetHScale(float h_scale) {}
|
||||
void DetectionPostProcess::SetWScale(float w_scale) {}
|
||||
void DetectionPostProcess::SetXScale(float x_scale) {}
|
||||
void DetectionPostProcess::SetYScale(float y_scale) {}
|
||||
void DetectionPostProcess::SetNmsIouThreshold(float nms_iou_threshold) {}
|
||||
void DetectionPostProcess::SetNmsScoreThreshold(float nms_score_threshold) {}
|
||||
void DetectionPostProcess::SetMaxDetections(int64_t max_detections) {}
|
||||
void DetectionPostProcess::SetDetectionsPreClass(int64_t detections_pre_class) {}
|
||||
void DetectionPostProcess::SetMaxClassesPreDetection(int64_t max_classes_pre_detection) {}
|
||||
void DetectionPostProcess::SetNumClasses(int64_t num_classes) {}
|
||||
void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,8 +28,22 @@ namespace lite {
|
|||
class DetectionPostProcess : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(DetectionPostProcess, PrimitiveC);
|
||||
DetectionPostProcess() = default;
|
||||
explicit DetectionPostProcess(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetFormat(int format);
|
||||
void SetInputSize(int input_size);
|
||||
void SetHScale(float h_scale);
|
||||
void SetWScale(float w_scale);
|
||||
void SetXScale(float x_scale);
|
||||
void SetYScale(float y_scale);
|
||||
void SetNmsIouThreshold(float nms_iou_threshold);
|
||||
void SetNmsScoreThreshold(float nms_score_threshold);
|
||||
void SetMaxDetections(int64_t max_detections);
|
||||
void SetDetectionsPreClass(int64_t detections_pre_class);
|
||||
void SetMaxClassesPreDetection(int64_t max_classes_pre_detection);
|
||||
void SetNumClasses(int64_t num_classes);
|
||||
void SetUseRegularNms(bool use_regular_nms);
|
||||
#else
|
||||
explicit DetectionPostProcess(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
|
@ -46,19 +60,6 @@ class DetectionPostProcess : public PrimitiveC {
|
|||
int64_t GetMaxClassesPreDetection() const;
|
||||
int64_t GetNumClasses() const;
|
||||
bool GetUseRegularNms() const;
|
||||
void SetFormat(int format);
|
||||
void SetInputSize(int input_size);
|
||||
void SetHScale(float h_scale);
|
||||
void SetWScale(float w_scale);
|
||||
void SetXScale(float x_scale);
|
||||
void SetYScale(float y_scale);
|
||||
void SetNmsIouThreshold(float nms_iou_threshold);
|
||||
void SetNmsScoreThreshold(float nms_score_threshold);
|
||||
void SetMaxDetections(int64_t max_detections);
|
||||
void SetDetectionsPreClass(int64_t detections_pre_class);
|
||||
void SetMaxClassesPreDetection(int64_t max_classes_pre_detection);
|
||||
void SetNumClasses(int64_t num_classes);
|
||||
void SetUseRegularNms(bool use_regular_nms);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,7 +29,6 @@ void Div::SetActivationType(int activation_type) {
|
|||
|
||||
int Div::GetActivationType() const { return this->primitive_->value_as_Div()->activationType(); }
|
||||
|
||||
void Div::SetActivationType(int activation_type) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,13 +28,15 @@ namespace lite {
|
|||
class Div : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Div, Arithmetic);
|
||||
Div() = default;
|
||||
explicit Div(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
void SetActivationType(int activation_type);
|
||||
|
||||
#else
|
||||
explicit Div(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
int GetActivationType() const;
|
||||
void SetActivationType(int activation_type);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,7 +27,6 @@ void Dropout::SetRatio(float ratio) { this->primitive_->value.AsDropout()->ratio
|
|||
|
||||
float Dropout::GetRatio() const { return this->primitive_->value_as_Dropout()->ratio(); }
|
||||
|
||||
void Dropout::SetRatio(float ratio) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,21 +20,23 @@
|
|||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Dropout : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Dropout, PrimitiveC);
|
||||
Dropout() = default;
|
||||
explicit Dropout(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetRatio(float ratio);
|
||||
|
||||
#else
|
||||
explicit Dropout(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
float GetRatio() const;
|
||||
void SetRatio(float ratio);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,7 +27,6 @@ void Eltwise::SetMode(int mode) { this->primitive_->value.AsEltwise()->mode = (s
|
|||
|
||||
int Eltwise::GetMode() const { return this->primitive_->value_as_Eltwise()->mode(); }
|
||||
|
||||
void Eltwise::SetMode(int mode) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,13 +28,15 @@ namespace lite {
|
|||
class Eltwise : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Eltwise, PrimitiveC);
|
||||
Eltwise() = default;
|
||||
explicit Eltwise(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetMode(int mode);
|
||||
|
||||
#else
|
||||
explicit Eltwise(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int GetMode() const;
|
||||
void SetMode(int mode);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,7 +27,6 @@ void Elu::SetAlpha(float alpha) { this->primitive_->value.AsElu()->alpha = alpha
|
|||
|
||||
float Elu::GetAlpha() const { return this->primitive_->value_as_Elu()->alpha(); }
|
||||
|
||||
void Elu::SetAlpha(float alpha) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,13 +28,15 @@ namespace lite {
|
|||
class Elu : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Elu, PrimitiveC);
|
||||
Elu() = default;
|
||||
explicit Elu(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetAlpha(float alpha);
|
||||
|
||||
#else
|
||||
explicit Elu(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
float GetAlpha() const;
|
||||
void SetAlpha(float alpha);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,7 +27,6 @@ void EmbeddingLookup::SetMaxNorm(float max_norm) { this->primitive_->value.AsEmb
|
|||
|
||||
float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value_as_EmbeddingLookup()->maxNorm(); }
|
||||
|
||||
void EmbeddingLookup::SetMaxNorm(float max_norm) {}
|
||||
#endif
|
||||
|
||||
int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
|
|
|
@ -28,14 +28,16 @@ namespace lite {
|
|||
class EmbeddingLookup : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(EmbeddingLookup, PrimitiveC);
|
||||
EmbeddingLookup() = default;
|
||||
explicit EmbeddingLookup(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetMaxNorm(float max_norm);
|
||||
|
||||
#else
|
||||
explicit EmbeddingLookup(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
float GetMaxNorm() const;
|
||||
void SetMaxNorm(float max_norm);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -51,9 +51,6 @@ float EmbeddingLookupSparse::GetMaxNortm() const {
|
|||
return this->primitive_->value_as_EmbeddingLookupSparse()->maxNortm();
|
||||
}
|
||||
|
||||
void EmbeddingLookupSparse::SetSpIds(const std::vector<int> &sp_ids) {}
|
||||
void EmbeddingLookupSparse::SetSpWeights(const std::vector<float> &sp_weights) {}
|
||||
void EmbeddingLookupSparse::SetMaxNortm(float max_nortm) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,17 +28,18 @@ namespace lite {
|
|||
class EmbeddingLookupSparse : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(EmbeddingLookupSparse, PrimitiveC);
|
||||
EmbeddingLookupSparse() = default;
|
||||
explicit EmbeddingLookupSparse(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetSpIds(const std::vector<int> &sp_ids);
|
||||
void SetSpWeights(const std::vector<float> &sp_weights);
|
||||
void SetMaxNortm(float max_nortm);
|
||||
#else
|
||||
explicit EmbeddingLookupSparse(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
std::vector<int> GetSpIds() const;
|
||||
std::vector<float> GetSpWeights() const;
|
||||
float GetMaxNortm() const;
|
||||
void SetSpIds(const std::vector<int> &sp_ids);
|
||||
void SetSpWeights(const std::vector<float> &sp_weights);
|
||||
void SetMaxNortm(float max_nortm);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,6 +28,7 @@ namespace lite {
|
|||
class Equal : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Equal, PrimitiveC);
|
||||
Equal() = default;
|
||||
explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#else
|
||||
|
|
|
@ -28,6 +28,7 @@ namespace lite {
|
|||
class Exp : public ArithmeticSelf {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Exp, ArithmeticSelf);
|
||||
Exp() = default;
|
||||
explicit Exp(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#else
|
||||
|
|
|
@ -27,7 +27,6 @@ void ExpandDims::SetDim(int dim) { this->primitive_->value.AsExpandDims()->dim =
|
|||
|
||||
int ExpandDims::GetDim() const { return this->primitive_->value_as_ExpandDims()->dim(); }
|
||||
|
||||
void ExpandDims::SetDim(int dim) {}
|
||||
#endif
|
||||
|
||||
int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
|
|
|
@ -28,14 +28,16 @@ namespace lite {
|
|||
class ExpandDims : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(ExpandDims, PrimitiveC);
|
||||
ExpandDims() = default;
|
||||
explicit ExpandDims(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetDim(int dim);
|
||||
|
||||
#else
|
||||
explicit ExpandDims(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetDim() const;
|
||||
void SetDim(int dim);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -40,8 +40,6 @@ int FakeQuantWithMinMaxVars::GetNumBits() const {
|
|||
return this->primitive_->value_as_FakeQuantWithMinMaxVars()->numBits();
|
||||
}
|
||||
|
||||
void FakeQuantWithMinMaxVars::SetNarrowRange(bool narrow_range) {}
|
||||
void FakeQuantWithMinMaxVars::SetNumBits(int num_bits) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,15 +28,16 @@ namespace lite {
|
|||
class FakeQuantWithMinMaxVars : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(FakeQuantWithMinMaxVars, PrimitiveC);
|
||||
FakeQuantWithMinMaxVars() = default;
|
||||
explicit FakeQuantWithMinMaxVars(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetNarrowRange(bool narrow_range);
|
||||
void SetNumBits(int num_bits);
|
||||
#else
|
||||
explicit FakeQuantWithMinMaxVars(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
bool GetNarrowRange() const;
|
||||
int GetNumBits() const;
|
||||
void SetNarrowRange(bool narrow_range);
|
||||
void SetNumBits(int num_bits);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,7 +30,6 @@ std::vector<int> Fill::GetDims() const {
|
|||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
void Fill::SetDims(const std::vector<int> &dims) {}
|
||||
#endif
|
||||
|
||||
int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
|
|
|
@ -28,14 +28,16 @@ namespace lite {
|
|||
class Fill : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Fill, PrimitiveC);
|
||||
Fill() = default;
|
||||
explicit Fill(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetDims(const std::vector<int> &dims);
|
||||
|
||||
#else
|
||||
explicit Fill(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
std::vector<int> GetDims() const;
|
||||
void SetDims(const std::vector<int> &dims);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -51,10 +51,30 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
|
|||
}
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int Flatten::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
this->primitive_ = new (schema::PrimitiveT);
|
||||
auto attr = std::make_unique<schema::FlattenT>();
|
||||
this->primitive_->value.type = schema::PrimitiveType_Flatten;
|
||||
this->primitive_->value.value = attr.release();
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_Flatten;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_Flatten) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::FlattenT();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive value is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -28,6 +28,7 @@ namespace lite {
|
|||
class Flatten : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Flatten, PrimitiveC);
|
||||
Flatten() = default;
|
||||
explicit Flatten(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
|
|
|
@ -28,6 +28,7 @@ namespace lite {
|
|||
class Floor : public ArithmeticSelf {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Floor, ArithmeticSelf);
|
||||
Floor() = default;
|
||||
explicit Floor(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#else
|
||||
|
|
|
@ -28,6 +28,7 @@ namespace lite {
|
|||
class FloorDiv : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(FloorDiv, Arithmetic);
|
||||
FloorDiv() = default;
|
||||
explicit FloorDiv(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#else
|
||||
|
|
|
@ -28,6 +28,7 @@ namespace lite {
|
|||
class FloorMod : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(FloorMod, Arithmetic);
|
||||
FloorMod() = default;
|
||||
explicit FloorMod(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#else
|
||||
|
|
|
@ -37,10 +37,6 @@ int FullConnection::GetAxis() const { return this->primitive_->value_as_FullConn
|
|||
bool FullConnection::GetUseAxis() const { return this->primitive_->value_as_FullConnection()->useAxis(); }
|
||||
int FullConnection::GetActivationType() const { return this->primitive_->value_as_FullConnection()->activationType(); }
|
||||
|
||||
void FullConnection::SetHasBias(bool has_bias) {}
|
||||
void FullConnection::SetAxis(int axis) {}
|
||||
void FullConnection::SetUseAxis(bool use_axis) {}
|
||||
void FullConnection::SetActivationType(int activationType) {}
|
||||
#endif
|
||||
int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
|
||||
std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
|
|
|
@ -28,8 +28,13 @@ namespace lite {
|
|||
class FullConnection : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(FullConnection, PrimitiveC);
|
||||
FullConnection() = default;
|
||||
explicit FullConnection(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetHasBias(bool has_bias);
|
||||
void SetAxis(int axis);
|
||||
void SetUseAxis(bool use_axis);
|
||||
void SetActivationType(int activationType);
|
||||
#else
|
||||
explicit FullConnection(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
|
@ -38,10 +43,6 @@ class FullConnection : public PrimitiveC {
|
|||
int GetAxis() const;
|
||||
bool GetUseAxis() const;
|
||||
int GetActivationType() const;
|
||||
void SetHasBias(bool has_bias);
|
||||
void SetAxis(int axis);
|
||||
void SetUseAxis(bool use_axis);
|
||||
void SetActivationType(int activationType);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,9 +33,6 @@ float FusedBatchNorm::GetEpsilon() const { return this->primitive_->value_as_Fus
|
|||
float FusedBatchNorm::GetMomentum() const { return this->primitive_->value_as_FusedBatchNorm()->momentum(); }
|
||||
int FusedBatchNorm::GetSpatial() const { return this->primitive_->value_as_FusedBatchNorm()->spatial(); }
|
||||
|
||||
void FusedBatchNorm::SetEpsilon(float epsilon) {}
|
||||
void FusedBatchNorm::SetMomentum(float momentum) {}
|
||||
void FusedBatchNorm::SetSpatial(int spatial) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,17 +28,18 @@ namespace lite {
|
|||
class FusedBatchNorm : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(FusedBatchNorm, PrimitiveC);
|
||||
FusedBatchNorm() = default;
|
||||
explicit FusedBatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetEpsilon(float epsilon);
|
||||
void SetMomentum(float momentum);
|
||||
void SetSpatial(int spatial);
|
||||
#else
|
||||
explicit FusedBatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
float GetEpsilon() const;
|
||||
float GetMomentum() const;
|
||||
int GetSpatial() const;
|
||||
void SetEpsilon(float epsilon);
|
||||
void SetMomentum(float momentum);
|
||||
void SetSpatial(int spatial);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,8 +33,6 @@ void Gather::SetBatchDims(int batch_dims) { this->primitive_->value.AsGather()->
|
|||
int Gather::GetAxis() const { return this->primitive_->value_as_Gather()->axis(); }
|
||||
int Gather::GetBatchDims() const { return this->primitive_->value_as_Gather()->batchDims(); }
|
||||
|
||||
void Gather::SetAxis(int axis) {}
|
||||
void Gather::SetBatchDims(int batch_dims) {}
|
||||
#endif
|
||||
|
||||
int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
|
|
|
@ -28,16 +28,17 @@ namespace lite {
|
|||
class Gather : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Gather, PrimitiveC);
|
||||
Gather() = default;
|
||||
explicit Gather(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetAxis(int axis);
|
||||
void SetBatchDims(int batch_dims);
|
||||
#else
|
||||
explicit Gather(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetAxis() const;
|
||||
int GetBatchDims() const;
|
||||
void SetAxis(int axis);
|
||||
void SetBatchDims(int batch_dims);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,7 +27,6 @@ void GatherNd::SetBatchDims(int batch_dims) { this->primitive_->value.AsGatherNd
|
|||
|
||||
int GatherNd::GetBatchDims() const { return this->primitive_->value_as_GatherNd()->batchDims(); }
|
||||
|
||||
void GatherNd::SetBatchDims(int batch_dims) {}
|
||||
#endif
|
||||
|
||||
int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
|
|
|
@ -28,14 +28,16 @@ namespace lite {
|
|||
class GatherNd : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(GatherNd, PrimitiveC);
|
||||
GatherNd() = default;
|
||||
explicit GatherNd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetBatchDims(int batch_dims);
|
||||
|
||||
#else
|
||||
explicit GatherNd(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetBatchDims() const;
|
||||
void SetBatchDims(int batch_dims);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,6 +27,7 @@ namespace lite {
|
|||
class Greater : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Greater, Arithmetic);
|
||||
Greater() = default;
|
||||
explicit Greater(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#else
|
||||
|
|
|
@ -28,6 +28,7 @@ namespace lite {
|
|||
class GreaterEqual : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(GreaterEqual, Arithmetic);
|
||||
GreaterEqual() = default;
|
||||
explicit GreaterEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#else
|
||||
|
|
|
@ -33,8 +33,6 @@ std::vector<int> L2Norm::GetAxis() const {
|
|||
}
|
||||
float L2Norm::GetEpsilon() const { return this->primitive_->value_as_L2Norm()->epsilon(); }
|
||||
|
||||
void L2Norm::SetAxis(const std::vector<int> &axis) {}
|
||||
void L2Norm::SetEpsilon(float epsilon) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,15 +28,16 @@ namespace lite {
|
|||
class L2Norm : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(L2Norm, PrimitiveC);
|
||||
L2Norm() = default;
|
||||
explicit L2Norm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetAxis(const std::vector<int> &axis);
|
||||
void SetEpsilon(float epsilon);
|
||||
#else
|
||||
explicit L2Norm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
std::vector<int> GetAxis() const;
|
||||
float GetEpsilon() const;
|
||||
void SetAxis(const std::vector<int> &axis);
|
||||
void SetEpsilon(float epsilon);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,7 +29,6 @@ void LeakyReLU::SetNegativeSlope(float negative_slope) {
|
|||
|
||||
float LeakyReLU::GetNegativeSlope() const { return this->primitive_->value_as_LeakyReLU()->negativeSlope(); }
|
||||
|
||||
void LeakyReLU::SetNegativeSlope(float negative_slope) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,13 +28,15 @@ namespace lite {
|
|||
class LeakyReLU : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(LeakyReLU, PrimitiveC);
|
||||
LeakyReLU() = default;
|
||||
explicit LeakyReLU(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetNegativeSlope(float negative_slope);
|
||||
|
||||
#else
|
||||
explicit LeakyReLU(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
#endif
|
||||
float GetNegativeSlope() const;
|
||||
void SetNegativeSlope(float negative_slope);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,6 +28,7 @@ namespace lite {
|
|||
class Less : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Less, Arithmetic);
|
||||
Less() = default;
|
||||
explicit Less(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#else
|
||||
|
|
|
@ -28,6 +28,7 @@ namespace lite {
|
|||
class LessEqual : public Arithmetic {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(LessEqual, Arithmetic);
|
||||
LessEqual() = default;
|
||||
explicit LessEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#else
|
||||
|
|
|
@ -60,10 +60,6 @@ float LocalResponseNormalization::GetBeta() const {
|
|||
return this->primitive_->value_as_LocalResponseNormalization()->beta();
|
||||
}
|
||||
|
||||
void LocalResponseNormalization::SetDepthRadius(int depth_radius) {}
|
||||
void LocalResponseNormalization::SetBias(float bias) {}
|
||||
void LocalResponseNormalization::SetAlpha(float alpha) {}
|
||||
void LocalResponseNormalization::SetBeta(float beta) {}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue