forked from mindspore-Ecosystem/mindspore
context support provider
This commit is contained in:
parent
d4eef75155
commit
e790386528
|
@ -16,7 +16,7 @@
|
|||
|
||||
#ifndef MINDSPORE_LITE_INCLUDE_CONTEXT_H_
|
||||
#define MINDSPORE_LITE_INCLUDE_CONTEXT_H_
|
||||
|
||||
#include <string>
|
||||
#include "include/ms_tensor.h"
|
||||
#include "include/lite_utils.h"
|
||||
#include "include/lite_types.h"
|
||||
|
@ -57,6 +57,8 @@ union DeviceInfo {
|
|||
struct DeviceContext {
|
||||
DeviceType device_type_ = DT_CPU;
|
||||
DeviceInfo device_info_;
|
||||
std::string provider_{};
|
||||
std::string provider_device_{};
|
||||
};
|
||||
|
||||
/// \brief Context defined for holding environment variables during runtime.
|
||||
|
|
|
@ -134,7 +134,7 @@ set(LITE_SRC
|
|||
${LITE_DIR}/src/common/prim_util.cc
|
||||
${LITE_DIR}/src/common/tensor_util.cc
|
||||
${LITE_DIR}/src/runtime/infer_manager.cc
|
||||
${LITE_DIR}/src/kernel_interface_registry.cc
|
||||
${LITE_DIR}/src/registry/kernel_interface_registry.cc
|
||||
${LITE_DIR}/src/lite_model.cc
|
||||
${LITE_DIR}/src/tensorlist.cc
|
||||
${LITE_DIR}/src/tensor.cc
|
||||
|
|
|
@ -64,9 +64,9 @@ set(LITE_SRC
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lite_model.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/register_kernel.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_interface.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_interface_registry.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/registry/register_kernel.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/registry/kernel_interface.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/registry/kernel_interface_registry.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inner_kernel.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel_util.cc
|
||||
|
|
|
@ -171,6 +171,12 @@ bool InnerContext::IsNpuEnabled() const {
|
|||
#endif
|
||||
}
|
||||
|
||||
bool InnerContext::IsProviderEnabled() const {
|
||||
return this->device_list_.end() !=
|
||||
std::find_if(this->device_list_.begin(), this->device_list_.end(),
|
||||
[](const DeviceContext &device) { return !device.provider_.empty(); });
|
||||
}
|
||||
|
||||
bool InnerContext::IsUserSetCpu() const {
|
||||
return this->device_list_.end() !=
|
||||
std::find_if(this->device_list_.begin(), this->device_list_.end(),
|
||||
|
@ -189,6 +195,16 @@ bool InnerContext::IsUserSetNpu() const {
|
|||
[](const DeviceContext &device) { return device.device_type_ == DT_NPU; });
|
||||
}
|
||||
|
||||
std::set<std::string> InnerContext::GetProviders() const {
|
||||
std::set<std::string> providers;
|
||||
for (auto &&device : device_list_) {
|
||||
if (!device.provider_.empty()) {
|
||||
providers.insert(device.provider_);
|
||||
}
|
||||
}
|
||||
return providers;
|
||||
}
|
||||
|
||||
CpuDeviceInfo InnerContext::GetCpuInfo() const {
|
||||
auto iter = std::find_if(this->device_list_.begin(), this->device_list_.end(),
|
||||
[](const DeviceContext &device) { return device.device_type_ == DT_CPU; });
|
||||
|
|
|
@ -16,7 +16,8 @@
|
|||
|
||||
#ifndef MINDSPORE_LITE_SRC_INNER_CONTEXT_H
|
||||
#define MINDSPORE_LITE_SRC_INNER_CONTEXT_H
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "include/context.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/allocator.h"
|
||||
|
@ -48,6 +49,10 @@ struct InnerContext : public Context {
|
|||
|
||||
bool IsNpuEnabled() const;
|
||||
|
||||
bool IsProviderEnabled() const;
|
||||
|
||||
std::set<std::string> GetProviders() const;
|
||||
|
||||
CpuDeviceInfo GetCpuInfo() const;
|
||||
|
||||
GpuDeviceInfo GetGpuInfo() const;
|
||||
|
|
|
@ -55,17 +55,6 @@ KernelRegistry *KernelRegistry::GetInstance() {
|
|||
return &instance;
|
||||
}
|
||||
|
||||
std::set<std::string> KernelRegistry::AllProviders() {
|
||||
std::set<std::string> providers;
|
||||
for (auto &&item : kernel_creators_) {
|
||||
providers.insert(item.first);
|
||||
}
|
||||
for (auto &&item : custom_kernel_creators_) {
|
||||
providers.insert(item.first);
|
||||
}
|
||||
return providers;
|
||||
}
|
||||
|
||||
int KernelRegistry::GetFuncIndex(const kernel::KernelKey &desc) {
|
||||
if (desc.data_type >= kNumberTypeEnd) {
|
||||
return -1;
|
||||
|
@ -166,12 +155,20 @@ kernel::CreateKernel KernelRegistry::GetProviderCreator(const kernel::KernelKey
|
|||
MS_ASSERT(param != nullptr);
|
||||
auto custom_type = param->type()->str();
|
||||
auto archs = custom_kernel_creators_[desc.provider];
|
||||
auto archs_iter = std::find_if(archs.begin(), archs.end(), [custom_type, data_type_index](auto &&item) {
|
||||
return item.second[custom_type] != nullptr && item.second[custom_type][data_type_index] != nullptr;
|
||||
});
|
||||
if (archs_iter != archs.end()) {
|
||||
return archs_iter->second[custom_type][data_type_index];
|
||||
if (desc.kernel_arch.empty()) {
|
||||
auto archs_iter = std::find_if(archs.begin(), archs.end(), [custom_type, data_type_index](auto &&item) {
|
||||
return item.second[custom_type] != nullptr && item.second[custom_type][data_type_index] != nullptr;
|
||||
});
|
||||
if (archs_iter != archs.end()) {
|
||||
return archs_iter->second[custom_type][data_type_index];
|
||||
}
|
||||
} else {
|
||||
auto find_arch_it = archs.find(desc.kernel_arch);
|
||||
if (find_arch_it != archs.end()) {
|
||||
return find_arch_it->second[custom_type][data_type_index];
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
auto index = GetFuncIndex(desc);
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
#include <vector>
|
||||
#include <set>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/register_kernel.h"
|
||||
#include "src/registry/register_kernel.h"
|
||||
#include "schema/model_generated.h"
|
||||
|
||||
using mindspore::kernel::kKernelArch_MAX;
|
||||
|
@ -43,7 +43,6 @@ class KernelRegistry {
|
|||
virtual kernel::CreateKernel GetProviderCreator(const kernel::KernelKey &desc, const schema::Primitive *prim);
|
||||
int GetCreatorFuncIndex(kernel::KernelKey desc);
|
||||
int GetFuncIndex(const kernel::KernelKey &desc);
|
||||
std::set<std::string> AllProviders();
|
||||
int RegCustomKernel(const std::string &arch, const std::string &vendor, TypeId data_type, const std::string &type,
|
||||
kernel::CreateKernel creator);
|
||||
void RegKernel(kernel::KernelKey desc, kernel::KernelCreator creator);
|
||||
|
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/kernel_interface.h"
|
||||
#include "src/kernel_interface_registry.h"
|
||||
#include "src/registry/kernel_interface.h"
|
||||
#include "src/registry/kernel_interface_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/kernel_interface_registry.h"
|
||||
#include "src/kernel_interface.h"
|
||||
#include "src/registry/kernel_interface_registry.h"
|
||||
#include "src/registry/kernel_interface.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/version_manager.h"
|
||||
|
@ -34,7 +34,7 @@ std::string GetCustomType(const schema::Primitive *primitive) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node) {
|
||||
bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node, std::set<std::string> &&providers) {
|
||||
if (VersionManager::GetInstance()->GetSchemaVersion() == SCHEMA_V0) {
|
||||
return false;
|
||||
}
|
||||
|
@ -46,7 +46,10 @@ bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node) {
|
|||
auto op_type = primitive->value_type();
|
||||
if (op_type == schema::PrimitiveType_Custom) {
|
||||
auto &&custom_type = GetCustomType(primitive);
|
||||
return std::any_of(custom_creators_.begin(), custom_creators_.end(), [&custom_type](auto &&item) {
|
||||
return std::any_of(custom_creators_.begin(), custom_creators_.end(), [&custom_type, &providers](auto &&item) {
|
||||
if (providers.find(item.first) == providers.end()) {
|
||||
return false;
|
||||
}
|
||||
if (item.second[custom_type] != nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
@ -156,17 +159,6 @@ int KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, Kerne
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
std::set<std::string> KernelInterfaceRegistry::AllProviders() {
|
||||
std::set<std::string> providers;
|
||||
for (auto &&item : kernel_creators_) {
|
||||
providers.insert(item.first);
|
||||
}
|
||||
for (auto &&item : custom_creators_) {
|
||||
providers.insert(item.first);
|
||||
}
|
||||
return providers;
|
||||
}
|
||||
|
||||
KernelInterfaceRegistry::~KernelInterfaceRegistry() {
|
||||
for (auto &&item : kernel_creators_) {
|
||||
free(item.second);
|
|
@ -21,7 +21,7 @@
|
|||
#include <map>
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include "src/kernel_interface.h"
|
||||
#include "src/registry/kernel_interface.h"
|
||||
#include "include/model.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -32,11 +32,10 @@ class KernelInterfaceRegistry {
|
|||
static KernelInterfaceRegistry instance;
|
||||
return &instance;
|
||||
}
|
||||
bool CheckReg(const lite::Model::Node *node);
|
||||
bool CheckReg(const lite::Model::Node *node, std::set<std::string> &&providers);
|
||||
kernel::KernelInterface *GetKernelInterface(const std::string &provider, const schema::Primitive *primitive);
|
||||
int CustomReg(const std::string &provider, const std::string &op_type, kernel::KernelInterfaceCreator creator);
|
||||
int Reg(const std::string &provider, int op_type, kernel::KernelInterfaceCreator creator);
|
||||
std::set<std::string> AllProviders();
|
||||
virtual ~KernelInterfaceRegistry();
|
||||
|
||||
private:
|
|
@ -13,7 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/register_kernel.h"
|
||||
#include "src/registry/register_kernel.h"
|
||||
#include "src/kernel_registry.h"
|
||||
|
||||
namespace mindspore {
|
|
@ -15,24 +15,26 @@
|
|||
*/
|
||||
#include "src/runtime/infer_manager.h"
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "src/common/prim_util.h"
|
||||
#include "src/common/tensor_util.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "src/tensorlist.h"
|
||||
#include "src/kernel_interface_registry.h"
|
||||
#include "src/registry/kernel_interface_registry.h"
|
||||
#include "src/kernel_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
const void *primitive) {
|
||||
const void *primitive, std::set<std::string> &&providers) {
|
||||
std::vector<tensor::MSTensor *> in_tensors;
|
||||
std::copy(inputs.begin(), inputs.end(), std::back_inserter(in_tensors));
|
||||
std::vector<tensor::MSTensor *> out_tensors;
|
||||
std::copy(outputs.begin(), outputs.end(), std::back_inserter(out_tensors));
|
||||
for (auto &&provider : KernelInterfaceRegistry::Instance()->AllProviders()) {
|
||||
for (auto &&provider : providers) {
|
||||
auto kernel_interface = KernelInterfaceRegistry::Instance()->GetKernelInterface(
|
||||
provider, static_cast<const schema::Primitive *>(primitive));
|
||||
if (kernel_interface == nullptr) {
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "src/common/prim_util.h"
|
||||
#include "src/common/common.h"
|
||||
#include "nnacl/tensor_c.h"
|
||||
|
@ -28,7 +30,7 @@ namespace mindspore::lite {
|
|||
int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, const std::vector<lite::Tensor *> &outputs,
|
||||
OpParameter *parameter);
|
||||
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
const void *primitive);
|
||||
const void *primitive, std::set<std::string> &&providers);
|
||||
class InferManager {
|
||||
public:
|
||||
static InferManager *GetInstance() {
|
||||
|
|
|
@ -48,7 +48,7 @@
|
|||
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
|
||||
#include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
|
||||
#endif
|
||||
#include "src/kernel_interface_registry.h"
|
||||
#include "src/registry/kernel_interface_registry.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
using kernel::KERNEL_ARCH::kCPU;
|
||||
|
@ -131,8 +131,8 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node) {
|
|||
std::vector<Tensor *> inputs;
|
||||
std::vector<Tensor *> outputs;
|
||||
FindNodeInoutTensors(*node, &inputs, &outputs);
|
||||
if (KernelInterfaceRegistry::Instance()->CheckReg(node)) {
|
||||
return KernelInferShape(inputs, outputs, node->primitive_);
|
||||
if (KernelInterfaceRegistry::Instance()->CheckReg(node, context_->GetProviders())) {
|
||||
return KernelInferShape(inputs, outputs, node->primitive_, context_->GetProviders());
|
||||
}
|
||||
|
||||
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
|
||||
|
@ -442,20 +442,26 @@ int Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std:
|
|||
int Scheduler::FindProviderKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel) {
|
||||
MS_ASSERT(kernel != nullptr);
|
||||
int ret = RET_ERROR;
|
||||
auto &&providers = KernelRegistry::GetInstance()->AllProviders();
|
||||
int ret = RET_NOT_SUPPORT;
|
||||
if (!context_->IsProviderEnabled()) {
|
||||
return ret;
|
||||
}
|
||||
if (VersionManager::GetInstance()->GetSchemaVersion() == SCHEMA_V0) {
|
||||
return ret;
|
||||
}
|
||||
for (auto &&provider : providers) {
|
||||
kernel::KernelKey desc{kCPU, data_type, GetPrimitiveType(node->primitive_), "", provider};
|
||||
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, nullptr, kernel,
|
||||
node->primitive_);
|
||||
if (ret == RET_OK && *kernel != nullptr) {
|
||||
return ret;
|
||||
for (auto &&device : context_->device_list_) {
|
||||
if (!device.provider_.empty()) {
|
||||
kernel::KernelKey desc{kCPU, data_type, GetPrimitiveType(node->primitive_), device.provider_device_,
|
||||
device.provider_};
|
||||
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, nullptr, kernel,
|
||||
node->primitive_);
|
||||
if (ret == RET_OK && *kernel != nullptr) {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors,
|
||||
|
|
|
@ -141,10 +141,10 @@ set(TEST_LITE_SRC
|
|||
${LITE_DIR}/src/tensorlist.cc
|
||||
${LITE_DIR}/src/executor.cc
|
||||
${LITE_DIR}/src/inner_context.cc
|
||||
${LITE_DIR}/src/kernel_interface.cc
|
||||
${LITE_DIR}/src/kernel_interface_registry.cc
|
||||
${LITE_DIR}/src/registry/kernel_interface.cc
|
||||
${LITE_DIR}/src/registry/kernel_interface_registry.cc
|
||||
${LITE_DIR}/src/kernel_registry.cc
|
||||
${LITE_DIR}/src/register_kernel.cc
|
||||
${LITE_DIR}/src/registry/register_kernel.cc
|
||||
${LITE_DIR}/src/inner_kernel.cc
|
||||
${LITE_DIR}/src/lite_kernel.cc
|
||||
${LITE_DIR}/src/lite_kernel_util.cc
|
||||
|
|
|
@ -124,9 +124,9 @@ set(LITE_SRC
|
|||
${SRC_DIR}/tensor.cc
|
||||
${SRC_DIR}/ms_tensor.cc
|
||||
${SRC_DIR}/tensorlist.cc
|
||||
${SRC_DIR}/kernel_interface_registry.cc
|
||||
${SRC_DIR}/registry/kernel_interface_registry.cc
|
||||
${SRC_DIR}/kernel_registry.cc
|
||||
${SRC_DIR}/register_kernel.cc
|
||||
${SRC_DIR}/registry/register_kernel.cc
|
||||
${SRC_DIR}/inner_kernel.cc
|
||||
${SRC_DIR}/lite_kernel.cc
|
||||
${SRC_DIR}/lite_kernel_util.cc
|
||||
|
|
Loading…
Reference in New Issue