context support provider

This commit is contained in:
chenjianping 2021-05-14 20:27:37 +08:00
parent d4eef75155
commit e790386528
18 changed files with 86 additions and 66 deletions

View File

@ -16,7 +16,7 @@
#ifndef MINDSPORE_LITE_INCLUDE_CONTEXT_H_
#define MINDSPORE_LITE_INCLUDE_CONTEXT_H_
#include <string>
#include "include/ms_tensor.h"
#include "include/lite_utils.h"
#include "include/lite_types.h"
@ -57,6 +57,8 @@ union DeviceInfo {
struct DeviceContext {
DeviceType device_type_ = DT_CPU;
DeviceInfo device_info_;
std::string provider_{};
std::string provider_device_{};
};
/// \brief Context defined for holding environment variables during runtime.

View File

@ -134,7 +134,7 @@ set(LITE_SRC
${LITE_DIR}/src/common/prim_util.cc
${LITE_DIR}/src/common/tensor_util.cc
${LITE_DIR}/src/runtime/infer_manager.cc
${LITE_DIR}/src/kernel_interface_registry.cc
${LITE_DIR}/src/registry/kernel_interface_registry.cc
${LITE_DIR}/src/lite_model.cc
${LITE_DIR}/src/tensorlist.cc
${LITE_DIR}/src/tensor.cc

View File

@ -64,9 +64,9 @@ set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_model.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
${CMAKE_CURRENT_SOURCE_DIR}/register_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel_interface.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel_interface_registry.cc
${CMAKE_CURRENT_SOURCE_DIR}/registry/register_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/registry/kernel_interface.cc
${CMAKE_CURRENT_SOURCE_DIR}/registry/kernel_interface_registry.cc
${CMAKE_CURRENT_SOURCE_DIR}/inner_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel_util.cc

View File

@ -171,6 +171,12 @@ bool InnerContext::IsNpuEnabled() const {
#endif
}
bool InnerContext::IsProviderEnabled() const {
return this->device_list_.end() !=
std::find_if(this->device_list_.begin(), this->device_list_.end(),
[](const DeviceContext &device) { return !device.provider_.empty(); });
}
bool InnerContext::IsUserSetCpu() const {
return this->device_list_.end() !=
std::find_if(this->device_list_.begin(), this->device_list_.end(),
@ -189,6 +195,16 @@ bool InnerContext::IsUserSetNpu() const {
[](const DeviceContext &device) { return device.device_type_ == DT_NPU; });
}
std::set<std::string> InnerContext::GetProviders() const {
std::set<std::string> providers;
for (auto &&device : device_list_) {
if (!device.provider_.empty()) {
providers.insert(device.provider_);
}
}
return providers;
}
CpuDeviceInfo InnerContext::GetCpuInfo() const {
auto iter = std::find_if(this->device_list_.begin(), this->device_list_.end(),
[](const DeviceContext &device) { return device.device_type_ == DT_CPU; });

View File

@ -16,7 +16,8 @@
#ifndef MINDSPORE_LITE_SRC_INNER_CONTEXT_H
#define MINDSPORE_LITE_SRC_INNER_CONTEXT_H
#include <set>
#include <string>
#include "include/context.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/allocator.h"
@ -48,6 +49,10 @@ struct InnerContext : public Context {
bool IsNpuEnabled() const;
bool IsProviderEnabled() const;
std::set<std::string> GetProviders() const;
CpuDeviceInfo GetCpuInfo() const;
GpuDeviceInfo GetGpuInfo() const;

View File

@ -55,17 +55,6 @@ KernelRegistry *KernelRegistry::GetInstance() {
return &instance;
}
std::set<std::string> KernelRegistry::AllProviders() {
std::set<std::string> providers;
for (auto &&item : kernel_creators_) {
providers.insert(item.first);
}
for (auto &&item : custom_kernel_creators_) {
providers.insert(item.first);
}
return providers;
}
int KernelRegistry::GetFuncIndex(const kernel::KernelKey &desc) {
if (desc.data_type >= kNumberTypeEnd) {
return -1;
@ -166,12 +155,20 @@ kernel::CreateKernel KernelRegistry::GetProviderCreator(const kernel::KernelKey
MS_ASSERT(param != nullptr);
auto custom_type = param->type()->str();
auto archs = custom_kernel_creators_[desc.provider];
auto archs_iter = std::find_if(archs.begin(), archs.end(), [custom_type, data_type_index](auto &&item) {
return item.second[custom_type] != nullptr && item.second[custom_type][data_type_index] != nullptr;
});
if (archs_iter != archs.end()) {
return archs_iter->second[custom_type][data_type_index];
if (desc.kernel_arch.empty()) {
auto archs_iter = std::find_if(archs.begin(), archs.end(), [custom_type, data_type_index](auto &&item) {
return item.second[custom_type] != nullptr && item.second[custom_type][data_type_index] != nullptr;
});
if (archs_iter != archs.end()) {
return archs_iter->second[custom_type][data_type_index];
}
} else {
auto find_arch_it = archs.find(desc.kernel_arch);
if (find_arch_it != archs.end()) {
return find_arch_it->second[custom_type][data_type_index];
}
}
return nullptr;
}
auto index = GetFuncIndex(desc);

View File

@ -23,7 +23,7 @@
#include <vector>
#include <set>
#include "src/lite_kernel.h"
#include "src/register_kernel.h"
#include "src/registry/register_kernel.h"
#include "schema/model_generated.h"
using mindspore::kernel::kKernelArch_MAX;
@ -43,7 +43,6 @@ class KernelRegistry {
virtual kernel::CreateKernel GetProviderCreator(const kernel::KernelKey &desc, const schema::Primitive *prim);
int GetCreatorFuncIndex(kernel::KernelKey desc);
int GetFuncIndex(const kernel::KernelKey &desc);
std::set<std::string> AllProviders();
int RegCustomKernel(const std::string &arch, const std::string &vendor, TypeId data_type, const std::string &type,
kernel::CreateKernel creator);
void RegKernel(kernel::KernelKey desc, kernel::KernelCreator creator);

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/kernel_interface.h"
#include "src/kernel_interface_registry.h"
#include "src/registry/kernel_interface.h"
#include "src/registry/kernel_interface_registry.h"
namespace mindspore {
namespace kernel {

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/kernel_interface_registry.h"
#include "src/kernel_interface.h"
#include "src/registry/kernel_interface_registry.h"
#include "src/registry/kernel_interface.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/common/version_manager.h"
@ -34,7 +34,7 @@ std::string GetCustomType(const schema::Primitive *primitive) {
}
} // namespace
bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node) {
bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node, std::set<std::string> &&providers) {
if (VersionManager::GetInstance()->GetSchemaVersion() == SCHEMA_V0) {
return false;
}
@ -46,7 +46,10 @@ bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node) {
auto op_type = primitive->value_type();
if (op_type == schema::PrimitiveType_Custom) {
auto &&custom_type = GetCustomType(primitive);
return std::any_of(custom_creators_.begin(), custom_creators_.end(), [&custom_type](auto &&item) {
return std::any_of(custom_creators_.begin(), custom_creators_.end(), [&custom_type, &providers](auto &&item) {
if (providers.find(item.first) == providers.end()) {
return false;
}
if (item.second[custom_type] != nullptr) {
return true;
}
@ -156,17 +159,6 @@ int KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, Kerne
return RET_OK;
}
std::set<std::string> KernelInterfaceRegistry::AllProviders() {
std::set<std::string> providers;
for (auto &&item : kernel_creators_) {
providers.insert(item.first);
}
for (auto &&item : custom_creators_) {
providers.insert(item.first);
}
return providers;
}
KernelInterfaceRegistry::~KernelInterfaceRegistry() {
for (auto &&item : kernel_creators_) {
free(item.second);

View File

@ -21,7 +21,7 @@
#include <map>
#include <mutex>
#include <set>
#include "src/kernel_interface.h"
#include "src/registry/kernel_interface.h"
#include "include/model.h"
namespace mindspore {
@ -32,11 +32,10 @@ class KernelInterfaceRegistry {
static KernelInterfaceRegistry instance;
return &instance;
}
bool CheckReg(const lite::Model::Node *node);
bool CheckReg(const lite::Model::Node *node, std::set<std::string> &&providers);
kernel::KernelInterface *GetKernelInterface(const std::string &provider, const schema::Primitive *primitive);
int CustomReg(const std::string &provider, const std::string &op_type, kernel::KernelInterfaceCreator creator);
int Reg(const std::string &provider, int op_type, kernel::KernelInterfaceCreator creator);
std::set<std::string> AllProviders();
virtual ~KernelInterfaceRegistry();
private:

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/register_kernel.h"
#include "src/registry/register_kernel.h"
#include "src/kernel_registry.h"
namespace mindspore {

View File

@ -15,24 +15,26 @@
*/
#include "src/runtime/infer_manager.h"
#include <algorithm>
#include <set>
#include <string>
#include "src/common/prim_util.h"
#include "src/common/tensor_util.h"
#include "schema/model_generated.h"
#include "include/errorcode.h"
#include "nnacl/errorcode.h"
#include "src/tensorlist.h"
#include "src/kernel_interface_registry.h"
#include "src/registry/kernel_interface_registry.h"
#include "src/kernel_registry.h"
namespace mindspore {
namespace lite {
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const void *primitive) {
const void *primitive, std::set<std::string> &&providers) {
std::vector<tensor::MSTensor *> in_tensors;
std::copy(inputs.begin(), inputs.end(), std::back_inserter(in_tensors));
std::vector<tensor::MSTensor *> out_tensors;
std::copy(outputs.begin(), outputs.end(), std::back_inserter(out_tensors));
for (auto &&provider : KernelInterfaceRegistry::Instance()->AllProviders()) {
for (auto &&provider : providers) {
auto kernel_interface = KernelInterfaceRegistry::Instance()->GetKernelInterface(
provider, static_cast<const schema::Primitive *>(primitive));
if (kernel_interface == nullptr) {

View File

@ -19,6 +19,8 @@
#include <map>
#include <vector>
#include <set>
#include <string>
#include "src/common/prim_util.h"
#include "src/common/common.h"
#include "nnacl/tensor_c.h"
@ -28,7 +30,7 @@ namespace mindspore::lite {
int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, const std::vector<lite::Tensor *> &outputs,
OpParameter *parameter);
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const void *primitive);
const void *primitive, std::set<std::string> &&providers);
class InferManager {
public:
static InferManager *GetInstance() {

View File

@ -48,7 +48,7 @@
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
#include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
#endif
#include "src/kernel_interface_registry.h"
#include "src/registry/kernel_interface_registry.h"
namespace mindspore::lite {
using kernel::KERNEL_ARCH::kCPU;
@ -131,8 +131,8 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node) {
std::vector<Tensor *> inputs;
std::vector<Tensor *> outputs;
FindNodeInoutTensors(*node, &inputs, &outputs);
if (KernelInterfaceRegistry::Instance()->CheckReg(node)) {
return KernelInferShape(inputs, outputs, node->primitive_);
if (KernelInterfaceRegistry::Instance()->CheckReg(node, context_->GetProviders())) {
return KernelInferShape(inputs, outputs, node->primitive_, context_->GetProviders());
}
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
@ -442,20 +442,26 @@ int Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std:
int Scheduler::FindProviderKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel) {
MS_ASSERT(kernel != nullptr);
int ret = RET_ERROR;
auto &&providers = KernelRegistry::GetInstance()->AllProviders();
int ret = RET_NOT_SUPPORT;
if (!context_->IsProviderEnabled()) {
return ret;
}
if (VersionManager::GetInstance()->GetSchemaVersion() == SCHEMA_V0) {
return ret;
}
for (auto &&provider : providers) {
kernel::KernelKey desc{kCPU, data_type, GetPrimitiveType(node->primitive_), "", provider};
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, nullptr, kernel,
node->primitive_);
if (ret == RET_OK && *kernel != nullptr) {
return ret;
for (auto &&device : context_->device_list_) {
if (!device.provider_.empty()) {
kernel::KernelKey desc{kCPU, data_type, GetPrimitiveType(node->primitive_), device.provider_device_,
device.provider_};
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, nullptr, kernel,
node->primitive_);
if (ret == RET_OK && *kernel != nullptr) {
return ret;
}
}
}
return ret;
return RET_NOT_SUPPORT;
}
kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors,

View File

@ -141,10 +141,10 @@ set(TEST_LITE_SRC
${LITE_DIR}/src/tensorlist.cc
${LITE_DIR}/src/executor.cc
${LITE_DIR}/src/inner_context.cc
${LITE_DIR}/src/kernel_interface.cc
${LITE_DIR}/src/kernel_interface_registry.cc
${LITE_DIR}/src/registry/kernel_interface.cc
${LITE_DIR}/src/registry/kernel_interface_registry.cc
${LITE_DIR}/src/kernel_registry.cc
${LITE_DIR}/src/register_kernel.cc
${LITE_DIR}/src/registry/register_kernel.cc
${LITE_DIR}/src/inner_kernel.cc
${LITE_DIR}/src/lite_kernel.cc
${LITE_DIR}/src/lite_kernel_util.cc

View File

@ -124,9 +124,9 @@ set(LITE_SRC
${SRC_DIR}/tensor.cc
${SRC_DIR}/ms_tensor.cc
${SRC_DIR}/tensorlist.cc
${SRC_DIR}/kernel_interface_registry.cc
${SRC_DIR}/registry/kernel_interface_registry.cc
${SRC_DIR}/kernel_registry.cc
${SRC_DIR}/register_kernel.cc
${SRC_DIR}/registry/register_kernel.cc
${SRC_DIR}/inner_kernel.cc
${SRC_DIR}/lite_kernel.cc
${SRC_DIR}/lite_kernel_util.cc