forked from mindspore-Ecosystem/mindspore
add op populate
This commit is contained in:
parent
0db81d3bb6
commit
c58e13c692
|
@ -0,0 +1,122 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "src/common/ops/operator_populate/operator_populate_register.h"
|
||||||
|
#include "nnacl/fp32/activation_fp32.h"
|
||||||
|
#include "ops/fusion/activation.h"
|
||||||
|
#include "ops/relu.h"
|
||||||
|
#include "ops/relu6.h"
|
||||||
|
#include "ops/leaky_relu.h"
|
||||||
|
#include "ops/sigmoid.h"
|
||||||
|
#include "ops/tanh.h"
|
||||||
|
#include "ops/hswish.h"
|
||||||
|
#include "ops/hsigmoid.h"
|
||||||
|
#include "ops/gelu.h"
|
||||||
|
#include "ops/softplus.h"
|
||||||
|
#include "ops/elu.h"
|
||||||
|
using mindspore::ops::kActivationType;
|
||||||
|
using mindspore::ops::kAlpha;
|
||||||
|
using mindspore::ops::kApproximate;
|
||||||
|
using mindspore::ops::kMaxVal;
|
||||||
|
using mindspore::ops::kMinVal;
|
||||||
|
using mindspore::ops::kNameActivation;
|
||||||
|
using mindspore::ops::kNameElu;
|
||||||
|
using mindspore::ops::kNameGeLU;
|
||||||
|
using mindspore::ops::kNameHSigmoid;
|
||||||
|
using mindspore::ops::kNameHSwish;
|
||||||
|
using mindspore::ops::kNameLeakyRelu;
|
||||||
|
using mindspore::ops::kNameReLU;
|
||||||
|
using mindspore::ops::kNameReLU6;
|
||||||
|
using mindspore::ops::kNameSigmoid;
|
||||||
|
using mindspore::ops::kNameSoftplus;
|
||||||
|
using mindspore::ops::kNameTanh;
|
||||||
|
using mindspore::schema::PrimitiveType_Activation;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
OpParameter *PopulateActivationOpParameter(const BaseOperatorPtr &base_operator) {
|
||||||
|
auto param = reinterpret_cast<ActivationParameter *>(PopulateOpParameter<ActivationParameter>());
|
||||||
|
if (param == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new ActivationParameter failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
mindspore::ValuePtr attr = base_operator->GetPrim()->GetAttr(kActivationType);
|
||||||
|
if (attr != nullptr) {
|
||||||
|
auto activation_type = GetValue<int64_t>(attr);
|
||||||
|
static const std::set<int> activation_types = {
|
||||||
|
schema::ActivationType_RELU, schema::ActivationType_RELU6, schema::ActivationType_LEAKY_RELU,
|
||||||
|
schema::ActivationType_SIGMOID, schema::ActivationType_TANH, schema::ActivationType_SWISH,
|
||||||
|
schema::ActivationType_HSWISH, schema::ActivationType_HSIGMOID, schema::ActivationType_HARD_TANH,
|
||||||
|
schema::ActivationType_GELU, schema::ActivationType_SOFTPLUS, schema::ActivationType_ELU};
|
||||||
|
if (activation_types.find(activation_type) == activation_types.end()) {
|
||||||
|
MS_LOG(ERROR) << "invalid activation type: " << activation_type;
|
||||||
|
free(param);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
param->type_ = activation_type;
|
||||||
|
} else {
|
||||||
|
auto type_name = base_operator->name();
|
||||||
|
static const std::map<std::string, int> op_type_map = {{kNameReLU, schema::ActivationType_RELU},
|
||||||
|
{kNameReLU6, schema::ActivationType_RELU6},
|
||||||
|
{kNameLeakyRelu, schema::ActivationType_LEAKY_RELU},
|
||||||
|
{kNameSigmoid, schema::ActivationType_SIGMOID},
|
||||||
|
{kNameTanh, schema::ActivationType_TANH},
|
||||||
|
{kNameHSwish, schema::ActivationType_HSWISH},
|
||||||
|
{kNameHSigmoid, schema::ActivationType_HSIGMOID},
|
||||||
|
{kNameGeLU, schema::ActivationType_GELU},
|
||||||
|
{kNameSoftplus, schema::ActivationType_SOFTPLUS},
|
||||||
|
{kNameElu, schema::ActivationType_ELU}};
|
||||||
|
auto iter = op_type_map.find(type_name);
|
||||||
|
if (iter == op_type_map.end()) {
|
||||||
|
MS_LOG(ERROR) << "invalid activation type: " << type_name;
|
||||||
|
free(param);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
param->type_ = iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
mindspore::ValuePtr alpha = base_operator->GetPrim()->GetAttr(kAlpha);
|
||||||
|
if (alpha != nullptr) {
|
||||||
|
param->alpha_ = GetValue<float>(alpha);
|
||||||
|
}
|
||||||
|
mindspore::ValuePtr min_val = base_operator->GetPrim()->GetAttr(kMinVal);
|
||||||
|
if (min_val != nullptr) {
|
||||||
|
param->min_val_ = GetValue<float>(min_val);
|
||||||
|
}
|
||||||
|
mindspore::ValuePtr max_val = base_operator->GetPrim()->GetAttr(kMaxVal);
|
||||||
|
if (max_val != nullptr) {
|
||||||
|
param->max_val_ = GetValue<float>(max_val);
|
||||||
|
}
|
||||||
|
mindspore::ValuePtr approximate = base_operator->GetPrim()->GetAttr(kApproximate);
|
||||||
|
if (approximate != nullptr) {
|
||||||
|
param->approximate_ = GetValue<bool>(approximate);
|
||||||
|
}
|
||||||
|
|
||||||
|
return reinterpret_cast<OpParameter *>(param);
|
||||||
|
}
|
||||||
|
|
||||||
|
REG_OPERATOR_POPULATE(kNameActivation, PrimitiveType_Activation, PopulateActivationOpParameter)
|
||||||
|
REG_OPERATOR_POPULATE(kNameReLU, PrimitiveType_Activation, PopulateActivationOpParameter)
|
||||||
|
REG_OPERATOR_POPULATE(kNameReLU6, PrimitiveType_Activation, PopulateActivationOpParameter)
|
||||||
|
REG_OPERATOR_POPULATE(kNameLeakyRelu, PrimitiveType_Activation, PopulateActivationOpParameter)
|
||||||
|
REG_OPERATOR_POPULATE(kNameSigmoid, PrimitiveType_Activation, PopulateActivationOpParameter)
|
||||||
|
REG_OPERATOR_POPULATE(kNameTanh, PrimitiveType_Activation, PopulateActivationOpParameter)
|
||||||
|
REG_OPERATOR_POPULATE(kNameHSwish, PrimitiveType_Activation, PopulateActivationOpParameter)
|
||||||
|
REG_OPERATOR_POPULATE(kNameHSigmoid, PrimitiveType_Activation, PopulateActivationOpParameter)
|
||||||
|
REG_OPERATOR_POPULATE(kNameGeLU, PrimitiveType_Activation, PopulateActivationOpParameter)
|
||||||
|
REG_OPERATOR_POPULATE(kNameSoftplus, PrimitiveType_Activation, PopulateActivationOpParameter)
|
||||||
|
REG_OPERATOR_POPULATE(kNameElu, PrimitiveType_Activation, PopulateActivationOpParameter)
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -1,49 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "src/common/ops/operator_populate/operator_populate_register.h"
|
|
||||||
#include "nnacl/arithmetic.h"
|
|
||||||
#include "ops/add.h"
|
|
||||||
using mindspore::ops::kNameAdd;
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
OpParameter *PopulateAddOpParameter(const BaseOperatorPtr &base_operator) {
|
|
||||||
if (base_operator == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "base_operator is nullptr";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto op_parameter_ptr = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
|
|
||||||
if (op_parameter_ptr == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Make OpParameter ptr failed";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
memset(op_parameter_ptr, 0, sizeof(ArithmeticParameter));
|
|
||||||
auto name = base_operator->name();
|
|
||||||
|
|
||||||
auto iter = kOpNameWithPrimitiveType.find(name);
|
|
||||||
if (iter == kOpNameWithPrimitiveType.end()) {
|
|
||||||
MS_LOG(ERROR) << "Can not find ParameterPtrGen : " << name;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
op_parameter_ptr->op_parameter_.type_ = iter->second;
|
|
||||||
op_parameter_ptr->broadcasting_ = false;
|
|
||||||
op_parameter_ptr->ndim_ = 0;
|
|
||||||
op_parameter_ptr->activation_type_ = 0;
|
|
||||||
return reinterpret_cast<OpParameter *>(op_parameter_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
REG_OPERATOR_POPULATE(kNameAdd, PopulateAddOpParameter)
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
|
@ -31,7 +31,12 @@
|
||||||
#include "ops/floor_mod.h"
|
#include "ops/floor_mod.h"
|
||||||
#include "ops/squared_difference.h"
|
#include "ops/squared_difference.h"
|
||||||
#include "ops/mod.h"
|
#include "ops/mod.h"
|
||||||
|
#include "ops/add.h"
|
||||||
|
#include "ops/fusion/add_fusion.h"
|
||||||
|
|
||||||
|
using mindspore::ops::kActivationType;
|
||||||
|
using mindspore::ops::kNameAdd;
|
||||||
|
using mindspore::ops::kNameAddFusion;
|
||||||
using mindspore::ops::kNameEqual;
|
using mindspore::ops::kNameEqual;
|
||||||
using mindspore::ops::kNameFloorDiv;
|
using mindspore::ops::kNameFloorDiv;
|
||||||
using mindspore::ops::kNameFloorMod;
|
using mindspore::ops::kNameFloorMod;
|
||||||
|
@ -48,57 +53,66 @@ using mindspore::ops::kNameNotEqual;
|
||||||
using mindspore::ops::kNameRealDiv;
|
using mindspore::ops::kNameRealDiv;
|
||||||
using mindspore::ops::kNameSquaredDifference;
|
using mindspore::ops::kNameSquaredDifference;
|
||||||
|
|
||||||
|
using mindspore::schema::PrimitiveType_AddFusion;
|
||||||
|
using mindspore::schema::PrimitiveType_Equal;
|
||||||
|
using mindspore::schema::PrimitiveType_FloorDiv;
|
||||||
|
using mindspore::schema::PrimitiveType_FloorMod;
|
||||||
|
using mindspore::schema::PrimitiveType_Greater;
|
||||||
|
using mindspore::schema::PrimitiveType_GreaterEqual;
|
||||||
|
using mindspore::schema::PrimitiveType_Less;
|
||||||
|
using mindspore::schema::PrimitiveType_LessEqual;
|
||||||
|
using mindspore::schema::PrimitiveType_LogicalAnd;
|
||||||
|
using mindspore::schema::PrimitiveType_LogicalOr;
|
||||||
|
using mindspore::schema::PrimitiveType_Maximum;
|
||||||
|
using mindspore::schema::PrimitiveType_Minimum;
|
||||||
|
using mindspore::schema::PrimitiveType_Mod;
|
||||||
|
using mindspore::schema::PrimitiveType_NotEqual;
|
||||||
|
using mindspore::schema::PrimitiveType_RealDiv;
|
||||||
|
using mindspore::schema::PrimitiveType_SquaredDifference;
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
ArithmeticParameter *PopulateArithmeticCommonOpPara(const BaseOperatorPtr &base_operator) {
|
OpParameter *PopulateArithmeticCommonOpPara(const BaseOperatorPtr &base_operator) {
|
||||||
if (base_operator == nullptr) {
|
auto param = reinterpret_cast<ArithmeticParameter *>(PopulateOpParameter<ArithmeticParameter>());
|
||||||
MS_LOG(ERROR) << "base_operator is nullptr";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
|
|
||||||
if (param == nullptr) {
|
if (param == nullptr) {
|
||||||
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
|
MS_LOG(ERROR) << "new ArithmeticParameter failed.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
memset(param, 0, sizeof(ArithmeticParameter));
|
|
||||||
auto name = base_operator->name();
|
|
||||||
|
|
||||||
auto iter = kOpNameWithPrimitiveType.find(name);
|
|
||||||
if (iter == kOpNameWithPrimitiveType.end()) {
|
|
||||||
MS_LOG(ERROR) << "Can not find ParameterPtrGen : " << name;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
param->op_parameter_.type_ = iter->second;
|
|
||||||
param->broadcasting_ = false;
|
param->broadcasting_ = false;
|
||||||
param->ndim_ = 0;
|
param->ndim_ = 0;
|
||||||
param->activation_type_ = 0;
|
param->activation_type_ = 0;
|
||||||
return param;
|
return reinterpret_cast<OpParameter *>(param);
|
||||||
}
|
}
|
||||||
|
|
||||||
OpParameter *PopulateArithmeticOp(const BaseOperatorPtr &base_operator) {
|
OpParameter *PopulateAddOpParameter(const BaseOperatorPtr &base_operator) {
|
||||||
ArithmeticParameter *param = PopulateArithmeticCommonOpPara(base_operator);
|
ArithmeticParameter *param = reinterpret_cast<ArithmeticParameter *>(PopulateArithmeticCommonOpPara(base_operator));
|
||||||
if (param == nullptr) {
|
if (param == nullptr) {
|
||||||
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
|
MS_LOG(ERROR) << "PopulateArithmeticCommonOpPara failed.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
mindspore::ValuePtr attr = base_operator->GetPrim()->GetAttr(kActivationType);
|
||||||
|
if (attr != nullptr) {
|
||||||
|
param->activation_type_ = ActivationType(GetValue<int64_t>(attr));
|
||||||
|
}
|
||||||
return reinterpret_cast<OpParameter *>(param);
|
return reinterpret_cast<OpParameter *>(param);
|
||||||
}
|
}
|
||||||
|
|
||||||
REG_OPERATOR_POPULATE(kNameRealDiv, PopulateArithmeticOp)
|
REG_OPERATOR_POPULATE(kNameAdd, PrimitiveType_AddFusion, PopulateArithmeticCommonOpPara)
|
||||||
REG_OPERATOR_POPULATE(kNameLogicalAnd, PopulateArithmeticOp)
|
REG_OPERATOR_POPULATE(kNameAddFusion, PrimitiveType_AddFusion, PopulateAddOpParameter)
|
||||||
REG_OPERATOR_POPULATE(kNameLogicalOr, PopulateArithmeticOp)
|
REG_OPERATOR_POPULATE(kNameRealDiv, PrimitiveType_RealDiv, PopulateArithmeticCommonOpPara)
|
||||||
REG_OPERATOR_POPULATE(kNameEqual, PopulateArithmeticOp);
|
REG_OPERATOR_POPULATE(kNameLogicalAnd, PrimitiveType_LogicalAnd, PopulateArithmeticCommonOpPara)
|
||||||
REG_OPERATOR_POPULATE(kNameNotEqual, PopulateArithmeticOp)
|
REG_OPERATOR_POPULATE(kNameLogicalOr, PrimitiveType_LogicalOr, PopulateArithmeticCommonOpPara)
|
||||||
REG_OPERATOR_POPULATE(kNameLess, PopulateArithmeticOp)
|
REG_OPERATOR_POPULATE(kNameEqual, PrimitiveType_Equal, PopulateArithmeticCommonOpPara);
|
||||||
REG_OPERATOR_POPULATE(kNameLessEqual, PopulateArithmeticOp)
|
REG_OPERATOR_POPULATE(kNameNotEqual, PrimitiveType_NotEqual, PopulateArithmeticCommonOpPara)
|
||||||
REG_OPERATOR_POPULATE(kNameGreater, PopulateArithmeticOp)
|
REG_OPERATOR_POPULATE(kNameLess, PrimitiveType_Less, PopulateArithmeticCommonOpPara)
|
||||||
REG_OPERATOR_POPULATE(kNameGreaterEqual, PopulateArithmeticOp)
|
REG_OPERATOR_POPULATE(kNameLessEqual, PrimitiveType_LessEqual, PopulateArithmeticCommonOpPara)
|
||||||
REG_OPERATOR_POPULATE(kNameMaximum, PopulateArithmeticOp)
|
REG_OPERATOR_POPULATE(kNameGreater, PrimitiveType_Greater, PopulateArithmeticCommonOpPara)
|
||||||
REG_OPERATOR_POPULATE(kNameMinimum, PopulateArithmeticOp)
|
REG_OPERATOR_POPULATE(kNameGreaterEqual, PrimitiveType_GreaterEqual, PopulateArithmeticCommonOpPara)
|
||||||
REG_OPERATOR_POPULATE(kNameFloorDiv, PopulateArithmeticOp)
|
REG_OPERATOR_POPULATE(kNameMaximum, PrimitiveType_Maximum, PopulateArithmeticCommonOpPara)
|
||||||
REG_OPERATOR_POPULATE(kNameFloorMod, PopulateArithmeticOp)
|
REG_OPERATOR_POPULATE(kNameMinimum, PrimitiveType_Minimum, PopulateArithmeticCommonOpPara)
|
||||||
REG_OPERATOR_POPULATE(kNameMod, PopulateArithmeticOp)
|
REG_OPERATOR_POPULATE(kNameFloorDiv, PrimitiveType_FloorDiv, PopulateArithmeticCommonOpPara)
|
||||||
REG_OPERATOR_POPULATE(kNameSquaredDifference, PopulateArithmeticOp)
|
REG_OPERATOR_POPULATE(kNameFloorMod, PrimitiveType_FloorMod, PopulateArithmeticCommonOpPara)
|
||||||
|
REG_OPERATOR_POPULATE(kNameMod, PrimitiveType_Mod, PopulateArithmeticCommonOpPara)
|
||||||
|
REG_OPERATOR_POPULATE(kNameSquaredDifference, PrimitiveType_SquaredDifference, PopulateArithmeticCommonOpPara)
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,46 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "src/common/ops/operator_populate/operator_populate_register.h"
|
|
||||||
#include "ops/assert.h"
|
|
||||||
|
|
||||||
using mindspore::ops::kNameAssert;
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
OpParameter *PopulateAssertOpParameter(const BaseOperatorPtr &base_operator) {
|
|
||||||
if (base_operator == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "base_operator is nullptr";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto op_parameter_ptr = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
|
|
||||||
if (op_parameter_ptr == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Make OpParameter ptr failed";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
memset(op_parameter_ptr, 0, sizeof(OpParameter));
|
|
||||||
auto name = base_operator->name();
|
|
||||||
|
|
||||||
auto iter = kOpNameWithPrimitiveType.find(name);
|
|
||||||
if (iter == kOpNameWithPrimitiveType.end()) {
|
|
||||||
MS_LOG(ERROR) << "Can not find ParameterPtrGen : " << name;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
op_parameter_ptr->type_ = iter->second;
|
|
||||||
return op_parameter_ptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
REG_OPERATOR_POPULATE(kNameAssert, PopulateAssertOpParameter)
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "src/common/ops/operator_populate/operator_populate_register.h"
|
||||||
|
#include "nnacl/where_parameter.h"
|
||||||
|
#include "nnacl/sparse_to_dense_parameter.h"
|
||||||
|
#include "nnacl/transpose.h"
|
||||||
|
#include "nnacl/triu_tril.h"
|
||||||
|
#include "nnacl/fp32/unique_fp32.h"
|
||||||
|
#include "nnacl/scatter_nd_parameter.h"
|
||||||
|
#include "ops/assert.h"
|
||||||
|
#include "ops/where.h"
|
||||||
|
#include "ops/unsorted_segment_sum.h"
|
||||||
|
#include "ops/unique.h"
|
||||||
|
#include "ops/triu.h"
|
||||||
|
#include "ops/tril.h"
|
||||||
|
#include "ops/transpose.h"
|
||||||
|
#include "ops/sparse_to_dense.h"
|
||||||
|
#include "ops/sparse_segment_sum.h"
|
||||||
|
#include "ops/sparse_reshape.h"
|
||||||
|
#include "ops/sparse_fill_empty_rows.h"
|
||||||
|
#include "ops/size.h"
|
||||||
|
#include "ops/shape.h"
|
||||||
|
#include "ops/select.h"
|
||||||
|
#include "ops/scatter_nd_update.h"
|
||||||
|
#include "ops/tensor_scatter_add.h"
|
||||||
|
#include "ops/scatter_nd.h"
|
||||||
|
#include "ops/expand_dims.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
REG_OP_BASE_POPULATE(Assert)
|
||||||
|
REG_OP_BASE_POPULATE(UnsortedSegmentSum)
|
||||||
|
REG_OP_BASE_POPULATE(SparseSegmentSum)
|
||||||
|
REG_OP_BASE_POPULATE(SparseReshape)
|
||||||
|
REG_OP_BASE_POPULATE(SparseFillEmptyRows)
|
||||||
|
REG_OP_BASE_POPULATE(Size)
|
||||||
|
REG_OP_BASE_POPULATE(Shape)
|
||||||
|
REG_OP_BASE_POPULATE(Select)
|
||||||
|
REG_OP_BASE_POPULATE(ExpandDims)
|
||||||
|
|
||||||
|
REG_OP_DEFAULT_POPULATE(SparseToDense)
|
||||||
|
REG_OP_DEFAULT_POPULATE(Transpose)
|
||||||
|
REG_OP_DEFAULT_POPULATE(Tril)
|
||||||
|
REG_OP_DEFAULT_POPULATE(Triu)
|
||||||
|
REG_OP_DEFAULT_POPULATE(Where)
|
||||||
|
REG_OP_DEFAULT_POPULATE(Unique)
|
||||||
|
using mindspore::ops::kNameScatterNd;
|
||||||
|
using mindspore::ops::kNameScatterNdUpdate;
|
||||||
|
using mindspore::ops::kNameTensorScatterAdd;
|
||||||
|
using mindspore::schema::PrimitiveType_ScatterNd;
|
||||||
|
using mindspore::schema::PrimitiveType_ScatterNdUpdate;
|
||||||
|
using mindspore::schema::PrimitiveType_TensorScatterAdd;
|
||||||
|
REG_OPERATOR_POPULATE(kNameScatterNd, PrimitiveType_ScatterNd, PopulateOpParameter<ScatterNDParameter>)
|
||||||
|
REG_OPERATOR_POPULATE(kNameScatterNdUpdate, PrimitiveType_ScatterNdUpdate, PopulateOpParameter<ScatterNDParameter>)
|
||||||
|
REG_OPERATOR_POPULATE(kNameTensorScatterAdd, PrimitiveType_TensorScatterAdd, PopulateOpParameter<ScatterNDParameter>)
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -22,5 +22,9 @@ OperatorPopulateRegistry *OperatorPopulateRegistry::GetInstance() {
|
||||||
static OperatorPopulateRegistry registry;
|
static OperatorPopulateRegistry registry;
|
||||||
return ®istry;
|
return ®istry;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OpParameter *CreatePopulatePtr(const BaseOperatorPtr &base_operator) {
|
||||||
|
return OperatorPopulateRegistry::GetInstance()->CreatePopulateByOp(base_operator);
|
||||||
|
}
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
#include "src/common/common.h"
|
#include "src/common/common.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
@ -28,58 +29,89 @@
|
||||||
#include "src/common/utils.h"
|
#include "src/common/utils.h"
|
||||||
#include "src/common/log_util.h"
|
#include "src/common/log_util.h"
|
||||||
#include "ops/base_operator.h"
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindspore/core/ir/primitive.h"
|
||||||
|
#include "ops/primitive_c.h"
|
||||||
|
#include "ops/op_name.h"
|
||||||
|
#include "schema/ops_generated.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
using BaseOperatorPtr = std::shared_ptr<mindspore::ops::BaseOperator>;
|
using BaseOperatorPtr = std::shared_ptr<mindspore::ops::BaseOperator>;
|
||||||
typedef OpParameter *(*ParameterPtrGen)(const BaseOperatorPtr &base_operator);
|
typedef OpParameter *(*ParameterPtrGen)(const BaseOperatorPtr &base_operator);
|
||||||
|
OpParameter *CreatePopulatePtr(const BaseOperatorPtr &base_operator);
|
||||||
static const std::map<std::string, int> kOpNameWithPrimitiveType = {{"Add", 5},
|
|
||||||
{"Assert", 17},
|
|
||||||
{"Equal", 53},
|
|
||||||
{"FloorDiv", 64},
|
|
||||||
{"FloorMod", 65},
|
|
||||||
{"Greater", 71},
|
|
||||||
{"GreaterEqual", 72},
|
|
||||||
{"Less", 77},
|
|
||||||
{"LessEqual", 78},
|
|
||||||
{"LogicalAnd", 81},
|
|
||||||
{"LogicalOr", 83},
|
|
||||||
{"Maximum", 90},
|
|
||||||
{"Minimum", 96},
|
|
||||||
{"Mod", 98},
|
|
||||||
{"NotEqual", 103},
|
|
||||||
{"RealDiv", 117},
|
|
||||||
{"SquaredDifference", 149}};
|
|
||||||
|
|
||||||
class OperatorPopulateRegistry {
|
class OperatorPopulateRegistry {
|
||||||
public:
|
public:
|
||||||
static OperatorPopulateRegistry *GetInstance();
|
static OperatorPopulateRegistry *GetInstance();
|
||||||
|
|
||||||
void InsertOperatorParameterMap(const std::string &name, ParameterPtrGen creator) { op_parameters_[name] = creator; }
|
void InsertOperatorParameterMap(const std::string &name, int type, ParameterPtrGen creator) {
|
||||||
|
op_parameters_[name] = std::make_pair(creator, type);
|
||||||
|
}
|
||||||
|
ParameterPtrGen GetParameterPtrCreator(const std::string &name) { return CreatePopulatePtr; }
|
||||||
|
|
||||||
ParameterPtrGen GetParameterPtrCreator(const std::string &name) {
|
OpParameter *CreatePopulateByOp(const BaseOperatorPtr &base_operator) {
|
||||||
auto iter = op_parameters_.find(name);
|
MS_CHECK_TRUE_RET(base_operator != nullptr, nullptr);
|
||||||
|
auto iter = op_parameters_.find(base_operator->name());
|
||||||
if (iter == op_parameters_.end()) {
|
if (iter == op_parameters_.end()) {
|
||||||
MS_LOG(ERROR) << "Unsupported op in creator " << name;
|
MS_LOG(ERROR) << "Unsupported op in creator " << base_operator->name();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return iter->second;
|
if (base_operator->GetPrim() == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "invalid op " << base_operator->name();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto param = iter->second.first(base_operator);
|
||||||
|
if (param != nullptr) {
|
||||||
|
param->type_ = iter->second.second;
|
||||||
|
}
|
||||||
|
return param;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::map<std::string, ParameterPtrGen> op_parameters_;
|
std::map<std::string, std::pair<ParameterPtrGen, int>> op_parameters_;
|
||||||
};
|
};
|
||||||
class Registry {
|
class Registry {
|
||||||
public:
|
public:
|
||||||
Registry(std::string name, ParameterPtrGen creator) noexcept {
|
Registry(std::string name, int type, ParameterPtrGen creator) noexcept {
|
||||||
OperatorPopulateRegistry::GetInstance()->InsertOperatorParameterMap(name, creator);
|
OperatorPopulateRegistry::GetInstance()->InsertOperatorParameterMap(name, type, creator);
|
||||||
}
|
}
|
||||||
|
|
||||||
~Registry() = default;
|
~Registry() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
#define REG_OPERATOR_POPULATE(name, creator) static Registry g_##name(name, creator);
|
template <typename T>
|
||||||
|
OpParameter *PopulateOpParameter() {
|
||||||
|
auto op_parameter_ptr = reinterpret_cast<OpParameter *>(malloc(sizeof(T)));
|
||||||
|
if (op_parameter_ptr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "malloc OpParameter ptr failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
memset(op_parameter_ptr, 0, sizeof(T));
|
||||||
|
return reinterpret_cast<OpParameter *>(op_parameter_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
OpParameter *PopulateOpParameter(const BaseOperatorPtr &base_operator) {
|
||||||
|
auto op_parameter_ptr = reinterpret_cast<OpParameter *>(malloc(sizeof(T)));
|
||||||
|
if (op_parameter_ptr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "malloc OpParameter ptr failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
memset(op_parameter_ptr, 0, sizeof(T));
|
||||||
|
return reinterpret_cast<OpParameter *>(op_parameter_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define REG_OPERATOR_POPULATE(name, type, creator) static Registry g_##name(name, type, creator);
|
||||||
|
|
||||||
|
#define REG_OP_BASE_POPULATE(op) \
|
||||||
|
using mindspore::ops::kName##op; \
|
||||||
|
using mindspore::schema::PrimitiveType_##op; \
|
||||||
|
REG_OPERATOR_POPULATE(kName##op, PrimitiveType_##op, PopulateOpParameter<OpParameter>)
|
||||||
|
|
||||||
|
#define REG_OP_DEFAULT_POPULATE(op) \
|
||||||
|
using mindspore::ops::kName##op; \
|
||||||
|
using mindspore::schema::PrimitiveType_##op; \
|
||||||
|
REG_OPERATOR_POPULATE(kName##op, PrimitiveType_##op, PopulateOpParameter<op##Parameter>)
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_LITE_SRC_COMMON_OPS_OPERATOR_POPULATE_H_
|
#endif // MINDSPORE_LITE_SRC_COMMON_OPS_OPERATOR_POPULATE_H_
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "src/common/ops/operator_populate/operator_populate_register.h"
|
||||||
|
#include "nnacl/softmax_parameter.h"
|
||||||
|
#include "ops/softmax.h"
|
||||||
|
using mindspore::ops::kAxis;
|
||||||
|
using mindspore::ops::kNameSoftmax;
|
||||||
|
using mindspore::schema::PrimitiveType_Softmax;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
OpParameter *PopulateSoftmaxOpParameter(const BaseOperatorPtr &base_operator) {
|
||||||
|
auto param = reinterpret_cast<SoftmaxParameter *>(PopulateOpParameter<SoftmaxParameter>());
|
||||||
|
if (param == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new SoftmaxParameter failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
ValuePtr attr = base_operator->GetPrim()->GetAttr(kAxis);
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The attr(" << kAxis << ") of operator(" << base_operator->name() << ") not exist";
|
||||||
|
free(param);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto flat_axis = GetValue<std::vector<int64_t>>(attr);
|
||||||
|
if (flat_axis.size() != 1) {
|
||||||
|
MS_LOG(ERROR) << "axis number invalid!number: " << flat_axis.size();
|
||||||
|
free(param);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
param->axis_ = flat_axis.data()[0];
|
||||||
|
return reinterpret_cast<OpParameter *>(param);
|
||||||
|
}
|
||||||
|
|
||||||
|
REG_OPERATOR_POPULATE(kNameSoftmax, PrimitiveType_Softmax, PopulateSoftmaxOpParameter)
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,100 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "src/common/ops/operator_populate/operator_populate_register.h"
|
||||||
|
#include "nnacl/split_parameter.h"
|
||||||
|
#include "nnacl/op_base.h"
|
||||||
|
#include "ops/split.h"
|
||||||
|
using mindspore::ops::kAxis;
|
||||||
|
using mindspore::ops::kNameSplit;
|
||||||
|
using mindspore::ops::kOutputNum;
|
||||||
|
using mindspore::ops::kSizeSplits;
|
||||||
|
using mindspore::schema::PrimitiveType_Split;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
void DestroySplitSizes(OpParameter *parameter) {
|
||||||
|
MS_CHECK_PTR_IF_NULL(parameter);
|
||||||
|
auto param = reinterpret_cast<SplitParameter *>(parameter);
|
||||||
|
if (param->split_sizes_ != nullptr) {
|
||||||
|
free(param->split_sizes_);
|
||||||
|
param->split_sizes_ = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OpParameter *PopulateSplitOpParameter(const BaseOperatorPtr &base_operator) {
|
||||||
|
auto param = reinterpret_cast<SplitParameter *>(PopulateOpParameter<SplitParameter>());
|
||||||
|
if (param == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new SplitParameter failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
mindspore::ValuePtr attr_output = base_operator->GetPrim()->GetAttr(kOutputNum);
|
||||||
|
if (attr_output == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The attr(" << kOutputNum << ") of operator(" << base_operator->name() << ") not exist";
|
||||||
|
free(param);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto output_num = GetValue<int64_t>(attr_output);
|
||||||
|
if (output_num > std::numeric_limits<int>::max() / static_cast<int>(sizeof(int)) || output_num <= 0) {
|
||||||
|
MS_LOG(ERROR) << "The value of param->num_split_ is not correct";
|
||||||
|
free(param);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
param->num_split_ = output_num;
|
||||||
|
|
||||||
|
/* free split_sizes_ in split op base */
|
||||||
|
param->split_sizes_ = reinterpret_cast<int *>(malloc(static_cast<size_t>(output_num) * sizeof(int)));
|
||||||
|
if (param->split_sizes_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "malloc param split_sizes_ error";
|
||||||
|
free(param);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
param->op_parameter_.destroy_func_ = DestroySplitSizes;
|
||||||
|
memset(param->split_sizes_, 0, static_cast<size_t>(param->num_split_) * sizeof(int));
|
||||||
|
|
||||||
|
mindspore::ValuePtr attr_size_splits = base_operator->GetPrim()->GetAttr(kSizeSplits);
|
||||||
|
if (attr_size_splits == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The attr(" << kSizeSplits << ") of operator(" << base_operator->name() << ") not exist";
|
||||||
|
DestroySplitSizes(param);
|
||||||
|
free(param);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto split_sizes_vector = GetValue<std::vector<int64_t>>(attr_size_splits);
|
||||||
|
|
||||||
|
if (split_sizes_vector.size() <= static_cast<uint32_t>(param->num_split_)) {
|
||||||
|
int i = 0;
|
||||||
|
for (auto iter : split_sizes_vector) {
|
||||||
|
param->split_sizes_[i++] = iter;
|
||||||
|
}
|
||||||
|
param->split_count_ = param->num_split_;
|
||||||
|
} else {
|
||||||
|
param->split_count_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
mindspore::ValuePtr attr = base_operator->GetPrim()->GetAttr(kAxis);
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The attr(" << kAxis << ") of operator(" << base_operator->name() << ") not exist";
|
||||||
|
DestroySplitSizes(param);
|
||||||
|
free(param);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
param->split_dim_ = GetValue<int64_t>(attr);
|
||||||
|
return reinterpret_cast<OpParameter *>(param);
|
||||||
|
}
|
||||||
|
|
||||||
|
REG_OPERATOR_POPULATE(kNameSplit, PrimitiveType_Split, PopulateSplitOpParameter)
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "src/common/ops/operator_populate/operator_populate_register.h"
|
||||||
|
#include "nnacl/squeeze_parameter.h"
|
||||||
|
#include "ops/squeeze.h"
|
||||||
|
using mindspore::ops::kAxis;
|
||||||
|
using mindspore::ops::kNameSqueeze;
|
||||||
|
using mindspore::schema::PrimitiveType_Squeeze;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
OpParameter *PopulateSqueezeOpParameter(const BaseOperatorPtr &base_operator) {
|
||||||
|
auto param = reinterpret_cast<SqueezeParameter *>(PopulateOpParameter<SqueezeParameter>());
|
||||||
|
if (param == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new SqueezeParameter failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
mindspore::ValuePtr attr = base_operator->GetPrim()->GetAttr(kAxis);
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The attr(" << kAxis << ") of operator(" << base_operator->name() << ") not exist";
|
||||||
|
free(param);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto flat_axis = GetValue<std::vector<int64_t>>(attr);
|
||||||
|
if (flat_axis.size() > MAX_SHAPE_SIZE) {
|
||||||
|
MS_LOG(ERROR) << "Invalid axis size " << flat_axis.size();
|
||||||
|
free(param);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
param->axis_size_ = flat_axis.size();
|
||||||
|
for (size_t i = 0; i < param->axis_size_; i++) {
|
||||||
|
CHECK_LESS_RETURN_RET(INT32_MAX, flat_axis[i], nullptr, param);
|
||||||
|
param->axis_[i] = flat_axis[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return reinterpret_cast<OpParameter *>(param);
|
||||||
|
}
|
||||||
|
|
||||||
|
REG_OPERATOR_POPULATE(kNameSqueeze, PrimitiveType_Squeeze, PopulateSqueezeOpParameter)
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,55 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "src/common/ops/operator_populate/operator_populate_register.h"
|
||||||
|
#include "nnacl/unsqueeze_parameter.h"
|
||||||
|
#include "ops/unsqueeze.h"
|
||||||
|
using mindspore::ops::kAxis;
|
||||||
|
using mindspore::ops::kNameUnsqueeze;
|
||||||
|
using mindspore::schema::PrimitiveType_Unsqueeze;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
OpParameter *PopulateUnsqueezeOpParameter(const BaseOperatorPtr &base_operator) {
|
||||||
|
auto param = reinterpret_cast<UnSqueezeParameter *>(PopulateOpParameter<UnSqueezeParameter>());
|
||||||
|
if (param == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new UnSqueezeParameter failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
ValuePtr attr = base_operator->GetPrim()->GetAttr(kAxis);
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The attr(" << kAxis << ") of operator(" << base_operator->name() << ") not exist";
|
||||||
|
free(param);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto flat_axis = GetValue<std::vector<int64_t>>(attr);
|
||||||
|
if (flat_axis.size() > COMM_SHAPE_SIZE) {
|
||||||
|
MS_LOG(ERROR) << "Invalid axis size " << flat_axis.size();
|
||||||
|
free(param);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
param->num_dim_ = flat_axis.size();
|
||||||
|
int i = 0;
|
||||||
|
for (auto flat_axi : flat_axis) {
|
||||||
|
CHECK_LESS_RETURN_RET(INT32_MAX, flat_axi, nullptr, param);
|
||||||
|
param->dims_[i++] = flat_axi;
|
||||||
|
}
|
||||||
|
return reinterpret_cast<OpParameter *>(param);
|
||||||
|
}
|
||||||
|
|
||||||
|
REG_OPERATOR_POPULATE(kNameUnsqueeze, PrimitiveType_Unsqueeze, PopulateUnsqueezeOpParameter)
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "src/common/ops/operator_populate/operator_populate_register.h"
|
||||||
|
#include "nnacl/unstack_parameter.h"
|
||||||
|
#include "ops/unstack.h"
|
||||||
|
using mindspore::ops::kAxis;
|
||||||
|
using mindspore::ops::kNameUnstack;
|
||||||
|
using mindspore::schema::PrimitiveType_Unstack;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
OpParameter *PopulateUnstackOpParameter(const BaseOperatorPtr &base_operator) {
|
||||||
|
auto param = reinterpret_cast<UnstackParameter *>(PopulateOpParameter<UnstackParameter>());
|
||||||
|
if (param == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new UnstackParameter failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
mindspore::ValuePtr attr = base_operator->GetPrim()->GetAttr(kAxis);
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The attr(" << kAxis << ") of operator(" << base_operator->name() << ") not exist";
|
||||||
|
free(param);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto axis = GetValue<int64_t>(attr);
|
||||||
|
CHECK_LESS_RETURN_RET(INT32_MAX, axis, nullptr, param);
|
||||||
|
param->axis_ = axis;
|
||||||
|
return reinterpret_cast<OpParameter *>(param);
|
||||||
|
}
|
||||||
|
REG_OPERATOR_POPULATE(kNameUnstack, PrimitiveType_Unstack, PopulateUnstackOpParameter)
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
Loading…
Reference in New Issue