add op populate

This commit is contained in:
gongdaguo1 2023-02-27 11:38:24 +08:00
parent 0db81d3bb6
commit c58e13c692
12 changed files with 611 additions and 158 deletions

View File

@ -0,0 +1,122 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/ops/operator_populate/operator_populate_register.h"
#include "nnacl/fp32/activation_fp32.h"
#include "ops/fusion/activation.h"
#include "ops/relu.h"
#include "ops/relu6.h"
#include "ops/leaky_relu.h"
#include "ops/sigmoid.h"
#include "ops/tanh.h"
#include "ops/hswish.h"
#include "ops/hsigmoid.h"
#include "ops/gelu.h"
#include "ops/softplus.h"
#include "ops/elu.h"
using mindspore::ops::kActivationType;
using mindspore::ops::kAlpha;
using mindspore::ops::kApproximate;
using mindspore::ops::kMaxVal;
using mindspore::ops::kMinVal;
using mindspore::ops::kNameActivation;
using mindspore::ops::kNameElu;
using mindspore::ops::kNameGeLU;
using mindspore::ops::kNameHSigmoid;
using mindspore::ops::kNameHSwish;
using mindspore::ops::kNameLeakyRelu;
using mindspore::ops::kNameReLU;
using mindspore::ops::kNameReLU6;
using mindspore::ops::kNameSigmoid;
using mindspore::ops::kNameSoftplus;
using mindspore::ops::kNameTanh;
using mindspore::schema::PrimitiveType_Activation;
namespace mindspore {
namespace lite {
OpParameter *PopulateActivationOpParameter(const BaseOperatorPtr &base_operator) {
auto param = reinterpret_cast<ActivationParameter *>(PopulateOpParameter<ActivationParameter>());
if (param == nullptr) {
MS_LOG(ERROR) << "new ActivationParameter failed.";
return nullptr;
}
mindspore::ValuePtr attr = base_operator->GetPrim()->GetAttr(kActivationType);
if (attr != nullptr) {
auto activation_type = GetValue<int64_t>(attr);
static const std::set<int> activation_types = {
schema::ActivationType_RELU, schema::ActivationType_RELU6, schema::ActivationType_LEAKY_RELU,
schema::ActivationType_SIGMOID, schema::ActivationType_TANH, schema::ActivationType_SWISH,
schema::ActivationType_HSWISH, schema::ActivationType_HSIGMOID, schema::ActivationType_HARD_TANH,
schema::ActivationType_GELU, schema::ActivationType_SOFTPLUS, schema::ActivationType_ELU};
if (activation_types.find(activation_type) == activation_types.end()) {
MS_LOG(ERROR) << "invalid activation type: " << activation_type;
free(param);
return nullptr;
}
param->type_ = activation_type;
} else {
auto type_name = base_operator->name();
static const std::map<std::string, int> op_type_map = {{kNameReLU, schema::ActivationType_RELU},
{kNameReLU6, schema::ActivationType_RELU6},
{kNameLeakyRelu, schema::ActivationType_LEAKY_RELU},
{kNameSigmoid, schema::ActivationType_SIGMOID},
{kNameTanh, schema::ActivationType_TANH},
{kNameHSwish, schema::ActivationType_HSWISH},
{kNameHSigmoid, schema::ActivationType_HSIGMOID},
{kNameGeLU, schema::ActivationType_GELU},
{kNameSoftplus, schema::ActivationType_SOFTPLUS},
{kNameElu, schema::ActivationType_ELU}};
auto iter = op_type_map.find(type_name);
if (iter == op_type_map.end()) {
MS_LOG(ERROR) << "invalid activation type: " << type_name;
free(param);
return nullptr;
}
param->type_ = iter->second;
}
mindspore::ValuePtr alpha = base_operator->GetPrim()->GetAttr(kAlpha);
if (alpha != nullptr) {
param->alpha_ = GetValue<float>(alpha);
}
mindspore::ValuePtr min_val = base_operator->GetPrim()->GetAttr(kMinVal);
if (min_val != nullptr) {
param->min_val_ = GetValue<float>(min_val);
}
mindspore::ValuePtr max_val = base_operator->GetPrim()->GetAttr(kMaxVal);
if (max_val != nullptr) {
param->max_val_ = GetValue<float>(max_val);
}
mindspore::ValuePtr approximate = base_operator->GetPrim()->GetAttr(kApproximate);
if (approximate != nullptr) {
param->approximate_ = GetValue<bool>(approximate);
}
return reinterpret_cast<OpParameter *>(param);
}
REG_OPERATOR_POPULATE(kNameActivation, PrimitiveType_Activation, PopulateActivationOpParameter)
REG_OPERATOR_POPULATE(kNameReLU, PrimitiveType_Activation, PopulateActivationOpParameter)
REG_OPERATOR_POPULATE(kNameReLU6, PrimitiveType_Activation, PopulateActivationOpParameter)
REG_OPERATOR_POPULATE(kNameLeakyRelu, PrimitiveType_Activation, PopulateActivationOpParameter)
REG_OPERATOR_POPULATE(kNameSigmoid, PrimitiveType_Activation, PopulateActivationOpParameter)
REG_OPERATOR_POPULATE(kNameTanh, PrimitiveType_Activation, PopulateActivationOpParameter)
REG_OPERATOR_POPULATE(kNameHSwish, PrimitiveType_Activation, PopulateActivationOpParameter)
REG_OPERATOR_POPULATE(kNameHSigmoid, PrimitiveType_Activation, PopulateActivationOpParameter)
REG_OPERATOR_POPULATE(kNameGeLU, PrimitiveType_Activation, PopulateActivationOpParameter)
REG_OPERATOR_POPULATE(kNameSoftplus, PrimitiveType_Activation, PopulateActivationOpParameter)
REG_OPERATOR_POPULATE(kNameElu, PrimitiveType_Activation, PopulateActivationOpParameter)
} // namespace lite
} // namespace mindspore

