forked from mindspore-Ecosystem/mindspore
decoupling primitive_c problem
This commit is contained in:
parent
70bb0a842a
commit
7a97fe710c
|
@ -19,6 +19,7 @@
|
|||
#include "include/errorcode.h"
|
||||
#include "src/common/graph_util.h"
|
||||
#include "include/version.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
|
||||
|
@ -31,7 +32,12 @@ bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) {
|
|||
}
|
||||
auto c_node = meta_graph->nodes()->GetAs<schema::CNode>(i);
|
||||
auto src_prim = c_node->primitive();
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
|
||||
#else
|
||||
auto primitive = const_cast<schema::Primitive *>(src_prim);
|
||||
node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive);
|
||||
#endif
|
||||
if (node->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "unpack primitive == nullptr!";
|
||||
delete node;
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "src/ops/abs.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -27,6 +28,9 @@ int Abs::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *AbsCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Abs>(primitive); }
|
||||
Registry AbsRegistry(schema::PrimitiveType_Abs, AbsCreator);
|
||||
#endif
|
||||
Registry AbsParameterRegistry(schema::PrimitiveType_Abs, PopulateArithmeticSelf);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/activation.h"
|
||||
#include <memory>
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/fp32/activation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -80,6 +82,30 @@ int Activation::GetType() const { return this->primitive_->value_as_Activation()
|
|||
float Activation::GetAlpha() const { return this->primitive_->value_as_Activation()->alpha(); }
|
||||
float Activation::GetMinVal() const { return this->primitive_->value_as_Activation()->min_val(); }
|
||||
float Activation::GetMaxVal() const { return this->primitive_->value_as_Activation()->max_val(); }
|
||||
|
||||
PrimitiveC *ActivationCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<Activation>(primitive);
|
||||
}
|
||||
Registry ActivationRegistry(schema::PrimitiveType_Activation, ActivationCreator);
|
||||
#endif
|
||||
OpParameter *PopulateActivationParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ActivationParameter *act_param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter)));
|
||||
if (act_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ActivationParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(act_param, 0, sizeof(ActivationParameter));
|
||||
act_param->op_parameter_.type_ = primitive->Type();
|
||||
auto activation =
|
||||
reinterpret_cast<mindspore::lite::Activation *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
act_param->type_ = static_cast<int>(activation->GetType());
|
||||
act_param->alpha_ = activation->GetAlpha();
|
||||
act_param->min_val_ = activation->GetMinVal();
|
||||
act_param->max_val_ = activation->GetMaxVal();
|
||||
return reinterpret_cast<OpParameter *>(act_param);
|
||||
}
|
||||
|
||||
Registry ActivationParameterRegistry(schema::PrimitiveType_Activation, PopulateActivationParameter);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/activation_grad.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -74,6 +76,11 @@ int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flat
|
|||
}
|
||||
int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); }
|
||||
float ActivationGrad::GetAlpha() const { return this->primitive_->value_as_ActivationGrad()->alpha(); }
|
||||
|
||||
PrimitiveC *ActivationGradCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<ActivationGrad>(primitive);
|
||||
}
|
||||
Registry ActivationGradRegistry(schema::PrimitiveType_ActivationGrad, ActivationGradCreator);
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "src/ops/adam.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -62,6 +64,9 @@ int Adam::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *AdamCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Adam>(primitive); }
|
||||
Registry AdamRegistry(schema::PrimitiveType_Adam, AdamCreator);
|
||||
#endif
|
||||
|
||||
int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/add.h"
|
||||
#include <memory>
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/arithmetic_common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -71,6 +73,31 @@ int Add::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl
|
|||
}
|
||||
int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); }
|
||||
|
||||
PrimitiveC *AddCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Add>(primitive); }
|
||||
Registry AddRegistry(schema::PrimitiveType_Add, AddCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateAddParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
|
||||
if (arithmetic_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
|
||||
arithmetic_param->op_parameter_.type_ = primitive->Type();
|
||||
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
|
||||
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
|
||||
arithmetic_param->activation_type_ =
|
||||
reinterpret_cast<mindspore::lite::Add *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType();
|
||||
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
|
||||
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
|
||||
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
|
||||
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
|
||||
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
|
||||
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
|
||||
return reinterpret_cast<OpParameter *>(arithmetic_param);
|
||||
}
|
||||
Registry AddParameterRegistry(schema::PrimitiveType_Add, PopulateAddParameter);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/addn.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -62,8 +64,22 @@ int AddN::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F
|
|||
}
|
||||
int AddN::GetN() const { return this->primitive_->value_as_AddN()->N(); }
|
||||
|
||||
PrimitiveC *AddNCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<AddN>(primitive); }
|
||||
Registry AddNRegistry(schema::PrimitiveType_AddN, AddNCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateAddNParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
OpParameter *addn_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
|
||||
if (addn_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc OpParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(addn_param, 0, sizeof(OpParameter));
|
||||
addn_param->type_ = primitive->Type();
|
||||
return reinterpret_cast<OpParameter *>(addn_param);
|
||||
}
|
||||
Registry AddNParameterRegistry(schema::PrimitiveType_AddN, PopulateAddNParameter);
|
||||
|
||||
namespace {
|
||||
constexpr int kLeastInputNum = 2;
|
||||
}
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "src/ops/apply_momentum.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -67,6 +69,11 @@ int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *ApplyMomentumCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<ApplyMomentum>(primitive);
|
||||
}
|
||||
Registry ApplyMomentumRegistry(schema::PrimitiveType_ApplyMomentum, ApplyMomentumCreator);
|
||||
#endif
|
||||
|
||||
int ApplyMomentum::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/argmax.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/arg_min_max_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -52,8 +55,29 @@ int ArgMax::GetTopK() const { return this->primitive_->value_as_ArgMax()->topK()
|
|||
bool ArgMax::GetKeepDims() const { return this->primitive_->value_as_ArgMax()->keepDims(); }
|
||||
int ArgMax::GetAxisType() const { return this->primitive_->value_as_ArgMax()->axisType(); }
|
||||
|
||||
PrimitiveC *ArgMaxCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<ArgMax>(primitive); }
|
||||
Registry ArgMaxRegistry(schema::PrimitiveType_ArgMax, ArgMaxCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateArgMaxParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ArgMinMaxParameter *arg_param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter)));
|
||||
if (arg_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ArgMinMaxParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(arg_param, 0, sizeof(ArgMinMaxParameter));
|
||||
arg_param->op_parameter_.type_ = primitive->Type();
|
||||
auto param = reinterpret_cast<mindspore::lite::ArgMax *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
arg_param->axis_ = param->GetAxis();
|
||||
arg_param->topk_ = param->GetTopK();
|
||||
arg_param->axis_type_ = param->GetAxisType();
|
||||
arg_param->out_value_ = param->GetOutMaxValue();
|
||||
arg_param->keep_dims_ = param->GetKeepDims();
|
||||
return reinterpret_cast<OpParameter *>(arg_param);
|
||||
}
|
||||
|
||||
Registry ArgMaxParameterRegistry(schema::PrimitiveType_ArgMax, PopulateArgMaxParameter);
|
||||
|
||||
int ArgMax::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
auto input = inputs_.front();
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/argmin.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/arg_min_max_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -52,8 +55,29 @@ int ArgMin::GetTopK() const { return this->primitive_->value_as_ArgMin()->topK()
|
|||
bool ArgMin::GetKeepDims() const { return this->primitive_->value_as_ArgMin()->keepDims(); }
|
||||
int ArgMin::GetAxisType() const { return this->primitive_->value_as_ArgMin()->axisType(); }
|
||||
|
||||
PrimitiveC *ArgMinCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<ArgMin>(primitive); }
|
||||
Registry ArgMinRegistry(schema::PrimitiveType_ArgMin, ArgMinCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateArgMinParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ArgMinMaxParameter *arg_param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter)));
|
||||
if (arg_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ArgMinMaxParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(arg_param, 0, sizeof(ArgMinMaxParameter));
|
||||
arg_param->op_parameter_.type_ = primitive->Type();
|
||||
auto param = reinterpret_cast<mindspore::lite::ArgMin *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
arg_param->axis_ = param->GetAxis();
|
||||
arg_param->topk_ = param->GetTopK();
|
||||
arg_param->axis_type_ = param->GetAxisType();
|
||||
arg_param->out_value_ = param->GetOutMaxValue();
|
||||
arg_param->keep_dims_ = param->GetKeepDims();
|
||||
return reinterpret_cast<OpParameter *>(arg_param);
|
||||
}
|
||||
|
||||
Registry ArgMinParameterRegistry(schema::PrimitiveType_ArgMin, PopulateArgMinParameter);
|
||||
|
||||
int ArgMin::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
auto input = inputs_.front();
|
||||
|
|
|
@ -21,6 +21,29 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
|
||||
if (arithmetic_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
|
||||
arithmetic_param->op_parameter_.type_ = primitive->Type();
|
||||
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
|
||||
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
|
||||
|
||||
arithmetic_param->activation_type_ = 0;
|
||||
|
||||
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
|
||||
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
|
||||
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
|
||||
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
|
||||
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
|
||||
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
|
||||
return reinterpret_cast<OpParameter *>(arithmetic_param);
|
||||
}
|
||||
|
||||
int Arithmetic::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
if (inputs_.size() != kDoubleNum) {
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "nnacl/arithmetic_common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -51,6 +52,8 @@ class Arithmetic : public PrimitiveC {
|
|||
std::vector<int> in_shape1_;
|
||||
std::vector<int> out_shape_;
|
||||
};
|
||||
|
||||
OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "nnacl/arithmetic_self_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
|
|
@ -17,9 +17,21 @@
|
|||
#include "src/ops/arithmetic_self.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
OpParameter *PopulateArithmeticSelf(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ArithmeticSelfParameter *arithmetic_self_param =
|
||||
reinterpret_cast<ArithmeticSelfParameter *>(malloc(sizeof(ArithmeticSelfParameter)));
|
||||
if (arithmetic_self_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ArithmeticSelfParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(arithmetic_self_param, 0, sizeof(ArithmeticSelfParameter));
|
||||
arithmetic_self_param->op_parameter_.type_ = primitive->Type();
|
||||
return reinterpret_cast<OpParameter *>(arithmetic_self_param);
|
||||
}
|
||||
|
||||
int ArithmeticSelf::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "nnacl/arithmetic_self_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -37,6 +38,7 @@ class ArithmeticSelf : public PrimitiveC {
|
|||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
OpParameter *PopulateArithmeticSelf(const mindspore::lite::PrimitiveC *primitive);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#include "src/ops/assign.h"
|
||||
#include <memory>
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -56,6 +58,9 @@ int Assign::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *AssignCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Assign>(primitive); }
|
||||
Registry AssignRegistry(schema::PrimitiveType_Assign, AssignCreator);
|
||||
#endif
|
||||
|
||||
int Assign::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/batch_norm.h"
|
||||
#include <memory>
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/batchnorm_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -60,6 +63,28 @@ int BatchNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe
|
|||
}
|
||||
float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); }
|
||||
|
||||
PrimitiveC *BatchNormCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<BatchNorm>(primitive);
|
||||
}
|
||||
Registry BatchNormRegistry(schema::PrimitiveType_BatchNorm, BatchNormCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateBatchNorm(const mindspore::lite::PrimitiveC *primitive) {
|
||||
const auto param =
|
||||
reinterpret_cast<mindspore::lite::BatchNorm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
BatchNormParameter *batch_norm_param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter)));
|
||||
if (batch_norm_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc BatchNormParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(batch_norm_param, 0, sizeof(BatchNormParameter));
|
||||
batch_norm_param->op_parameter_.type_ = primitive->Type();
|
||||
batch_norm_param->epsilon_ = param->GetEpsilon();
|
||||
batch_norm_param->fused_ = false;
|
||||
return reinterpret_cast<OpParameter *>(batch_norm_param);
|
||||
}
|
||||
|
||||
Registry BatchNormParameterRegistry(schema::PrimitiveType_BatchNorm, PopulateBatchNorm);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,6 +20,9 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
#include "src/tensor.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/batch_to_space.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -66,7 +69,49 @@ std::vector<int> BatchToSpace::GetCrops() const {
|
|||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
PrimitiveC *BatchToSpaceCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<BatchToSpace>(primitive);
|
||||
}
|
||||
Registry BatchToSpaceRegistry(schema::PrimitiveType_BatchToSpace, BatchToSpaceCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateBatchToSpaceParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
BatchToSpaceParameter *batch_space_param =
|
||||
reinterpret_cast<BatchToSpaceParameter *>(malloc(sizeof(BatchToSpaceParameter)));
|
||||
if (batch_space_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc BatchToSpaceParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(batch_space_param, 0, sizeof(BatchToSpaceParameter));
|
||||
batch_space_param->op_parameter_.type_ = primitive->Type();
|
||||
auto param = reinterpret_cast<mindspore::lite::BatchToSpace *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
auto block_shape = param->GetBlockShape();
|
||||
if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) {
|
||||
MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE;
|
||||
free(batch_space_param);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto crops = param->GetCrops();
|
||||
if (crops.size() != BATCH_TO_SPACE_CROPS_SIZE) {
|
||||
MS_LOG(ERROR) << "batch_to_space crops size should be " << BATCH_TO_SPACE_CROPS_SIZE;
|
||||
free(batch_space_param);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) {
|
||||
batch_space_param->block_shape_[i] = block_shape[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) {
|
||||
batch_space_param->crops_[i] = crops[i];
|
||||
}
|
||||
return reinterpret_cast<OpParameter *>(batch_space_param);
|
||||
}
|
||||
|
||||
Registry BatchToSpaceParameterRegistry(schema::PrimitiveType_BatchToSpace, PopulateBatchToSpaceParameter);
|
||||
Registry BatchToSpaceNDParameterRegistry(schema::PrimitiveType_BatchToSpaceND, PopulateBatchToSpaceParameter);
|
||||
|
||||
namespace {
|
||||
constexpr int kBatchToSpaceOutputNum = 1;
|
||||
constexpr int kBatchToSpaceInputNum = 1;
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/bias_add.h"
|
||||
#include <memory>
|
||||
#include "nnacl/arithmetic_common.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -78,6 +80,22 @@ std::vector<int> BiasAdd::GetAxis() const {
|
|||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
PrimitiveC *BiasAddCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<BiasAdd>(primitive); }
|
||||
Registry BiasAddRegistry(schema::PrimitiveType_BiasAdd, BiasAddCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateBiasAddParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
|
||||
if (arithmetic_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
|
||||
arithmetic_param->op_parameter_.type_ = primitive->Type();
|
||||
|
||||
return reinterpret_cast<OpParameter *>(arithmetic_param);
|
||||
}
|
||||
Registry BiasAddParameterRegistry(schema::PrimitiveType_BiasAdd, PopulateBiasAddParameter);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/bias_grad.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -74,6 +76,11 @@ std::vector<int> BiasGrad::GetAxis() const {
|
|||
auto fb_vector = this->primitive_->value_as_BiasGrad()->axis();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
PrimitiveC *BiasGradCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<BiasGrad>(primitive);
|
||||
}
|
||||
Registry BiasGradRegistry(schema::PrimitiveType_BiasGrad, BiasGradCreator);
|
||||
#endif
|
||||
|
||||
int BiasGrad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/bn_grad.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/broadcast_to.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/fp32/broadcast_to.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -50,7 +53,32 @@ std::vector<int> BroadcastTo::GetDstShape() const {
|
|||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
PrimitiveC *BroadcastToCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<BroadcastTo>(primitive);
|
||||
}
|
||||
Registry BroadcastToRegistry(schema::PrimitiveType_BroadcastTo, BroadcastToCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateBroadcastToParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
BroadcastToParameter *broadcast_param =
|
||||
reinterpret_cast<BroadcastToParameter *>(malloc(sizeof(BroadcastToParameter)));
|
||||
if (broadcast_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc BroadcastToParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(broadcast_param, 0, sizeof(BroadcastToParameter));
|
||||
auto param = reinterpret_cast<mindspore::lite::BroadcastTo *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
broadcast_param->op_parameter_.type_ = primitive->Type();
|
||||
auto dst_shape = param->GetDstShape();
|
||||
broadcast_param->shape_size_ = dst_shape.size();
|
||||
for (size_t i = 0; i < broadcast_param->shape_size_; ++i) {
|
||||
broadcast_param->shape_[i] = dst_shape[i];
|
||||
}
|
||||
return reinterpret_cast<OpParameter *>(broadcast_param);
|
||||
}
|
||||
|
||||
Registry BroadcastToParameterRegistry(schema::PrimitiveType_BroadcastTo, PopulateBroadcastToParameter);
|
||||
|
||||
namespace {
|
||||
constexpr int kBroadcastToInputNum = 1;
|
||||
constexpr int kBroadcastToOutputNum = 1;
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/cast.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/fp32/cast.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -75,8 +78,26 @@ int Cast::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F
|
|||
int Cast::GetSrcT() const { return this->primitive_->value_as_Cast()->srcT(); }
|
||||
int Cast::GetDstT() const { return this->primitive_->value_as_Cast()->dstT(); }
|
||||
|
||||
PrimitiveC *CastCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Cast>(primitive); }
|
||||
Registry CastRegistry(schema::PrimitiveType_Cast, CastCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateCastParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
CastParameter *cast_param = reinterpret_cast<CastParameter *>(malloc(sizeof(CastParameter)));
|
||||
if (cast_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc CastParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(cast_param, 0, sizeof(CastParameter));
|
||||
cast_param->op_parameter_.type_ = primitive->Type();
|
||||
auto param = reinterpret_cast<mindspore::lite::Cast *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
cast_param->src_type_ = param->GetSrcT();
|
||||
cast_param->dst_type_ = param->GetDstT();
|
||||
return reinterpret_cast<OpParameter *>(cast_param);
|
||||
}
|
||||
|
||||
Registry CastParameterRegistry(schema::PrimitiveType_Cast, PopulateCastParameter);
|
||||
|
||||
int Cast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
auto input = inputs_.front();
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/ceil.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
Registry CeilParameterRegistry(schema::PrimitiveType_Ceil, PopulateArithmeticSelf);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -21,6 +21,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "src/ops/arithmetic_self.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -43,6 +44,7 @@ class Ceil : public ArithmeticSelf {
|
|||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/clip.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/clip.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -42,6 +45,24 @@ int Clip::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F
|
|||
float Clip::GetMax() const { return this->primitive_->value_as_Clip()->max(); }
|
||||
float Clip::GetMin() const { return this->primitive_->value_as_Clip()->min(); }
|
||||
|
||||
PrimitiveC *ClipCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Clip>(primitive); }
|
||||
Registry ClipRegistry(schema::PrimitiveType_Clip, ClipCreator);
|
||||
#endif
|
||||
OpParameter *PopulateClipParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ClipParameter *act_param = reinterpret_cast<ClipParameter *>(malloc(sizeof(ClipParameter)));
|
||||
if (act_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ClipParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(act_param, 0, sizeof(ClipParameter));
|
||||
act_param->op_parameter_.type_ = primitive->Type();
|
||||
auto activation = reinterpret_cast<mindspore::lite::Clip *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
act_param->min_val_ = activation->GetMin();
|
||||
act_param->max_val_ = activation->GetMax();
|
||||
return reinterpret_cast<OpParameter *>(act_param);
|
||||
}
|
||||
|
||||
Registry ClipParameterRegistry(schema::PrimitiveType_Clip, PopulateClipParameter);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/tensor.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/concat_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -76,8 +78,26 @@ int Concat::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:
|
|||
int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); }
|
||||
int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); }
|
||||
|
||||
PrimitiveC *ConcatCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Concat>(primitive); }
|
||||
Registry ConcatRegistry(schema::PrimitiveType_Concat, ConcatCreator);
|
||||
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateConcatParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ConcatParameter *concat_param = reinterpret_cast<ConcatParameter *>(malloc(sizeof(ConcatParameter)));
|
||||
if (concat_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ConcatParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(concat_param, 0, sizeof(ConcatParameter));
|
||||
concat_param->op_parameter_.type_ = primitive->Type();
|
||||
auto param = reinterpret_cast<mindspore::lite::Concat *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
concat_param->axis_ = param->GetAxis();
|
||||
return reinterpret_cast<OpParameter *>(concat_param);
|
||||
}
|
||||
|
||||
Registry ConcatParameterRegistry(schema::PrimitiveType_Concat, PopulateConcatParameter);
|
||||
|
||||
namespace {
|
||||
constexpr int kConcatOutputNum = 1;
|
||||
}
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/tensor.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/fp32/constant_of_shape.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
|
@ -45,8 +47,29 @@ int ConstantOfShape::UnPackToFlatBuilder(const schema::Primitive *primitive, fla
|
|||
}
|
||||
float ConstantOfShape::GetValue() const { return this->primitive_->value_as_ConstantOfShape()->value(); }
|
||||
|
||||
PrimitiveC *ConstantOfShapeCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<ConstantOfShape>(primitive);
|
||||
}
|
||||
Registry ConstantOfShapeRegistry(schema::PrimitiveType_ConstantOfShape, ConstantOfShapeCreator);
|
||||
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateConstantOfShapeParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto attr =
|
||||
reinterpret_cast<mindspore::lite::ConstantOfShape *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
ConstantOfShapeParameter *param =
|
||||
reinterpret_cast<ConstantOfShapeParameter *>(malloc(sizeof(ConstantOfShapeParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ConstantOfShapeParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(param, 0, sizeof(ConstantOfShapeParameter));
|
||||
param->op_parameter_.type_ = primitive->Type();
|
||||
param->value_ = attr->GetValue();
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
Registry ConstantOfShapeParameterRegistry(schema::PrimitiveType_ConstantOfShape, PopulateConstantOfShapeParameter);
|
||||
|
||||
int ConstantOfShape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
if (inputs_.size() != kShapeInputNum) {
|
||||
MS_LOG(ERROR) << "inputs to ConstantOfShape operator should be 1, but " << inputs_.size() << " is given.";
|
||||
|
|
|
@ -24,9 +24,10 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include <float.h>
|
||||
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#endif
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -320,7 +321,51 @@ int Conv2D::GetDilateH() const { return this->primitive_->value_as_Conv2D()->dil
|
|||
bool Conv2D::GetHasBias() const { return this->primitive_->value_as_Conv2D()->hasBias(); }
|
||||
int Conv2D::GetActivationType() const { return this->primitive_->value_as_Conv2D()->activationType(); }
|
||||
|
||||
PrimitiveC *Conv2DCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Conv2D>(primitive); }
|
||||
Registry Conv2DRegistry(schema::PrimitiveType_Conv2D, Conv2DCreator);
|
||||
#endif
|
||||
OpParameter *PopulateConvParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
|
||||
if (conv_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ConvParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(conv_param, 0, sizeof(ConvParameter));
|
||||
conv_param->op_parameter_.type_ = primitive->Type();
|
||||
auto conv_primitive =
|
||||
reinterpret_cast<mindspore::lite::Conv2D *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
conv_param->kernel_h_ = conv_primitive->GetKernelH();
|
||||
conv_param->kernel_w_ = conv_primitive->GetKernelW();
|
||||
conv_param->group_ = conv_primitive->GetGroup();
|
||||
conv_param->stride_h_ = conv_primitive->GetStrideH();
|
||||
conv_param->stride_w_ = conv_primitive->GetStrideW();
|
||||
|
||||
auto conv2d_lite_primitive = (lite::Conv2D *)primitive;
|
||||
conv_param->pad_u_ = conv2d_lite_primitive->PadUp();
|
||||
conv_param->pad_d_ = conv2d_lite_primitive->PadDown();
|
||||
conv_param->pad_l_ = conv2d_lite_primitive->PadLeft();
|
||||
conv_param->pad_r_ = conv2d_lite_primitive->PadRight();
|
||||
conv_param->dilation_h_ = conv_primitive->GetDilateH();
|
||||
conv_param->dilation_w_ = conv_primitive->GetDilateW();
|
||||
conv_param->input_channel_ = conv_primitive->GetChannelIn();
|
||||
conv_param->output_channel_ = conv_primitive->GetChannelOut();
|
||||
conv_param->group_ = conv_primitive->GetGroup();
|
||||
auto act_type = conv_primitive->GetActivationType();
|
||||
switch (act_type) {
|
||||
case schema::ActivationType_RELU:
|
||||
conv_param->act_type_ = ActType_Relu;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
conv_param->act_type_ = ActType_Relu6;
|
||||
break;
|
||||
default:
|
||||
conv_param->act_type_ = ActType_No;
|
||||
break;
|
||||
}
|
||||
return reinterpret_cast<OpParameter *>(conv_param);
|
||||
}
|
||||
Registry Conv2DParameterRegistry(schema::PrimitiveType_Conv2D, PopulateConvParameter);
|
||||
|
||||
void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output_w) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
int kernel_w = GetKernelW();
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "src/ops/conv2d_grad_filter.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -176,6 +177,10 @@ int Conv2DGradFilter::GetActivationType() const {
|
|||
return this->primitive_->value_as_Conv2DGradFilter()->activationType();
|
||||
}
|
||||
|
||||
PrimitiveC *Conv2DGradFilterCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<Conv2DGradFilter>(primitive);
|
||||
}
|
||||
Registry conv2DGradFilterRegistry(schema::PrimitiveType_Conv2DGradFilter, Conv2DGradFilterCreator);
|
||||
#endif
|
||||
|
||||
int Conv2DGradFilter::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/conv2d_grad_input.h"
|
||||
#include "src/ops/group_conv2d_grad_input.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -178,6 +180,10 @@ int Conv2DGradInput::GetActivationType() const {
|
|||
return this->primitive_->value_as_Conv2DGradInput()->activationType();
|
||||
}
|
||||
|
||||
PrimitiveC *Conv2DGradInputCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<Conv2DGradInput>(primitive);
|
||||
}
|
||||
Registry Conv2DGradInputRegistry(schema::PrimitiveType_Conv2DGradInput, Conv2DGradInputCreator);
|
||||
#endif
|
||||
|
||||
int Conv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/cos.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
@ -27,6 +29,10 @@ int Cos::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *CosCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Cos>(primitive); }
|
||||
Registry CosRegistry(schema::PrimitiveType_Cos, CosCreator);
|
||||
#endif
|
||||
Registry CosParameterRegistry(schema::PrimitiveType_Cos, PopulateArithmeticSelf);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/crop.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/crop_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -51,7 +54,33 @@ std::vector<int64_t> Crop::GetOffsets() const {
|
|||
return std::vector<int64_t>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
PrimitiveC *CropCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Crop>(primitive); }
|
||||
Registry CropRegistry(schema::PrimitiveType_Crop, CropCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateCropParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto param = reinterpret_cast<mindspore::lite::Crop *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
auto param_offset = param->GetOffsets();
|
||||
if (param_offset.size() > CROP_OFFSET_MAX_SIZE) {
|
||||
MS_LOG(ERROR) << "crop_param offset size(" << param_offset.size() << ") should <= " << CROP_OFFSET_MAX_SIZE;
|
||||
return nullptr;
|
||||
}
|
||||
CropParameter *crop_param = reinterpret_cast<CropParameter *>(malloc(sizeof(CropParameter)));
|
||||
if (crop_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc CropParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(crop_param, 0, sizeof(CropParameter));
|
||||
crop_param->op_parameter_.type_ = primitive->Type();
|
||||
crop_param->axis_ = param->GetAxis();
|
||||
crop_param->offset_size_ = param_offset.size();
|
||||
for (size_t i = 0; i < param_offset.size(); ++i) {
|
||||
crop_param->offset_[i] = param_offset[i];
|
||||
}
|
||||
return reinterpret_cast<OpParameter *>(crop_param);
|
||||
}
|
||||
Registry CropParameterRegistry(schema::PrimitiveType_Crop, PopulateCropParameter);
|
||||
|
||||
namespace {
|
||||
constexpr int kCropOutputNum = 1;
|
||||
constexpr int kCropInputNum = 2;
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
#include "src/common/string_util.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -31,8 +33,25 @@ int CustomExtractFeatures::UnPackToFlatBuilder(const schema::Primitive *primitiv
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *CustomExtractFeaturesCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<CustomExtractFeatures>(primitive);
|
||||
}
|
||||
Registry CustomExtractFeaturesRegistry(schema::PrimitiveType_CustomExtractFeatures, CustomExtractFeaturesCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateExtractFeaturesParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "new OpParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(param, 0, sizeof(OpParameter));
|
||||
param->type_ = primitive->Type();
|
||||
return param;
|
||||
}
|
||||
Registry CustomExtractFeaturesParameterRegistry(schema::PrimitiveType_CustomExtractFeatures,
|
||||
PopulateExtractFeaturesParameter);
|
||||
|
||||
int CustomExtractFeatures::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
auto input = inputs_.at(0);
|
||||
auto output0 = outputs_.at(0);
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
#include "src/common/string_util.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -30,7 +32,25 @@ int CustomNormalize::UnPackToFlatBuilder(const schema::Primitive *primitive, fla
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *CustomNormalizeCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<CustomNormalize>(primitive);
|
||||
}
|
||||
Registry CustomNormalizeRegistry(schema::PrimitiveType_CustomNormalize, CustomNormalizeCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateCustomNormalizeParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "new OpParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(param, 0, sizeof(OpParameter));
|
||||
param->type_ = primitive->Type();
|
||||
return param;
|
||||
}
|
||||
Registry CustomNormalizeParameterRegistry(schema::PrimitiveType_CustomNormalize, PopulateCustomNormalizeParameter);
|
||||
|
||||
int CustomNormalize::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
auto input = inputs_.at(0);
|
||||
auto output = outputs_.at(0);
|
||||
|
|
|
@ -15,6 +15,9 @@
|
|||
*/
|
||||
#include "src/ops/custom_predict.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/predict_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -45,7 +48,27 @@ int CustomPredict::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *CustomPredictCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<CustomPredict>(primitive);
|
||||
}
|
||||
Registry CustomPredictRegistry(schema::PrimitiveType_CustomPredict, CustomPredictCreator);
|
||||
#endif
|
||||
OpParameter *PopulateCustomPredictParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
PredictParameter *param = reinterpret_cast<PredictParameter *>(malloc(sizeof(PredictParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc param failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(param, 0, sizeof(PredictParameter));
|
||||
param->op_parameter_.type_ = primitive->Type();
|
||||
auto prim = reinterpret_cast<mindspore::lite::CustomPredict *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
param->output_num = prim->GetOutputNum();
|
||||
param->weight_threshold = prim->GetWeightThreshold();
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
Registry CustomPredictParameterRegistry(schema::PrimitiveType_CustomPredict, PopulateCustomPredictParameter);
|
||||
|
||||
int CustomPredict::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
auto input = inputs_.at(0);
|
||||
auto output0 = outputs_.at(0);
|
||||
|
|
|
@ -25,6 +25,9 @@
|
|||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#endif
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -295,7 +298,51 @@ int DeConv2D::GetDilateH() const { return this->primitive_->value_as_DeConv2D()-
|
|||
bool DeConv2D::GetHasBias() const { return this->primitive_->value_as_DeConv2D()->hasBias(); }
|
||||
int DeConv2D::GetActivationType() const { return this->primitive_->value_as_DeConv2D()->activationType(); }
|
||||
|
||||
PrimitiveC *DeConv2DCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<DeConv2D>(primitive);
|
||||
}
|
||||
Registry DeConv2DRegistry(schema::PrimitiveType_DeConv2D, DeConv2DCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateDeconvParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
|
||||
if (conv_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ConvParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(conv_param, 0, sizeof(ConvParameter));
|
||||
conv_param->op_parameter_.type_ = primitive->Type();
|
||||
auto conv_primitive =
|
||||
reinterpret_cast<mindspore::lite::DeConv2D *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
conv_param->kernel_h_ = conv_primitive->GetKernelH();
|
||||
conv_param->kernel_w_ = conv_primitive->GetKernelW();
|
||||
conv_param->stride_h_ = conv_primitive->GetStrideH();
|
||||
conv_param->stride_w_ = conv_primitive->GetStrideW();
|
||||
|
||||
auto deconv_lite_primitive = (lite::DeConv2D *)primitive;
|
||||
conv_param->pad_u_ = deconv_lite_primitive->PadUp();
|
||||
conv_param->pad_d_ = deconv_lite_primitive->PadDown();
|
||||
conv_param->pad_l_ = deconv_lite_primitive->PadLeft();
|
||||
conv_param->pad_r_ = deconv_lite_primitive->PadRight();
|
||||
conv_param->dilation_h_ = conv_primitive->GetDilateH();
|
||||
conv_param->dilation_w_ = conv_primitive->GetDilateW();
|
||||
auto act_type = conv_primitive->GetActivationType();
|
||||
switch (act_type) {
|
||||
case schema::ActivationType_RELU:
|
||||
conv_param->act_type_ = ActType_Relu;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
conv_param->act_type_ = ActType_Relu6;
|
||||
break;
|
||||
default:
|
||||
conv_param->act_type_ = ActType_No;
|
||||
break;
|
||||
}
|
||||
return reinterpret_cast<OpParameter *>(conv_param);
|
||||
}
|
||||
|
||||
Registry DeConv2DParameterRegistry(schema::PrimitiveType_DeConv2D, PopulateDeconvParameter);
|
||||
|
||||
int DeConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
auto input = inputs_.front();
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/dedepthwise_conv2d.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -109,7 +112,51 @@ int DeDepthwiseConv2D::GetActivationType() const {
|
|||
return this->primitive_->value_as_DeDepthwiseConv2D()->activationType();
|
||||
}
|
||||
|
||||
PrimitiveC *DeDepthwiseConv2DCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<DeDepthwiseConv2D>(primitive);
|
||||
}
|
||||
Registry DeDepthwiseConv2DRegistry(schema::PrimitiveType_DeDepthwiseConv2D, DeDepthwiseConv2DCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateDeconvDwParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
|
||||
if (conv_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ConvParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(conv_param, 0, sizeof(ConvParameter));
|
||||
conv_param->op_parameter_.type_ = primitive->Type();
|
||||
auto conv_primitive =
|
||||
reinterpret_cast<mindspore::lite::DeDepthwiseConv2D *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
conv_param->kernel_h_ = conv_primitive->GetKernelH();
|
||||
conv_param->kernel_w_ = conv_primitive->GetKernelW();
|
||||
conv_param->stride_h_ = conv_primitive->GetStrideH();
|
||||
conv_param->stride_w_ = conv_primitive->GetStrideW();
|
||||
|
||||
auto deconvdw_lite_primitive = (mindspore::lite::DeDepthwiseConv2D *)primitive;
|
||||
conv_param->pad_u_ = deconvdw_lite_primitive->PadUp();
|
||||
conv_param->pad_d_ = deconvdw_lite_primitive->PadDown();
|
||||
conv_param->pad_l_ = deconvdw_lite_primitive->PadLeft();
|
||||
conv_param->pad_r_ = deconvdw_lite_primitive->PadRight();
|
||||
conv_param->dilation_h_ = conv_primitive->GetDilateH();
|
||||
conv_param->dilation_w_ = conv_primitive->GetDilateW();
|
||||
auto act_type = conv_primitive->GetActivationType();
|
||||
switch (act_type) {
|
||||
case schema::ActivationType_RELU:
|
||||
conv_param->act_type_ = ActType_Relu;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
conv_param->act_type_ = ActType_Relu6;
|
||||
break;
|
||||
default:
|
||||
conv_param->act_type_ = ActType_No;
|
||||
break;
|
||||
}
|
||||
return reinterpret_cast<OpParameter *>(conv_param);
|
||||
}
|
||||
|
||||
Registry DeDepthwiseConv2DParameterRegistry(schema::PrimitiveType_DeDepthwiseConv2D, PopulateDeconvDwParameter);
|
||||
|
||||
int DeDepthwiseConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
|
||||
if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) {
|
||||
MS_LOG(ERROR) << "inputs number is invalid";
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -52,6 +54,9 @@ int Depend::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *DependCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Depend>(primitive); }
|
||||
Registry DependRegistry(schema::PrimitiveType_Depend, DependCreator);
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/depth_to_space.h"
|
||||
#include "src/common/common.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/depth_to_space_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -42,7 +45,29 @@ int DepthToSpace::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu
|
|||
int DepthToSpace::GetBlockSize() const { return this->primitive_->value_as_DepthToSpace()->blockSize(); }
|
||||
int DepthToSpace::GetFormat() const { return this->primitive_->value_as_DepthToSpace()->format(); }
|
||||
|
||||
PrimitiveC *DepthToSpaceCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<DepthToSpace>(primitive);
|
||||
}
|
||||
Registry DepthToSpaceRegistry(schema::PrimitiveType_DepthToSpace, DepthToSpaceCreator);
|
||||
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateDepthToSpaceParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
DepthToSpaceParameter *depth_space_param =
|
||||
reinterpret_cast<DepthToSpaceParameter *>(malloc(sizeof(DepthToSpaceParameter)));
|
||||
if (depth_space_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc DepthToSpaceParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(depth_space_param, 0, sizeof(DepthToSpaceParameter));
|
||||
auto param = reinterpret_cast<mindspore::lite::DepthToSpace *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
depth_space_param->op_parameter_.type_ = primitive->Type();
|
||||
depth_space_param->block_size_ = param->GetBlockSize();
|
||||
return reinterpret_cast<OpParameter *>(depth_space_param);
|
||||
}
|
||||
|
||||
Registry DepthToSpaceParameterRegistry(schema::PrimitiveType_DepthToSpace, PopulateDepthToSpaceParameter);
|
||||
|
||||
namespace {
|
||||
constexpr int kDepthToSpaceOutputNum = 1;
|
||||
constexpr int kDepthToSpaceInputNum = 1;
|
||||
|
|
|
@ -21,6 +21,9 @@
|
|||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#endif
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -191,7 +194,54 @@ int DepthwiseConv2D::GetActivationType() const {
|
|||
return this->primitive_->value_as_DepthwiseConv2D()->activationType();
|
||||
}
|
||||
|
||||
PrimitiveC *DepthWiseConv2DCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<DepthwiseConv2D>(primitive);
|
||||
}
|
||||
Registry DepthWiseConv2DRegistry(schema::PrimitiveType_DepthwiseConv2D, DepthWiseConv2DCreator);
|
||||
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateConvDwParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
|
||||
if (conv_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ConvParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(conv_param, 0, sizeof(ConvParameter));
|
||||
conv_param->op_parameter_.type_ = primitive->Type();
|
||||
|
||||
auto conv_primitive =
|
||||
reinterpret_cast<mindspore::lite::DepthwiseConv2D *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
conv_param->kernel_h_ = conv_primitive->GetKernelH();
|
||||
conv_param->kernel_w_ = conv_primitive->GetKernelW();
|
||||
conv_param->stride_h_ = conv_primitive->GetStrideH();
|
||||
conv_param->stride_w_ = conv_primitive->GetStrideW();
|
||||
|
||||
auto convdw_lite_primitive = (lite::DepthwiseConv2D *)primitive;
|
||||
conv_param->pad_u_ = convdw_lite_primitive->PadUp();
|
||||
conv_param->pad_d_ = convdw_lite_primitive->PadDown();
|
||||
conv_param->pad_l_ = convdw_lite_primitive->PadLeft();
|
||||
conv_param->pad_r_ = convdw_lite_primitive->PadRight();
|
||||
conv_param->input_channel_ = convdw_lite_primitive->GetInputChannel();
|
||||
conv_param->dilation_h_ = conv_primitive->GetDilateH();
|
||||
conv_param->dilation_w_ = conv_primitive->GetDilateW();
|
||||
auto act_type = conv_primitive->GetActivationType();
|
||||
switch (act_type) {
|
||||
case schema::ActivationType_RELU:
|
||||
conv_param->act_type_ = ActType_Relu;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
conv_param->act_type_ = ActType_Relu6;
|
||||
break;
|
||||
default:
|
||||
conv_param->act_type_ = ActType_No;
|
||||
break;
|
||||
}
|
||||
return reinterpret_cast<OpParameter *>(conv_param);
|
||||
}
|
||||
|
||||
Registry DepthwiseConv2DParameterRegistry(schema::PrimitiveType_DepthwiseConv2D, PopulateConvDwParameter);
|
||||
|
||||
int DepthwiseConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
|
||||
if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) {
|
||||
MS_LOG(ERROR) << "inputs number is invalid";
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/detection_post_process.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/detection_post_process_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -139,7 +142,38 @@ bool DetectionPostProcess::GetOutQuantized() const {
|
|||
return this->primitive_->value_as_DetectionPostProcess()->OutQuantized();
|
||||
}
|
||||
|
||||
PrimitiveC *DetectionPostProcessCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<DetectionPostProcess>(primitive);
|
||||
}
|
||||
Registry DetectionPostProcessRegistry(schema::PrimitiveType_DetectionPostProcess, DetectionPostProcessCreator);
|
||||
#endif
|
||||
OpParameter *PopulateDetectionPostProcessParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
DetectionPostProcessParameter *detection_post_process_parameter =
|
||||
reinterpret_cast<DetectionPostProcessParameter *>(malloc(sizeof(DetectionPostProcessParameter)));
|
||||
if (detection_post_process_parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc EluParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(detection_post_process_parameter, 0, sizeof(DetectionPostProcessParameter));
|
||||
detection_post_process_parameter->op_parameter_.type_ = primitive->Type();
|
||||
auto param =
|
||||
reinterpret_cast<mindspore::lite::DetectionPostProcess *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
detection_post_process_parameter->h_scale_ = param->GetHScale();
|
||||
detection_post_process_parameter->w_scale_ = param->GetWScale();
|
||||
detection_post_process_parameter->x_scale_ = param->GetXScale();
|
||||
detection_post_process_parameter->y_scale_ = param->GetYScale();
|
||||
detection_post_process_parameter->nms_iou_threshold_ = param->GetNmsIouThreshold();
|
||||
detection_post_process_parameter->nms_score_threshold_ = param->GetNmsScoreThreshold();
|
||||
detection_post_process_parameter->max_detections_ = param->GetMaxDetections();
|
||||
detection_post_process_parameter->detections_per_class_ = param->GetDetectionsPerClass();
|
||||
detection_post_process_parameter->max_classes_per_detection_ = param->GetMaxClassesPerDetection();
|
||||
detection_post_process_parameter->num_classes_ = param->GetNumClasses();
|
||||
detection_post_process_parameter->use_regular_nms_ = param->GetUseRegularNms();
|
||||
return reinterpret_cast<OpParameter *>(detection_post_process_parameter);
|
||||
}
|
||||
Registry DetectionPostProcessParameterRegistry(schema::PrimitiveType_DetectionPostProcess,
|
||||
PopulateDetectionPostProcessParameter);
|
||||
|
||||
namespace {
|
||||
constexpr int kDetectionPostProcessOutputNum = 4;
|
||||
constexpr int kDetectionPostProcessInputNum = 3;
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/div.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -41,6 +43,30 @@ int Div::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl
|
|||
}
|
||||
int Div::GetActivationType() const { return this->primitive_->value_as_Div()->activationType(); }
|
||||
|
||||
PrimitiveC *DivCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Div>(primitive); }
|
||||
Registry DivRegistry(schema::PrimitiveType_Div, DivCreator);
|
||||
#endif
|
||||
OpParameter *PopulateDivParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
|
||||
if (arithmetic_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
|
||||
arithmetic_param->op_parameter_.type_ = primitive->Type();
|
||||
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
|
||||
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
|
||||
arithmetic_param->activation_type_ =
|
||||
reinterpret_cast<mindspore::lite::Div *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType();
|
||||
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
|
||||
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
|
||||
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
|
||||
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
|
||||
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
|
||||
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
|
||||
return reinterpret_cast<OpParameter *>(arithmetic_param);
|
||||
}
|
||||
Registry DivParameterRegistry(schema::PrimitiveType_Div, PopulateDivParameter);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/dropout.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -39,6 +41,8 @@ int Dropout::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
|
|||
}
|
||||
float Dropout::GetRatio() const { return this->primitive_->value_as_Dropout()->ratio(); }
|
||||
|
||||
PrimitiveC *DropoutCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Dropout>(primitive); }
|
||||
Registry DropoutRegistry(schema::PrimitiveType_Dropout, DropoutCreator);
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/eltwise.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/arithmetic_common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -39,6 +42,35 @@ int Eltwise::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
|
|||
}
|
||||
int Eltwise::GetMode() const { return this->primitive_->value_as_Eltwise()->mode(); }
|
||||
|
||||
PrimitiveC *EltwiseCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Eltwise>(primitive); }
|
||||
Registry EltwiseRegistry(schema::PrimitiveType_Eltwise, EltwiseCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateEltwiseParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
|
||||
if (arithmetic_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
|
||||
auto eltwise = reinterpret_cast<mindspore::lite::Eltwise *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
switch (eltwise->GetMode()) {
|
||||
case schema::EltwiseMode_PROD:
|
||||
arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Mul;
|
||||
break;
|
||||
case schema::EltwiseMode_SUM:
|
||||
arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Add;
|
||||
break;
|
||||
case schema::EltwiseMode_MAXIMUM:
|
||||
arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Maximum;
|
||||
break;
|
||||
default:
|
||||
free(arithmetic_param);
|
||||
return nullptr;
|
||||
}
|
||||
return reinterpret_cast<OpParameter *>(arithmetic_param);
|
||||
}
|
||||
Registry EltwiseParameterRegistry(schema::PrimitiveType_Eltwise, PopulateEltwiseParameter);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/elu.h"
|
||||
#include <memory>
|
||||
#include "nnacl/fp32/elu.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -61,6 +63,22 @@ int Elu::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl
|
|||
}
|
||||
float Elu::GetAlpha() const { return this->primitive_->value_as_Elu()->alpha(); }
|
||||
|
||||
PrimitiveC *EluCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Elu>(primitive); }
|
||||
Registry EluRegistry(schema::PrimitiveType_Elu, EluCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateEluParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
EluParameter *elu_parameter = reinterpret_cast<EluParameter *>(malloc(sizeof(EluParameter)));
|
||||
if (elu_parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc EluParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(elu_parameter, 0, sizeof(EluParameter));
|
||||
elu_parameter->op_parameter_.type_ = primitive->Type();
|
||||
auto param = reinterpret_cast<mindspore::lite::Elu *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
elu_parameter->alpha_ = param->GetAlpha();
|
||||
return reinterpret_cast<OpParameter *>(elu_parameter);
|
||||
}
|
||||
Registry EluParameterRegistry(schema::PrimitiveType_Elu, PopulateEluParameter);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/embedding_lookup.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/fp32/embedding_lookup.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -41,8 +44,35 @@ int EmbeddingLookup::UnPackToFlatBuilder(const schema::Primitive *primitive, fla
|
|||
}
|
||||
float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value_as_EmbeddingLookup()->maxNorm(); }
|
||||
|
||||
PrimitiveC *EmbeddingLookupCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<EmbeddingLookup>(primitive);
|
||||
}
|
||||
Registry EmbeddingLookupRegistry(schema::PrimitiveType_EmbeddingLookup, EmbeddingLookupCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateEmbeddingLookupParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
EmbeddingLookupParameter *embedding_lookup_parameter =
|
||||
reinterpret_cast<EmbeddingLookupParameter *>(malloc(sizeof(EmbeddingLookupParameter)));
|
||||
if (embedding_lookup_parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc EmbeddingLookupParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(embedding_lookup_parameter, 0, sizeof(EmbeddingLookupParameter));
|
||||
embedding_lookup_parameter->op_parameter_.type_ = primitive->Type();
|
||||
auto param =
|
||||
reinterpret_cast<mindspore::lite::EmbeddingLookup *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
embedding_lookup_parameter->max_norm_ = param->GetMaxNorm();
|
||||
if (embedding_lookup_parameter->max_norm_ < 0) {
|
||||
MS_LOG(ERROR) << "Embedding lookup max norm should be positive number, got "
|
||||
<< embedding_lookup_parameter->max_norm_;
|
||||
free(embedding_lookup_parameter);
|
||||
return nullptr;
|
||||
}
|
||||
return reinterpret_cast<OpParameter *>(embedding_lookup_parameter);
|
||||
}
|
||||
|
||||
Registry EmbeddingLookupParameterRegistry(schema::PrimitiveType_EmbeddingLookup, PopulateEmbeddingLookupParameter);
|
||||
|
||||
int EmbeddingLookup::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
if (inputs_.size() < kDoubleNum) {
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/embedding_lookup_sparse.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -76,6 +78,10 @@ float EmbeddingLookupSparse::GetMaxNortm() const {
|
|||
return this->primitive_->value_as_EmbeddingLookupSparse()->maxNortm();
|
||||
}
|
||||
|
||||
PrimitiveC *EmbeddingLookupSparseCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<EmbeddingLookupSparse>(primitive);
|
||||
}
|
||||
Registry EmbeddingLookupSparseRegistry(schema::PrimitiveType_EmbeddingLookupSparse, EmbeddingLookupSparseCreator);
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/equal.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
@ -28,6 +30,8 @@ int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *EqualCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Equal>(primitive); }
|
||||
Registry EqualRegistry(schema::PrimitiveType_Equal, EqualCreator);
|
||||
#endif
|
||||
int Equal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
auto input = inputs_.front();
|
||||
|
@ -39,5 +43,6 @@ int Equal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
|
|||
output->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/exp.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "src/ops/arithmetic_self.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -71,6 +74,10 @@ float Exp::GetBase() const { return this->primitive_->value_as_Exp()->base(); }
|
|||
float Exp::GetScale() const { return this->primitive_->value_as_Exp()->scale(); }
|
||||
float Exp::GetShift() const { return this->primitive_->value_as_Exp()->shift(); }
|
||||
|
||||
PrimitiveC *ExpCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Exp>(primitive); }
|
||||
Registry ExpRegistry(schema::PrimitiveType_Exp, ExpCreator);
|
||||
#endif
|
||||
Registry ExpParameterRegistry(schema::PrimitiveType_Exp, PopulateArithmeticSelf);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/expand_dims.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/fp32/expandDims.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -40,8 +43,27 @@ int ExpandDims::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuff
|
|||
}
|
||||
int ExpandDims::GetDim() const { return this->primitive_->value_as_ExpandDims()->dim(); }
|
||||
|
||||
PrimitiveC *ExpandDimsCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<ExpandDims>(primitive);
|
||||
}
|
||||
Registry ExpandDimsRegistry(schema::PrimitiveType_ExpandDims, ExpandDimsCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateExpandDimsParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto param = reinterpret_cast<mindspore::lite::ExpandDims *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
ExpandDimsParameter *expand_dims_param = reinterpret_cast<ExpandDimsParameter *>(malloc(sizeof(ExpandDimsParameter)));
|
||||
if (expand_dims_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ExpandDimsParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(expand_dims_param, 0, sizeof(ExpandDimsParameter));
|
||||
expand_dims_param->op_parameter_.type_ = primitive->Type();
|
||||
expand_dims_param->dim_ = param->GetDim();
|
||||
return reinterpret_cast<OpParameter *>(expand_dims_param);
|
||||
}
|
||||
|
||||
Registry ExpandDimsParameterRegistry(schema::PrimitiveType_ExpandDims, PopulateExpandDimsParameter);
|
||||
|
||||
int ExpandDims::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
auto input = inputs_.front();
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/fake_quant_with_min_max_vars.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -54,6 +56,10 @@ int FakeQuantWithMinMaxVars::GetNumBits() const {
|
|||
return this->primitive_->value_as_FakeQuantWithMinMaxVars()->numBits();
|
||||
}
|
||||
|
||||
PrimitiveC *FakeQuantWithMinMaxVarsCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<FakeQuantWithMinMaxVars>(primitive);
|
||||
}
|
||||
Registry FakeQuantWithMinMaxVarsRegistry(schema::PrimitiveType_FakeQuantWithMinMaxVars, FakeQuantWithMinMaxVarsCreator);
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/fill.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/fp32/fill.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -48,8 +51,30 @@ std::vector<int> Fill::GetDims() const {
|
|||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
PrimitiveC *FillCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Fill>(primitive); }
|
||||
Registry FillRegistry(schema::PrimitiveType_Fill, FillCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateFillParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
const auto param = reinterpret_cast<mindspore::lite::Fill *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
FillParameter *fill_param = reinterpret_cast<FillParameter *>(malloc(sizeof(FillParameter)));
|
||||
if (fill_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc FillParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(fill_param, 0, sizeof(FillParameter));
|
||||
fill_param->op_parameter_.type_ = primitive->Type();
|
||||
auto flatDims = param->GetDims();
|
||||
fill_param->num_dims_ = flatDims.size();
|
||||
int i = 0;
|
||||
for (auto iter = flatDims.begin(); iter != flatDims.end(); iter++) {
|
||||
fill_param->dims_[i++] = *iter;
|
||||
}
|
||||
return reinterpret_cast<OpParameter *>(fill_param);
|
||||
}
|
||||
|
||||
Registry FillParameterRegistry(schema::PrimitiveType_Fill, PopulateFillParameter);
|
||||
|
||||
int Fill::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
auto input = inputs_.front();
|
||||
|
|
|
@ -17,6 +17,9 @@
|
|||
#include "src/ops/flatten.h"
|
||||
#include <memory>
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/flatten.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
|
@ -86,6 +89,22 @@ int Flatten::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *FlattenCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Flatten>(primitive); }
|
||||
Registry FlattenRegistry(schema::PrimitiveType_Flatten, FlattenCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateFlattenParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
FlattenParameter *flatten_param = reinterpret_cast<FlattenParameter *>(malloc(sizeof(FlattenParameter)));
|
||||
if (flatten_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc FlattenParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(flatten_param, 0, sizeof(FlattenParameter));
|
||||
flatten_param->op_parameter_.type_ = primitive->Type();
|
||||
return reinterpret_cast<OpParameter *>(flatten_param);
|
||||
}
|
||||
|
||||
Registry FlattenParameterRegistry(schema::PrimitiveType_Flatten, PopulateFlattenParameter);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#include "src/ops/flatten_grad.h"
|
||||
#include <memory>
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
int FlattenGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
|
@ -85,6 +87,10 @@ int FlattenGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuf
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *FlattenGradCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<FlattenGrad>(primitive);
|
||||
}
|
||||
Registry FlattenGradRegistry(schema::PrimitiveType_FlattenGrad, FlattenGradCreator);
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/floor.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
@ -28,7 +30,10 @@ int Floor::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *FloorCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Floor>(primitive); }
|
||||
Registry FloorRegistry(schema::PrimitiveType_Floor, FloorCreator);
|
||||
#endif
|
||||
Registry FloorParameterRegistry(schema::PrimitiveType_Floor, PopulateArithmeticSelf);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/floor_div.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
@ -29,6 +31,11 @@ int FloorDiv::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *FloorDivCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<FloorDiv>(primitive);
|
||||
}
|
||||
Registry FloorDivRegistry(schema::PrimitiveType_FloorDiv, FloorDivCreator);
|
||||
#endif
|
||||
Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/floor_mod.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
@ -28,7 +30,11 @@ int FloorMod::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *FloorModCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<FloorMod>(primitive);
|
||||
}
|
||||
Registry FloorModRegistry(schema::PrimitiveType_FloorMod, FloorModCreator);
|
||||
#endif
|
||||
Registry FloorModParameterRegistry(schema::PrimitiveType_FloorMod, PopulateArithmetic);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/full_connection.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/matmul_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -51,7 +54,38 @@ int FullConnection::GetAxis() const { return this->primitive_->value_as_FullConn
|
|||
bool FullConnection::GetUseAxis() const { return this->primitive_->value_as_FullConnection()->useAxis(); }
|
||||
int FullConnection::GetActivationType() const { return this->primitive_->value_as_FullConnection()->activationType(); }
|
||||
|
||||
PrimitiveC *FullConnectionCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<FullConnection>(primitive);
|
||||
}
|
||||
Registry FullConnectionRegistry(schema::PrimitiveType_FullConnection, FullConnectionCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateFullconnectionParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto param =
|
||||
reinterpret_cast<mindspore::lite::FullConnection *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
MatMulParameter *matmul_param = reinterpret_cast<MatMulParameter *>(malloc(sizeof(MatMulParameter)));
|
||||
if (matmul_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc MatMulParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(matmul_param, 0, sizeof(MatMulParameter));
|
||||
matmul_param->op_parameter_.type_ = primitive->Type();
|
||||
matmul_param->b_transpose_ = true;
|
||||
matmul_param->a_transpose_ = false;
|
||||
matmul_param->has_bias_ = param->GetHasBias();
|
||||
if (param->GetActivationType() == schema::ActivationType_RELU) {
|
||||
matmul_param->act_type_ = ActType_Relu;
|
||||
} else if (param->GetActivationType() == schema::ActivationType_RELU6) {
|
||||
matmul_param->act_type_ = ActType_Relu6;
|
||||
} else {
|
||||
matmul_param->act_type_ = ActType_No;
|
||||
}
|
||||
|
||||
return reinterpret_cast<OpParameter *>(matmul_param);
|
||||
}
|
||||
|
||||
Registry FullConnectionParameterRegistry(schema::PrimitiveType_FullConnection, PopulateFullconnectionParameter);
|
||||
|
||||
int FullConnection::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
auto input0 = inputs_.front();
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/fused_batchnorm.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/batchnorm_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -72,7 +75,29 @@ float FusedBatchNorm::GetEpsilon() const { return this->primitive_->value_as_Fus
|
|||
float FusedBatchNorm::GetMomentum() const { return this->primitive_->value_as_FusedBatchNorm()->momentum(); }
|
||||
int FusedBatchNorm::GetSpatial() const { return this->primitive_->value_as_FusedBatchNorm()->spatial(); }
|
||||
|
||||
PrimitiveC *FusedBatchNormCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<FusedBatchNorm>(primitive);
|
||||
}
|
||||
Registry FusedBatchNormRegistry(schema::PrimitiveType_FusedBatchNorm, FusedBatchNormCreator);
|
||||
#endif
|
||||
OpParameter *PopulateFusedBatchNorm(const mindspore::lite::PrimitiveC *primitive) {
|
||||
BatchNormParameter *batch_norm_param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter)));
|
||||
if (batch_norm_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc BatchNormParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(batch_norm_param, 0, sizeof(BatchNormParameter));
|
||||
batch_norm_param->op_parameter_.type_ = primitive->Type();
|
||||
auto param =
|
||||
reinterpret_cast<mindspore::lite::FusedBatchNorm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
batch_norm_param->epsilon_ = param->GetEpsilon();
|
||||
batch_norm_param->momentum_ = param->GetMomentum();
|
||||
batch_norm_param->fused_ = true;
|
||||
return reinterpret_cast<OpParameter *>(batch_norm_param);
|
||||
}
|
||||
|
||||
Registry FusedBatchNormParameterRegistry(schema::PrimitiveType_FusedBatchNorm, PopulateFusedBatchNorm);
|
||||
|
||||
int FusedBatchNorm::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
|
||||
for (size_t i = 0; i < inputs_.size(); i++) {
|
||||
if (outputs_.size() <= i) break;
|
||||
|
|
|
@ -19,6 +19,9 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
#include "src/tensor.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/gather_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -46,8 +49,25 @@ int Gather::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:
|
|||
int Gather::GetAxis() const { return this->primitive_->value_as_Gather()->axis(); }
|
||||
int Gather::GetBatchDims() const { return this->primitive_->value_as_Gather()->batchDims(); }
|
||||
|
||||
PrimitiveC *GatherCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Gather>(primitive); }
|
||||
Registry GatherRegistry(schema::PrimitiveType_Gather, GatherCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateGatherParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto gather_attr = reinterpret_cast<mindspore::lite::Gather *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
GatherParameter *gather_param = reinterpret_cast<GatherParameter *>(malloc(sizeof(GatherParameter)));
|
||||
if (gather_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc GatherParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(gather_param, 0, sizeof(GatherParameter));
|
||||
gather_param->op_parameter_.type_ = primitive->Type();
|
||||
gather_param->axis_ = gather_attr->GetAxis();
|
||||
gather_param->batchDims_ = gather_attr->GetBatchDims();
|
||||
return reinterpret_cast<OpParameter *>(gather_param);
|
||||
}
|
||||
Registry GatherParameterRegistry(schema::PrimitiveType_Gather, PopulateGatherParameter);
|
||||
|
||||
int Gather::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
if (inputs_.size() != kDoubleNum) {
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/gather_nd.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/fp32/gatherNd.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -40,8 +43,28 @@ int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer
|
|||
}
|
||||
int GatherNd::GetBatchDims() const { return this->primitive_->value_as_GatherNd()->batchDims(); }
|
||||
|
||||
PrimitiveC *GatherNdCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<GatherNd>(primitive);
|
||||
}
|
||||
Registry GatherNdRegistry(schema::PrimitiveType_GatherNd, GatherNdCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateGatherNdParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
GatherNdParameter *gather_nd_param = reinterpret_cast<GatherNdParameter *>(malloc(sizeof(GatherNdParameter)));
|
||||
if (gather_nd_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc GatherNdParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(gather_nd_param, 0, sizeof(GatherNdParameter));
|
||||
gather_nd_param->op_parameter_.type_ = primitive->Type();
|
||||
auto gatherNd_attr =
|
||||
reinterpret_cast<mindspore::lite::GatherNd *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
gather_nd_param->batchDims_ = gatherNd_attr->GetBatchDims();
|
||||
return reinterpret_cast<OpParameter *>(gather_nd_param);
|
||||
}
|
||||
|
||||
Registry GatherNdParameterRegistry(schema::PrimitiveType_GatherNd, PopulateGatherNdParameter);
|
||||
|
||||
int GatherNd::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
if (inputs_.size() != kDoubleNum) {
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/greater.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
@ -28,6 +30,9 @@ int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *GreaterCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Greater>(primitive); }
|
||||
Registry GreaterRegistry(schema::PrimitiveType_Greater, GreaterCreator);
|
||||
#endif
|
||||
int Greater::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
auto input = inputs_.front();
|
||||
|
@ -39,5 +44,6 @@ int Greater::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
|
|||
output->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/greater_equal.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
@ -27,6 +29,12 @@ int GreaterEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *GreaterEqualCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<GreaterEqual>(primitive);
|
||||
}
|
||||
Registry GreaterEqualRegistry(schema::PrimitiveType_GreaterEqual, GreaterEqualCreator);
|
||||
|
||||
#endif
|
||||
int GreaterEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
auto input = inputs_.front();
|
||||
|
@ -38,5 +46,6 @@ int GreaterEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *
|
|||
output->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
Registry GreaterEqualParameterRegistry(schema::PrimitiveType_GreaterEqual, PopulateArithmetic);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/group_conv2d_grad_input.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -127,6 +129,10 @@ int GroupConv2DGradInput::GetActivationType() const {
|
|||
return this->primitive_->value_as_GroupConv2DGradInput()->activationType();
|
||||
}
|
||||
|
||||
PrimitiveC *GroupConv2DGradInputCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<GroupConv2DGradInput>(primitive);
|
||||
}
|
||||
Registry GroupConv2DGradInputRegistry(schema::PrimitiveType_GroupConv2DGradInput, GroupConv2DGradInputCreator);
|
||||
#endif
|
||||
|
||||
int GroupConv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
#include "src/common/string_util.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -30,7 +32,24 @@ int HashtableLookup::UnPackToFlatBuilder(const schema::Primitive *primitive, fla
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *HashtableLookupCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<HashtableLookup>(primitive);
|
||||
}
|
||||
Registry HashtableLookupRegistry(schema::PrimitiveType_HashtableLookup, HashtableLookupCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateHashtableLookupParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "new OpParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(param, 0, sizeof(OpParameter));
|
||||
param->type_ = primitive->Type();
|
||||
return param;
|
||||
}
|
||||
Registry HashtableLookupParameterRegistry(schema::PrimitiveType_HashtableLookup, PopulateHashtableLookupParameter);
|
||||
|
||||
int HashtableLookup::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
auto input = inputs_.at(0);
|
||||
auto values = inputs_.at(2);
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/l2_norm.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/l2_norm_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -57,6 +60,44 @@ std::vector<int> L2Norm::GetAxis() const {
|
|||
float L2Norm::GetEpsilon() const { return this->primitive_->value_as_L2Norm()->epsilon(); }
|
||||
int L2Norm::GetActivationType() const { return this->primitive_->value_as_L2Norm()->activationType(); }
|
||||
|
||||
PrimitiveC *L2NormCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<L2Norm>(primitive); }
|
||||
Registry L2NormRegistry(schema::PrimitiveType_L2Norm, L2NormCreator);
|
||||
#endif
|
||||
OpParameter *PopulateL2NormParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
L2NormParameter *l2_norm_parameter = reinterpret_cast<L2NormParameter *>(malloc(sizeof(L2NormParameter)));
|
||||
if (l2_norm_parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc L2NormParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(l2_norm_parameter, 0, sizeof(L2NormParameter));
|
||||
l2_norm_parameter->op_parameter_.type_ = primitive->Type();
|
||||
auto param = reinterpret_cast<mindspore::lite::L2Norm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
auto axis_vec = param->GetAxis();
|
||||
l2_norm_parameter->axis_num_ = axis_vec.size();
|
||||
l2_norm_parameter->axis_ = reinterpret_cast<int *>(malloc(axis_vec.size() * sizeof(int)));
|
||||
if (l2_norm_parameter->axis_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc axis_ data failed";
|
||||
free(l2_norm_parameter);
|
||||
return nullptr;
|
||||
}
|
||||
for (size_t i = 0; i < axis_vec.size(); i++) {
|
||||
l2_norm_parameter->axis_[i] = axis_vec[i];
|
||||
}
|
||||
if (param->GetEpsilon() < 1e-6) {
|
||||
l2_norm_parameter->epsilon_ = 1e-6;
|
||||
} else {
|
||||
l2_norm_parameter->epsilon_ = param->GetEpsilon();
|
||||
}
|
||||
if (param->GetActivationType() == static_cast<int>(schema::ActivationType_RELU)) {
|
||||
l2_norm_parameter->act_type_ = ActType_Relu;
|
||||
} else if (param->GetActivationType() == static_cast<int>(schema::ActivationType_RELU6)) {
|
||||
l2_norm_parameter->act_type_ = ActType_Relu6;
|
||||
} else {
|
||||
l2_norm_parameter->act_type_ = ActType_No;
|
||||
}
|
||||
return reinterpret_cast<OpParameter *>(l2_norm_parameter);
|
||||
}
|
||||
Registry L2NormParameterRegistry(schema::PrimitiveType_L2Norm, PopulateL2NormParameter);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/leaky_relu.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -42,6 +44,10 @@ int LeakyReLU::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *LeakyReLUCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<LeakyReLU>(primitive);
|
||||
}
|
||||
Registry LeakyReLURegistry(schema::PrimitiveType_LeakyReLU, LeakyReLUCreator);
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/less.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -29,6 +31,10 @@ int Less::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *LessCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Less>(primitive); }
|
||||
Registry LessRegistry(schema::PrimitiveType_Less, LessCreator);
|
||||
|
||||
#endif
|
||||
int Less::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
auto input = inputs_.front();
|
||||
|
@ -40,5 +46,6 @@ int Less::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
|
|||
output->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/less_equal.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -28,6 +30,10 @@ int LessEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *LessEqualCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<LessEqual>(primitive);
|
||||
}
|
||||
Registry LessEqualRegistry(schema::PrimitiveType_LessEqual, LessEqualCreator);
|
||||
#endif
|
||||
int LessEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
auto input = inputs_.front();
|
||||
|
@ -39,5 +45,6 @@ int LessEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> o
|
|||
output->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/local_response_normalization.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/fp32/local_response_norm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -76,6 +79,34 @@ int LocalResponseNormalization::UnPackToFlatBuilder(const schema::Primitive *pri
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *LocalResponseNormalizationCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<LocalResponseNormalization>(primitive);
|
||||
}
|
||||
Registry LocalResponseNormalizationRegistry(schema::PrimitiveType_LocalResponseNormalization,
|
||||
LocalResponseNormalizationCreator);
|
||||
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateLocalResponseNormParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto local_response_norm_attr = reinterpret_cast<mindspore::lite::LocalResponseNormalization *>(
|
||||
const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
LocalResponseNormParameter *lrn_param =
|
||||
reinterpret_cast<LocalResponseNormParameter *>(malloc(sizeof(LocalResponseNormParameter)));
|
||||
if (lrn_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc LocalResponseNormParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(lrn_param, 0, sizeof(LocalResponseNormParameter));
|
||||
lrn_param->op_parameter_.type_ = primitive->Type();
|
||||
lrn_param->depth_radius_ = local_response_norm_attr->GetDepthRadius();
|
||||
lrn_param->bias_ = local_response_norm_attr->GetBias();
|
||||
lrn_param->alpha_ = local_response_norm_attr->GetAlpha();
|
||||
lrn_param->beta_ = local_response_norm_attr->GetBeta();
|
||||
return reinterpret_cast<OpParameter *>(lrn_param);
|
||||
}
|
||||
|
||||
Registry LocalResponseNormalizationParameterRegistry(schema::PrimitiveType_LocalResponseNormalization,
|
||||
PopulateLocalResponseNormParameter);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#include "src/ops/log.h"
|
||||
#include <memory>
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -50,6 +52,11 @@ int Log::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *LogCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Log>(primitive); }
|
||||
Registry LogRegistry(schema::PrimitiveType_Log, LogCreator);
|
||||
|
||||
#endif
|
||||
Registry LogParameterRegistry(schema::PrimitiveType_Log, PopulateArithmeticSelf);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/log_grad.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "src/ops/arithmetic_self.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
@ -32,6 +35,11 @@ int LogGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *LogGradCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<LogGrad>(primitive); }
|
||||
Registry LogGradRegistry(schema::PrimitiveType_LogGrad, LogGradCreator);
|
||||
#endif
|
||||
Registry LogGradParameterRegistry(schema::PrimitiveType_LogGrad, PopulateArithmeticSelf);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/logical_and.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -28,6 +30,13 @@ int LogicalAnd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuff
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *LogicalAndCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<LogicalAnd>(primitive);
|
||||
}
|
||||
Registry LogicalAndRegistry(schema::PrimitiveType_LogicalAnd, LogicalAndCreator);
|
||||
#endif
|
||||
|
||||
Registry LogicalAndParameterRegistry(schema::PrimitiveType_LogicalAnd, PopulateArithmetic);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/logical_not.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -28,6 +30,12 @@ int LogicalNot::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuff
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *LogicalNotCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<LogicalNot>(primitive);
|
||||
}
|
||||
Registry LogicalNotRegistry(schema::PrimitiveType_LogicalNot, LogicalNotCreator);
|
||||
#endif
|
||||
Registry LogicalNotParameterRegistry(schema::PrimitiveType_LogicalNot, PopulateArithmeticSelf);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/logical_or.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -28,6 +30,12 @@ int LogicalOr::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *LogicalOrCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<LogicalOr>(primitive);
|
||||
}
|
||||
Registry LogicalOrRegistry(schema::PrimitiveType_LogicalOr, LogicalOrCreator);
|
||||
#endif
|
||||
Registry LogicalOrParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/lrn.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -49,6 +51,9 @@ int Lrn::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *LrnCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Lrn>(primitive); }
|
||||
Registry LrnRegistry(schema::PrimitiveType_Lrn, LrnCreator);
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
#include "nnacl/lsh_projection_parameter.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -38,7 +40,29 @@ int LshProjection::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *LshProjectionCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<LshProjection>(primitive);
|
||||
}
|
||||
Registry LshProjectionRegistry(schema::PrimitiveType_LshProjection, LshProjectionCreator);
|
||||
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateLshProjectionParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
LshProjectionParameter *lsh_project_param =
|
||||
reinterpret_cast<LshProjectionParameter *>(malloc(sizeof(LshProjectionParameter)));
|
||||
if (lsh_project_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc LshProjectionParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(lsh_project_param, 0, sizeof(LshProjectionParameter));
|
||||
lsh_project_param->op_parameter_.type_ = primitive->Type();
|
||||
auto param = reinterpret_cast<mindspore::lite::LshProjection *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
lsh_project_param->lsh_type_ = param->GetLshType();
|
||||
return reinterpret_cast<OpParameter *>(lsh_project_param);
|
||||
}
|
||||
Registry LshProjectionParameterRegistry(schema::PrimitiveType_LshProjection, PopulateLshProjectionParameter);
|
||||
|
||||
int LshProjection::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) {
|
||||
MS_LOG(ERROR) << "inputs to LshProjection operator should be 2 or 3, but " << inputs_.size() << " is given.";
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/lstm.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/fp32/lstm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -39,8 +42,31 @@ int Lstm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *LstmCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Lstm>(primitive); }
|
||||
Registry LstmRegistry(schema::PrimitiveType_Lstm, LstmCreator);
|
||||
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateLstmParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
LstmParameter *lstm_param = reinterpret_cast<LstmParameter *>(malloc(sizeof(LstmParameter)));
|
||||
if (lstm_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc LstmParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(lstm_param, 0, sizeof(LstmParameter));
|
||||
lstm_param->op_parameter_.type_ = primitive->Type();
|
||||
auto param = reinterpret_cast<mindspore::lite::Lstm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
if (param == nullptr) {
|
||||
free(lstm_param);
|
||||
MS_LOG(ERROR) << "get Lstm param nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
lstm_param->bidirectional_ = param->GetBidirection();
|
||||
return reinterpret_cast<OpParameter *>(lstm_param);
|
||||
}
|
||||
Registry LstmParameterRegistry(schema::PrimitiveType_Lstm, PopulateLstmParameter);
|
||||
|
||||
const int kLstmInputNum = 6;
|
||||
const int kLstmOutputNum = 3;
|
||||
int Lstm::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -57,6 +59,11 @@ int MakeTuple::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *MakeTupleCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<MakeTuple>(primitive);
|
||||
}
|
||||
Registry MakeTupleRegistry(schema::PrimitiveType_MakeTuple, MakeTupleCreator);
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,6 +21,9 @@
|
|||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#endif
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/matmul_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -86,8 +89,27 @@ int MatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *MatMulCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<MatMul>(primitive); }
|
||||
Registry MatMulRegistry(schema::PrimitiveType_MatMul, MatMulCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateMatMulParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto param = reinterpret_cast<mindspore::lite::MatMul *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
MatMulParameter *matmul_param = reinterpret_cast<MatMulParameter *>(malloc(sizeof(MatMulParameter)));
|
||||
if (matmul_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc MatMulParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(matmul_param, 0, sizeof(MatMulParameter));
|
||||
matmul_param->op_parameter_.type_ = primitive->Type();
|
||||
matmul_param->b_transpose_ = param->GetTransposeB();
|
||||
matmul_param->a_transpose_ = param->GetTransposeA();
|
||||
matmul_param->has_bias_ = false;
|
||||
matmul_param->act_type_ = ActType_No;
|
||||
return reinterpret_cast<OpParameter *>(matmul_param);
|
||||
}
|
||||
Registry MatMulParameterRegistry(schema::PrimitiveType_MatMul, PopulateMatMulParameter);
|
||||
|
||||
int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
auto input0 = inputs_.front();
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/matrix_diag.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -52,6 +54,11 @@ int MatrixDiag::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuff
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *MatrixDiagCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<MatrixDiag>(primitive);
|
||||
}
|
||||
Registry MatrixDiagRegistry(schema::PrimitiveType_MatrixDiag, MatrixDiagCreator);
|
||||
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,6 +23,8 @@
|
|||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#endif
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -62,6 +64,10 @@ int Maximum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *MaximumCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Maximum>(primitive); }
|
||||
Registry MaximumRegistry(schema::PrimitiveType_Maximum, MaximumCreator);
|
||||
#endif
|
||||
Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/mean.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/reduce_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -53,8 +56,36 @@ int Mean::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *MeanCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Mean>(primitive); }
|
||||
Registry MeanRegistry(schema::PrimitiveType_Mean, MeanCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateMeanParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ReduceParameter *mean_param = reinterpret_cast<ReduceParameter *>(malloc(sizeof(ReduceParameter)));
|
||||
if (mean_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ReduceParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(mean_param, 0, sizeof(ReduceParameter));
|
||||
mean_param->op_parameter_.type_ = primitive->Type();
|
||||
auto mean = reinterpret_cast<mindspore::lite::Mean *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
mean_param->keep_dims_ = mean->GetKeepDims();
|
||||
auto axisVector = mean->GetAxis();
|
||||
if (axisVector.size() > REDUCE_MAX_AXES_NUM) {
|
||||
MS_LOG(ERROR) << "Reduce axes size " << axisVector.size() << " exceed limit " << REDUCE_MAX_AXES_NUM;
|
||||
free(mean_param);
|
||||
return nullptr;
|
||||
}
|
||||
mean_param->num_axes_ = static_cast<int>(axisVector.size());
|
||||
int i = 0;
|
||||
for (auto iter = axisVector.begin(); iter != axisVector.end(); iter++) {
|
||||
mean_param->axes_[i++] = *iter;
|
||||
}
|
||||
mean_param->mode_ = static_cast<int>(schema::ReduceMode_ReduceMean);
|
||||
return reinterpret_cast<OpParameter *>(mean_param);
|
||||
}
|
||||
Registry MeanParameterRegistry(schema::PrimitiveType_Mean, PopulateMeanParameter);
|
||||
|
||||
namespace {
|
||||
constexpr size_t kInputSize = 1;
|
||||
constexpr size_t kOutputSize = 1;
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/minimum.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -28,6 +30,10 @@ int Minimum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *MinimumCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Minimum>(primitive); }
|
||||
Registry MinimumRegistry(schema::PrimitiveType_Minimum, MinimumCreator);
|
||||
#endif
|
||||
|
||||
Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/mul.h"
|
||||
#include <memory>
|
||||
#include "nnacl/arithmetic_common.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -72,6 +74,30 @@ int Mul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *MulCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Mul>(primitive); }
|
||||
Registry MulRegistry(schema::PrimitiveType_Mul, MulCreator);
|
||||
#endif
|
||||
OpParameter *PopulateMulParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
|
||||
if (arithmetic_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
|
||||
arithmetic_param->op_parameter_.type_ = primitive->Type();
|
||||
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
|
||||
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
|
||||
arithmetic_param->activation_type_ =
|
||||
reinterpret_cast<mindspore::lite::Mul *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType();
|
||||
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
|
||||
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
|
||||
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
|
||||
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
|
||||
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
|
||||
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
|
||||
return reinterpret_cast<OpParameter *>(arithmetic_param);
|
||||
}
|
||||
Registry MulParameterRegistry(schema::PrimitiveType_Mul, PopulateMulParameter);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,6 +17,9 @@
|
|||
#include "src/ops/nchw2nhwc.h"
|
||||
#include "src/common/common.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/transpose.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -29,8 +32,29 @@ int Nchw2Nhwc::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *Nchw2NhwcCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<Nchw2Nhwc>(primitive);
|
||||
}
|
||||
Registry Nchw2NhwcRegistry(schema::PrimitiveType_Nchw2Nhwc, Nchw2NhwcCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateNchw2NhwcParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
TransposeParameter *parameter = reinterpret_cast<TransposeParameter *>(malloc(sizeof(TransposeParameter)));
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc OpParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(parameter, 0, sizeof(OpParameter));
|
||||
parameter->op_parameter_.type_ = primitive->Type();
|
||||
parameter->num_axes_ = 4;
|
||||
parameter->perm_[0] = 0;
|
||||
parameter->perm_[1] = 2;
|
||||
parameter->perm_[2] = 3;
|
||||
parameter->perm_[3] = 1;
|
||||
return reinterpret_cast<OpParameter *>(parameter);
|
||||
}
|
||||
Registry Nchw2NhwcParameterRegistry(schema::PrimitiveType_Nchw2Nhwc, PopulateNchw2NhwcParameter);
|
||||
|
||||
int Nchw2Nhwc::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
auto input = inputs_.front();
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/neg.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -52,6 +54,9 @@ int Neg::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *NegCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Neg>(primitive); }
|
||||
Registry NegRegistry(schema::PrimitiveType_Neg, NegCreator);
|
||||
#endif
|
||||
Registry NegParameterRegistry(schema::PrimitiveType_Neg, PopulateArithmeticSelf);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/neg_grad.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
@ -28,6 +30,11 @@ int NegGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *NegGradCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<NegGrad>(primitive); }
|
||||
Registry NegGradRegistry(schema::PrimitiveType_NegGrad, NegGradCreator);
|
||||
|
||||
#endif
|
||||
Registry NegGradParameterRegistry(schema::PrimitiveType_NegGrad, PopulateArithmeticSelf);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,6 +17,9 @@
|
|||
#include "src/ops/nhwc2nchw.h"
|
||||
#include "src/common/common.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/transpose.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
|
@ -30,8 +33,30 @@ int Nhwc2Nchw::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *Nhwc2NchwCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<Nhwc2Nchw>(primitive);
|
||||
}
|
||||
Registry Nhwc2NchwRegistry(schema::PrimitiveType_Nhwc2Nchw, Nhwc2NchwCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateNhwc2NchwParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
TransposeParameter *parameter = reinterpret_cast<TransposeParameter *>(malloc(sizeof(TransposeParameter)));
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc OpParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(parameter, 0, sizeof(OpParameter));
|
||||
parameter->op_parameter_.type_ = primitive->Type();
|
||||
parameter->num_axes_ = 4;
|
||||
parameter->perm_[0] = 0;
|
||||
parameter->perm_[1] = 3;
|
||||
parameter->perm_[2] = 1;
|
||||
parameter->perm_[3] = 2;
|
||||
return reinterpret_cast<OpParameter *>(parameter);
|
||||
}
|
||||
|
||||
Registry Nhwc2NchwParameterRegistry(schema::PrimitiveType_Nhwc2Nchw, PopulateNhwc2NchwParameter);
|
||||
|
||||
int Nhwc2Nchw::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
auto input = inputs_.front();
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/not_equal.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -28,6 +30,11 @@ int NotEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *NotEqualCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<NotEqual>(primitive);
|
||||
}
|
||||
Registry NotEqualRegistry(schema::PrimitiveType_NotEqual, NotEqualCreator);
|
||||
|
||||
#endif
|
||||
int NotEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
auto input = inputs_.front();
|
||||
|
@ -39,5 +46,6 @@ int NotEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> ou
|
|||
output->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/one_hot.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/fp32/one_hot.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -71,8 +74,30 @@ int OneHot::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *OneHotCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<OneHot>(primitive); }
|
||||
Registry OneHotRegistry(schema::PrimitiveType_OneHot, OneHotCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulateOneHotParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
OneHotParameter *one_hot_param = reinterpret_cast<OneHotParameter *>(malloc(sizeof(OneHotParameter)));
|
||||
if (one_hot_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc OneHotParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(one_hot_param, 0, sizeof(OneHotParameter));
|
||||
one_hot_param->op_parameter_.type_ = primitive->Type();
|
||||
auto param = reinterpret_cast<mindspore::lite::OneHot *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
if (param == nullptr) {
|
||||
free(one_hot_param);
|
||||
MS_LOG(ERROR) << "get OneHot param nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
one_hot_param->axis_ = param->GetAxis();
|
||||
return reinterpret_cast<OpParameter *>(one_hot_param);
|
||||
}
|
||||
Registry OneHotParameterRegistry(schema::PrimitiveType_OneHot, PopulateOneHotParameter);
|
||||
|
||||
namespace {
|
||||
constexpr size_t kOneHotInputNum = 4;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_OP_REGISTER_H
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_OP_REGISTER_H
|
||||
|
||||
#include <map>
|
||||
#include "src/ops/primitive_c.h"
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OpsRegistry {
|
||||
public:
|
||||
static OpsRegistry *GetInstance() {
|
||||
static OpsRegistry registry;
|
||||
return ®istry;
|
||||
}
|
||||
|
||||
void insertPrimitiveCMap(schema::PrimitiveType type, PrimitiveCCreator creator) {
|
||||
primitive_creators[type] = creator;
|
||||
}
|
||||
PrimitiveCCreator getPrimitiveCreator(schema::PrimitiveType type) {
|
||||
if (primitive_creators.find(type) != primitive_creators.end()) {
|
||||
return primitive_creators[type];
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(type);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void insertParameterMap(schema::PrimitiveType type, ParameterCreator creator) { parameter_creators[type] = creator; }
|
||||
|
||||
ParameterCreator getParameterCreator(schema::PrimitiveType type) {
|
||||
if (parameter_creators.find(type) != parameter_creators.end()) {
|
||||
return parameter_creators[type];
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported parameter type in Create : " << schema::EnumNamePrimitiveType(type);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
std::map<schema::PrimitiveType, PrimitiveCCreator> primitive_creators;
|
||||
std::map<schema::PrimitiveType, ParameterCreator> parameter_creators;
|
||||
};
|
||||
|
||||
class Registry {
|
||||
public:
|
||||
Registry(schema::PrimitiveType primitive_type, PrimitiveCCreator creator) {
|
||||
OpsRegistry::GetInstance()->insertPrimitiveCMap(primitive_type, creator);
|
||||
}
|
||||
Registry(schema::PrimitiveType primitive_type, ParameterCreator creator) {
|
||||
OpsRegistry::GetInstance()->insertParameterMap(primitive_type, creator);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_OP_REGISTER_H
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/p_relu.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/prelu_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -46,6 +49,24 @@ int PReLU::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *PReLUCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<PReLU>(primitive); }
|
||||
Registry PReLURegistry(schema::PrimitiveType_PReLU, PReLUCreator);
|
||||
|
||||
#endif
|
||||
|
||||
OpParameter *PopulatePReLUParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto param = reinterpret_cast<mindspore::lite::PReLU *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
PReluParameter *prelu_param = reinterpret_cast<PReluParameter *>(malloc(sizeof(PReluParameter)));
|
||||
if (prelu_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc PReluParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(prelu_param, 0, sizeof(PReluParameter));
|
||||
prelu_param->op_parameter_.type_ = primitive->Type();
|
||||
prelu_param->channelShared = param->GetChannelShared();
|
||||
return reinterpret_cast<OpParameter *>(prelu_param);
|
||||
}
|
||||
Registry PReLUParameterRegistry(schema::PrimitiveType_PReLU, PopulatePReLUParameter);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/pad.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/pad_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -57,7 +60,42 @@ int Pad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *PadCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Pad>(primitive); }
|
||||
Registry PadRegistry(schema::PrimitiveType_Pad, PadCreator);
|
||||
#endif
|
||||
OpParameter *PopulatePadParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
PadParameter *pad_param = reinterpret_cast<PadParameter *>(malloc(sizeof(PadParameter)));
|
||||
if (pad_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc PadParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(pad_param, 0, sizeof(PadParameter));
|
||||
pad_param->op_parameter_.type_ = primitive->Type();
|
||||
auto pad_node = reinterpret_cast<mindspore::lite::Pad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
pad_param->pad_mode_ = pad_node->GetPaddingMode();
|
||||
if (pad_param->pad_mode_ == static_cast<int>(schema::PaddingMode_CONSTANT)) {
|
||||
pad_param->constant_value_ = pad_node->GetConstantValue();
|
||||
auto size = pad_node->GetPaddings().size();
|
||||
if (size > MAX_PAD_SIZE) {
|
||||
MS_LOG(ERROR) << "Invalid padding size: " << size;
|
||||
free(pad_param);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < MAX_PAD_SIZE - size; ++i) {
|
||||
pad_param->paddings_[i] = 0;
|
||||
}
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
pad_param->paddings_[MAX_PAD_SIZE - size + i] = pad_node->GetPaddings()[i];
|
||||
}
|
||||
pad_param->padding_length = MAX_PAD_SIZE;
|
||||
}
|
||||
|
||||
return reinterpret_cast<OpParameter *>(pad_param);
|
||||
}
|
||||
Registry PadParameterRegistry(schema::PrimitiveType_Pad, PopulatePadParameter);
|
||||
|
||||
int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
if (this->primitive_ == nullptr) {
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/permute.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -50,6 +52,9 @@ int Permute::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *PermuteCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Permute>(primitive); }
|
||||
Registry PermuteRegistry(schema::PrimitiveType_Permute, PermuteCreator);
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,9 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/pooling_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
|
@ -158,8 +161,73 @@ int Pooling::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *PoolingCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Pooling>(primitive); }
|
||||
Registry PoolingRegistry(schema::PrimitiveType_Pooling, PoolingCreator);
|
||||
|
||||
#endif
|
||||
|
||||
OpParameter *PopulatePoolingParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto pooling_primitive =
|
||||
reinterpret_cast<mindspore::lite::Pooling *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
|
||||
if (pooling_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc PoolingParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(pooling_param, 0, sizeof(PoolingParameter));
|
||||
pooling_param->op_parameter_.type_ = primitive->Type();
|
||||
pooling_param->global_ = pooling_primitive->GetGlobal();
|
||||
pooling_param->window_w_ = pooling_primitive->GetWindowW();
|
||||
pooling_param->window_h_ = pooling_primitive->GetWindowH();
|
||||
auto pooling_lite_primitive = (lite::Pooling *)primitive;
|
||||
pooling_param->pad_u_ = pooling_lite_primitive->PadUp();
|
||||
pooling_param->pad_d_ = pooling_lite_primitive->PadDown();
|
||||
pooling_param->pad_l_ = pooling_lite_primitive->PadLeft();
|
||||
pooling_param->pad_r_ = pooling_lite_primitive->PadRight();
|
||||
pooling_param->stride_w_ = pooling_primitive->GetStrideW();
|
||||
pooling_param->stride_h_ = pooling_primitive->GetStrideH();
|
||||
pooling_param->avg_mode_ = pooling_primitive->GetAvgMode();
|
||||
|
||||
auto is_global = pooling_primitive->GetGlobal();
|
||||
pooling_param->global_ = is_global;
|
||||
auto pool_mode = pooling_primitive->GetPoolingMode();
|
||||
switch (pool_mode) {
|
||||
case schema::PoolMode_MAX_POOLING:
|
||||
pooling_param->pool_mode_ = PoolMode_MaxPool;
|
||||
break;
|
||||
case schema::PoolMode_MEAN_POOLING:
|
||||
pooling_param->pool_mode_ = PoolMode_AvgPool;
|
||||
break;
|
||||
default:
|
||||
pooling_param->pool_mode_ = PoolMode_No;
|
||||
break;
|
||||
}
|
||||
|
||||
auto round_mode = pooling_primitive->GetRoundMode();
|
||||
switch (round_mode) {
|
||||
case schema::RoundMode_FLOOR:
|
||||
pooling_param->round_mode_ = RoundMode_Floor;
|
||||
break;
|
||||
case schema::RoundMode_CEIL:
|
||||
pooling_param->round_mode_ = RoundMode_Ceil;
|
||||
break;
|
||||
default:
|
||||
pooling_param->round_mode_ = RoundMode_No;
|
||||
break;
|
||||
}
|
||||
|
||||
if (pooling_primitive->GetActivationType() == schema::ActivationType_RELU) {
|
||||
pooling_param->act_type_ = ActType_Relu;
|
||||
} else if (pooling_primitive->GetActivationType() == schema::ActivationType_RELU6) {
|
||||
pooling_param->act_type_ = ActType_Relu6;
|
||||
} else {
|
||||
pooling_param->act_type_ = ActType_No;
|
||||
}
|
||||
return reinterpret_cast<OpParameter *>(pooling_param);
|
||||
}
|
||||
|
||||
Registry PoolingParameterRegistry(schema::PrimitiveType_Pooling, PopulatePoolingParameter);
|
||||
|
||||
int Pooling::PadUp() const { return this->pad_u_; }
|
||||
int Pooling::PadDown() const { return this->pad_d_; }
|
||||
int Pooling::PadLeft() const { return this->pad_l_; }
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/pooling_grad.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -142,6 +144,11 @@ int PoolingGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuf
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *PoolingGradCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<PoolingGrad>(primitive);
|
||||
}
|
||||
Registry PoolingGradRegistry(schema::PrimitiveType_PoolingGrad, PoolingGradCreator);
|
||||
#endif
|
||||
|
||||
int PoolingGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "src/ops/power.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
#include "nnacl/power_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -45,8 +48,28 @@ int Power::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *PowerCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Power>(primitive); }
|
||||
Registry PowerRegistry(schema::PrimitiveType_Power, PowerCreator);
|
||||
#endif
|
||||
|
||||
OpParameter *PopulatePowerParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
PowerParameter *power_param = reinterpret_cast<PowerParameter *>(malloc(sizeof(PowerParameter)));
|
||||
if (power_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc PowerParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(power_param, 0, sizeof(PowerParameter));
|
||||
power_param->op_parameter_.type_ = primitive->Type();
|
||||
auto power = reinterpret_cast<mindspore::lite::Power *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
power_param->power_ = power->GetPower();
|
||||
power_param->scale_ = power->GetScale();
|
||||
power_param->shift_ = power->GetShift();
|
||||
return reinterpret_cast<OpParameter *>(power_param);
|
||||
}
|
||||
|
||||
Registry PowerParameterRegistry(schema::PrimitiveType_Power, PopulatePowerParameter);
|
||||
|
||||
int Power::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
auto x_tensor = inputs[0];
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/ops/power_grad.h"
|
||||
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -77,6 +79,11 @@ int PowerGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe
|
|||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *PowerGradCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<PowerGrad>(primitive);
|
||||
}
|
||||
Registry PowerGradRegistry(schema::PrimitiveType_PowerGrad, PowerGradCreator);
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue