!23028 [MSLITE] support custom kernel registry for codegen

Merge pull request !23028 from zhanyuan/nnie_micro_recitify
This commit is contained in:
i-robot 2021-09-09 02:05:24 +00:00 committed by Gitee
commit 41019c77a1
19 changed files with 178 additions and 430 deletions

View File

@ -13,14 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_CUSTOM_PARAMETER_H_
#define MINDSPORE_NNACL_CUSTOM_PARAMETER_H_
#include "nnacl/op_base.h"
#include "coder/user_registry/user_kernel_register.h"
#include "schema/ops_generated.h"
#define MAX_STR_LEN 64
#define MAX_ATTR_NUM 8
using mindspore::schema::PrimitiveType_Custom;
namespace mindspore::lite::micro {
REG_USER_KERNEL(kARM32A, kNumberTypeInt8, PrimitiveType_Custom, nnie_micro.h, NnieKernel, nniemicro)
REG_USER_KERNEL(kARM32A, kNumberTypeUInt8, PrimitiveType_Custom, nnie_micro.h, NnieKernel, nniemicro)
REG_USER_KERNEL(kARM32A, kNumberTypeFloat32, PrimitiveType_Custom, nnie_micro.h, NnieKernel, nniemicro)
} // namespace mindspore::lite::micro
typedef struct CustomParameter {
OpParameter op_parameter_;
char type[MAX_STR_LEN];
char attr_name[MAX_ATTR_NUM][MAX_STR_LEN];
char *attr_data[MAX_ATTR_NUM];
int attr_num;
} CustomParameter;
#endif // MINDSPORE_NNACL_CUSTOM_PARAMETER_H_

View File

@ -20,7 +20,7 @@
#include <arm_neon.h>
#endif
#include <memory.h>
#include <mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/resize_parameter.h>
#include "nnacl/resize_parameter.h"
#include "nnacl/op_base.h"
#include "nnacl/crop_parameter.h"

View File

@ -24,6 +24,7 @@ typedef struct TensorC {
void *data_;
size_t shape_size_;
int shape_[MAX_SHAPE_SIZE];
char *name_;
} TensorC;
#endif // MINDSPORE_NNACL_TENSOR_C_H_

View File

@ -154,9 +154,7 @@ set(LITE_SRC
)
set(REGISTRY_SRC
${MICRO_DIR}/coder/user_registry/user_kernel_register.cc
${MICRO_DIR}/coder/user_registry/nnie_kernel_reg.cc
${MICRO_DIR}/coder/user_registry/nnie_infer.cc
${MICRO_DIR}/coder/opcoders/kernel_registry.cc
)
list(APPEND FILE_SET ${CODER_SRC} ${CODER_OPCODERS_SRC} ${CODER_GENERATOR_SRC}

View File

@ -48,9 +48,6 @@ class Configurator {
int ParseProjDir(std::string model_path);
std::string proj_dir() const { return proj_dir_; }
void SetCustomFlag() { has_custom_op_ = true; }
bool CustomFlag() const { return has_custom_op_; }
private:
Configurator() = default;
~Configurator() = default;
@ -60,7 +57,6 @@ class Configurator {
bool support_parallel_{false};
bool debug_mode_{false};
std::string proj_dir_;
bool has_custom_op_{false};
};
} // namespace mindspore::lite::micro

View File

@ -17,7 +17,6 @@
#include "coder/generator/component/cmake_component.h"
#include <set>
#include <memory>
#include "coder/user_registry/user_kernel_register.h"
namespace mindspore::lite::micro {
void CodeCMakeNetLibrary(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator *config) {
@ -59,13 +58,5 @@ void CodeCMakeNetLibrary(std::ofstream &ofs, const std::unique_ptr<CoderContext>
" ${CMAKE_CURRENT_SOURCE_DIR}/*.c\n"
" )\n"
"add_library(net STATIC ${NET_SRC})\n";
auto user_libs = UserKernelFactory::GetInstance()->UserKernelLibNames();
if (config->CustomFlag() && !user_libs.empty()) {
ofs << "target_link_libraries(net";
for (auto lib : user_libs) {
ofs << " " << lib;
}
ofs << ")\n";
}
}
} // namespace mindspore::lite::micro

View File

@ -178,50 +178,4 @@ void *MTensor::MutableData() {
} // namespace mindspore
)RAW";
const char custom_params_source[] = R"RAW(
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_MICRO_LIBRARY_SOURCE_CUSTOM_PARAMS_H_
#define MINDSPORE_LITE_MICRO_LIBRARY_SOURCE_CUSTOM_PARAMS_H_
#define MAX_TENSOR_NAME_LEN 32
#define MAX_SHAPE_SIZE 8
#define MAX_STR_LEN 32
#define MAX_ATTR_NUM 8
typedef struct {
void *data;
int data_size;
int shape[MAX_SHAPE_SIZE];
int shape_size;
char name[MAX_TENSOR_NAME_LEN];
int data_type;
int format;
} CustomTensor;
typedef struct {
char type[MAX_STR_LEN];
char attr_name[MAX_ATTR_NUM][MAX_STR_LEN];
char attr_data[MAX_ATTR_NUM][MAX_STR_LEN];
int attr_num;
} CustomParams;
#endif // MINDSPORE_LITE_MICRO_LIBRARY_SOURCE_CUSTOM_PARAMS_H_
)RAW";
} // namespace mindspore::lite::micro