View File

@ -1,49 +0,0 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/ops/operator_populate/operator_populate_register.h"
#include "nnacl/arithmetic.h"
#include "ops/add.h"
using mindspore::ops::kNameAdd;
namespace mindspore {
namespace lite {
OpParameter *PopulateAddOpParameter(const BaseOperatorPtr &base_operator) {
if (base_operator == nullptr) {
MS_LOG(ERROR) << "base_operator is nullptr";
return nullptr;
}
auto op_parameter_ptr = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (op_parameter_ptr == nullptr) {
MS_LOG(ERROR) << "Make OpParameter ptr failed";
return nullptr;
}
memset(op_parameter_ptr, 0, sizeof(ArithmeticParameter));
auto name = base_operator->name();
auto iter = kOpNameWithPrimitiveType.find(name);
if (iter == kOpNameWithPrimitiveType.end()) {
MS_LOG(ERROR) << "Can not find ParameterPtrGen : " << name;
return nullptr;
}
op_parameter_ptr->op_parameter_.type_ = iter->second;
op_parameter_ptr->broadcasting_ = false;
op_parameter_ptr->ndim_ = 0;
op_parameter_ptr->activation_type_ = 0;
return reinterpret_cast<OpParameter *>(op_parameter_ptr);
}
REG_OPERATOR_POPULATE(kNameAdd, PopulateAddOpParameter)
} // namespace lite
} // namespace mindspore

View File

