forked from mindspore-Ecosystem/mindspore
support get config and attr from kernel
This commit is contained in:
parent
010cc7a435
commit
8acc951a47
|
@ -19,13 +19,14 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include "schema/model_generated.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/context.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
/// \brief The Kernel class is used to define a MindSpore Kernel.
|
||||
class Kernel {
|
||||
class MS_API Kernel {
|
||||
public:
|
||||
Kernel() = default;
|
||||
/// \brief Constructor.
|
||||
|
@ -37,9 +38,7 @@ class Kernel {
|
|||
Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
|
||||
const schema::Primitive *primitive, const mindspore::Context *ctx)
|
||||
: context_(ctx), inputs_(std::move(inputs)), outputs_(std::move(outputs)), primitive_(primitive) {
|
||||
if (primitive != nullptr) {
|
||||
type_ = primitive->value_type();
|
||||
}
|
||||
Initialize();
|
||||
}
|
||||
/// \brief Destructor.
|
||||
virtual ~Kernel() = default;
|
||||
|
@ -102,6 +101,44 @@ class Kernel {
|
|||
/// \return the primitive of kernel generated by flatbuffers.
|
||||
const schema::Primitive *primitive() const { return this->primitive_; }
|
||||
|
||||
/// \brief get kernel's attribute.
|
||||
///
|
||||
/// \param[in] key define the kernel's attribute key.
|
||||
std::string GetAttr(const std::string &key) const {
|
||||
auto iter = attrs_.find(key);
|
||||
if (iter != attrs_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
/// \brief set kernel's config.
|
||||
///
|
||||
/// \param[in] config define the kernel's config.
|
||||
void SetConfig(const std::map<std::string, std::map<std::string, std::string>> *config) {
|
||||
config_ = config;
|
||||
}
|
||||
/// \brief set kernel's config.
|
||||
///
|
||||
/// \param[in] config define the kernel's config.
|
||||
std::map<std::string, std::string> GetConfig(const std::string §ion) const {
|
||||
if (config_ == nullptr) {
|
||||
return std::map<std::string, std::string>();
|
||||
}
|
||||
auto iter = config_->find(section);
|
||||
if (iter != config_->end()) {
|
||||
return iter->second;
|
||||
}
|
||||
return std::map<std::string, std::string>();
|
||||
}
|
||||
|
||||
protected:
|
||||
/// \brief set kernel's attribute
|
||||
///
|
||||
/// \param[in] key define the kernel's attribute key.
|
||||
/// \param[in] value define the kernel's attribute value.
|
||||
void SetAttr(const std::string &key, const std::string &value) { attrs_[key] = value; }
|
||||
|
||||
protected:
|
||||
std::string name_;
|
||||
const mindspore::Context *context_ = nullptr;
|
||||
|
@ -109,6 +146,11 @@ class Kernel {
|
|||
std::vector<mindspore::MSTensor> outputs_;
|
||||
schema::PrimitiveType type_ = schema::PrimitiveType_NONE;
|
||||
const schema::Primitive *primitive_ = nullptr;
|
||||
std::map<std::string, std::string> attrs_;
|
||||
const std::map<std::string, std::map<std::string, std::string>> *config_;
|
||||
|
||||
private:
|
||||
void Initialize();
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -106,6 +106,14 @@ class MS_API Model {
|
|||
/// \return Status.
|
||||
inline Status LoadConfig(const std::string &config_path);
|
||||
|
||||
/// \brief Update config.
|
||||
///
|
||||
/// \param[in] section define the config section.
|
||||
/// \param[in] config define the config will be updated.
|
||||
///
|
||||
/// \return Status.
|
||||
inline Status UpdateConfig(const std::string §ion, const std::pair<std::string, std::string> &config);
|
||||
|
||||
/// \brief Obtains all input tensors of the model.
|
||||
///
|
||||
/// \return The vector that includes all input tensors.
|
||||
|
@ -215,6 +223,7 @@ class MS_API Model {
|
|||
MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name);
|
||||
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
|
||||
Status LoadConfig(const std::vector<char> &config_path);
|
||||
Status UpdateConfig(const std::vector<char> §ion, const std::pair<std::vector<char>, std::vector<char>> &config);
|
||||
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::vector<char> &dec_mode);
|
||||
Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
|
||||
|
@ -241,6 +250,12 @@ Status Model::LoadConfig(const std::string &config_path) {
|
|||
return LoadConfig(StringToChar(config_path));
|
||||
}
|
||||
|
||||
Status Model::UpdateConfig(const std::string §ion, const std::pair<std::string, std::string> &config) {
|
||||
std::pair<std::vector<char>, std::vector<char>> config_pair = {StringToChar(config.first),
|
||||
StringToChar(config.second)};
|
||||
return UpdateConfig(StringToChar(section), config_pair);
|
||||
}
|
||||
|
||||
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
|
||||
return Build(model_data, data_size, model_type, model_context, dec_key, StringToChar(dec_mode));
|
||||
|
|
|
@ -71,7 +71,7 @@ class MS_API RegisterKernel {
|
|||
///
|
||||
/// \return Status as a status identification of registering.
|
||||
inline static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
|
||||
CreateKernel creator);
|
||||
const CreateKernel creator);
|
||||
|
||||
/// \brief Static method to register kernel which is corresponding to custom op.
|
||||
///
|
||||
|
@ -83,7 +83,7 @@ class MS_API RegisterKernel {
|
|||
///
|
||||
/// \return Status as a status identification of registering.
|
||||
inline static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
|
||||
const std::string &type, CreateKernel creator);
|
||||
const std::string &type, const CreateKernel creator);
|
||||
|
||||
/// \brief Static methon to get a kernel's create function.
|
||||
///
|
||||
|
@ -95,9 +95,9 @@ class MS_API RegisterKernel {
|
|||
|
||||
private:
|
||||
static Status RegKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
|
||||
int type, CreateKernel creator);
|
||||
int type, const CreateKernel creator);
|
||||
static Status RegCustomKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
|
||||
const std::vector<char> &type, CreateKernel creator);
|
||||
const std::vector<char> &type, const CreateKernel creator);
|
||||
static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDescHelper *desc);
|
||||
};
|
||||
|
||||
|
@ -115,7 +115,7 @@ class MS_API KernelReg {
|
|||
/// \param[in] op_type Define the ordinary op type.
|
||||
/// \param[in] creator Define a function pointer to create a kernel.
|
||||
KernelReg(const std::string &arch, const std::string &provider, DataType data_type, int op_type,
|
||||
CreateKernel creator) {
|
||||
const CreateKernel creator) {
|
||||
RegisterKernel::RegKernel(arch, provider, data_type, op_type, creator);
|
||||
}
|
||||
|
||||
|
@ -127,18 +127,18 @@ class MS_API KernelReg {
|
|||
/// \param[in] op_type Define the concrete type of a custom op.
|
||||
/// \param[in] creator Define a function pointer to create a kernel.
|
||||
KernelReg(const std::string &arch, const std::string &provider, DataType data_type, const std::string &op_type,
|
||||
CreateKernel creator) {
|
||||
const CreateKernel creator) {
|
||||
RegisterKernel::RegCustomKernel(arch, provider, data_type, op_type, creator);
|
||||
}
|
||||
};
|
||||
|
||||
Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
|
||||
CreateKernel creator) {
|
||||
const CreateKernel creator) {
|
||||
return RegKernel(StringToChar(arch), StringToChar(provider), data_type, type, creator);
|
||||
}
|
||||
|
||||
Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
|
||||
const std::string &type, CreateKernel creator) {
|
||||
const std::string &type, const CreateKernel creator) {
|
||||
return RegCustomKernel(StringToChar(arch), StringToChar(provider), data_type, StringToChar(type), creator);
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,9 @@
|
|||
#include "schema/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class Kernel;
|
||||
}
|
||||
namespace registry {
|
||||
/// \brief KernelInterfaceCreator defined a functor to create KernelInterface.
|
||||
using KernelInterfaceCreator = std::function<std::shared_ptr<kernel::KernelInterface>()>;
|
||||
|
@ -40,7 +43,7 @@ class MS_API RegisterKernelInterface {
|
|||
///
|
||||
/// \return Status as a status identification of registering.
|
||||
inline static Status CustomReg(const std::string &provider, const std::string &op_type,
|
||||
KernelInterfaceCreator creator);
|
||||
const KernelInterfaceCreator creator);
|
||||
|
||||
/// \brief Static method to register op whose primitive type is ordinary.
|
||||
///
|
||||
|
@ -49,23 +52,26 @@ class MS_API RegisterKernelInterface {
|
|||
/// \param[in] creator Define the KernelInterface create function.
|
||||
///
|
||||
/// \return Status as a status identification of registering.
|
||||
inline static Status Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
|
||||
inline static Status Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator);
|
||||
|
||||
/// \brief Static method to get registration of a certain op.
|
||||
///
|
||||
/// \param[in] provider Define the identification of user.
|
||||
/// \param[in] primitive Define the attributes of a certain op.
|
||||
/// \param[in] kernel Define the kernel of a certain op.
|
||||
///
|
||||
/// \return Boolean value to represent registration of a certain op is existing or not.
|
||||
inline static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
|
||||
const schema::Primitive *primitive);
|
||||
const schema::Primitive *primitive,
|
||||
const kernel::Kernel *kernel = nullptr);
|
||||
|
||||
private:
|
||||
static Status CustomReg(const std::vector<char> &provider, const std::vector<char> &op_type,
|
||||
KernelInterfaceCreator creator);
|
||||
static Status Reg(const std::vector<char> &provider, int op_type, KernelInterfaceCreator creator);
|
||||
const KernelInterfaceCreator creator);
|
||||
static Status Reg(const std::vector<char> &provider, int op_type, const KernelInterfaceCreator creator);
|
||||
static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::vector<char> &provider,
|
||||
const schema::Primitive *primitive);
|
||||
const schema::Primitive *primitive,
|
||||
const kernel::Kernel *kernel = nullptr);
|
||||
};
|
||||
|
||||
/// \brief KernelInterfaceReg defined registration class of KernelInterface.
|
||||
|
@ -76,7 +82,7 @@ class MS_API KernelInterfaceReg {
|
|||
/// \param[in] provider Define the identification of user.
|
||||
/// \param[in] op_type Define the ordinary op type.
|
||||
/// \param[in] creator Define the KernelInterface create function.
|
||||
KernelInterfaceReg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
|
||||
KernelInterfaceReg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
|
||||
RegisterKernelInterface::Reg(provider, op_type, creator);
|
||||
}
|
||||
|
||||
|
@ -85,23 +91,26 @@ class MS_API KernelInterfaceReg {
|
|||
/// \param[in] provider Define the identification of user.
|
||||
/// \param[in] op_type Define the concrete type of a custom op.
|
||||
/// \param[in] creator Define the KernelInterface create function.
|
||||
KernelInterfaceReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator) {
|
||||
KernelInterfaceReg(const std::string &provider, const std::string &op_type, const KernelInterfaceCreator creator) {
|
||||
RegisterKernelInterface::CustomReg(provider, op_type, creator);
|
||||
}
|
||||
|
||||
virtual ~KernelInterfaceReg() = default;
|
||||
};
|
||||
|
||||
Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
|
||||
KernelInterfaceCreator creator) {
|
||||
const KernelInterfaceCreator creator) {
|
||||
return CustomReg(StringToChar(provider), StringToChar(op_type), creator);
|
||||
}
|
||||
|
||||
Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
|
||||
Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
|
||||
return Reg(StringToChar(provider), op_type, creator);
|
||||
}
|
||||
|
||||
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
|
||||
const std::string &provider, const schema::Primitive *primitive) {
|
||||
return GetKernelInterface(StringToChar(provider), primitive);
|
||||
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(const std::string &provider,
|
||||
const schema::Primitive *primitive,
|
||||
const kernel::Kernel *kernel) {
|
||||
return GetKernelInterface(StringToChar(provider), primitive, kernel);
|
||||
}
|
||||
|
||||
/// \brief Defined registering macro to register ordinary op, which called by user directly.
|
||||
|
|
|
@ -21,19 +21,55 @@
|
|||
#endif
|
||||
namespace {
|
||||
constexpr size_t kLengthOfParentheses = 2;
|
||||
}
|
||||
constexpr size_t kMinSectionLineLength = 2;
|
||||
constexpr size_t kMaxValidLineCount = 100000;
|
||||
constexpr size_t kMaxLineCount = 100100;
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
int GetSectionInfoFromConfigFile(const std::string &file, const std::string §ion_name,
|
||||
std::map<std::string, std::string> *section_info) {
|
||||
if (file.empty()) {
|
||||
MS_LOG(ERROR) << "file is nullptr";
|
||||
namespace {
|
||||
void ParseLine(const std::string &line, std::map<std::string, std::string> *section_config, std::string *section,
|
||||
size_t *valid_line_count, std::map<std::string, std::map<std::string, std::string>> *config) {
|
||||
// eg: [section]
|
||||
// key=value
|
||||
if (line[0] == '[' && line[line.length() - 1] == ']') {
|
||||
if (!section->empty() && !section_config->empty()) {
|
||||
config->insert(std::make_pair(*section, *section_config));
|
||||
}
|
||||
section_config->clear();
|
||||
*section = line.substr(1, line.length() - kLengthOfParentheses);
|
||||
*valid_line_count = *valid_line_count + 1;
|
||||
}
|
||||
|
||||
if (!section->empty()) {
|
||||
auto index = line.find('=');
|
||||
if (index == std::string::npos) {
|
||||
return;
|
||||
}
|
||||
auto key = line.substr(0, index);
|
||||
if (index + 1 > line.size()) {
|
||||
return;
|
||||
}
|
||||
auto value = line.substr(index + 1);
|
||||
lite::Trim(&key);
|
||||
lite::Trim(&value);
|
||||
section_config->insert(std::make_pair(key, value));
|
||||
*valid_line_count = *valid_line_count + 1;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int GetAllSectionInfoFromConfigFile(const std::string &file,
|
||||
std::map<std::string, std::map<std::string, std::string>> *config) {
|
||||
if (file.empty() || config == nullptr) {
|
||||
MS_LOG(ERROR) << "input Invalid!check file and config.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto resolved_path = std::make_unique<char[]>(PATH_MAX);
|
||||
if (resolved_path == nullptr) {
|
||||
MS_LOG(ERROR) << "new resolved_path failed";
|
||||
MS_LOG(ERROR) << "new resolved_path fail!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
|
@ -56,44 +92,25 @@ int GetSectionInfoFromConfigFile(const std::string &file, const std::string &sec
|
|||
return RET_ERROR;
|
||||
}
|
||||
std::string line;
|
||||
|
||||
bool find_section = false;
|
||||
std::string section;
|
||||
std::map<std::string, std::string> section_config;
|
||||
size_t line_count = 0;
|
||||
size_t valid_line_count = 0;
|
||||
while (std::getline(ifs, line)) {
|
||||
line_count++;
|
||||
if (line_count >= kMaxLineCount || valid_line_count >= kMaxValidLineCount) {
|
||||
MS_LOG(ERROR) << "config too many lines!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
lite::Trim(&line);
|
||||
if (line.empty()) {
|
||||
if (line.length() <= kMinSectionLineLength || line[0] == '#') {
|
||||
continue;
|
||||
}
|
||||
if (line[0] == '#') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (line[0] == '[') {
|
||||
if (find_section == true) {
|
||||
break;
|
||||
}
|
||||
std::string section = line.substr(1, line.length() - kLengthOfParentheses);
|
||||
if (section != section_name) {
|
||||
continue;
|
||||
}
|
||||
find_section = true;
|
||||
}
|
||||
|
||||
if (find_section == true) {
|
||||
auto index = line.find('=');
|
||||
if (index == std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
auto key = line.substr(0, index);
|
||||
if (index + 1 > line.size()) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto value = line.substr(index + 1);
|
||||
lite::Trim(&key);
|
||||
lite::Trim(&value);
|
||||
section_info->insert(std::make_pair(key, value));
|
||||
}
|
||||
ParseLine(line, §ion_config, §ion, &valid_line_count, config);
|
||||
}
|
||||
if (!section.empty() && !section_config.empty()) {
|
||||
config->insert(std::make_pair(section, section_config));
|
||||
}
|
||||
|
||||
ifs.close();
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -35,10 +35,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
constexpr int MAX_CONFIG_FILE_LENGTH = 1024;
|
||||
#define CONFIG_FILE_EXECUTION_PLAN "execution_plan"
|
||||
|
||||
int GetSectionInfoFromConfigFile(const std::string &file, const std::string §ion_name,
|
||||
std::map<std::string, std::string> *section_info);
|
||||
int GetAllSectionInfoFromConfigFile(const std::string &file,
|
||||
std::map<std::string, std::map<std::string, std::string>> *config);
|
||||
|
||||
void ParserExecutionPlan(const std::map<std::string, std::string> *config_infos,
|
||||
std::map<std::string, TypeId> *data_type_plan);
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/kernel.h"
|
||||
namespace mindspore::kernel {
|
||||
void Kernel::Initialize() {
|
||||
if (primitive_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
type_ = primitive_->value_type();
|
||||
if (type_ == schema::PrimitiveType_Custom) {
|
||||
auto param = primitive_->value_as_Custom();
|
||||
if (param != nullptr && param->type() != nullptr) {
|
||||
SetAttr("type", param->type()->str());
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::kernel
|
|
@ -209,6 +209,19 @@ Status Model::LoadConfig(const std::vector<char> &config_path) {
|
|||
return kSuccess;
|
||||
}
|
||||
|
||||
Status Model::UpdateConfig(const std::vector<char> §ion,
|
||||
const std::pair<std::vector<char>, std::vector<char>> &config) {
|
||||
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
||||
if (impl_ == nullptr) {
|
||||
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
||||
}
|
||||
if (impl_ != nullptr) {
|
||||
return impl_->UpdateConfig(CharToString(section), {CharToString(config.first), CharToString(config.second)});
|
||||
}
|
||||
MS_LOG(ERROR) << "Model implement is null!";
|
||||
return kLiteFileError;
|
||||
}
|
||||
|
||||
Status Model::SetTrainMode(bool train) {
|
||||
if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
|
||||
MS_LOG(ERROR) << "Model is null.";
|
||||
|
|
|
@ -17,6 +17,10 @@
|
|||
#include "src/cxx_api/model/model_impl.h"
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/lite_session.h"
|
||||
|
@ -32,6 +36,11 @@
|
|||
#include "src/common/config_file.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
static const char *kExecutionPlan = "execution_plan";
|
||||
static constexpr size_t kMaxSectionNum = 100;
|
||||
static constexpr size_t kMaxConfigNumPerSection = 1000;
|
||||
} // namespace
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
|
||||
|
@ -195,15 +204,16 @@ Status ModelImpl::RunGraph(const MSKernelCallBack &before, const MSKernelCallBac
|
|||
bool ModelImpl::IsTrainModel() { return (graph_ && graph_->graph_data_ && graph_->graph_data_->IsTrainModel()); }
|
||||
|
||||
Status ModelImpl::LoadConfig(const std::string &config_path) {
|
||||
std::map<std::string, std::string> config_info;
|
||||
int ret = lite::GetSectionInfoFromConfigFile(config_path, CONFIG_FILE_EXECUTION_PLAN, &config_info);
|
||||
std::map<std::string, std::map<std::string, std::string>> all_config_info;
|
||||
int ret = lite::GetAllSectionInfoFromConfigFile(config_path, &all_config_info);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "GetSectionInfoFromConfigFile failed.";
|
||||
MS_LOG(ERROR) << "GetAllSectionInfoFromConfigFile fail!ret: " << ret;
|
||||
return kLiteFileError;
|
||||
}
|
||||
|
||||
config_info_ = all_config_info;
|
||||
std::map<std::string, std::string> config_info = all_config_info[kExecutionPlan];
|
||||
if (config_info.empty()) {
|
||||
MS_LOG(WARNING) << "No valid info in config file.";
|
||||
MS_LOG(WARNING) << "No valid execution plan info in config file.";
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
|
@ -211,6 +221,24 @@ Status ModelImpl::LoadConfig(const std::string &config_path) {
|
|||
return kSuccess;
|
||||
}
|
||||
|
||||
Status ModelImpl::UpdateConfig(const std::string §ion, const std::pair<std::string, std::string> &config) {
|
||||
auto iter = config_info_.find(section);
|
||||
if (iter == config_info_.end()) {
|
||||
if (config_info_.size() >= kMaxSectionNum) {
|
||||
MS_LOG(ERROR) << "config too many sections!";
|
||||
return kLiteError;
|
||||
}
|
||||
config_info_[section][config.first] = config.second;
|
||||
return kSuccess;
|
||||
}
|
||||
if (iter->second.size() >= kMaxConfigNumPerSection) {
|
||||
MS_LOG(ERROR) << "config too many items!";
|
||||
return kLiteError;
|
||||
}
|
||||
iter->second[config.first] = config.second;
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||
const MSKernelCallBack &before, const MSKernelCallBack &after) {
|
||||
if (outputs == nullptr) {
|
||||
|
@ -567,6 +595,7 @@ session::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context)
|
|||
}
|
||||
|
||||
session->InitExecutionConfig(&execution_plan_);
|
||||
session->SetConfigInfo(&config_info_);
|
||||
|
||||
auto ret = session->Init(context);
|
||||
if (ret != mindspore::lite::RET_OK) {
|
||||
|
|
|
@ -70,6 +70,7 @@ class ModelImpl {
|
|||
session::LiteSession *CreateLiteSession(lite::InnerContext *context);
|
||||
|
||||
Status LoadConfig(const std::string &config_path);
|
||||
Status UpdateConfig(const std::string §ion, const std::pair<std::string, std::string> &config);
|
||||
std::vector<MSTensor> GetInputs();
|
||||
std::vector<MSTensor> GetOutputs();
|
||||
std::vector<MSTensor> GetGradients() const;
|
||||
|
@ -112,6 +113,7 @@ class ModelImpl {
|
|||
void SetConfig(const std::shared_ptr<TrainCfg> cfg) { cfg_ = cfg; }
|
||||
Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after);
|
||||
std::map<std::string, TypeId> execution_plan_;
|
||||
std::map<std::string, std::map<std::string, std::string>> config_info_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -523,6 +523,7 @@ int LiteSession::CompileGraph(Model *model) {
|
|||
Scheduler scheduler(context_, ms_context_, model, &tensors_, inputs_, outputs_, is_train_session_, &is_infershape_,
|
||||
&is_control_flow_, execution_plan_, delegate_, delegate_device_type_);
|
||||
scheduler.SetupSchedulerCb(std::move(sched_cb_));
|
||||
scheduler.SetConfig(config_info_);
|
||||
ret = scheduler.Schedule(&kernels_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Schedule kernels failed: " << ret;
|
||||
|
|
|
@ -87,6 +87,10 @@ class LiteSession : public session::LiteSession {
|
|||
|
||||
const Delegate *get_delegate() const { return this->delegate_.get(); }
|
||||
|
||||
void SetConfigInfo(const std::map<std::string, std::map<std::string, std::string>> *config_info) {
|
||||
config_info_ = config_info;
|
||||
}
|
||||
|
||||
protected:
|
||||
static void ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lite::Tensor *dst_tensor);
|
||||
|
||||
|
@ -182,6 +186,7 @@ class LiteSession : public session::LiteSession {
|
|||
std::shared_ptr<Delegate> delegate_ = nullptr;
|
||||
int delegate_device_type_ = -1; // -1: not specified; 0: CPU; 1: GPU; 2: NPU
|
||||
std::map<std::string, TypeId> *execution_plan_ = nullptr;
|
||||
const std::map<std::string, std::map<std::string, std::string>> *config_info_ = nullptr;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -70,14 +70,14 @@ OpParameter *PopulateAffineParameter(const void *prim) {
|
|||
affine_param->context_size_ = static_cast<int>(context.size());
|
||||
|
||||
// malloc && memset for context
|
||||
affine_param->context_ = reinterpret_cast<int *>(malloc(affine_param->context_size_ * sizeof(int)));
|
||||
affine_param->context_ = reinterpret_cast<int *>(malloc(context.size() * sizeof(int)));
|
||||
if (affine_param->context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc param context_ for affine layer failed!";
|
||||
ReleaseParam(affine_param, matmul_param);
|
||||
return nullptr;
|
||||
}
|
||||
memset(affine_param->context_, 0, affine_param->context_size_ * sizeof(int));
|
||||
for (int i = 0; i < affine_param->context_size_; ++i) {
|
||||
(void)memset(affine_param->context_, 0, context.size() * sizeof(int));
|
||||
for (size_t i = 0; i < context.size(); ++i) {
|
||||
affine_param->context_[i] = context.at(i);
|
||||
}
|
||||
affine_param->output_dim_ = value->output_dim();
|
||||
|
|
|
@ -43,8 +43,8 @@ OpParameter *PopulateTensorArrayParameter(const void *prim) {
|
|||
bool identical_element_shapes = value->identical_element_shapes();
|
||||
param->identical_element_shapes_ = identical_element_shapes;
|
||||
std::vector<int> primitive_element_shape(value->element_shape()->begin(), value->element_shape()->end());
|
||||
param->element_shape_size_ = primitive_element_shape.size();
|
||||
int size = sizeof(int) * param->element_shape_size_;
|
||||
param->element_shape_size_ = static_cast<int>(primitive_element_shape.size());
|
||||
auto size = sizeof(int) * param->element_shape_size_;
|
||||
param->element_shape_ = static_cast<int *>(malloc(size));
|
||||
if (param->element_shape_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc element_shape failed!";
|
||||
|
|
|
@ -52,7 +52,7 @@ OpParameter *PopulateSpliceParameter(const void *prim) {
|
|||
param->context_dim_ = static_cast<int>(primitive_context.size());
|
||||
|
||||
// malloc && memset for context
|
||||
param->context_ = reinterpret_cast<int *>(malloc(param->context_dim_ * sizeof(int)));
|
||||
param->context_ = reinterpret_cast<int *>(malloc(primitive_context.size() * sizeof(int)));
|
||||
if (param->context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc param context_ error";
|
||||
free(param);
|
||||
|
@ -60,8 +60,8 @@ OpParameter *PopulateSpliceParameter(const void *prim) {
|
|||
}
|
||||
// src_to_dst_row_offset
|
||||
int src_to_dst_row_offset = INT32_MIN;
|
||||
memset(param->context_, 0, param->context_dim_ * sizeof(int));
|
||||
for (int i = 0; i < param->context_dim_; ++i) {
|
||||
(void)memset(param->context_, 0, primitive_context.size() * sizeof(int));
|
||||
for (size_t i = 0; i < primitive_context.size(); ++i) {
|
||||
param->context_[i] = primitive_context[i];
|
||||
src_to_dst_row_offset = std::max(src_to_dst_row_offset, std::abs(primitive_context.at(i)));
|
||||
}
|
||||
|
@ -83,15 +83,15 @@ OpParameter *PopulateSpliceParameter(const void *prim) {
|
|||
param->forward_indexes_dim_ = static_cast<int>(primitive_forward_indexes.size());
|
||||
|
||||
// malloc && memset for forward_indexes
|
||||
param->forward_indexes_ = reinterpret_cast<int *>(malloc(param->forward_indexes_dim_ * sizeof(int)));
|
||||
param->forward_indexes_ = reinterpret_cast<int *>(malloc(primitive_context.size() * sizeof(int)));
|
||||
if (param->forward_indexes_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc param forward_indexes_ error";
|
||||
free(param->context_);
|
||||
free(param);
|
||||
return nullptr;
|
||||
}
|
||||
memset(param->forward_indexes_, 0, param->forward_indexes_dim_ * sizeof(int));
|
||||
memcpy(param->forward_indexes_, primitive_forward_indexes.data(), param->forward_indexes_dim_ * sizeof(int));
|
||||
(void)memset(param->forward_indexes_, 0, primitive_context.size() * sizeof(int));
|
||||
(void)memcpy(param->forward_indexes_, primitive_forward_indexes.data(), primitive_context.size() * sizeof(int));
|
||||
param->output_dim_ = value->output_dim();
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/version_manager.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "include/api/kernel.h"
|
||||
|
||||
using mindspore::registry::KernelInterfaceCreator;
|
||||
using mindspore::schema::PrimitiveType_MAX;
|
||||
|
@ -27,16 +28,33 @@ using mindspore::schema::PrimitiveType_MIN;
|
|||
namespace mindspore {
|
||||
namespace registry {
|
||||
namespace {
|
||||
static constexpr auto kMaxProviderNum = 10;
|
||||
static constexpr auto KMaxCustomTypeNum = 200;
|
||||
static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1;
|
||||
std::string GetCustomType(const schema::Primitive *primitive) {
|
||||
auto param = primitive->value_as_Custom();
|
||||
MS_ASSERT(param != nullptr);
|
||||
if (param == nullptr || param->type() == nullptr) {
|
||||
return "";
|
||||
}
|
||||
|
||||
return param->type()->str();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &type,
|
||||
KernelInterfaceCreator creator) {
|
||||
const KernelInterfaceCreator creator) {
|
||||
auto provider_iter = custom_creators_.find(provider);
|
||||
if (provider_iter == custom_creators_.end() && custom_creators_.size() >= kMaxProviderNum) {
|
||||
MS_LOG(ERROR) << "register too many provider!";
|
||||
return kLiteError;
|
||||
}
|
||||
if (provider_iter != custom_creators_.end()) {
|
||||
auto type_iter = provider_iter->second.find(type);
|
||||
if (type_iter == provider_iter->second.end() && provider_iter->second.size() >= KMaxCustomTypeNum) {
|
||||
MS_LOG(ERROR) << "register too many custom type!";
|
||||
return kLiteError;
|
||||
}
|
||||
}
|
||||
custom_creators_[provider][type] = creator;
|
||||
return kSuccess;
|
||||
}
|
||||
|
@ -73,15 +91,19 @@ std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomCache
|
|||
}
|
||||
|
||||
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKernelInterface(
|
||||
const schema::Primitive *primitive) {
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
const schema::Primitive *primitive, const kernel::Kernel *kernel) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
auto &&type = GetCustomType(primitive);
|
||||
std::string type;
|
||||
if (kernel == nullptr) {
|
||||
type = GetCustomType(primitive);
|
||||
} else {
|
||||
type = kernel->GetAttr("type");
|
||||
}
|
||||
for (auto &&item : custom_creators_) {
|
||||
auto &&provider = item.first;
|
||||
auto kernel = GetCustomCacheInterface(provider, type);
|
||||
if (kernel != nullptr) {
|
||||
return kernel;
|
||||
auto kernel_interface = GetCustomCacheInterface(provider, type);
|
||||
if (kernel_interface != nullptr) {
|
||||
return kernel_interface;
|
||||
}
|
||||
auto provider_iter = custom_creators_.find(provider);
|
||||
if (provider_iter == custom_creators_.end()) {
|
||||
|
@ -89,47 +111,54 @@ std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKerne
|
|||
}
|
||||
auto creator_iter = provider_iter->second.find(type);
|
||||
if (creator_iter != provider_iter->second.end()) {
|
||||
kernel = creator_iter->second();
|
||||
custom_kernels_[provider][type] = kernel;
|
||||
return kernel;
|
||||
kernel_interface = creator_iter->second();
|
||||
custom_kernels_[provider][type] = kernel_interface;
|
||||
return kernel_interface;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface(
|
||||
const std::string &provider, const schema::Primitive *primitive) {
|
||||
if (primitive == nullptr) {
|
||||
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface(const std::string &provider,
|
||||
const schema::Primitive *primitive,
|
||||
const kernel::Kernel *kernel) {
|
||||
if (primitive == nullptr && kernel == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
int op_type;
|
||||
if (kernel == nullptr) {
|
||||
op_type = static_cast<int>(primitive->value_type());
|
||||
} else {
|
||||
op_type = static_cast<int>(kernel->type());
|
||||
}
|
||||
if (op_type > PrimitiveType_MAX || op_type <= PrimitiveType_MIN) {
|
||||
return nullptr;
|
||||
}
|
||||
int op_type = primitive->value_type();
|
||||
if (op_type == schema::PrimitiveType_Custom) {
|
||||
return GetCustomKernelInterface(primitive);
|
||||
return GetCustomKernelInterface(primitive, kernel);
|
||||
}
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
auto kernel = GetCacheInterface(provider, op_type);
|
||||
if (kernel != nullptr) {
|
||||
return kernel;
|
||||
auto kernel_interface = GetCacheInterface(provider, op_type);
|
||||
if (kernel_interface != nullptr) {
|
||||
return kernel_interface;
|
||||
}
|
||||
auto iter = kernel_creators_.find(provider);
|
||||
if (iter == kernel_creators_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (op_type > PrimitiveType_MAX || op_type <= PrimitiveType_MIN) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto creator = iter->second[op_type];
|
||||
if (creator != nullptr) {
|
||||
kernel = creator();
|
||||
kernel_interfaces_[provider][op_type] = kernel;
|
||||
return kernel;
|
||||
kernel_interface = creator();
|
||||
kernel_interfaces_[provider][op_type] = kernel_interface;
|
||||
return kernel_interface;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
|
||||
Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
|
||||
if (op_type <= PrimitiveType_MIN || op_type > PrimitiveType_MAX) {
|
||||
MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << PrimitiveType_MAX;
|
||||
return kLiteParamInvalid;
|
||||
|
@ -142,6 +171,10 @@ Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, Ke
|
|||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
auto iter = kernel_creators_.find(provider);
|
||||
if (iter == kernel_creators_.end()) {
|
||||
if (kernel_creators_.size() >= kMaxProviderNum) {
|
||||
MS_LOG(ERROR) << "register too many provider!";
|
||||
return kLiteError;
|
||||
}
|
||||
kernel_creators_[provider] =
|
||||
reinterpret_cast<KernelInterfaceCreator *>(calloc(kMaxKernelNum, sizeof(KernelInterfaceCreator)));
|
||||
if (kernel_creators_[provider] == nullptr) {
|
||||
|
|
|
@ -35,9 +35,11 @@ class KernelInterfaceRegistry {
|
|||
}
|
||||
|
||||
std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
|
||||
const schema::Primitive *primitive);
|
||||
Status CustomReg(const std::string &provider, const std::string &op_type, registry::KernelInterfaceCreator creator);
|
||||
Status Reg(const std::string &provider, int op_type, registry::KernelInterfaceCreator creator);
|
||||
const schema::Primitive *primitive,
|
||||
const kernel::Kernel *kernel);
|
||||
Status CustomReg(const std::string &provider, const std::string &op_type,
|
||||
const registry::KernelInterfaceCreator creator);
|
||||
Status Reg(const std::string &provider, int op_type, const registry::KernelInterfaceCreator creator);
|
||||
virtual ~KernelInterfaceRegistry();
|
||||
|
||||
private:
|
||||
|
@ -45,7 +47,8 @@ class KernelInterfaceRegistry {
|
|||
std::shared_ptr<kernel::KernelInterface> GetCacheInterface(const std::string &provider, int op_type);
|
||||
std::shared_ptr<kernel::KernelInterface> GetCustomCacheInterface(const std::string &provider,
|
||||
const std::string &type);
|
||||
std::shared_ptr<kernel::KernelInterface> GetCustomKernelInterface(const schema::Primitive *primitive);
|
||||
std::shared_ptr<kernel::KernelInterface> GetCustomKernelInterface(const schema::Primitive *primitive,
|
||||
const kernel::Kernel *kernel);
|
||||
|
||||
std::mutex mutex_;
|
||||
// key: provider
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
namespace mindspore {
|
||||
namespace registry {
|
||||
Status RegisterKernel::RegCustomKernel(const std::vector<char> &arch, const std::vector<char> &provider,
|
||||
DataType data_type, const std::vector<char> &type, CreateKernel creator) {
|
||||
DataType data_type, const std::vector<char> &type, const CreateKernel creator) {
|
||||
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
||||
return RegistryKernelImpl::GetInstance()->RegCustomKernel(CharToString(arch), CharToString(provider), data_type,
|
||||
CharToString(type), creator);
|
||||
|
@ -34,7 +34,7 @@ Status RegisterKernel::RegCustomKernel(const std::vector<char> &arch, const std:
|
|||
}
|
||||
|
||||
Status RegisterKernel::RegKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
|
||||
int op_type, CreateKernel creator) {
|
||||
int op_type, const CreateKernel creator) {
|
||||
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
||||
return RegistryKernelImpl::GetInstance()->RegKernel(CharToString(arch), CharToString(provider), data_type, op_type,
|
||||
creator);
|
||||
|
|
|
@ -25,15 +25,14 @@ using mindspore::schema::PrimitiveType_MAX;
|
|||
using mindspore::schema::PrimitiveType_MIN;
|
||||
namespace mindspore::registry {
|
||||
namespace {
|
||||
static const auto kKernelMaxNum =
|
||||
(static_cast<int>(DataType::kNumberTypeEnd) - static_cast<int>(DataType::kNumberTypeBegin) - 1) *
|
||||
(PrimitiveType_MAX - PrimitiveType_MIN);
|
||||
static const auto kOpTypeLen = PrimitiveType_MAX - PrimitiveType_MIN + 1;
|
||||
static const auto kDataTypeLen =
|
||||
static_cast<int>(DataType::kNumberTypeEnd) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
|
||||
static const auto kOpTypeLen = PrimitiveType_MAX - PrimitiveType_MIN;
|
||||
} // namespace
|
||||
|
||||
int RegistryKernelImpl::GetFuncIndex(const KernelDesc &desc) {
|
||||
static const auto kKernelMaxNum = kOpTypeLen * kDataTypeLen;
|
||||
static constexpr auto kMaxProviderNum = 10;
|
||||
static constexpr auto kMaxArchPerProviderNum = 10;
|
||||
static constexpr auto kMaxCustomTypeNum = 200;
|
||||
int GetFuncIndex(const KernelDesc &desc) {
|
||||
if (desc.data_type >= DataType::kNumberTypeEnd) {
|
||||
return -1;
|
||||
}
|
||||
|
@ -47,14 +46,36 @@ int RegistryKernelImpl::GetFuncIndex(const KernelDesc &desc) {
|
|||
}
|
||||
return index;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
|
||||
const std::string &type, CreateKernel creator) {
|
||||
if (data_type >= DataType::kNumberTypeEnd) {
|
||||
const std::string &type, const CreateKernel creator) {
|
||||
int data_type_index = static_cast<int>(data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
|
||||
if (data_type_index < 0 || data_type_index >= kDataTypeLen) {
|
||||
MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider;
|
||||
return kLiteError;
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
auto provider_iter = custom_kernel_creators_.find(provider);
|
||||
if (provider_iter == custom_kernel_creators_.end() && custom_kernel_creators_.size() >= kMaxProviderNum) {
|
||||
MS_LOG(ERROR) << "register too many provider!";
|
||||
return kLiteError;
|
||||
}
|
||||
if (provider_iter != custom_kernel_creators_.end()) {
|
||||
auto arch_iter = provider_iter->second.find(arch);
|
||||
if (arch_iter == provider_iter->second.end()) {
|
||||
if (provider_iter->second.size() >= kMaxArchPerProviderNum) {
|
||||
MS_LOG(ERROR) << "register too many arch!";
|
||||
return kLiteError;
|
||||
}
|
||||
} else {
|
||||
auto type_iter = arch_iter->second.find(type);
|
||||
if (type_iter == arch_iter->second.end() && arch_iter->second.size() >= kMaxCustomTypeNum) {
|
||||
MS_LOG(ERROR) << "register too many type!";
|
||||
return kLiteError;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (custom_kernel_creators_[provider][arch][type] == nullptr) {
|
||||
custom_kernel_creators_[provider][arch][type] =
|
||||
reinterpret_cast<CreateKernel *>(calloc(kDataTypeLen, sizeof(CreateKernel)));
|
||||
|
@ -64,20 +85,30 @@ Status RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::s
|
|||
}
|
||||
}
|
||||
|
||||
int data_type_index = static_cast<int>(data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
|
||||
if (data_type_index < 0 || data_type_index >= kDataTypeLen) {
|
||||
MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider;
|
||||
return kLiteError;
|
||||
}
|
||||
custom_kernel_creators_[provider][arch][type][data_type_index] = creator;
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
|
||||
registry::CreateKernel creator) {
|
||||
const registry::CreateKernel creator) {
|
||||
if (type <= static_cast<int>(PrimitiveType_MIN) || type > static_cast<int>(PrimitiveType_MAX)) {
|
||||
MS_LOG(ERROR) << "Invalid op type : " << type;
|
||||
return kLiteParamInvalid;
|
||||
}
|
||||
KernelDesc desc = {data_type, type, arch, provider};
|
||||
int index = GetFuncIndex(desc);
|
||||
if (index < 0) {
|
||||
MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << static_cast<int>(data_type) << ",op type "
|
||||
<< type;
|
||||
return kLiteError;
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
auto iter = kernel_creators_.find(provider);
|
||||
if (iter == kernel_creators_.end()) {
|
||||
if (kernel_creators_.size() >= kMaxProviderNum) {
|
||||
MS_LOG(ERROR) << "register too many provider!";
|
||||
return kLiteError;
|
||||
}
|
||||
kernel_creators_[provider][arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel)));
|
||||
if (kernel_creators_[provider][arch] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
|
||||
|
@ -86,6 +117,10 @@ Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string
|
|||
} else {
|
||||
auto iter_arch = iter->second.find(arch);
|
||||
if (iter_arch == iter->second.end()) {
|
||||
if (iter->second.size() >= kMaxArchPerProviderNum) {
|
||||
MS_LOG(ERROR) << "register too many arch!";
|
||||
return kLiteError;
|
||||
}
|
||||
iter->second[arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel)));
|
||||
if (iter->second[arch] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
|
||||
|
@ -94,14 +129,6 @@ Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string
|
|||
}
|
||||
}
|
||||
|
||||
KernelDesc desc = {data_type, type, arch, provider};
|
||||
int index = GetFuncIndex(desc);
|
||||
if (index < 0) {
|
||||
MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << static_cast<int>(data_type) << ",op type "
|
||||
<< type;
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
kernel_creators_[provider][arch][index] = creator;
|
||||
return kSuccess;
|
||||
}
|
||||
|
@ -109,11 +136,11 @@ Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string
|
|||
registry::CreateKernel RegistryKernelImpl::GetCustomKernelCreator(const schema::Primitive *primitive,
|
||||
KernelDesc *desc) {
|
||||
int data_type_index = static_cast<int>(desc->data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
|
||||
if (data_type_index < 0 || data_type_index >= kDataTypeLen) {
|
||||
if (data_type_index < 0 || desc->data_type >= DataType::kNumberTypeEnd) {
|
||||
return nullptr;
|
||||
}
|
||||
auto param = primitive->value_as_Custom();
|
||||
if (param == nullptr) {
|
||||
if (param == nullptr || param->type() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto custom_type = param->type()->str();
|
||||
|
|
|
@ -37,10 +37,10 @@ class RegistryKernelImpl {
|
|||
}
|
||||
|
||||
Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
|
||||
const std::string &type, registry::CreateKernel creator);
|
||||
const std::string &type, const registry::CreateKernel creator);
|
||||
|
||||
Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
|
||||
registry::CreateKernel creator);
|
||||
const registry::CreateKernel creator);
|
||||
|
||||
virtual registry::CreateKernel GetProviderCreator(const schema::Primitive *primitive, registry::KernelDesc *desc);
|
||||
|
||||
|
@ -60,7 +60,6 @@ class RegistryKernelImpl {
|
|||
std::mutex lock_;
|
||||
|
||||
registry::CreateKernel GetCustomKernelCreator(const schema::Primitive *primitive, registry::KernelDesc *desc);
|
||||
int GetFuncIndex(const registry::KernelDesc &desc);
|
||||
};
|
||||
} // namespace mindspore::registry
|
||||
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace registry {
|
||||
Status RegisterKernelInterface::Reg(const std::vector<char> &provider, int op_type, KernelInterfaceCreator creator) {
|
||||
Status RegisterKernelInterface::Reg(const std::vector<char> &provider, int op_type,
|
||||
const KernelInterfaceCreator creator) {
|
||||
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
||||
return KernelInterfaceRegistry::Instance()->Reg(CharToString(provider), op_type, creator);
|
||||
#else
|
||||
|
@ -32,7 +33,7 @@ Status RegisterKernelInterface::Reg(const std::vector<char> &provider, int op_ty
|
|||
}
|
||||
|
||||
Status RegisterKernelInterface::CustomReg(const std::vector<char> &provider, const std::vector<char> &op_type,
|
||||
KernelInterfaceCreator creator) {
|
||||
const KernelInterfaceCreator creator) {
|
||||
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
||||
return KernelInterfaceRegistry::Instance()->CustomReg(CharToString(provider), CharToString(op_type), creator);
|
||||
#else
|
||||
|
@ -41,10 +42,11 @@ Status RegisterKernelInterface::CustomReg(const std::vector<char> &provider, con
|
|||
#endif
|
||||
}
|
||||
|
||||
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
|
||||
const std::vector<char> &provider, const schema::Primitive *primitive) {
|
||||
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(const std::vector<char> &provider,
|
||||
const schema::Primitive *primitive,
|
||||
const kernel::Kernel *kernel) {
|
||||
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
||||
return KernelInterfaceRegistry::Instance()->GetKernelInterface(CharToString(provider), primitive);
|
||||
return KernelInterfaceRegistry::Instance()->GetKernelInterface(CharToString(provider), primitive, kernel);
|
||||
#else
|
||||
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
||||
return nullptr;
|
||||
|
|
|
@ -34,23 +34,33 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
||||
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
const void *primitive, std::set<std::string> &&providers, int schema_version) {
|
||||
if (primitive == nullptr) {
|
||||
const void *primitive, std::set<std::string> &&providers, int schema_version,
|
||||
const kernel::Kernel *kernel) {
|
||||
if (primitive == nullptr && kernel == nullptr) {
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
std::shared_ptr<kernel::KernelInterface> kernel_interface = nullptr;
|
||||
if (IsCustomNode(primitive, schema_version)) {
|
||||
kernel_interface =
|
||||
registry::RegisterKernelInterface::GetKernelInterface("", static_cast<const schema::Primitive *>(primitive));
|
||||
bool is_custom_node = false;
|
||||
if (kernel == nullptr) {
|
||||
if (IsCustomNode(primitive, schema_version)) {
|
||||
is_custom_node = true;
|
||||
}
|
||||
} else if (kernel->type() == schema::PrimitiveType_Custom) {
|
||||
is_custom_node = true;
|
||||
}
|
||||
if (is_custom_node) {
|
||||
kernel_interface = registry::RegisterKernelInterface::GetKernelInterface(
|
||||
"", static_cast<const schema::Primitive *>(primitive), kernel);
|
||||
} else {
|
||||
for (auto &&provider : providers) {
|
||||
kernel_interface = registry::RegisterKernelInterface::GetKernelInterface(
|
||||
provider, static_cast<const schema::Primitive *>(primitive));
|
||||
provider, static_cast<const schema::Primitive *>(primitive), kernel);
|
||||
if (kernel_interface != nullptr) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (kernel_interface == nullptr) {
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
|
|
|
@ -26,13 +26,15 @@
|
|||
#include "src/tensor.h"
|
||||
#include "nnacl/tensor_c.h"
|
||||
#include "nnacl/infer/infer.h"
|
||||
#include "include/api/kernel.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, const std::vector<lite::Tensor *> &outputs,
|
||||
OpParameter *parameter);
|
||||
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
||||
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
const void *primitive, std::set<std::string> &&providers, int schema_version);
|
||||
const void *primitive, std::set<std::string> &&providers, int schema_version,
|
||||
const kernel::Kernel *kernel = nullptr);
|
||||
#endif
|
||||
class InferManager {
|
||||
public:
|
||||
|
|
|
@ -428,7 +428,7 @@ int TryFusionConvScaleWeight(LiteKernel *conv_kernel, LiteKernel *scale_kernel)
|
|||
MS_ASSERT(conv_kernel);
|
||||
MS_ASSERT(scale_kernel);
|
||||
auto *scale_param =
|
||||
reinterpret_cast<ScaleParameter *>(reinterpret_cast<OpenCLKernel *>(scale_kernel)->GetParameter());
|
||||
reinterpret_cast<ScaleParameter *>(reinterpret_cast<OpenCLKernel *>(scale_kernel->kernel())->GetParameter());
|
||||
MS_ASSERT(scale_param);
|
||||
MS_ASSERT(conv_kernel->in_tensors().size() >= INPUT_TENSOR_SIZE_2);
|
||||
auto *filter = conv_kernel->in_tensors().at(1);
|
||||
|
|
|
@ -1373,6 +1373,9 @@ kernel::LiteKernel *Scheduler::ScheduleNodeToKernel(const lite::Model::Node *src
|
|||
|
||||
SetKernelTensorDataType(kernel);
|
||||
kernel->set_name(src_node->name_);
|
||||
if (kernel->kernel() != nullptr) {
|
||||
kernel->kernel()->SetConfig(config_info_);
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
|
|
|
@ -59,6 +59,9 @@ class Scheduler {
|
|||
~Scheduler() = default;
|
||||
int Schedule(std::vector<kernel::LiteKernel *> *dst_kernels);
|
||||
void SetupSchedulerCb(std::unique_ptr<SchedulerCb> cb) { sched_cb_ = std::move(cb); }
|
||||
void SetConfig(const std::map<std::string, std::map<std::string, std::string>> *config_info) {
|
||||
config_info_ = config_info;
|
||||
}
|
||||
|
||||
private:
|
||||
int SchedulePreProcess();
|
||||
|
@ -165,6 +168,7 @@ class Scheduler {
|
|||
#endif
|
||||
int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR;
|
||||
std::map<std::string, TypeId> *execution_plan_ = nullptr;
|
||||
const std::map<std::string, std::map<std::string, std::string>> *config_info_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
|
|
|
@ -92,7 +92,7 @@ int SubGraphKernel::ReSize() {
|
|||
int ret;
|
||||
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
||||
ret = lite::KernelInferShape(inputs, outputs, kernel->kernel()->primitive(), kernel->Context()->GetProviders(),
|
||||
schema_version_);
|
||||
schema_version_, kernel->kernel());
|
||||
if (ret == lite::RET_NOT_SUPPORT) {
|
||||
#endif
|
||||
auto parameter = kernel->op_parameter();
|
||||
|
|
|
@ -51,10 +51,10 @@ TEST_F(MixDataTypeTest, Config1) {
|
|||
|
||||
std::string filename = "MixDataTypeTestConfig";
|
||||
std::string sectionname = "execution_plan";
|
||||
std::map<std::string, std::string> config_info;
|
||||
ret = lite::GetSectionInfoFromConfigFile(filename, sectionname, &config_info);
|
||||
std::map<std::string, std::map<std::string, std::string>> configs;
|
||||
ret = lite::GetAllSectionInfoFromConfigFile(filename, &configs);
|
||||
ASSERT_EQ(ret, 0);
|
||||
|
||||
std::map<std::string, std::string> config_info = configs[sectionname];
|
||||
ASSERT_EQ(config_info.size(), 2);
|
||||
|
||||
auto info0 = config_info.at("op1");
|
||||
|
|
Loading…
Reference in New Issue