!23028 [MSLITE] support custom kernel registry for codegen
Merge pull request !23028 from zhanyuan/nnie_micro_recitify
This commit is contained in:
commit
41019c77a1
mindspore
ccsrc/backend/kernel_compiler/cpu/nnacl
lite/micro
|
@ -13,14 +13,18 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_NNACL_CUSTOM_PARAMETER_H_
|
||||
#define MINDSPORE_NNACL_CUSTOM_PARAMETER_H_
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
#include "coder/user_registry/user_kernel_register.h"
|
||||
#include "schema/ops_generated.h"
|
||||
#define MAX_STR_LEN 64
|
||||
#define MAX_ATTR_NUM 8
|
||||
|
||||
using mindspore::schema::PrimitiveType_Custom;
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
REG_USER_KERNEL(kARM32A, kNumberTypeInt8, PrimitiveType_Custom, nnie_micro.h, NnieKernel, nniemicro)
|
||||
REG_USER_KERNEL(kARM32A, kNumberTypeUInt8, PrimitiveType_Custom, nnie_micro.h, NnieKernel, nniemicro)
|
||||
REG_USER_KERNEL(kARM32A, kNumberTypeFloat32, PrimitiveType_Custom, nnie_micro.h, NnieKernel, nniemicro)
|
||||
} // namespace mindspore::lite::micro
|
||||
typedef struct CustomParameter {
|
||||
OpParameter op_parameter_;
|
||||
char type[MAX_STR_LEN];
|
||||
char attr_name[MAX_ATTR_NUM][MAX_STR_LEN];
|
||||
char *attr_data[MAX_ATTR_NUM];
|
||||
int attr_num;
|
||||
} CustomParameter;
|
||||
#endif // MINDSPORE_NNACL_CUSTOM_PARAMETER_H_
|
|
@ -20,7 +20,7 @@
|
|||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include <memory.h>
|
||||
#include <mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/resize_parameter.h>
|
||||
#include "nnacl/resize_parameter.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/crop_parameter.h"
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ typedef struct TensorC {
|
|||
void *data_;
|
||||
size_t shape_size_;
|
||||
int shape_[MAX_SHAPE_SIZE];
|
||||
char *name_;
|
||||
} TensorC;
|
||||
|
||||
#endif // MINDSPORE_NNACL_TENSOR_C_H_
|
||||
|
|
|
@ -154,9 +154,7 @@ set(LITE_SRC
|
|||
)
|
||||
|
||||
set(REGISTRY_SRC
|
||||
${MICRO_DIR}/coder/user_registry/user_kernel_register.cc
|
||||
${MICRO_DIR}/coder/user_registry/nnie_kernel_reg.cc
|
||||
${MICRO_DIR}/coder/user_registry/nnie_infer.cc
|
||||
${MICRO_DIR}/coder/opcoders/kernel_registry.cc
|
||||
)
|
||||
|
||||
list(APPEND FILE_SET ${CODER_SRC} ${CODER_OPCODERS_SRC} ${CODER_GENERATOR_SRC}
|
||||
|
|
|
@ -48,9 +48,6 @@ class Configurator {
|
|||
int ParseProjDir(std::string model_path);
|
||||
std::string proj_dir() const { return proj_dir_; }
|
||||
|
||||
void SetCustomFlag() { has_custom_op_ = true; }
|
||||
bool CustomFlag() const { return has_custom_op_; }
|
||||
|
||||
private:
|
||||
Configurator() = default;
|
||||
~Configurator() = default;
|
||||
|
@ -60,7 +57,6 @@ class Configurator {
|
|||
bool support_parallel_{false};
|
||||
bool debug_mode_{false};
|
||||
std::string proj_dir_;
|
||||
bool has_custom_op_{false};
|
||||
};
|
||||
} // namespace mindspore::lite::micro
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#include "coder/generator/component/cmake_component.h"
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include "coder/user_registry/user_kernel_register.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
void CodeCMakeNetLibrary(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator *config) {
|
||||
|
@ -59,13 +58,5 @@ void CodeCMakeNetLibrary(std::ofstream &ofs, const std::unique_ptr<CoderContext>
|
|||
" ${CMAKE_CURRENT_SOURCE_DIR}/*.c\n"
|
||||
" )\n"
|
||||
"add_library(net STATIC ${NET_SRC})\n";
|
||||
auto user_libs = UserKernelFactory::GetInstance()->UserKernelLibNames();
|
||||
if (config->CustomFlag() && !user_libs.empty()) {
|
||||
ofs << "target_link_libraries(net";
|
||||
for (auto lib : user_libs) {
|
||||
ofs << " " << lib;
|
||||
}
|
||||
ofs << ")\n";
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::lite::micro
|
||||
|
|
|
@ -178,50 +178,4 @@ void *MTensor::MutableData() {
|
|||
} // namespace mindspore
|
||||
|
||||
)RAW";
|
||||
|
||||
const char custom_params_source[] = R"RAW(
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_MICRO_LIBRARY_SOURCE_CUSTOM_PARAMS_H_
|
||||
#define MINDSPORE_LITE_MICRO_LIBRARY_SOURCE_CUSTOM_PARAMS_H_
|
||||
|
||||
#define MAX_TENSOR_NAME_LEN 32
|
||||
#define MAX_SHAPE_SIZE 8
|
||||
#define MAX_STR_LEN 32
|
||||
#define MAX_ATTR_NUM 8
|
||||
|
||||
typedef struct {
|
||||
void *data;
|
||||
int data_size;
|
||||
int shape[MAX_SHAPE_SIZE];
|
||||
int shape_size;
|
||||
char name[MAX_TENSOR_NAME_LEN];
|
||||
int data_type;
|
||||
int format;
|
||||
} CustomTensor;
|
||||
|
||||
typedef struct {
|
||||
char type[MAX_STR_LEN];
|
||||
char attr_name[MAX_ATTR_NUM][MAX_STR_LEN];
|
||||
char attr_data[MAX_ATTR_NUM][MAX_STR_LEN];
|
||||
int attr_num;
|
||||
} CustomParams;
|
||||
|
||||
#endif // MINDSPORE_LITE_MICRO_LIBRARY_SOURCE_CUSTOM_PARAMS_H_
|
||||
|
||||
)RAW";
|
||||
|
||||
} // namespace mindspore::lite::micro
|
||||
|
|
|
@ -21,7 +21,6 @@ namespace mindspore::lite::micro {
|
|||
|
||||
extern const char tensor_header[];
|
||||
extern const char tensor_source[];
|
||||
extern const char custom_params_source[];
|
||||
|
||||
} // namespace mindspore::lite::micro
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
#include "coder/generator/component/const_blocks/license.h"
|
||||
#include "coder/log.h"
|
||||
#include "coder/opcoders/parallel.h"
|
||||
#include "coder/opcoders/kernel_registry.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
int WriteContentToFile(const std::string &file, const std::string &content) {
|
||||
|
@ -97,9 +98,6 @@ int Generator::CodeStaticContent() {
|
|||
{net_src_file_path_ + "tensor.cc", tensor_source},
|
||||
{net_src_file_path_ + "mmodel.h", model_header}};
|
||||
|
||||
if (config_->CustomFlag()) {
|
||||
const_blocks.emplace_back(std::make_pair(net_src_file_path_ + "custom_params.h", custom_params_source));
|
||||
}
|
||||
if (config_->support_parallel()) {
|
||||
const_blocks.emplace_back(std::make_pair(net_src_file_path_ + kThreadWrapper, thread_header));
|
||||
}
|
||||
|
@ -167,6 +165,24 @@ int Generator::CodeWeightFile() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int Generator::CodeRegKernelHFile() {
|
||||
if (!KernelRegistry::GetInstance()->HasKernelRegistered()) return RET_OK;
|
||||
if (!KernelRegistry::GetInstance()->CheckRegistered(schema::PrimitiveType_Custom)) {
|
||||
MS_LOG(ERROR) << "Only support custom kernel to register now!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
std::string reg_kernel_header = net_src_file_path_ + "registered_kernel.h";
|
||||
std::ofstream cofs(reg_kernel_header);
|
||||
MS_CHECK_TRUE(!cofs.bad(), "filed to open file");
|
||||
MS_LOG(INFO) << "write " << reg_kernel_header;
|
||||
cofs << g_hwLicense;
|
||||
cofs << "#include \"nnacl/tensor_c.h\"\n";
|
||||
cofs << "#include \"nnacl/custom_parameter.h\"\n\n";
|
||||
cofs << KernelRegistry::GetInstance()->GenKernelInterface(kCustomKernelName, kCustomKernelParam) << "\n";
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int Generator::GenerateCode() {
|
||||
MS_CHECK_RET_CODE(CodeNetHFile(), "code net h file failed.");
|
||||
MS_CHECK_RET_CODE(CodeNetCFile(), "code net c file failed.");
|
||||
|
@ -174,6 +190,7 @@ int Generator::GenerateCode() {
|
|||
MS_CHECK_RET_CODE(CodeSourceCMakeFile(), "code net cmake file failed.");
|
||||
MS_CHECK_RET_CODE(CodeStaticContent(), "code static content failed.");
|
||||
MS_CHECK_RET_CODE(CodeSessionImplement(), "code session file failed.");
|
||||
MS_CHECK_RET_CODE(CodeRegKernelHFile(), "code registered kernel header file failed.");
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite::micro
|
||||
|
|
|
@ -46,6 +46,7 @@ class Generator {
|
|||
virtual int CodeNetHFile() = 0;
|
||||
virtual int CodeNetCFile() = 0;
|
||||
virtual int CodeWeightFile();
|
||||
virtual int CodeRegKernelHFile();
|
||||
|
||||
void CodeNetRunFunc(std::ofstream &ofs);
|
||||
|
||||
|
|
|
@ -21,14 +21,12 @@
|
|||
#include "coder/opcoders/serializers/serializer.h"
|
||||
#include "coder/opcoders/custom/custom_coder.h"
|
||||
#include "coder/opcoders/op_coder_register.h"
|
||||
#include "coder/user_registry/user_kernel_register.h"
|
||||
#include "coder/opcoders/kernel_registry.h"
|
||||
#include "src/common/prim_util.h"
|
||||
#include "nnacl/custom_parameter.h"
|
||||
|
||||
using mindspore::schema::PrimitiveType_Custom;
|
||||
#define MAX_TENSORS 16
|
||||
#define MAX_STR_LEN 32
|
||||
#define MAX_ATTR_NUM 8
|
||||
#define MAX_TENSOR_NAME_LEN 32
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
std::map<Tensor *, void *> CustomCoder::const_tensor_map_;
|
||||
|
@ -51,14 +49,6 @@ int CustomCoder::Prepare(CoderContext *const context) {
|
|||
MS_LOG(ERROR) << "Primitive type should be custom";
|
||||
return RET_ERROR;
|
||||
}
|
||||
CoderKey key(target_, input_tensors_[0]->data_type(), PrimitiveType_Custom);
|
||||
auto result = UserKernelFactory::GetInstance()->FindUserKernel(key);
|
||||
if (result.empty()) {
|
||||
MS_LOG(ERROR) << "No user kernel register for custom op";
|
||||
return RET_ERROR;
|
||||
}
|
||||
header_ = result[0];
|
||||
function_ = result[1];
|
||||
Populate(node_->primitive_);
|
||||
for (const auto &tensor : input_tensors_) {
|
||||
if (tensor->category() == Tensor::Category::CONST_TENSOR) {
|
||||
|
@ -69,7 +59,7 @@ int CustomCoder::Prepare(CoderContext *const context) {
|
|||
}
|
||||
}
|
||||
}
|
||||
Configurator::GetInstance()->SetCustomFlag();
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -78,7 +68,7 @@ int CustomCoder::TransformTensors(Serializer *code, std::string array_name, cons
|
|||
MS_LOG(ERROR) << "The number of tensors is too large";
|
||||
return RET_ERROR;
|
||||
}
|
||||
(*code) << "\t\tCustomTensor " << array_name << "[" << tensors.size() << "];\n";
|
||||
(*code) << "\t\tTensorC " << array_name << "[" << tensors.size() << "];\n";
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
if (tensors[i]->category() == Tensor::Category::CONST_TENSOR) {
|
||||
if (!const_tensor_map_.count(tensors[i])) {
|
||||
|
@ -86,22 +76,23 @@ int CustomCoder::TransformTensors(Serializer *code, std::string array_name, cons
|
|||
return RET_ERROR;
|
||||
}
|
||||
(*code) << "\t\t" << array_name << "[" << i
|
||||
<< "].data = " << allocator_->GetRuntimeAddr(const_tensor_map_[tensors[i]]) << ";\n";
|
||||
<< "].data_ = " << allocator_->GetRuntimeAddr(const_tensor_map_[tensors[i]]) << ";\n";
|
||||
} else {
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].data = " << allocator_->GetRuntimeAddr(tensors[i]) << ";\n";
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].data_ = " << allocator_->GetRuntimeAddr(tensors[i]) << ";\n";
|
||||
}
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].data_size = " << tensors[i]->Size() << ";\n";
|
||||
for (size_t j = 0; j < tensors[i]->shape().size(); ++j) {
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].shape[" << j << "] = " << tensors[i]->shape()[j] << ";\n";
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].shape_[" << j << "] = " << tensors[i]->shape()[j] << ";\n";
|
||||
}
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].shape_size = " << tensors[i]->shape().size() << ";\n";
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].data_type = " << tensors[i]->data_type() << ";\n";
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].format = " << tensors[i]->format() << ";\n";
|
||||
if (tensors[i]->tensor_name().size() > MAX_TENSOR_NAME_LEN) {
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].shape_size_ = " << tensors[i]->shape().size() << ";\n";
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].data_type_ = " << tensors[i]->data_type() << ";\n";
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].format_ = " << tensors[i]->format() << ";\n";
|
||||
if (tensors[i]->tensor_name().size() > MAX_STR_LEN) {
|
||||
MS_LOG(ERROR) << "tensor name is too long: " << tensors[i]->tensor_name();
|
||||
return RET_ERROR;
|
||||
}
|
||||
(*code) << "\t\tstrcpy(" << array_name << "[" << i << "].name, "
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].name_ = "
|
||||
<< "malloc(" << tensors[i]->tensor_name().length() + 1 << ");\n";
|
||||
(*code) << "\t\tstrcpy(" << array_name << "[" << i << "].name_, "
|
||||
<< "\"" << tensors[i]->tensor_name() << "\""
|
||||
<< ");\n";
|
||||
}
|
||||
|
@ -115,7 +106,7 @@ int CustomCoder::TransformParams(Serializer *code, std::string var_name) {
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
(*code) << "\t\tCustomParams " << var_name << ";\n";
|
||||
(*code) << "\t\tCustomParameter " << var_name << ";\n";
|
||||
if (type_.size() > MAX_STR_LEN) {
|
||||
MS_LOG(ERROR) << "type name is too long: " << type_;
|
||||
return RET_ERROR;
|
||||
|
@ -129,13 +120,11 @@ int CustomCoder::TransformParams(Serializer *code, std::string var_name) {
|
|||
MS_LOG(ERROR) << "attr name is too long: " << iter->first;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (iter->second.size() > MAX_STR_LEN) {
|
||||
MS_LOG(ERROR) << "attr " << iter->first << " data is too long";
|
||||
return RET_ERROR;
|
||||
}
|
||||
(*code) << "\t\tstrcpy(" << var_name << ".attr_name[" << i << "], "
|
||||
<< "\"" << iter->first << "\""
|
||||
<< ");\n";
|
||||
(*code) << "\t\t" << var_name << ".attr_data[" << i << "] = "
|
||||
<< "malloc(" << iter->second.size() + 1 << ");\n";
|
||||
(*code) << "\t\tstrcpy(" << var_name << ".attr_data[" << i++ << "], "
|
||||
<< "\"" << iter->second << "\""
|
||||
<< ");\n";
|
||||
|
@ -144,13 +133,29 @@ int CustomCoder::TransformParams(Serializer *code, std::string var_name) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
/// Emits one free() statement for each attribute buffer `<var_name>.attr_data[i]`,
/// one per entry of attrs_.
void CustomCoder::FreeParams(Serializer *code, std::string var_name) {
  const int attr_count = static_cast<int>(attrs_.size());
  for (int idx = 0; idx < attr_count; ++idx) {
    (*code) << "\t\tfree(" << var_name << ".attr_data[" << idx << "]);\n";
  }
}
|
||||
|
||||
/// Emits one free() statement for the name_ buffer of each of the first
/// `tensors_num` elements of the generated array `array_name`.
void CustomCoder::FreeTensors(Serializer *code, std::string array_name, size_t tensors_num) {
  size_t idx = 0;
  while (idx < tensors_num) {
    (*code) << "\t\tfree(" << array_name << "[" << idx << "].name_);\n";
    ++idx;
  }
}
|
||||
|
||||
/// Generates the call site for the registered custom kernel.
///
/// The emitted code builds the "inputs"/"outputs" TensorC arrays and the "param"
/// CustomParameter, calls kCustomKernelName with them, and then frees the attribute
/// and tensor-name buffers that the transform steps allocated.
///
/// @param context coder context that collects the required headers and the generated code.
/// @return RET_OK on success; the error code of the failing transform step otherwise.
int CustomCoder::DoCode(CoderContext *const context) {
  Collect(context, {"nnacl/custom_parameter.h", "nnacl/tensor_c.h", "registered_kernel.h"}, {});
  Serializer code;
  MS_CHECK_RET_CODE(TransformTensors(&code, "inputs", input_tensors_), "Transform input tensors error!");
  MS_CHECK_RET_CODE(TransformTensors(&code, "outputs", output_tensors_), "Transform output tensors error!");
  MS_CHECK_RET_CODE(TransformParams(&code, "param"), "Transform params error!");
  // The last argument is the address of the generated "param" variable.
  code.CodeFunction(kCustomKernelName, "inputs", input_tensors_.size(), "outputs", output_tensors_.size(), "&param");
  FreeParams(&code, "param");
  FreeTensors(&code, "inputs", input_tensors_.size());
  FreeTensors(&code, "outputs", output_tensors_.size());
  context->AppendCode(code.str());
  return RET_OK;
}
|
||||
|
|
|
@ -38,11 +38,11 @@ class CustomCoder final : public OperatorCoder {
|
|||
void Populate(const void *prim);
|
||||
int TransformTensors(Serializer *code, std::string array_name, const std::vector<Tensor *> &tensors);
|
||||
int TransformParams(Serializer *code, std::string var_name);
|
||||
void FreeParams(Serializer *code, std::string var_name);
|
||||
void FreeTensors(Serializer *code, std::string array_name, size_t tensors_num);
|
||||
|
||||
std::string type_;
|
||||
std::map<std::string, std::string> attrs_;
|
||||
std::string header_;
|
||||
std::string function_;
|
||||
static std::map<Tensor *, void *> const_tensor_map_;
|
||||
};
|
||||
} // namespace mindspore::lite::micro
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/opcoders/kernel_registry.h"
|
||||
#include <set>
|
||||
#include "schema/ops_generated.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
/// Returns the shared KernelRegistry, a function-local static constructed on first use.
KernelRegistry *KernelRegistry::GetInstance() {
  static KernelRegistry instance;
  return &instance;
}
|
||||
|
||||
/// Records `op` as having a registered kernel; registering the same type again has no effect.
void KernelRegistry::RegisterKernel(schema::PrimitiveType op) {
  registry_.emplace(op);
}
|
||||
|
||||
/// Returns true when `op` has been recorded through RegisterKernel.
bool KernelRegistry::CheckRegistered(schema::PrimitiveType op) {
  return registry_.count(op) != 0;
}
|
||||
|
||||
/// Returns true when at least one primitive type has been registered.
bool KernelRegistry::HasKernelRegistered() {
  return registry_.size() > 0;
}
|
||||
|
||||
/// Builds the C prototype of a registered kernel:
/// "int <func>(TensorC *inputs, int input_num, TensorC *outputs, int output_num, <param> *param);"
///
/// @param func  name of the kernel function; must not be null.
/// @param param name of the parameter struct type; must not be null.
/// @return the prototype text, terminated by a semicolon and without a trailing newline.
std::string KernelRegistry::GenKernelInterface(const char *func, const char *param) {
  std::string declaration = "int ";
  declaration += std::string(func);
  declaration += "(TensorC *inputs, int input_num, TensorC *outputs, int output_num, ";
  declaration += std::string(param);
  declaration += " *param);";
  return declaration;
}
|
||||
|
||||
} // namespace mindspore::lite::micro
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_KERNEL_REGISTER_H_
|
||||
#define MINDSPORE_LITE_MICRO_CODER_OPCODERS_KERNEL_REGISTER_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include "coder/config.h"
|
||||
#include "coder/opcoders/op_coder_register.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "schema/ops_generated.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
|
||||
constexpr char kCustomKernelName[] = "CustomKernel";
|
||||
constexpr char kCustomKernelParam[] = "CustomParameter";
|
||||
|
||||
class KernelRegistry {
|
||||
public:
|
||||
KernelRegistry() = default;
|
||||
|
||||
static KernelRegistry *GetInstance();
|
||||
|
||||
void RegisterKernel(schema::PrimitiveType op);
|
||||
|
||||
bool CheckRegistered(schema::PrimitiveType op);
|
||||
|
||||
bool HasKernelRegistered();
|
||||
|
||||
std::string GenKernelInterface(const char *func, const char *param);
|
||||
|
||||
~KernelRegistry() { registry_.clear(); }
|
||||
|
||||
private:
|
||||
std::set<schema::PrimitiveType> registry_;
|
||||
};
|
||||
} // namespace mindspore::lite::micro
|
||||
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_KERNEL_REGISTER_H_
|
|
@ -25,6 +25,7 @@
|
|||
#include "coder/generator/inference/inference_generator.h"
|
||||
#include "coder/generator/train/train_generator.h"
|
||||
#include "coder/opcoders/op_coder_builder.h"
|
||||
#include "coder/opcoders/kernel_registry.h"
|
||||
#include "coder/utils/coder_utils.h"
|
||||
#include "coder/log.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
|
@ -35,6 +36,7 @@
|
|||
#include "include/errorcode.h"
|
||||
#include "include/model.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "src/common/prim_util.h"
|
||||
#include "coder/opcoders/nnacl/dequant/de_quant.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
|
@ -263,13 +265,13 @@ int CoderSession::CreateOpCoders() {
|
|||
}
|
||||
inputs.push_back(all_tensors.at(in_index));
|
||||
}
|
||||
for (auto ou_index : output_indices) {
|
||||
ou_index = static_cast<size_t>(ou_index);
|
||||
if (ou_index > all_tensors.size()) {
|
||||
MS_LOG(ERROR) << "ou_index is invalid";
|
||||
for (auto out_index : output_indices) {
|
||||
out_index = static_cast<size_t>(out_index);
|
||||
if (out_index > all_tensors.size()) {
|
||||
MS_LOG(ERROR) << "out_index is invalid";
|
||||
return RET_ERROR;
|
||||
}
|
||||
outputs.push_back(all_tensors.at(ou_index));
|
||||
outputs.push_back(all_tensors.at(out_index));
|
||||
}
|
||||
if (inputs.empty()) {
|
||||
MS_LOG(ERROR) << "node: " << node->name_ << "has no inputs tensor";
|
||||
|
@ -281,9 +283,10 @@ int CoderSession::CreateOpCoders() {
|
|||
}
|
||||
|
||||
OpParameter *parameter = nullptr;
|
||||
if (lite::KernelInferShape(inputs, outputs, node->primitive_, std::set<std::string>{}, schema_version_) ==
|
||||
lite::RET_NOT_SUPPORT) { // custom op infer
|
||||
parameter = GenParameterAndInfer(node, inputs, &outputs); // general ops infer
|
||||
if (IsCustomNode(node->primitive_, schema_version_)) {
|
||||
KernelRegistry::GetInstance()->RegisterKernel(schema::PrimitiveType_Custom);
|
||||
} else {
|
||||
parameter = GenParameterAndInfer(node, inputs, &outputs); // built-in ops infer
|
||||
MS_CHECK_PTR(parameter);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,158 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/user_registry/nnie_infer.h"
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include "include/errorcode.h"
|
||||
#include "include/api/format.h"
|
||||
#include "include/registry/register_kernel_interface.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
using mindspore::kernel::KernelInterface;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Custom;
|
||||
#define MAX_SIZE 1024
|
||||
|
||||
namespace mindspore {
|
||||
namespace nnie {
|
||||
std::shared_ptr<KernelInterface> CustomInferCreater() {
|
||||
auto infer = new (std::nothrow) CustomInterface();
|
||||
if (infer == nullptr) {
|
||||
MS_LOG(ERROR) << "new custom infer is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
return std::shared_ptr<KernelInterface>(infer);
|
||||
}
|
||||
|
||||
/// Parses the shape list stored in the attribute named `attr` of a Custom op.
///
/// The attribute data is read as comma-separated decimal text: each shape is encoded as
/// its rank followed by that many dimensions, e.g. "4,1,3,224,224,2,1,1000" describes
/// the shapes {1,3,224,224} and {1,1000}.
///
/// @param op     Custom primitive whose attributes are searched.
/// @param attr   name of the attribute holding the shape list.
/// @param shapes receives the parsed shapes, appended in order.
/// @return RET_OK on success; RET_ERROR when an argument is null, the attribute is missing,
///         its data is MAX_SIZE bytes or longer, or the text is malformed or truncated.
int GetCustomShape(const mindspore::schema::Custom *op, const std::string &attr,
                   std::vector<std::vector<int64_t>> *shapes) {
  if (op == nullptr || op->attr() == nullptr || shapes == nullptr) {
    MS_LOG(ERROR) << "Custom op, its attributes or the shapes container is nullptr";
    return RET_ERROR;
  }

  char buf[MAX_SIZE];
  bool has_outputs_shape = false;

  for (size_t i = 0; i < op->attr()->size(); i++) {
    auto attribute = op->attr()->Get(i);
    if (attribute == nullptr || attribute->name() == nullptr || attribute->name()->str() != attr) {
      continue;
    }
    auto output_info = attribute->data();
    if (output_info == nullptr) {
      MS_LOG(ERROR) << "attr data is nullptr";
      return RET_ERROR;
    }
    int attr_size = static_cast<int>(output_info->size());
    if (attr_size >= MAX_SIZE) {
      MS_LOG(ERROR) << "attr size too big";
      return RET_ERROR;
    }
    for (int j = 0; j < attr_size; j++) {
      buf[j] = static_cast<char>(output_info->Get(j));
    }
    buf[attr_size] = 0;
    has_outputs_shape = true;
    break;
  }

  if (!has_outputs_shape) {
    MS_LOG(ERROR) << "Custom op don't have " << attr.c_str() << " attr.";
    return RET_ERROR;
  }

  char delims[] = ",";
  char *save_ptr = nullptr;
  char *res = strtok_r(buf, delims, &save_ptr);
  while (res != nullptr) {
    // The first number of every group is the rank of the shape that follows.
    int64_t ndims = strtol(res, &res, 10);
    // The text is shorter than MAX_SIZE, so a valid rank can never reach it; rejecting
    // early also keeps a negative value from turning into a huge allocation.
    if (ndims < 0 || ndims >= MAX_SIZE) {
      MS_LOG(ERROR) << "Invalid dims number " << ndims << " in " << attr.c_str() << " attr.";
      return RET_ERROR;
    }
    std::vector<int64_t> shape(static_cast<size_t>(ndims));
    for (int64_t j = 0; j < ndims; j++) {
      res = strtok_r(nullptr, delims, &save_ptr);
      if (res == nullptr) {
        // Fewer dimensions than the declared rank: stop instead of parsing a null token.
        MS_LOG(ERROR) << "Shape data of " << attr.c_str() << " attr is incomplete.";
        return RET_ERROR;
      }
      shape[static_cast<size_t>(j)] = static_cast<int64_t>(strtol(res, &res, 10));
    }
    shapes->push_back(shape);

    res = strtok_r(nullptr, delims, &save_ptr);
  }
  return RET_OK;
}
|
||||
|
||||
/// Infers the output shapes of an NNIE Custom op from its "inputs_shape" and
/// "outputs_shape" attributes.
///
/// The first input may differ from the recorded input shape only in dimension 0; in that
/// case dimension 0 of every output shape is replaced by the actual value. Every output is
/// set to Float32 data type and NCHW format.
///
/// @param inputs    input tensors; one more tensor than there are recorded input shapes.
/// @param outputs   output tensors to update.
/// @param primitive primitive of the node; must be a PrimitiveType_Custom.
/// @return kSuccess on success, kLiteError on any invalid argument or attribute.
Status CustomInterface::Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
                              const mindspore::schema::Primitive *primitive) {
  if (inputs == nullptr || outputs == nullptr || primitive == nullptr) {
    MS_LOG(ERROR) << "Inputs, outputs or primitive is nullptr";
    return kLiteError;
  }
  if (inputs->empty()) {
    MS_LOG(ERROR) << "Inputs size 0";
    return kLiteError;
  }
  if (outputs->empty()) {
    MS_LOG(ERROR) << "Outputs size 0";
    return kLiteError;
  }
  if (primitive->value_type() != mindspore::schema::PrimitiveType_Custom) {
    MS_LOG(ERROR) << "Primitive type is not PrimitiveType_Custom";
    return kLiteError;
  }

  auto op = primitive->value_as_Custom();
  if (op == nullptr || op->attr() == nullptr || op->attr()->size() < 1) {
    MS_LOG(ERROR) << "There are at least 1 attribute of Custom";
    return kLiteError;
  }
  std::vector<std::vector<int64_t>> inputs_shape;
  if (GetCustomShape(op, "inputs_shape", &inputs_shape) != RET_OK) {
    MS_LOG(ERROR) << "parser inputs_shape attribute err.";
    return kLiteError;
  }
  std::vector<std::vector<int64_t>> outputs_shape;
  if (GetCustomShape(op, "outputs_shape", &outputs_shape) != RET_OK) {
    MS_LOG(ERROR) << "parser outputs_shape attribute err.";
    return kLiteError;
  }
  // An empty inputs_shape would pass the count check for a single input and then be indexed.
  if (inputs_shape.empty() || inputs_shape.size() != (inputs->size() - 1)) {
    MS_LOG(ERROR) << "inputs num diff inputs_shape num.";
    return kLiteError;
  }
  // Every output tensor needs a parsed shape before outputs_shape is indexed below.
  if (outputs_shape.size() < outputs->size()) {
    MS_LOG(ERROR) << "outputs num diff outputs_shape num.";
    return kLiteError;
  }
  const auto &input_shape = (*inputs)[0].Shape();
  if (inputs_shape[0].size() != input_shape.size()) {
    MS_LOG(ERROR) << "shape size err.";
    return kLiteError;
  }
  bool resize_flag = false;
  int64_t resize_num = 1;
  for (size_t i = 0; i < inputs_shape[0].size(); i++) {
    if (inputs_shape[0][i] == input_shape[i]) {
      continue;
    }
    if (i != 0) {
      MS_LOG(ERROR) << "Custom of NNIE only support batch_num resize.";
      return kLiteError;
    }
    resize_flag = true;
    resize_num = input_shape[i];
  }
  if (resize_flag) {
    for (auto &output_shape : outputs_shape) {
      if (output_shape.empty()) {
        MS_LOG(ERROR) << "output shape is empty, can not resize batch_num.";
        return kLiteError;
      }
      output_shape[0] = resize_num;
    }
  }
  for (size_t i = 0; i < outputs->size(); i++) {
    (*outputs)[i].SetShape(outputs_shape[i]);
    (*outputs)[i].SetDataType(DataType::kNumberTypeFloat32);
    (*outputs)[i].SetFormat(Format::NCHW);
  }
  return kSuccess;
}
|
||||
} // namespace nnie
|
||||
} // namespace mindspore
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
REGISTER_CUSTOM_KERNEL_INTERFACE(NNIE, NNIE, nnie::CustomInferCreater);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -1,35 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_NNACL_CUSTOM_PARAMETER_H_
|
||||
#define MINDSPORE_LITE_NNACL_CUSTOM_PARAMETER_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "include/kernel_interface.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace nnie {
|
||||
class CustomInterface : public mindspore::kernel::KernelInterface {
|
||||
public:
|
||||
CustomInterface() {}
|
||||
|
||||
~CustomInterface() = default;
|
||||
|
||||
Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const mindspore::schema::Primitive *primitive) override;
|
||||
};
|
||||
} // namespace nnie
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_NNACL_CUSTOM_PARAMETER_H_
|
|
@ -1,57 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/user_registry/user_kernel_register.h"
|
||||
#include <set>
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
/// Returns the shared UserKernelFactory, a function-local static constructed on first use.
UserKernelFactory *UserKernelFactory::GetInstance() {
  static UserKernelFactory instance;
  return &instance;
}
|
||||
|
||||
/// Registers a user kernel for the (target, data_type, operator_type) combination.
///
/// The stored entry is the triple {header, func, lib}, in that order.
///
/// @return RET_OK when stored, RET_ERROR when the combination is already registered.
int UserKernelFactory::RegistUserKernel(Target target, TypeId data_type, schema::PrimitiveType operator_type,
                                        const std::string &header, const std::string &func, const std::string &lib) {
  CoderKey key(target, data_type, operator_type);
  const bool already_registered = user_kernel_sets_.count(key) != 0;
  if (already_registered) {
    MS_LOG(ERROR) << "Register the user kernel multiple times: " << key.ToString();
    return RET_ERROR;
  }
  std::vector<std::string> kernel_info{header, func, lib};
  user_kernel_sets_[key] = kernel_info;
  return RET_OK;
}
|
||||
|
||||
std::vector<std::string> UserKernelFactory::FindUserKernel(const CoderKey &key) {
|
||||
if (user_kernel_sets_.find(key) != user_kernel_sets_.end()) {
|
||||
return user_kernel_sets_[key];
|
||||
} else {
|
||||
return std::vector<std::string>{};
|
||||
}
|
||||
}
|
||||
|
||||
/// Collects the distinct library names (element 2 of every registered entry).
std::set<std::string> UserKernelFactory::UserKernelLibNames() {
  std::set<std::string> lib_names;
  for (const auto &entry : user_kernel_sets_) {
    lib_names.insert(entry.second[2]);
  }
  return lib_names;
}
|
||||
|
||||
/// Forwards all arguments to UserKernelFactory::RegistUserKernel on the shared factory.
/// The status returned by the factory is not inspected here.
UserKernelRegister::UserKernelRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type,
                                       const std::string &header, const std::string &func, const std::string &lib) {
  auto *factory = UserKernelFactory::GetInstance();
  factory->RegistUserKernel(target, data_type, operator_type, header, func, lib);
}
|
||||
} // namespace mindspore::lite::micro
|
|
@ -1,64 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_USER_KERNEL_REGISTER_H_
|
||||
#define MINDSPORE_LITE_MICRO_CODER_OPCODERS_USER_KERNEL_REGISTER_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include "coder/config.h"
|
||||
#include "coder/opcoders/op_coder_register.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "schema/ops_generated.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
class UserKernelFactory {
|
||||
public:
|
||||
UserKernelFactory() = default;
|
||||
|
||||
static UserKernelFactory *GetInstance();
|
||||
|
||||
int RegistUserKernel(Target target, TypeId data_type, schema::PrimitiveType operator_type, const std::string &header,
|
||||
const std::string &func, const std::string &lib);
|
||||
|
||||
std::vector<std::string> FindUserKernel(const CoderKey &key);
|
||||
|
||||
std::set<std::string> UserKernelLibNames();
|
||||
|
||||
~UserKernelFactory() { user_kernel_sets_.clear(); }
|
||||
|
||||
private:
|
||||
std::map<CoderKey, std::vector<std::string>> user_kernel_sets_;
|
||||
};
|
||||
|
||||
class UserKernelRegister {
|
||||
public:
|
||||
UserKernelRegister() = delete;
|
||||
|
||||
UserKernelRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type, const std::string &header,
|
||||
const std::string &func, const std::string &lib);
|
||||
|
||||
~UserKernelRegister() = default;
|
||||
};
|
||||
|
||||
#define REG_USER_KERNEL(target, data_type, operator_type, header, func, lib) \
|
||||
static UserKernelRegister g_user_##target##data_type##operator_type##function(target, data_type, operator_type, \
|
||||
#header, #func, #lib);
|
||||
} // namespace mindspore::lite::micro
|
||||
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_USER_KERNEL_REGISTER_H_
|
Loading…
Reference in New Issue