forked from mindspore-Ecosystem/mindspore
!22048 [MSLITE] support custom op for codegen
Merge pull request !22048 from zhanyuan/custom_op3
This commit is contained in:
commit
570b3911ec
|
@ -122,8 +122,11 @@ set(CODER_OPCODERS_SRC
|
|||
${MICRO_DIR}/coder/opcoders/nnacl/int8/sigmoid_int8_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/int8/relux_int8_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/int8/div_int8_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/int8/transpose_int8_coder.cc
|
||||
#### nnacl dequant coder
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/dequant/de_quant.cc
|
||||
#### custom
|
||||
${MICRO_DIR}/coder/opcoders/custom/custom_coder.cc
|
||||
)
|
||||
|
||||
set(LITE_SRC
|
||||
|
@ -154,6 +157,11 @@ set(LITE_SRC
|
|||
${LITE_DIR}/tools/common/flag_parser.cc
|
||||
)
|
||||
|
||||
set(REGISTRY_SRC
|
||||
${MICRO_DIR}/coder/user_registry/user_kernel_register.cc
|
||||
${MICRO_DIR}/coder/user_registry/nnie_kernel_reg.cc
|
||||
${MICRO_DIR}/coder/user_registry/nnie_infer.cc
|
||||
)
|
||||
|
||||
list(APPEND FILE_SET ${CODER_SRC} ${CODER_OPCODERS_SRC} ${CODER_GENERATOR_SRC}
|
||||
${CODER_ALLOCATOR_SRC} ${LITE_SRC} ${MINDSPORE_CORE})
|
||||
${CODER_ALLOCATOR_SRC} ${LITE_SRC} ${MINDSPORE_CORE} ${REGISTRY_SRC})
|
||||
|
|
|
@ -48,6 +48,9 @@ class Configurator {
|
|||
int ParseProjDir(std::string model_path);
|
||||
std::string proj_dir() const { return proj_dir_; }
|
||||
|
||||
void SetCustomFlag() { has_custom_op_ = true; }
|
||||
bool CustomFlag() const { return has_custom_op_; }
|
||||
|
||||
private:
|
||||
Configurator() = default;
|
||||
~Configurator() = default;
|
||||
|
@ -57,6 +60,7 @@ class Configurator {
|
|||
bool support_parallel_{false};
|
||||
bool debug_mode_{false};
|
||||
std::string proj_dir_;
|
||||
bool has_custom_op_{false};
|
||||
};
|
||||
} // namespace mindspore::lite::micro
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "coder/generator/component/cmake_component.h"
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include "coder/user_registry/user_kernel_register.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
void CodeCMakeNetLibrary(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator *config) {
|
||||
|
@ -58,5 +59,13 @@ void CodeCMakeNetLibrary(std::ofstream &ofs, const std::unique_ptr<CoderContext>
|
|||
" ${CMAKE_CURRENT_SOURCE_DIR}/*.c\n"
|
||||
" )\n"
|
||||
"add_library(net STATIC ${NET_SRC})\n";
|
||||
auto user_libs = UserKernelFactory::GetInstance()->UserKernelLibNames();
|
||||
if (config->CustomFlag() && !user_libs.empty()) {
|
||||
ofs << "target_link_libraries(net";
|
||||
for (auto lib : user_libs) {
|
||||
ofs << " " << lib;
|
||||
}
|
||||
ofs << ")\n";
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::lite::micro
|
||||
|
|
|
@ -201,8 +201,7 @@ int Calibrator::CompareOutputs(const Vector<tensor::MSTensor *> &outputs) const
|
|||
CalibTensor *calib = calib_outputs_[i];
|
||||
MS_ERROR_IF_NULL(calib);
|
||||
if (output->tensor_name() != calib->tensor_name()) {
|
||||
printf("error, output tensor name is not equal to calib\n");
|
||||
return RET_ERROR;
|
||||
printf("warning, output tensor name is not equal to calib\n");
|
||||
}
|
||||
if (output->ElementsNum() != calib->ElementsNum()) {
|
||||
printf("error, output elements num is not equal to calib\n");
|
||||
|
|
|
@ -65,6 +65,7 @@ else()
|
|||
string(REPLACE "-g" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
|
||||
string(REPLACE "-g" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||
endif()
|
||||
string(APPEND CMAKE_EXE_LINKER_FLAGS " -Wl,--gc-sections")
|
||||
|
||||
add_subdirectory(src)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
|
|
@ -178,4 +178,50 @@ void *MTensor::MutableData() {
|
|||
} // namespace mindspore
|
||||
|
||||
)RAW";
|
||||
|
||||
const char custom_params_source[] = R"RAW(
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_MICRO_LIBRARY_SOURCE_CUSTOM_PARAMS_H_
|
||||
#define MINDSPORE_LITE_MICRO_LIBRARY_SOURCE_CUSTOM_PARAMS_H_
|
||||
|
||||
#define MAX_TENSOR_NAME_LEN 32
|
||||
#define MAX_SHAPE_SIZE 8
|
||||
#define MAX_STR_LEN 32
|
||||
#define MAX_ATTR_NUM 8
|
||||
|
||||
typedef struct {
|
||||
void *data;
|
||||
int data_size;
|
||||
int shape[MAX_SHAPE_SIZE];
|
||||
int shape_size;
|
||||
char name[MAX_TENSOR_NAME_LEN];
|
||||
int data_type;
|
||||
int format;
|
||||
} CustomTensor;
|
||||
|
||||
typedef struct {
|
||||
char type[MAX_STR_LEN];
|
||||
char attr_name[MAX_ATTR_NUM][MAX_STR_LEN];
|
||||
char attr_data[MAX_ATTR_NUM][MAX_STR_LEN];
|
||||
int attr_num;
|
||||
} CustomParams;
|
||||
|
||||
#endif // MINDSPORE_LITE_MICRO_LIBRARY_SOURCE_CUSTOM_PARAMS_H_
|
||||
|
||||
)RAW";
|
||||
|
||||
} // namespace mindspore::lite::micro
|
||||
|
|
|
@ -21,6 +21,7 @@ namespace mindspore::lite::micro {
|
|||
|
||||
extern const char tensor_header[];
|
||||
extern const char tensor_source[];
|
||||
extern const char custom_params_source[];
|
||||
|
||||
} // namespace mindspore::lite::micro
|
||||
|
||||
|
|
|
@ -96,6 +96,10 @@ int Generator::CodeStaticContent() {
|
|||
{net_src_file_path_ + "tensor.h", tensor_header},
|
||||
{net_src_file_path_ + "tensor.cc", tensor_source},
|
||||
{net_src_file_path_ + "mmodel.h", model_header}};
|
||||
|
||||
if (config_->CustomFlag()) {
|
||||
const_blocks.emplace_back(std::make_pair(net_src_file_path_ + "custom_params.h", custom_params_source));
|
||||
}
|
||||
if (config_->support_parallel()) {
|
||||
const_blocks.emplace_back(std::make_pair(net_src_file_path_ + kThreadWrapper, thread_header));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,161 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "coder/opcoders/op_coder.h"
|
||||
#include "coder/opcoders/file_collector.h"
|
||||
#include "coder/opcoders/serializers/serializer.h"
|
||||
#include "coder/opcoders/custom/custom_coder.h"
|
||||
#include "coder/opcoders/op_coder_register.h"
|
||||
#include "coder/user_registry/user_kernel_register.h"
|
||||
#include "src/common/prim_util.h"
|
||||
|
||||
using mindspore::schema::PrimitiveType_Custom;
|
||||
#define MAX_TENSORS 16
|
||||
#define MAX_STR_LEN 32
|
||||
#define MAX_ATTR_NUM 8
|
||||
#define MAX_TENSOR_NAME_LEN 32
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
std::map<Tensor *, void *> CustomCoder::const_tensor_map_;
|
||||
|
||||
void CustomCoder::Populate(const void *prim) {
|
||||
auto op = static_cast<const schema::Primitive *>(prim)->value_as_Custom();
|
||||
type_ = op->type()->str();
|
||||
for (size_t i = 0; i < op->attr()->size(); ++i) {
|
||||
auto attr = op->attr()->Get(i);
|
||||
std::string data;
|
||||
for (size_t j = 0; j < attr->data()->size(); ++j) {
|
||||
data.push_back(static_cast<char>(attr->data()->Get(j)));
|
||||
}
|
||||
attrs_[attr->name()->str()] = data;
|
||||
}
|
||||
}
|
||||
|
||||
int CustomCoder::Prepare(CoderContext *const context) {
|
||||
if (GetPrimitiveType(node_->primitive_, schema_version_) != PrimitiveType_Custom) {
|
||||
MS_LOG(ERROR) << "Primitive type should be custom";
|
||||
return RET_ERROR;
|
||||
}
|
||||
CoderKey key(target_, input_tensors_[0]->data_type(), PrimitiveType_Custom);
|
||||
auto result = UserKernelFactory::GetInstance()->FindUserKernel(key);
|
||||
if (result.empty()) {
|
||||
MS_LOG(ERROR) << "No user kernel register for custom op";
|
||||
return RET_ERROR;
|
||||
}
|
||||
header_ = result[0];
|
||||
function_ = result[1];
|
||||
Populate(node_->primitive_);
|
||||
for (const auto &tensor : input_tensors_) {
|
||||
if (tensor->category() == Tensor::Category::CONST_TENSOR) {
|
||||
if (!const_tensor_map_.count(tensor)) {
|
||||
auto buff = allocator_->Malloc(kNumberTypeUInt8, tensor->Size(), kOfflinePackWeight);
|
||||
memcpy_s(buff, tensor->Size(), tensor->data(), tensor->Size());
|
||||
const_tensor_map_[tensor] = buff;
|
||||
}
|
||||
}
|
||||
}
|
||||
Configurator::GetInstance()->SetCustomFlag();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int CustomCoder::TransformTensors(Serializer *code, std::string array_name, const std::vector<Tensor *> &tensors) {
|
||||
if (tensors.size() > MAX_TENSORS) {
|
||||
MS_LOG(ERROR) << "The number of tensors is too large";
|
||||
return RET_ERROR;
|
||||
}
|
||||
(*code) << "\t\tCustomTensor " << array_name << "[" << tensors.size() << "];\n";
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
if (tensors[i]->category() == Tensor::Category::CONST_TENSOR) {
|
||||
if (!const_tensor_map_.count(tensors[i])) {
|
||||
MS_LOG(ERROR) << "can't find the const tensor's runtime address";
|
||||
return RET_ERROR;
|
||||
}
|
||||
(*code) << "\t\t" << array_name << "[" << i
|
||||
<< "].data = " << allocator_->GetRuntimeAddr(const_tensor_map_[tensors[i]]) << ";\n";
|
||||
} else {
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].data = " << allocator_->GetRuntimeAddr(tensors[i]) << ";\n";
|
||||
}
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].data_size = " << tensors[i]->Size() << ";\n";
|
||||
for (size_t j = 0; j < tensors[i]->shape().size(); ++j) {
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].shape[" << j << "] = " << tensors[i]->shape()[j] << ";\n";
|
||||
}
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].shape_size = " << tensors[i]->shape().size() << ";\n";
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].data_type = " << tensors[i]->data_type() << ";\n";
|
||||
(*code) << "\t\t" << array_name << "[" << i << "].format = " << tensors[i]->format() << ";\n";
|
||||
if (tensors[i]->tensor_name().size() > MAX_TENSOR_NAME_LEN) {
|
||||
MS_LOG(ERROR) << "tensor name is too long: " << tensors[i]->tensor_name();
|
||||
return RET_ERROR;
|
||||
}
|
||||
(*code) << "\t\tstrcpy(" << array_name << "[" << i << "].name, "
|
||||
<< "\"" << tensors[i]->tensor_name() << "\""
|
||||
<< ");\n";
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int CustomCoder::TransformParams(Serializer *code, std::string var_name) {
|
||||
if (attrs_.size() > MAX_ATTR_NUM) {
|
||||
MS_LOG(ERROR) << "Attrs's number exceeds the maximum";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
(*code) << "\t\tCustomParams " << var_name << ";\n";
|
||||
if (type_.size() > MAX_STR_LEN) {
|
||||
MS_LOG(ERROR) << "type name is too long: " << type_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
(*code) << "\t\tstrcpy(" << var_name << ".type, "
|
||||
<< "\"" << type_ << "\""
|
||||
<< ");\n";
|
||||
int i = 0;
|
||||
for (auto iter = attrs_.begin(); iter != attrs_.end(); ++iter) {
|
||||
if (iter->first.size() > MAX_STR_LEN) {
|
||||
MS_LOG(ERROR) << "attr name is too long: " << iter->first;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (iter->second.size() > MAX_STR_LEN) {
|
||||
MS_LOG(ERROR) << "attr " << iter->first << " data is too long";
|
||||
return RET_ERROR;
|
||||
}
|
||||
(*code) << "\t\tstrcpy(" << var_name << ".attr_name[" << i << "], "
|
||||
<< "\"" << iter->first << "\""
|
||||
<< ");\n";
|
||||
(*code) << "\t\tstrcpy(" << var_name << ".attr_data[" << i++ << "], "
|
||||
<< "\"" << iter->second << "\""
|
||||
<< ");\n";
|
||||
}
|
||||
(*code) << "\t\t" << var_name << ".attr_num = " << attrs_.size() << ";\n";
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int CustomCoder::DoCode(CoderContext *const context) {
|
||||
Collect(context, {header_, "custom_params.h"}, {});
|
||||
Serializer code;
|
||||
MS_CHECK_RET_CODE(TransformTensors(&code, "inputs", input_tensors_), "Transform input tensors error!");
|
||||
MS_CHECK_RET_CODE(TransformTensors(&code, "outputs", output_tensors_), "Transform output tensors error!");
|
||||
MS_CHECK_RET_CODE(TransformParams(&code, "param"), "Transform params error!");
|
||||
code.CodeFunction(function_, "inputs", input_tensors_.size(), "outputs", output_tensors_.size(), "&param");
|
||||
context->AppendCode(code.str());
|
||||
return 0;
|
||||
}
|
||||
|
||||
REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt8, PrimitiveType_Custom, CPUOpCoderCreator<CustomCoder>)
|
||||
REG_OPERATOR_CODER(kAllTargets, kNumberTypeUInt8, PrimitiveType_Custom, CPUOpCoderCreator<CustomCoder>)
|
||||
REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Custom, CPUOpCoderCreator<CustomCoder>)
|
||||
} // namespace mindspore::lite::micro
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_CUSTOM_CODER_H
|
||||
#define MINDSPORE_LITE_MICRO_CODER_OPCODERS_CUSTOM_CODER_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "coder/opcoders/op_coder.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
class CustomCoder final : public OperatorCoder {
|
||||
public:
|
||||
CustomCoder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
const Model::Node *node, size_t node_index, Target target)
|
||||
: OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}
|
||||
|
||||
~CustomCoder() override = default;
|
||||
|
||||
int Prepare(CoderContext *const context) override;
|
||||
|
||||
int DoCode(CoderContext *const context) override;
|
||||
|
||||
private:
|
||||
void Populate(const void *prim);
|
||||
int TransformTensors(Serializer *code, std::string array_name, const std::vector<Tensor *> &tensors);
|
||||
int TransformParams(Serializer *code, std::string var_name);
|
||||
|
||||
std::string type_;
|
||||
std::map<std::string, std::string> attrs_;
|
||||
std::string header_;
|
||||
std::string function_;
|
||||
static std::map<Tensor *, void *> const_tensor_map_;
|
||||
};
|
||||
} // namespace mindspore::lite::micro
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_CUSTOM_CODER_H
|
|
@ -1,57 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/opcoders/nnacl/fp32/nchw2nhwc_fp32_coder.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
|
||||
#include "coder/opcoders/file_collector.h"
|
||||
|
||||
using mindspore::schema::PrimitiveType_Nchw2Nhwc;
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
int Nchw2NhwcFP32Coder::Prepare(CoderContext *const context) { return RET_OK; }
|
||||
|
||||
int Nchw2NhwcFP32Coder::DoCode(CoderContext *context) {
|
||||
// generate code .h .c
|
||||
Collect(context,
|
||||
{
|
||||
"nnacl/pack.h",
|
||||
},
|
||||
{
|
||||
"nnacl/pack.c",
|
||||
});
|
||||
NNaclFp32Serializer code;
|
||||
if (input_tensor_->shape().size() == DIMENSION_4D) {
|
||||
if (input_tensor_->data_type() == kNumberTypeFloat32) {
|
||||
code.CodeFunction("PackNCHWToNHWCFp32", input_tensor_, output_tensor_, output_tensor_->Batch(),
|
||||
output_tensor_->Height() * output_tensor_->Width(), output_tensor_->Channel());
|
||||
} else if (input_tensor_->data_type() == kNumberTypeInt8) {
|
||||
code.CodeFunction("PackNCHWToNHWCInt8", input_tensor_, output_tensor_, output_tensor_->Batch(),
|
||||
output_tensor_->Height() * output_tensor_->Width(), output_tensor_->Channel());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported format transform";
|
||||
}
|
||||
} else {
|
||||
code.CodeFunction("memcpy", output_tensor_, input_tensor_, input_tensor_->ElementsNum() * sizeof(float));
|
||||
}
|
||||
|
||||
context->AppendCode(code.str());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, CPUOpCoderCreator<Nchw2NhwcFP32Coder>)
|
||||
} // namespace mindspore::lite::micro::nnacl
|
|
@ -1,38 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_NCHW2FP32_CODER_H_
|
||||
#define MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_NCHW2FP32_CODER_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "coder/opcoders/op_coder.h"
|
||||
#include "nnacl/base/tile_base.h"
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
class Nchw2NhwcFP32Coder final : public OperatorCoder {
|
||||
public:
|
||||
Nchw2NhwcFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
const Model::Node *node, size_t node_index, Target target)
|
||||
: OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}
|
||||
|
||||
~Nchw2NhwcFP32Coder() override = default;
|
||||
int Prepare(CoderContext *const context) override;
|
||||
|
||||
int DoCode(CoderContext *const context) override;
|
||||
};
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_NCHW2FP32_CODER_H_
|
|
@ -1,56 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/opcoders/nnacl/fp32/nhwc2nchw_fp32_coder.h"
|
||||
#include <string>
|
||||
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
|
||||
#include "coder/opcoders/file_collector.h"
|
||||
|
||||
using mindspore::schema::PrimitiveType_Nhwc2Nchw;
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
int Nhwc2NchwFP32Coder::Prepare(CoderContext *const context) { return RET_OK; }
|
||||
|
||||
int Nhwc2NchwFP32Coder::DoCode(CoderContext *const context) {
|
||||
// generate code .h .c
|
||||
Collect(context,
|
||||
{
|
||||
"nnacl/pack.h",
|
||||
},
|
||||
{
|
||||
"pack.c",
|
||||
});
|
||||
|
||||
NNaclFp32Serializer code;
|
||||
if (input_tensor_->shape().size() == DIMENSION_4D) {
|
||||
if (input_tensor_->data_type() == kNumberTypeFloat32) {
|
||||
code.CodeFunction("PackNHWCToNCHWFp32", input_tensor_, output_tensor_, output_tensor_->Batch(),
|
||||
output_tensor_->Height() * output_tensor_->Width(), output_tensor_->Channel());
|
||||
} else if (input_tensor_->data_type() == kNumberTypeInt8) {
|
||||
code.CodeFunction("PackNHWCToNCHWInt8", input_tensor_, output_tensor_, output_tensor_->Batch(),
|
||||
output_tensor_->Height() * output_tensor_->Width(), output_tensor_->Channel());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported format transform";
|
||||
}
|
||||
} else {
|
||||
code.CodeFunction("memcpy", output_tensor_, input_tensor_, input_tensor_->ElementsNum() * sizeof(float));
|
||||
}
|
||||
|
||||
context->AppendCode(code.str());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, CPUOpCoderCreator<Nhwc2NchwFP32Coder>)
|
||||
} // namespace mindspore::lite::micro::nnacl
|
|
@ -0,0 +1,82 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/opcoders/nnacl/int8/transpose_int8_coder.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "coder/log.h"
|
||||
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h"
|
||||
#include "coder/opcoders/file_collector.h"
|
||||
|
||||
using mindspore::schema::PrimitiveType_Transpose;
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
|
||||
int TransposeInt8Coder::Prepare(CoderContext *const context) {
|
||||
auto in_tensor = input_tensors_.front();
|
||||
auto out_tensor = output_tensors_.front();
|
||||
auto in_shape = in_tensor->shape();
|
||||
auto out_shape = out_tensor->shape();
|
||||
param_ = reinterpret_cast<TransposeParameter *>(parameter_);
|
||||
param_->data_num_ = in_tensor->ElementsNum();
|
||||
|
||||
auto perm_tensor = input_tensors_.at(1);
|
||||
int *perm_data = reinterpret_cast<int *>(perm_tensor->data_c());
|
||||
MS_ASSERT(perm_data != nullptr);
|
||||
param_->num_axes_ = perm_tensor->ElementsNum();
|
||||
for (int i = 0; i < param_->num_axes_; ++i) {
|
||||
param_->perm_[i] = perm_data[i];
|
||||
}
|
||||
param_->strides_[param_->num_axes_ - 1] = 1;
|
||||
param_->out_strides_[param_->num_axes_ - 1] = 1;
|
||||
for (int i = param_->num_axes_ - 2; i >= 0; i--) {
|
||||
param_->strides_[i] = in_shape.at(i + 1) * param_->strides_[i + 1];
|
||||
param_->out_strides_[i] = out_shape.at(i + 1) * param_->out_strides_[i + 1];
|
||||
}
|
||||
|
||||
if (in_tensor->shape().size() == DIMENSION_4D &&
|
||||
((param_->perm_[0] == 0 && param_->perm_[1] == 2 && param_->perm_[2] == 3 && param_->perm_[3] == 1) ||
|
||||
(param_->perm_[0] == 0 && param_->perm_[1] == 3 && param_->perm_[2] == 1 && param_->perm_[3] == 2))) {
|
||||
return RET_OK;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Currently transpose op only supports NCHW<->NHWC";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
int TransposeInt8Coder::DoCode(CoderContext *const context) {
|
||||
Collect(context,
|
||||
{
|
||||
"nnacl/int8/pack_int8.h",
|
||||
},
|
||||
{
|
||||
"pack_int8.c",
|
||||
});
|
||||
|
||||
NNaclInt8Serializer code;
|
||||
auto out_shape = output_tensors_[0]->shape();
|
||||
if (param_->perm_[0] == 0 && param_->perm_[1] == 2 && param_->perm_[2] == 3 && param_->perm_[3] == 1) {
|
||||
code.CodeFunction("PackNCHWToNHWCInt8", input_tensors_[0], output_tensors_[0], out_shape[0],
|
||||
out_shape[1] * out_shape[2], out_shape[3]);
|
||||
} else if (param_->perm_[0] == 0 && param_->perm_[1] == 3 && param_->perm_[2] == 1 && param_->perm_[3] == 2) {
|
||||
code.CodeFunction("PackNCHWToNHWCInt8", input_tensors_[0], output_tensors_[0], out_shape[0],
|
||||
out_shape[2] * out_shape[3], out_shape[1]);
|
||||
}
|
||||
context->AppendCode(code.str());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt8, PrimitiveType_Transpose, CPUOpCoderCreator<TransposeInt8Coder>)
|
||||
} // namespace mindspore::lite::micro::nnacl
|
|
@ -13,25 +13,28 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_NHWC2NCHW_FP32_CODER_H_
|
||||
#define MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_NHWC2NCHW_FP32_CODER_H_
|
||||
#ifndef MINDSPORE_LITE_MICRO_CODER_TRANSPOSE_INT8_CODER_H_
|
||||
#define MINDSPORE_LITE_MICRO_CODER_TRANSPOSE_INT8_CODER_H_
|
||||
|
||||
#include <vector>
|
||||
#include "coder/opcoders/op_coder.h"
|
||||
#include "nnacl/base/tile_base.h"
|
||||
#include "nnacl/transpose.h"
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
class Nhwc2NchwFP32Coder final : public OperatorCoder {
|
||||
class TransposeInt8Coder final : public OperatorCoder {
|
||||
public:
|
||||
Nhwc2NchwFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
TransposeInt8Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
const Model::Node *node, size_t node_index, Target target)
|
||||
: OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}
|
||||
~Nhwc2NchwFP32Coder() override = default;
|
||||
|
||||
~TransposeInt8Coder() override = default;
|
||||
|
||||
int Prepare(CoderContext *const context) override;
|
||||
|
||||
int DoCode(CoderContext *const context) override;
|
||||
|
||||
private:
|
||||
TransposeParameter *param_{nullptr};
|
||||
};
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_NHWC2NCHW_FP32_CODER_H_
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_TRANSPOSE_INT8_CODER_H_
|
|
@ -51,7 +51,9 @@ std::unique_ptr<OperatorCoder> OpCoderBuilder::build(int schema_version) {
|
|||
op_coder->set_output_tensor_indices(output_indices_);
|
||||
int thread_num = support_parallel_ ? kMaxThreadNumSupported : 1;
|
||||
op_coder->set_thread_num(thread_num);
|
||||
if (primitive_type != schema::PrimitiveType_Custom) {
|
||||
parameter_->thread_num_ = thread_num;
|
||||
}
|
||||
op_coder->set_parameter(parameter_);
|
||||
op_coder->set_type(primitive_type);
|
||||
return op_coder;
|
||||
|
|
|
@ -279,8 +279,14 @@ int CoderSession::CreateOpCoders() {
|
|||
MS_LOG(ERROR) << "node: " << node->name_ << "has no outputs tensor";
|
||||
return RET_ERROR;
|
||||
}
|
||||
OpParameter *parameter = GenParameterAndInfer(node, inputs, &outputs);
|
||||
|
||||
OpParameter *parameter = nullptr;
|
||||
if (lite::KernelInferShape(inputs, outputs, node->primitive_, std::set<std::string>{}, schema_version_) ==
|
||||
lite::RET_NOT_SUPPORT) { // custom op infer
|
||||
parameter = GenParameterAndInfer(node, inputs, &outputs); // general ops infer
|
||||
MS_CHECK_PTR(parameter);
|
||||
}
|
||||
|
||||
TypeId tensor_data_type = inputs.at(0)->data_type();
|
||||
std::unique_ptr<OperatorCoder> op_coder = builder.inputs(inputs)
|
||||
.outputs(outputs)
|
||||
|
|
|
@ -0,0 +1,161 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/user_registry/nnie_infer.h"
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include "include/errorcode.h"
|
||||
#include "include/api/format.h"
|
||||
#include "include/registry/register_kernel_interface.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
using mindspore::kernel::KernelInterface;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Custom;
|
||||
#define MAX_SIZE 1024
|
||||
|
||||
namespace mindspore {
|
||||
namespace nnie {
|
||||
std::shared_ptr<KernelInterface> CustomInferCreater() {
|
||||
auto infer = new (std::nothrow) CustomInterface();
|
||||
if (infer == nullptr) {
|
||||
MS_LOG(ERROR) << "new custom infer is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
return std::shared_ptr<KernelInterface>(infer);
|
||||
}
|
||||
|
||||
int GetCustomShape(const mindspore::schema::Custom *op, const std::string &attr,
|
||||
std::vector<std::vector<int64_t>> *shapes) {
|
||||
char buf[MAX_SIZE];
|
||||
bool has_outputs_shape = false;
|
||||
|
||||
for (size_t i = 0; i < op->attr()->size(); i++) {
|
||||
if (op->attr()->Get(i)->name()->str() == attr) {
|
||||
auto output_info = op->attr()->Get(i)->data();
|
||||
int attr_size = static_cast<int>(output_info->size());
|
||||
if (attr_size >= MAX_SIZE) {
|
||||
MS_LOG(ERROR) << "attr size too big";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (int j = 0; j < attr_size; j++) {
|
||||
buf[j] = static_cast<char>(output_info->Get(j));
|
||||
}
|
||||
buf[attr_size] = 0;
|
||||
has_outputs_shape = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!has_outputs_shape) {
|
||||
MS_LOG(ERROR) << "Custom op don't have " << attr.c_str() << " attr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
char delims[] = ",";
|
||||
char *res = nullptr;
|
||||
char *save_ptr = nullptr;
|
||||
res = strtok_r(buf, delims, &save_ptr);
|
||||
while (res != nullptr) {
|
||||
// TODO: to be completed
|
||||
// outputs[id]->format_ = input->format_;
|
||||
// outputs[id]->data_type_ = kNumberTypeFloat32;
|
||||
int64_t ndims = strtol(res, &res, 10);
|
||||
int j = 0;
|
||||
std::vector<int64_t> shape;
|
||||
shape.resize(ndims);
|
||||
for (; j < ndims; j++) {
|
||||
res = strtok_r(NULL, delims, &save_ptr);
|
||||
shape[j] = static_cast<int64_t>(strtol(res, &res, 10));
|
||||
}
|
||||
shapes->push_back(shape);
|
||||
|
||||
res = strtok_r(NULL, delims, &save_ptr);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// Shape inference for the NNIE Custom op. Expected output shapes are carried
// in the op's "outputs_shape" attribute; only a resize of the batch dimension
// (dim 0) of the first input is supported, and the new batch is propagated to
// every output. All outputs are reported as float32 / NCHW.
Status CustomInterface::Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
                              const mindspore::schema::Primitive *primitive) {
  if (inputs->empty()) {
    MS_LOG(ERROR) << "Inputs size 0";
    return kLiteError;
  }
  if (outputs->empty()) {
    MS_LOG(ERROR) << "Outputs size 0";
    return kLiteError;
  }
  if (primitive == nullptr || primitive->value_type() != mindspore::schema::PrimitiveType_Custom) {
    MS_LOG(ERROR) << "Primitive type is not PrimitiveType_Custom";
    return kLiteError;
  }

  // value_as_Custom() returns nullptr when the union payload is absent, and
  // attr() is an optional flatbuffer field — both must be checked before use.
  auto op = primitive->value_as_Custom();
  if (op == nullptr || op->attr() == nullptr || op->attr()->size() < 1) {
    MS_LOG(ERROR) << "There are at least 1 attribute of Custom";
    return kLiteError;
  }
  std::vector<std::vector<int64_t>> inputs_shape;
  if (GetCustomShape(op, "inputs_shape", &inputs_shape) != RET_OK) {
    MS_LOG(ERROR) << "parser inputs_shape attribute err.";
    return kLiteError;
  }
  std::vector<std::vector<int64_t>> outputs_shape;
  if (GetCustomShape(op, "outputs_shape", &outputs_shape) != RET_OK) {
    MS_LOG(ERROR) << "parser outputs_shape attribute err.";
    return kLiteError;
  }
  // By NNIE convention the last input is the model buffer, so the attribute
  // describes inputs->size() - 1 tensors.
  if (inputs_shape.size() != (inputs->size() - 1)) {
    MS_LOG(ERROR) << "inputs num diff inputs_shape num.";
    return kLiteError;
  }
  // Guard the SetShape loop below against an attribute that lists fewer
  // shapes than there are output tensors.
  if (outputs_shape.size() != outputs->size()) {
    MS_LOG(ERROR) << "outputs num diff outputs_shape num.";
    return kLiteError;
  }
  if (inputs_shape[0].size() != (*inputs)[0].Shape().size()) {
    MS_LOG(ERROR) << "shape size err.";
    return kLiteError;
  }
  // Detect a runtime resize: only the batch dimension (index 0) may differ
  // from the shape recorded at export time.
  bool resize_flag = false;
  int64_t resize_num = 1;
  for (size_t i = 0; i < inputs_shape[0].size(); i++) {
    if (inputs_shape[0][i] != (*inputs)[0].Shape()[i]) {
      if (i == 0) {
        resize_flag = true;
        resize_num = (*inputs)[0].Shape()[i];
      } else {
        MS_LOG(ERROR) << "Custom of NNIE only support batch_num resize.";
        return kLiteError;
      }
    }
  }
  if (resize_flag) {
    for (auto &output_shape : outputs_shape) {
      output_shape[0] = resize_num;
    }
  }
  for (size_t i = 0; i < outputs->size(); i++) {
    (*outputs)[i].SetShape(outputs_shape[i]);
    (*outputs)[i].SetDataType(DataType::kNumberTypeFloat32);
    (*outputs)[i].SetFormat(Format::NCHW);
  }
  return kSuccess;
}
|
||||
} // namespace nnie
|
||||
} // namespace mindspore
|
||||
namespace mindspore {
namespace kernel {
// Registers CustomInferCreater as the shape-inference provider for Custom ops
// whose provider and op type are both "NNIE".
REGISTER_CUSTOM_KERNEL_INTERFACE(NNIE, NNIE, nnie::CustomInferCreater);
}  // namespace kernel
}  // namespace mindspore
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
// NOTE(review): the original guard MINDSPORE_LITE_NNACL_CUSTOM_PARAMETER_H_
// was copy-pasted from nnacl's custom_parameter.h; if both headers are ever
// included in one TU, the second one is silently skipped. Renamed to a
// file-specific guard.
#ifndef MINDSPORE_LITE_NNIE_CUSTOM_INFER_H_
#define MINDSPORE_LITE_NNIE_CUSTOM_INFER_H_
#include <vector>
#include <memory>
#include "include/kernel_interface.h"

namespace mindspore {
namespace nnie {
// Shape-inference provider for NNIE Custom ops: derives output shapes from
// the op's "inputs_shape"/"outputs_shape" attributes (see Infer).
class CustomInterface : public mindspore::kernel::KernelInterface {
 public:
  CustomInterface() {}

  ~CustomInterface() = default;

  // Fills each output tensor's shape, data type (float32) and format (NCHW).
  Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
               const mindspore::schema::Primitive *primitive) override;
};
}  // namespace nnie
}  // namespace mindspore
#endif  // MINDSPORE_LITE_NNIE_CUSTOM_INFER_H_
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/user_registry/user_kernel_register.h"
|
||||
#include "schema/ops_generated.h"
|
||||
|
||||
using mindspore::schema::PrimitiveType_Custom;
|
||||
|
||||
namespace mindspore::lite::micro {
// Registers the NNIE user kernel (declared in nnie_micro.h, implemented in the
// nniemicro library) for Custom ops on the ARM32A target — one entry per
// supported input data type.
REG_USER_KERNEL(kARM32A, kNumberTypeInt8, PrimitiveType_Custom, nnie_micro.h, NnieKernel, nniemicro)
REG_USER_KERNEL(kARM32A, kNumberTypeUInt8, PrimitiveType_Custom, nnie_micro.h, NnieKernel, nniemicro)
REG_USER_KERNEL(kARM32A, kNumberTypeFloat32, PrimitiveType_Custom, nnie_micro.h, NnieKernel, nniemicro)
}  // namespace mindspore::lite::micro
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/user_registry/user_kernel_register.h"
|
||||
#include <set>
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
// Meyers singleton: the factory is created on first use and shared by every
// REG_USER_KERNEL registration in the process.
UserKernelFactory *UserKernelFactory::GetInstance() {
  static UserKernelFactory reg;
  return &reg;
}
|
||||
|
||||
// Records the {header, func, lib} triple for (target, data_type, operator_type).
// Returns RET_ERROR (without overwriting) if the key is already registered.
int UserKernelFactory::RegistUserKernel(Target target, TypeId data_type, schema::PrimitiveType operator_type,
                                        const std::string &header, const std::string &func, const std::string &lib) {
  CoderKey key(target, data_type, operator_type);
  // Single-lookup insert: emplace fails (second == false) on a duplicate key,
  // replacing the original find-then-operator[] double lookup.
  auto result = user_kernel_sets_.emplace(key, std::vector<std::string>{header, func, lib});
  if (!result.second) {
    MS_LOG(ERROR) << "Register the user kernel multiple times: " << key.ToString();
    return RET_ERROR;
  }
  return RET_OK;
}
|
||||
|
||||
std::vector<std::string> UserKernelFactory::FindUserKernel(const CoderKey &key) {
|
||||
if (user_kernel_sets_.find(key) != user_kernel_sets_.end()) {
|
||||
return user_kernel_sets_[key];
|
||||
} else {
|
||||
return std::vector<std::string>{};
|
||||
}
|
||||
}
|
||||
|
||||
// Collects the distinct library names across all registered user kernels.
std::set<std::string> UserKernelFactory::UserKernelLibNames() {
  std::set<std::string> lib_names;
  for (const auto &entry : user_kernel_sets_) {
    lib_names.insert(entry.second[2]);  // index 2 holds the library name
  }
  return lib_names;
}
|
||||
|
||||
// Forwards the registration to the factory singleton. The return value is
// intentionally ignored: a duplicate registration only logs an error.
UserKernelRegister::UserKernelRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type,
                                       const std::string &header, const std::string &func, const std::string &lib) {
  UserKernelFactory::GetInstance()->RegistUserKernel(target, data_type, operator_type, header, func, lib);
}
|
||||
} // namespace mindspore::lite::micro
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_USER_KERNEL_REGISTER_H_
|
||||
#define MINDSPORE_LITE_MICRO_CODER_OPCODERS_USER_KERNEL_REGISTER_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include "coder/config.h"
|
||||
#include "coder/opcoders/op_coder_register.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "schema/ops_generated.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
// Registry mapping (target, data type, operator type) to the code fragments
// codegen needs to call a user-supplied (custom) kernel: the header to
// include, the function to invoke and the library to link against.
class UserKernelFactory {
 public:
  UserKernelFactory() = default;

  // Process-wide singleton used by the REG_USER_KERNEL registrations.
  static UserKernelFactory *GetInstance();

  // Records {header, func, lib} under the key; fails on duplicate registration.
  int RegistUserKernel(Target target, TypeId data_type, schema::PrimitiveType operator_type, const std::string &header,
                       const std::string &func, const std::string &lib);

  // Returns the registered {header, func, lib} triple, or an empty vector.
  std::vector<std::string> FindUserKernel(const CoderKey &key);

  // Distinct library names across all registered user kernels.
  std::set<std::string> UserKernelLibNames();

  ~UserKernelFactory() { user_kernel_sets_.clear(); }

 private:
  // Value layout: [0] header file, [1] kernel function, [2] library name.
  std::map<CoderKey, std::vector<std::string>> user_kernel_sets_;
};
|
||||
|
||||
// Registration helper: constructing a (file-scope static) instance records a
// user kernel in the UserKernelFactory singleton during static initialization.
class UserKernelRegister {
 public:
  UserKernelRegister() = delete;

  UserKernelRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type, const std::string &header,
                     const std::string &func, const std::string &lib);

  ~UserKernelRegister() = default;
};
|
||||
|
||||
// Expands to a file-scope static registrar whose constructor runs before
// main(), registering the user kernel. The header/func/lib arguments are
// stringized (#) so call sites pass bare tokens, e.g.
//   REG_USER_KERNEL(kARM32A, kNumberTypeInt8, PrimitiveType_Custom, nnie_micro.h, NnieKernel, nniemicro)
#define REG_USER_KERNEL(target, data_type, operator_type, header, func, lib)                                      \
  static UserKernelRegister g_user_##target##data_type##operator_type##function(target, data_type, operator_type, \
                                                                                #header, #func, #lib);
|
||||
} // namespace mindspore::lite::micro
|
||||
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_USER_KERNEL_REGISTER_H_
|
Loading…
Reference in New Issue