View File

@ -21,7 +21,6 @@ namespace mindspore::lite::micro {
extern const char tensor_header[];
extern const char tensor_source[];
extern const char custom_params_source[];
} // namespace mindspore::lite::micro

View File

@ -33,6 +33,7 @@
#include "coder/generator/component/const_blocks/license.h"
#include "coder/log.h"
#include "coder/opcoders/parallel.h"
#include "coder/opcoders/kernel_registry.h"
namespace mindspore::lite::micro {
int WriteContentToFile(const std::string &file, const std::string &content) {
@ -97,9 +98,6 @@ int Generator::CodeStaticContent() {
{net_src_file_path_ + "tensor.cc", tensor_source},
{net_src_file_path_ + "mmodel.h", model_header}};
if (config_->CustomFlag()) {
const_blocks.emplace_back(std::make_pair(net_src_file_path_ + "custom_params.h", custom_params_source));
}
if (config_->support_parallel()) {
const_blocks.emplace_back(std::make_pair(net_src_file_path_ + kThreadWrapper, thread_header));
}
@ -167,6 +165,24 @@ int Generator::CodeWeightFile() {
return RET_OK;
}
int Generator::CodeRegKernelHFile() {
if (!KernelRegistry::GetInstance()->HasKernelRegistered()) return RET_OK;
if (!KernelRegistry::GetInstance()->CheckRegistered(schema::PrimitiveType_Custom)) {
MS_LOG(ERROR) << "Only support custom kernel to register now!";
return RET_ERROR;
}
std::string reg_kernel_header = net_src_file_path_ + "registered_kernel.h";
std::ofstream cofs(reg_kernel_header);
MS_CHECK_TRUE(!cofs.bad(), "filed to open file");
MS_LOG(INFO) << "write " << reg_kernel_header;
cofs << g_hwLicense;
cofs << "#include \"nnacl/tensor_c.h\"\n";
cofs << "#include \"nnacl/custom_parameter.h\"\n\n";
cofs << KernelRegistry::GetInstance()->GenKernelInterface(kCustomKernelName, kCustomKernelParam) << "\n";
return RET_OK;
}
int Generator::GenerateCode() {
MS_CHECK_RET_CODE(CodeNetHFile(), "code net h file failed.");
MS_CHECK_RET_CODE(CodeNetCFile(), "code net c file failed.");
@ -174,6 +190,7 @@ int Generator::GenerateCode() {
MS_CHECK_RET_CODE(CodeSourceCMakeFile(), "code net cmake file failed.");
MS_CHECK_RET_CODE(CodeStaticContent(), "code static content failed.");
MS_CHECK_RET_CODE(CodeSessionImplement(), "code session file failed.");
MS_CHECK_RET_CODE(CodeRegKernelHFile(), "code registered kernel header file failed.");
return RET_OK;
}
} // namespace mindspore::lite::micro

View File

@ -46,6 +46,7 @@ class Generator {
virtual int CodeNetHFile() = 0;
virtual int CodeNetCFile() = 0;
virtual int CodeWeightFile();
virtual int CodeRegKernelHFile();
void CodeNetRunFunc(std::ofstream &ofs);

View File

@ -21,14 +21,12 @@
#include "coder/opcoders/serializers/serializer.h"
#include "coder/opcoders/custom/custom_coder.h"
#include "coder/opcoders/op_coder_register.h"
#include "coder/user_registry/user_kernel_register.h"
#include "coder/opcoders/kernel_registry.h"
#include "src/common/prim_util.h"
#include "nnacl/custom_parameter.h"
using mindspore::schema::PrimitiveType_Custom;
#define MAX_TENSORS 16
#define MAX_STR_LEN 32
#define MAX_ATTR_NUM 8
#define MAX_TENSOR_NAME_LEN 32
namespace mindspore::lite::micro {
std::map<Tensor *, void *> CustomCoder::const_tensor_map_;
@ -51,14 +49,6 @@ int CustomCoder::Prepare(CoderContext *const context) {
MS_LOG(ERROR) << "Primitive type should be custom";
return RET_ERROR;
}
CoderKey key(target_, input_tensors_[0]->data_type(), PrimitiveType_Custom);
auto result = UserKernelFactory::GetInstance()->FindUserKernel(key);
if (result.empty()) {
MS_LOG(ERROR) << "No user kernel register for custom op";
return RET_ERROR;
}
header_ = result[0];
function_ = result[1];
Populate(node_->primitive_);
for (const auto &tensor : input_tensors_) {
if (tensor->category() == Tensor::Category::CONST_TENSOR) {
@ -69,7 +59,7 @@ int CustomCoder::Prepare(CoderContext *const context) {
}
}
}
Configurator::GetInstance()->SetCustomFlag();
return RET_OK;
}
@ -78,7 +68,7 @@ int CustomCoder::TransformTensors(Serializer *code, std::string array_name, cons
MS_LOG(ERROR) << "The number of tensors is too large";
return RET_ERROR;
}
(*code) << "\t\tCustomTensor " << array_name << "[" << tensors.size() << "];\n";
(*code) << "\t\tTensorC " << array_name << "[" << tensors.size() << "];\n";
for (size_t i = 0; i < tensors.size(); ++i) {
if (tensors[i]->category() == Tensor::Category::CONST_TENSOR) {
if (!const_tensor_map_.count(tensors[i])) {
@ -86,22 +76,23 @@ int CustomCoder::TransformTensors(Serializer *code, std::string array_name, cons
return RET_ERROR;
}
(*code) << "\t\t" << array_name << "[" << i
<< "].data = " << allocator_->GetRuntimeAddr(const_tensor_map_[tensors[i]]) << ";\n";
<< "].data_ = " << allocator_->GetRuntimeAddr(const_tensor_map_[tensors[i]]) << ";\n";
} else {
(*code) << "\t\t" << array_name << "[" << i << "].data = " << allocator_->GetRuntimeAddr(tensors[i]) << ";\n";
(*code) << "\t\t" << array_name << "[" << i << "].data_ = " << allocator_->GetRuntimeAddr(tensors[i]) << ";\n";
}
(*code) << "\t\t" << array_name << "[" << i << "].data_size = " << tensors[i]->Size() << ";\n";
for (size_t j = 0; j < tensors[i]->shape().size(); ++j) {
(*code) << "\t\t" << array_name << "[" << i << "].shape[" << j << "] = " << tensors[i]->shape()[j] << ";\n";
(*code) << "\t\t" << array_name << "[" << i << "].shape_[" << j << "] = " << tensors[i]->shape()[j] << ";\n";
}
(*code) << "\t\t" << array_name << "[" << i << "].shape_size = " << tensors[i]->shape().size() << ";\n";
(*code) << "\t\t" << array_name << "[" << i << "].data_type = " << tensors[i]->data_type() << ";\n";
(*code) << "\t\t" << array_name << "[" << i << "].format = " << tensors[i]->format() << ";\n";
if (tensors[i]->tensor_name().size() > MAX_TENSOR_NAME_LEN) {
(*code) << "\t\t" << array_name << "[" << i << "].shape_size_ = " << tensors[i]->shape().size() << ";\n";
(*code) << "\t\t" << array_name << "[" << i << "].data_type_ = " << tensors[i]->data_type() << ";\n";
(*code) << "\t\t" << array_name << "[" << i << "].format_ = " << tensors[i]->format() << ";\n";
if (tensors[i]->tensor_name().size() > MAX_STR_LEN) {
MS_LOG(ERROR) << "tensor name is too long: " << tensors[i]->tensor_name();
return RET_ERROR;
}
(*code) << "\t\tstrcpy(" << array_name << "[" << i << "].name, "
(*code) << "\t\t" << array_name << "[" << i << "].name_ = "
<< "malloc(" << tensors[i]->tensor_name().length() + 1 << ");\n";
(*code) << "\t\tstrcpy(" << array_name << "[" << i << "].name_, "
<< "\"" << tensors[i]->tensor_name() << "\""
<< ");\n";
}
@ -115,7 +106,7 @@ int CustomCoder::TransformParams(Serializer *code, std::string var_name) {
return RET_ERROR;
}
(*code) << "\t\tCustomParams " << var_name << ";\n";
(*code) << "\t\tCustomParameter " << var_name << ";\n";
if (type_.size() > MAX_STR_LEN) {
MS_LOG(ERROR) << "type name is too long: " << type_;
return RET_ERROR;
@ -129,13 +120,11 @@ int CustomCoder::TransformParams(Serializer *code, std::string var_name) {
MS_LOG(ERROR) << "attr name is too long: " << iter->first;
return RET_ERROR;
}
if (iter->second.size() > MAX_STR_LEN) {
MS_LOG(ERROR) << "attr " << iter->first << " data is too long";
return RET_ERROR;
}
(*code) << "\t\tstrcpy(" << var_name << ".attr_name[" << i << "], "
<< "\"" << iter->first << "\""
<< ");\n";
(*code) << "\t\t" << var_name << ".attr_data[" << i << "] = "
<< "malloc(" << iter->second.size() + 1 << ");\n";
(*code) << "\t\tstrcpy(" << var_name << ".attr_data[" << i++ << "], "
<< "\"" << iter->second << "\""
<< ");\n";
@ -144,13 +133,29 @@ int CustomCoder::TransformParams(Serializer *code, std::string var_name) {
return RET_OK;
}
void CustomCoder::FreeParams(Serializer *code, std::string var_name) {
int i = 0;
for (auto iter = attrs_.begin(); iter != attrs_.end(); ++iter) {
(*code) << "\t\tfree(" << var_name << ".attr_data[" << i++ << "]);\n";
}
}
void CustomCoder::FreeTensors(Serializer *code, std::string array_name, size_t tensors_num) {
for (size_t i = 0; i < tensors_num; i++) {
(*code) << "\t\tfree(" << array_name << "[" << i << "].name_);\n";
}
}
int CustomCoder::DoCode(CoderContext *const context) {
Collect(context, {header_, "custom_params.h"}, {});
Collect(context, {"nnacl/custom_parameter.h", "nnacl/tensor_c.h", "registered_kernel.h"}, {});
Serializer code;
MS_CHECK_RET_CODE(TransformTensors(&code, "inputs", input_tensors_), "Transform input tensors error!");
MS_CHECK_RET_CODE(TransformTensors(&code, "outputs", output_tensors_), "Transform output tensors error!");
MS_CHECK_RET_CODE(TransformParams(&code, "param"), "Transform output tensors error!");
code.CodeFunction(function_, "inputs", input_tensors_.size(), "outputs", output_tensors_.size(), "&param");
code.CodeFunction(kCustomKernelName, "inputs", input_tensors_.size(), "outputs", output_tensors_.size(), "&param");
FreeParams(&code, "param");
FreeTensors(&code, "inputs", input_tensors_.size());
FreeTensors(&code, "outputs", output_tensors_.size());
context->AppendCode(code.str());
return 0;
}

View File

@ -38,11 +38,11 @@ class CustomCoder final : public OperatorCoder {
void Populate(const void *prim);
int TransformTensors(Serializer *code, std::string array_name, const std::vector<Tensor *> &tensors);
int TransformParams(Serializer *code, std::string var_name);
void FreeParams(Serializer *code, std::string var_name);
void FreeTensors(Serializer *code, std::string array_name, size_t tensors_num);
std::string type_;
std::map<std::string, std::string> attrs_;
std::string header_;
std::string function_;
static std::map<Tensor *, void *> const_tensor_map_;
};
} // namespace mindspore::lite::micro

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "coder/opcoders/kernel_registry.h"
#include <set>
#include "schema/ops_generated.h"
namespace mindspore::lite::micro {
KernelRegistry *KernelRegistry::GetInstance() {
static KernelRegistry reg;
return &reg;
}
void KernelRegistry::RegisterKernel(schema::PrimitiveType op) { registry_.insert(op); }
bool KernelRegistry::CheckRegistered(schema::PrimitiveType op) { return registry_.find(op) != registry_.end(); }
bool KernelRegistry::HasKernelRegistered() { return !registry_.empty(); }
std::string KernelRegistry::GenKernelInterface(const char *func, const char *param) {
return "int " + std::string(func) + "(TensorC *inputs, int input_num, TensorC *outputs, int output_num, " +
std::string(param) + " *param);";
}
} // namespace mindspore::lite::micro

View File

@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_KERNEL_REGISTER_H_
#define MINDSPORE_LITE_MICRO_CODER_OPCODERS_KERNEL_REGISTER_H_
#include <map>
#include <vector>
#include <string>
#include <set>
#include "coder/config.h"
#include "coder/opcoders/op_coder_register.h"
#include "ir/dtype/type_id.h"
#include "schema/ops_generated.h"
namespace mindspore::lite::micro {
constexpr char kCustomKernelName[] = "CustomKernel";
constexpr char kCustomKernelParam[] = "CustomParameter";
class KernelRegistry {
public:
KernelRegistry() = default;
static KernelRegistry *GetInstance();
void RegisterKernel(schema::PrimitiveType op);
bool CheckRegistered(schema::PrimitiveType op);
bool HasKernelRegistered();
std::string GenKernelInterface(const char *func, const char *param);
~KernelRegistry() { registry_.clear(); }
private:
std::set<schema::PrimitiveType> registry_;
};
} // namespace mindspore::lite::micro
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_KERNEL_REGISTER_H_

View File

@ -25,6 +25,7 @@
#include "coder/generator/inference/inference_generator.h"
#include "coder/generator/train/train_generator.h"
#include "coder/opcoders/op_coder_builder.h"
#include "coder/opcoders/kernel_registry.h"
#include "coder/utils/coder_utils.h"
#include "coder/log.h"
#include "src/ops/populate/populate_register.h"
@ -35,6 +36,7 @@
#include "include/errorcode.h"
#include "include/model.h"
#include "src/common/file_utils.h"
#include "src/common/prim_util.h"
#include "coder/opcoders/nnacl/dequant/de_quant.h"
namespace mindspore::lite::micro {
@ -263,13 +265,13 @@ int CoderSession::CreateOpCoders() {
}
inputs.push_back(all_tensors.at(in_index));
}
for (auto ou_index : output_indices) {
ou_index = static_cast<size_t>(ou_index);
if (ou_index > all_tensors.size()) {
MS_LOG(ERROR) << "ou_index is invalid";
for (auto out_index : output_indices) {
out_index = static_cast<size_t>(out_index);
if (out_index > all_tensors.size()) {
MS_LOG(ERROR) << "out_index is invalid";
return RET_ERROR;
}
outputs.push_back(all_tensors.at(ou_index));
outputs.push_back(all_tensors.at(out_index));
}
if (inputs.empty()) {
MS_LOG(ERROR) << "node: " << node->name_ << "has no inputs tensor";
@ -281,9 +283,10 @@ int CoderSession::CreateOpCoders() {
}
OpParameter *parameter = nullptr;
if (lite::KernelInferShape(inputs, outputs, node->primitive_, std::set<std::string>{}, schema_version_) ==
lite::RET_NOT_SUPPORT) { // custom op infer
parameter = GenParameterAndInfer(node, inputs, &outputs); // general ops infer
if (IsCustomNode(node->primitive_, schema_version_)) {
KernelRegistry::GetInstance()->RegisterKernel(schema::PrimitiveType_Custom);
} else {
parameter = GenParameterAndInfer(node, inputs, &outputs); // built-in ops infer
MS_CHECK_PTR(parameter);
}

View File

@ -1,158 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "coder/user_registry/nnie_infer.h"
#include <string>
#include <iostream>
#include "include/errorcode.h"
#include "include/api/format.h"
#include "include/registry/register_kernel_interface.h"
#include "src/common/log_adapter.h"
using mindspore::kernel::KernelInterface;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Custom;
#define MAX_SIZE 1024
namespace mindspore {
namespace nnie {
std::shared_ptr<KernelInterface> CustomInferCreater() {
auto infer = new (std::nothrow) CustomInterface();
if (infer == nullptr) {
MS_LOG(ERROR) << "new custom infer is nullptr";
return nullptr;
}
return std::shared_ptr<KernelInterface>(infer);
}
int GetCustomShape(const mindspore::schema::Custom *op, const std::string &attr,
std::vector<std::vector<int64_t>> *shapes) {
char buf[MAX_SIZE];
bool has_outputs_shape = false;
for (size_t i = 0; i < op->attr()->size(); i++) {
if (op->attr()->Get(i)->name()->str() == attr) {
auto output_info = op->attr()->Get(i)->data();
int attr_size = static_cast<int>(output_info->size());
if (attr_size >= MAX_SIZE) {
MS_LOG(ERROR) << "attr size too big";
return RET_ERROR;
}
for (int j = 0; j < attr_size; j++) {
buf[j] = static_cast<char>(output_info->Get(j));
}
buf[attr_size] = 0;
has_outputs_shape = true;
break;
}
}
if (!has_outputs_shape) {
MS_LOG(ERROR) << "Custom op don't have " << attr.c_str() << " attr.";
return RET_ERROR;
}
char delims[] = ",";
char *res = nullptr;
char *save_ptr = nullptr;
res = strtok_r(buf, delims, &save_ptr);
while (res != nullptr) {
int64_t ndims = strtol(res, &res, 10);
int j = 0;
std::vector<int64_t> shape;
shape.resize(ndims);
for (; j < ndims; j++) {
res = strtok_r(NULL, delims, &save_ptr);
shape[j] = static_cast<int64_t>(strtol(res, &res, 10));
}
shapes->push_back(shape);
res = strtok_r(NULL, delims, &save_ptr);
}
return RET_OK;
}
Status CustomInterface::Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const mindspore::schema::Primitive *primitive) {
if (inputs->empty()) {
MS_LOG(ERROR) << "Inputs size 0";
return kLiteError;
}
if (outputs->empty()) {
MS_LOG(ERROR) << "Outputs size 0";
return kLiteError;
}
if (primitive->value_type() != mindspore::schema::PrimitiveType_Custom) {
MS_LOG(ERROR) << "Primitive type is not PrimitiveType_Custom";
return kLiteError;
}
auto op = primitive->value_as_Custom();
if (op->attr()->size() < 1) {
MS_LOG(ERROR) << "There are at least 1 attribute of Custom";
return kLiteError;
}
std::vector<std::vector<int64_t>> inputs_shape;
if (GetCustomShape(op, "inputs_shape", &inputs_shape) != RET_OK) {
MS_LOG(ERROR) << "parser inputs_shape attribute err.";
return kLiteError;
}
std::vector<std::vector<int64_t>> outputs_shape;
if (GetCustomShape(op, "outputs_shape", &outputs_shape) != RET_OK) {
MS_LOG(ERROR) << "parser outputs_shape attribute err.";
return kLiteError;
}
if (inputs_shape.size() != (inputs->size() - 1)) {
MS_LOG(ERROR) << "inputs num diff inputs_shape num.";
return kLiteError;
}
if (inputs_shape[0].size() != (*inputs)[0].Shape().size()) {
MS_LOG(ERROR) << "shape size err.";
return kLiteError;
}
bool resize_flag = false;
int resize_num = 1;
for (size_t i = 0; i < inputs_shape[0].size(); i++) {
if (inputs_shape[0][i] != (*inputs)[0].Shape()[i]) {
if (i == 0) {
resize_flag = true;
resize_num = (*inputs)[0].Shape()[i];
} else {
MS_LOG(ERROR) << "Custom of NNIE only support batch_num resize.";
return kLiteError;
}
}
}
if (resize_flag) {
for (auto &output_shape : outputs_shape) {
output_shape[0] = resize_num;
}
}
for (size_t i = 0; i < outputs->size(); i++) {
(*outputs)[i].SetShape(outputs_shape[i]);
(*outputs)[i].SetDataType(DataType::kNumberTypeFloat32);
(*outputs)[i].SetFormat(Format::NCHW);
}
return kSuccess;
}
} // namespace nnie
} // namespace mindspore
namespace mindspore {
namespace kernel {
REGISTER_CUSTOM_KERNEL_INTERFACE(NNIE, NNIE, nnie::CustomInferCreater);
} // namespace kernel
} // namespace mindspore

View File

@ -1,35 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_CUSTOM_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_CUSTOM_PARAMETER_H_
#include <vector>
#include <memory>
#include "include/kernel_interface.h"
namespace mindspore {
namespace nnie {
class CustomInterface : public mindspore::kernel::KernelInterface {
public:
CustomInterface() {}
~CustomInterface() = default;
Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const mindspore::schema::Primitive *primitive) override;
};
} // namespace nnie
} // namespace mindspore
#endif // MINDSPORE_LITE_NNACL_CUSTOM_PARAMETER_H_

View File

@ -1,57 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "coder/user_registry/user_kernel_register.h"
#include <set>
namespace mindspore::lite::micro {
UserKernelFactory *UserKernelFactory::GetInstance() {
static UserKernelFactory reg;
return &reg;
}
int UserKernelFactory::RegistUserKernel(Target target, TypeId data_type, schema::PrimitiveType operator_type,
const std::string &header, const std::string &func, const std::string &lib) {
CoderKey key(target, data_type, operator_type);
if (user_kernel_sets_.find(key) != user_kernel_sets_.end()) {
MS_LOG(ERROR) << "Register the user kernel multiple times: " << key.ToString();
return RET_ERROR;
}
user_kernel_sets_[key] = {header, func, lib};
return RET_OK;
}
std::vector<std::string> UserKernelFactory::FindUserKernel(const CoderKey &key) {
if (user_kernel_sets_.find(key) != user_kernel_sets_.end()) {
return user_kernel_sets_[key];
} else {
return std::vector<std::string>{};
}
}
std::set<std::string> UserKernelFactory::UserKernelLibNames() {
std::set<std::string> names;
for (auto it = user_kernel_sets_.begin(); it != user_kernel_sets_.end(); ++it) {
names.insert(it->second[2]);
}
return names;
}
UserKernelRegister::UserKernelRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type,
const std::string &header, const std::string &func, const std::string &lib) {
UserKernelFactory::GetInstance()->RegistUserKernel(target, data_type, operator_type, header, func, lib);
}
} // namespace mindspore::lite::micro

View File

@ -1,64 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_USER_KERNEL_REGISTER_H_
#define MINDSPORE_LITE_MICRO_CODER_OPCODERS_USER_KERNEL_REGISTER_H_
#include <map>
#include <vector>
#include <string>
#include <set>
#include "coder/config.h"
#include "coder/opcoders/op_coder_register.h"
#include "ir/dtype/type_id.h"
#include "schema/ops_generated.h"
namespace mindspore::lite::micro {
class UserKernelFactory {
public:
UserKernelFactory() = default;
static UserKernelFactory *GetInstance();
int RegistUserKernel(Target target, TypeId data_type, schema::PrimitiveType operator_type, const std::string &header,
const std::string &func, const std::string &lib);
std::vector<std::string> FindUserKernel(const CoderKey &key);
std::set<std::string> UserKernelLibNames();
~UserKernelFactory() { user_kernel_sets_.clear(); }
private:
std::map<CoderKey, std::vector<std::string>> user_kernel_sets_;
};
class UserKernelRegister {
public:
UserKernelRegister() = delete;
UserKernelRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type, const std::string &header,
const std::string &func, const std::string &lib);
~UserKernelRegister() = default;
};
#define REG_USER_KERNEL(target, data_type, operator_type, header, func, lib) \
static UserKernelRegister g_user_##target##data_type##operator_type##function(target, data_type, operator_type, \
#header, #func, #lib);
} // namespace mindspore::lite::micro
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_USER_KERNEL_REGISTER_H_