@ -31,7 +31,12 @@
#include "ops/floor_mod.h" #include "ops/floor_mod.h"
#include "ops/squared_difference.h" #include "ops/squared_difference.h"
#include "ops/mod.h" #include "ops/mod.h"
#include "ops/add.h"
#include "ops/fusion/add_fusion.h"
using mindspore::ops::kActivationType;
using mindspore::ops::kNameAdd;
using mindspore::ops::kNameAddFusion;
using mindspore::ops::kNameEqual; using mindspore::ops::kNameEqual;
using mindspore::ops::kNameFloorDiv; using mindspore::ops::kNameFloorDiv;
using mindspore::ops::kNameFloorMod; using mindspore::ops::kNameFloorMod;
@ -48,57 +53,66 @@ using mindspore::ops::kNameNotEqual;
using mindspore::ops::kNameRealDiv; using mindspore::ops::kNameRealDiv;
using mindspore::ops::kNameSquaredDifference; using mindspore::ops::kNameSquaredDifference;
using mindspore::schema::PrimitiveType_AddFusion;
using mindspore::schema::PrimitiveType_Equal;
using mindspore::schema::PrimitiveType_FloorDiv;
using mindspore::schema::PrimitiveType_FloorMod;
using mindspore::schema::PrimitiveType_Greater;
using mindspore::schema::PrimitiveType_GreaterEqual;
using mindspore::schema::PrimitiveType_Less;
using mindspore::schema::PrimitiveType_LessEqual;
using mindspore::schema::PrimitiveType_LogicalAnd;
using mindspore::schema::PrimitiveType_LogicalOr;
using mindspore::schema::PrimitiveType_Maximum;
using mindspore::schema::PrimitiveType_Minimum;
using mindspore::schema::PrimitiveType_Mod;
using mindspore::schema::PrimitiveType_NotEqual;
using mindspore::schema::PrimitiveType_RealDiv;
using mindspore::schema::PrimitiveType_SquaredDifference;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
ArithmeticParameter *PopulateArithmeticCommonOpPara(const BaseOperatorPtr &base_operator) { OpParameter *PopulateArithmeticCommonOpPara(const BaseOperatorPtr &base_operator) {
if (base_operator == nullptr) { auto param = reinterpret_cast<ArithmeticParameter *>(PopulateOpParameter<ArithmeticParameter>());
MS_LOG(ERROR) << "base_operator is nullptr";
return nullptr;
}
auto *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; MS_LOG(ERROR) << "new ArithmeticParameter failed.";
return nullptr; return nullptr;
} }
memset(param, 0, sizeof(ArithmeticParameter));
auto name = base_operator->name();
auto iter = kOpNameWithPrimitiveType.find(name);
if (iter == kOpNameWithPrimitiveType.end()) {
MS_LOG(ERROR) << "Can not find ParameterPtrGen : " << name;
return nullptr;
}
param->op_parameter_.type_ = iter->second;
param->broadcasting_ = false; param->broadcasting_ = false;
param->ndim_ = 0; param->ndim_ = 0;
param->activation_type_ = 0; param->activation_type_ = 0;
return param; return reinterpret_cast<OpParameter *>(param);
} }
OpParameter *PopulateArithmeticOp(const BaseOperatorPtr &base_operator) { OpParameter *PopulateAddOpParameter(const BaseOperatorPtr &base_operator) {
ArithmeticParameter *param = PopulateArithmeticCommonOpPara(base_operator); ArithmeticParameter *param = reinterpret_cast<ArithmeticParameter *>(PopulateArithmeticCommonOpPara(base_operator));
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; MS_LOG(ERROR) << "PopulateArithmeticCommonOpPara failed.";
return nullptr; return nullptr;
} }
mindspore::ValuePtr attr = base_operator->GetPrim()->GetAttr(kActivationType);
if (attr != nullptr) {
param->activation_type_ = ActivationType(GetValue<int64_t>(attr));
}
return reinterpret_cast<OpParameter *>(param); return reinterpret_cast<OpParameter *>(param);
} }
REG_OPERATOR_POPULATE(kNameRealDiv, PopulateArithmeticOp) REG_OPERATOR_POPULATE(kNameAdd, PrimitiveType_AddFusion, PopulateArithmeticCommonOpPara)
REG_OPERATOR_POPULATE(kNameLogicalAnd, PopulateArithmeticOp) REG_OPERATOR_POPULATE(kNameAddFusion, PrimitiveType_AddFusion, PopulateAddOpParameter)
REG_OPERATOR_POPULATE(kNameLogicalOr, PopulateArithmeticOp) REG_OPERATOR_POPULATE(kNameRealDiv, PrimitiveType_RealDiv, PopulateArithmeticCommonOpPara)
REG_OPERATOR_POPULATE(kNameEqual, PopulateArithmeticOp); REG_OPERATOR_POPULATE(kNameLogicalAnd, PrimitiveType_LogicalAnd, PopulateArithmeticCommonOpPara)
REG_OPERATOR_POPULATE(kNameNotEqual, PopulateArithmeticOp) REG_OPERATOR_POPULATE(kNameLogicalOr, PrimitiveType_LogicalOr, PopulateArithmeticCommonOpPara)
REG_OPERATOR_POPULATE(kNameLess, PopulateArithmeticOp) REG_OPERATOR_POPULATE(kNameEqual, PrimitiveType_Equal, PopulateArithmeticCommonOpPara);
REG_OPERATOR_POPULATE(kNameLessEqual, PopulateArithmeticOp) REG_OPERATOR_POPULATE(kNameNotEqual, PrimitiveType_NotEqual, PopulateArithmeticCommonOpPara)
REG_OPERATOR_POPULATE(kNameGreater, PopulateArithmeticOp) REG_OPERATOR_POPULATE(kNameLess, PrimitiveType_Less, PopulateArithmeticCommonOpPara)
REG_OPERATOR_POPULATE(kNameGreaterEqual, PopulateArithmeticOp) REG_OPERATOR_POPULATE(kNameLessEqual, PrimitiveType_LessEqual, PopulateArithmeticCommonOpPara)
REG_OPERATOR_POPULATE(kNameMaximum, PopulateArithmeticOp) REG_OPERATOR_POPULATE(kNameGreater, PrimitiveType_Greater, PopulateArithmeticCommonOpPara)
REG_OPERATOR_POPULATE(kNameMinimum, PopulateArithmeticOp) REG_OPERATOR_POPULATE(kNameGreaterEqual, PrimitiveType_GreaterEqual, PopulateArithmeticCommonOpPara)
REG_OPERATOR_POPULATE(kNameFloorDiv, PopulateArithmeticOp) REG_OPERATOR_POPULATE(kNameMaximum, PrimitiveType_Maximum, PopulateArithmeticCommonOpPara)
REG_OPERATOR_POPULATE(kNameFloorMod, PopulateArithmeticOp) REG_OPERATOR_POPULATE(kNameMinimum, PrimitiveType_Minimum, PopulateArithmeticCommonOpPara)
REG_OPERATOR_POPULATE(kNameMod, PopulateArithmeticOp) REG_OPERATOR_POPULATE(kNameFloorDiv, PrimitiveType_FloorDiv, PopulateArithmeticCommonOpPara)
REG_OPERATOR_POPULATE(kNameSquaredDifference, PopulateArithmeticOp) REG_OPERATOR_POPULATE(kNameFloorMod, PrimitiveType_FloorMod, PopulateArithmeticCommonOpPara)
REG_OPERATOR_POPULATE(kNameMod, PrimitiveType_Mod, PopulateArithmeticCommonOpPara)
REG_OPERATOR_POPULATE(kNameSquaredDifference, PrimitiveType_SquaredDifference, PopulateArithmeticCommonOpPara)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -1,46 +0,0 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/ops/operator_populate/operator_populate_register.h"
#include "ops/assert.h"
using mindspore::ops::kNameAssert;
namespace mindspore {
namespace lite {
OpParameter *PopulateAssertOpParameter(const BaseOperatorPtr &base_operator) {
if (base_operator == nullptr) {
MS_LOG(ERROR) << "base_operator is nullptr";
return nullptr;
}
auto op_parameter_ptr = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (op_parameter_ptr == nullptr) {
MS_LOG(ERROR) << "Make OpParameter ptr failed";
return nullptr;
}
memset(op_parameter_ptr, 0, sizeof(OpParameter));
auto name = base_operator->name();
auto iter = kOpNameWithPrimitiveType.find(name);
if (iter == kOpNameWithPrimitiveType.end()) {
MS_LOG(ERROR) << "Can not find ParameterPtrGen : " << name;
return nullptr;
}
op_parameter_ptr->type_ = iter->second;
return op_parameter_ptr;
}
REG_OPERATOR_POPULATE(kNameAssert, PopulateAssertOpParameter)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,70 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/ops/operator_populate/operator_populate_register.h"
#include "nnacl/where_parameter.h"
#include "nnacl/sparse_to_dense_parameter.h"
#include "nnacl/transpose.h"
#include "nnacl/triu_tril.h"
#include "nnacl/fp32/unique_fp32.h"
#include "nnacl/scatter_nd_parameter.h"
#include "ops/assert.h"
#include "ops/where.h"
#include "ops/unsorted_segment_sum.h"
#include "ops/unique.h"
#include "ops/triu.h"
#include "ops/tril.h"
#include "ops/transpose.h"
#include "ops/sparse_to_dense.h"
#include "ops/sparse_segment_sum.h"
#include "ops/sparse_reshape.h"
#include "ops/sparse_fill_empty_rows.h"
#include "ops/size.h"
#include "ops/shape.h"
#include "ops/select.h"
#include "ops/scatter_nd_update.h"
#include "ops/tensor_scatter_add.h"
#include "ops/scatter_nd.h"
#include "ops/expand_dims.h"
namespace mindspore {
namespace lite {
REG_OP_BASE_POPULATE(Assert)
REG_OP_BASE_POPULATE(UnsortedSegmentSum)
REG_OP_BASE_POPULATE(SparseSegmentSum)
REG_OP_BASE_POPULATE(SparseReshape)
REG_OP_BASE_POPULATE(SparseFillEmptyRows)
REG_OP_BASE_POPULATE(Size)
REG_OP_BASE_POPULATE(Shape)
REG_OP_BASE_POPULATE(Select)
REG_OP_BASE_POPULATE(ExpandDims)
REG_OP_DEFAULT_POPULATE(SparseToDense)
REG_OP_DEFAULT_POPULATE(Transpose)
REG_OP_DEFAULT_POPULATE(Tril)
REG_OP_DEFAULT_POPULATE(Triu)
REG_OP_DEFAULT_POPULATE(Where)
REG_OP_DEFAULT_POPULATE(Unique)
using mindspore::ops::kNameScatterNd;
using mindspore::ops::kNameScatterNdUpdate;
using mindspore::ops::kNameTensorScatterAdd;
using mindspore::schema::PrimitiveType_ScatterNd;
using mindspore::schema::PrimitiveType_ScatterNdUpdate;
using mindspore::schema::PrimitiveType_TensorScatterAdd;
REG_OPERATOR_POPULATE(kNameScatterNd, PrimitiveType_ScatterNd, PopulateOpParameter<ScatterNDParameter>)
REG_OPERATOR_POPULATE(kNameScatterNdUpdate, PrimitiveType_ScatterNdUpdate, PopulateOpParameter<ScatterNDParameter>)
REG_OPERATOR_POPULATE(kNameTensorScatterAdd, PrimitiveType_TensorScatterAdd, PopulateOpParameter<ScatterNDParameter>)
} // namespace lite
} // namespace mindspore

View File

@ -22,5 +22,9 @@ OperatorPopulateRegistry *OperatorPopulateRegistry::GetInstance() {
static OperatorPopulateRegistry registry; static OperatorPopulateRegistry registry;
return &registry; return &registry;
} }
OpParameter *CreatePopulatePtr(const BaseOperatorPtr &base_operator) {
return OperatorPopulateRegistry::GetInstance()->CreatePopulateByOp(base_operator);
}
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -21,6 +21,7 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility>
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "src/common/common.h" #include "src/common/common.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
@ -28,58 +29,89 @@
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/common/log_util.h" #include "src/common/log_util.h"
#include "ops/base_operator.h" #include "ops/base_operator.h"
#include "mindspore/core/ir/primitive.h"
#include "ops/primitive_c.h"
#include "ops/op_name.h"
#include "schema/ops_generated.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
using BaseOperatorPtr = std::shared_ptr<mindspore::ops::BaseOperator>; using BaseOperatorPtr = std::shared_ptr<mindspore::ops::BaseOperator>;
typedef OpParameter *(*ParameterPtrGen)(const BaseOperatorPtr &base_operator); typedef OpParameter *(*ParameterPtrGen)(const BaseOperatorPtr &base_operator);
OpParameter *CreatePopulatePtr(const BaseOperatorPtr &base_operator);
static const std::map<std::string, int> kOpNameWithPrimitiveType = {{"Add", 5},
{"Assert", 17},
{"Equal", 53},
{"FloorDiv", 64},
{"FloorMod", 65},
{"Greater", 71},
{"GreaterEqual", 72},
{"Less", 77},
{"LessEqual", 78},
{"LogicalAnd", 81},
{"LogicalOr", 83},
{"Maximum", 90},
{"Minimum", 96},
{"Mod", 98},
{"NotEqual", 103},
{"RealDiv", 117},
{"SquaredDifference", 149}};
class OperatorPopulateRegistry { class OperatorPopulateRegistry {
public: public:
static OperatorPopulateRegistry *GetInstance(); static OperatorPopulateRegistry *GetInstance();
void InsertOperatorParameterMap(const std::string &name, ParameterPtrGen creator) { op_parameters_[name] = creator; } void InsertOperatorParameterMap(const std::string &name, int type, ParameterPtrGen creator) {
op_parameters_[name] = std::make_pair(creator, type);
}
ParameterPtrGen GetParameterPtrCreator(const std::string &name) { return CreatePopulatePtr; }
ParameterPtrGen GetParameterPtrCreator(const std::string &name) { OpParameter *CreatePopulateByOp(const BaseOperatorPtr &base_operator) {
auto iter = op_parameters_.find(name); MS_CHECK_TRUE_RET(base_operator != nullptr, nullptr);
auto iter = op_parameters_.find(base_operator->name());
if (iter == op_parameters_.end()) { if (iter == op_parameters_.end()) {
MS_LOG(ERROR) << "Unsupported op in creator " << name; MS_LOG(ERROR) << "Unsupported op in creator " << base_operator->name();
return nullptr; return nullptr;
} }
return iter->second; if (base_operator->GetPrim() == nullptr) {
MS_LOG(ERROR) << "invalid op " << base_operator->name();
return nullptr;
}
auto param = iter->second.first(base_operator);
if (param != nullptr) {
param->type_ = iter->second.second;
}
return param;
} }
protected: protected:
std::map<std::string, ParameterPtrGen> op_parameters_; std::map<std::string, std::pair<ParameterPtrGen, int>> op_parameters_;
}; };
class Registry { class Registry {
public: public:
Registry(std::string name, ParameterPtrGen creator) noexcept { Registry(std::string name, int type, ParameterPtrGen creator) noexcept {
OperatorPopulateRegistry::GetInstance()->InsertOperatorParameterMap(name, creator); OperatorPopulateRegistry::GetInstance()->InsertOperatorParameterMap(name, type, creator);
} }
~Registry() = default; ~Registry() = default;
}; };
#define REG_OPERATOR_POPULATE(name, creator) static Registry g_##name(name, creator); template <typename T>
OpParameter *PopulateOpParameter() {
auto op_parameter_ptr = reinterpret_cast<OpParameter *>(malloc(sizeof(T)));
if (op_parameter_ptr == nullptr) {
MS_LOG(ERROR) << "malloc OpParameter ptr failed";
return nullptr;
}
memset(op_parameter_ptr, 0, sizeof(T));
return reinterpret_cast<OpParameter *>(op_parameter_ptr);
}
template <typename T>
OpParameter *PopulateOpParameter(const BaseOperatorPtr &base_operator) {
auto op_parameter_ptr = reinterpret_cast<OpParameter *>(malloc(sizeof(T)));
if (op_parameter_ptr == nullptr) {
MS_LOG(ERROR) << "malloc OpParameter ptr failed";
return nullptr;
}
memset(op_parameter_ptr, 0, sizeof(T));
return reinterpret_cast<OpParameter *>(op_parameter_ptr);
}
#define REG_OPERATOR_POPULATE(name, type, creator) static Registry g_##name(name, type, creator);
#define REG_OP_BASE_POPULATE(op) \
using mindspore::ops::kName##op; \
using mindspore::schema::PrimitiveType_##op; \
REG_OPERATOR_POPULATE(kName##op, PrimitiveType_##op, PopulateOpParameter<OpParameter>)
#define REG_OP_DEFAULT_POPULATE(op) \
using mindspore::ops::kName##op; \
using mindspore::schema::PrimitiveType_##op; \
REG_OPERATOR_POPULATE(kName##op, PrimitiveType_##op, PopulateOpParameter<op##Parameter>)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_LITE_SRC_COMMON_OPS_OPERATOR_POPULATE_H_ #endif // MINDSPORE_LITE_SRC_COMMON_OPS_OPERATOR_POPULATE_H_

View File

@ -0,0 +1,49 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/ops/operator_populate/operator_populate_register.h"
#include "nnacl/softmax_parameter.h"
#include "ops/softmax.h"
using mindspore::ops::kAxis;
using mindspore::ops::kNameSoftmax;
using mindspore::schema::PrimitiveType_Softmax;
namespace mindspore {
namespace lite {
OpParameter *PopulateSoftmaxOpParameter(const BaseOperatorPtr &base_operator) {
auto param = reinterpret_cast<SoftmaxParameter *>(PopulateOpParameter<SoftmaxParameter>());
if (param == nullptr) {
MS_LOG(ERROR) << "new SoftmaxParameter failed.";
return nullptr;
}
ValuePtr attr = base_operator->GetPrim()->GetAttr(kAxis);
if (attr == nullptr) {
MS_LOG(ERROR) << "The attr(" << kAxis << ") of operator(" << base_operator->name() << ") not exist";
free(param);
return nullptr;
}
auto flat_axis = GetValue<std::vector<int64_t>>(attr);
if (flat_axis.size() != 1) {
MS_LOG(ERROR) << "axis number invalid!number: " << flat_axis.size();
free(param);
return nullptr;
}
param->axis_ = flat_axis.data()[0];
return reinterpret_cast<OpParameter *>(param);
}
REG_OPERATOR_POPULATE(kNameSoftmax, PrimitiveType_Softmax, PopulateSoftmaxOpParameter)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,100 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/ops/operator_populate/operator_populate_register.h"
#include "nnacl/split_parameter.h"
#include "nnacl/op_base.h"
#include "ops/split.h"
using mindspore::ops::kAxis;
using mindspore::ops::kNameSplit;
using mindspore::ops::kOutputNum;
using mindspore::ops::kSizeSplits;
using mindspore::schema::PrimitiveType_Split;
namespace mindspore {
namespace lite {
void DestroySplitSizes(OpParameter *parameter) {
MS_CHECK_PTR_IF_NULL(parameter);
auto param = reinterpret_cast<SplitParameter *>(parameter);
if (param->split_sizes_ != nullptr) {
free(param->split_sizes_);
param->split_sizes_ = nullptr;
}
}
OpParameter *PopulateSplitOpParameter(const BaseOperatorPtr &base_operator) {
auto param = reinterpret_cast<SplitParameter *>(PopulateOpParameter<SplitParameter>());
if (param == nullptr) {
MS_LOG(ERROR) << "new SplitParameter failed.";
return nullptr;
}
mindspore::ValuePtr attr_output = base_operator->GetPrim()->GetAttr(kOutputNum);
if (attr_output == nullptr) {
MS_LOG(ERROR) << "The attr(" << kOutputNum << ") of operator(" << base_operator->name() << ") not exist";
free(param);
return nullptr;
}
auto output_num = GetValue<int64_t>(attr_output);
if (output_num > std::numeric_limits<int>::max() / static_cast<int>(sizeof(int)) || output_num <= 0) {
MS_LOG(ERROR) << "The value of param->num_split_ is not correct";
free(param);
return nullptr;
}
param->num_split_ = output_num;
/* free split_sizes_ in split op base */
param->split_sizes_ = reinterpret_cast<int *>(malloc(static_cast<size_t>(output_num) * sizeof(int)));
if (param->split_sizes_ == nullptr) {
MS_LOG(ERROR) << "malloc param split_sizes_ error";
free(param);
return nullptr;
}
param->op_parameter_.destroy_func_ = DestroySplitSizes;
memset(param->split_sizes_, 0, static_cast<size_t>(param->num_split_) * sizeof(int));
mindspore::ValuePtr attr_size_splits = base_operator->GetPrim()->GetAttr(kSizeSplits);
if (attr_size_splits == nullptr) {
MS_LOG(ERROR) << "The attr(" << kSizeSplits << ") of operator(" << base_operator->name() << ") not exist";
DestroySplitSizes(param);
free(param);
return nullptr;
}
auto split_sizes_vector = GetValue<std::vector<int64_t>>(attr_size_splits);
if (split_sizes_vector.size() <= static_cast<uint32_t>(param->num_split_)) {
int i = 0;
for (auto iter : split_sizes_vector) {
param->split_sizes_[i++] = iter;
}
param->split_count_ = param->num_split_;
} else {
param->split_count_ = 0;
}
mindspore::ValuePtr attr = base_operator->GetPrim()->GetAttr(kAxis);
if (attr == nullptr) {
MS_LOG(ERROR) << "The attr(" << kAxis << ") of operator(" << base_operator->name() << ") not exist";
DestroySplitSizes(param);
free(param);
return nullptr;
}
param->split_dim_ = GetValue<int64_t>(attr);
return reinterpret_cast<OpParameter *>(param);
}
REG_OPERATOR_POPULATE(kNameSplit, PrimitiveType_Split, PopulateSplitOpParameter)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,57 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/ops/operator_populate/operator_populate_register.h"
#include "nnacl/squeeze_parameter.h"
#include "ops/squeeze.h"
using mindspore::ops::kAxis;
using mindspore::ops::kNameSqueeze;
using mindspore::schema::PrimitiveType_Squeeze;
namespace mindspore {
namespace lite {
OpParameter *PopulateSqueezeOpParameter(const BaseOperatorPtr &base_operator) {
auto param = reinterpret_cast<SqueezeParameter *>(PopulateOpParameter<SqueezeParameter>());
if (param == nullptr) {
MS_LOG(ERROR) << "new SqueezeParameter failed.";
return nullptr;
}
mindspore::ValuePtr attr = base_operator->GetPrim()->GetAttr(kAxis);
if (attr == nullptr) {
MS_LOG(ERROR) << "The attr(" << kAxis << ") of operator(" << base_operator->name() << ") not exist";
free(param);
return nullptr;
}
auto flat_axis = GetValue<std::vector<int64_t>>(attr);
if (flat_axis.size() > MAX_SHAPE_SIZE) {
MS_LOG(ERROR) << "Invalid axis size " << flat_axis.size();
free(param);
return nullptr;
}
param->axis_size_ = flat_axis.size();
for (size_t i = 0; i < param->axis_size_; i++) {
CHECK_LESS_RETURN_RET(INT32_MAX, flat_axis[i], nullptr, param);
param->axis_[i] = flat_axis[i];
}
return reinterpret_cast<OpParameter *>(param);
}
REG_OPERATOR_POPULATE(kNameSqueeze, PrimitiveType_Squeeze, PopulateSqueezeOpParameter)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,55 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/ops/operator_populate/operator_populate_register.h"
#include "nnacl/unsqueeze_parameter.h"
#include "ops/unsqueeze.h"
using mindspore::ops::kAxis;
using mindspore::ops::kNameUnsqueeze;
using mindspore::schema::PrimitiveType_Unsqueeze;
namespace mindspore {
namespace lite {
OpParameter *PopulateUnsqueezeOpParameter(const BaseOperatorPtr &base_operator) {
auto param = reinterpret_cast<UnSqueezeParameter *>(PopulateOpParameter<UnSqueezeParameter>());
if (param == nullptr) {
MS_LOG(ERROR) << "new UnSqueezeParameter failed.";
return nullptr;
}
ValuePtr attr = base_operator->GetPrim()->GetAttr(kAxis);
if (attr == nullptr) {
MS_LOG(ERROR) << "The attr(" << kAxis << ") of operator(" << base_operator->name() << ") not exist";
free(param);
return nullptr;
}
auto flat_axis = GetValue<std::vector<int64_t>>(attr);
if (flat_axis.size() > COMM_SHAPE_SIZE) {
MS_LOG(ERROR) << "Invalid axis size " << flat_axis.size();
free(param);
return nullptr;
}
param->num_dim_ = flat_axis.size();
int i = 0;
for (auto flat_axi : flat_axis) {
CHECK_LESS_RETURN_RET(INT32_MAX, flat_axi, nullptr, param);
param->dims_[i++] = flat_axi;
}
return reinterpret_cast<OpParameter *>(param);
}
REG_OPERATOR_POPULATE(kNameUnsqueeze, PrimitiveType_Unsqueeze, PopulateUnsqueezeOpParameter)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,45 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/ops/operator_populate/operator_populate_register.h"
#include "nnacl/unstack_parameter.h"
#include "ops/unstack.h"
using mindspore::ops::kAxis;
using mindspore::ops::kNameUnstack;
using mindspore::schema::PrimitiveType_Unstack;
namespace mindspore {
namespace lite {
OpParameter *PopulateUnstackOpParameter(const BaseOperatorPtr &base_operator) {
auto param = reinterpret_cast<UnstackParameter *>(PopulateOpParameter<UnstackParameter>());
if (param == nullptr) {
MS_LOG(ERROR) << "new UnstackParameter failed.";
return nullptr;
}
mindspore::ValuePtr attr = base_operator->GetPrim()->GetAttr(kAxis);
if (attr == nullptr) {
MS_LOG(ERROR) << "The attr(" << kAxis << ") of operator(" << base_operator->name() << ") not exist";
free(param);
return nullptr;
}
auto axis = GetValue<int64_t>(attr);
CHECK_LESS_RETURN_RET(INT32_MAX, axis, nullptr, param);
param->axis_ = axis;
return reinterpret_cast<OpParameter *>(param);
}
REG_OPERATOR_POPULATE(kNameUnstack, PrimitiveType_Unstack, PopulateUnstackOpParameter)
} // namespace lite
} // namespace mindspore