registry kernel support cxx api

This commit is contained in:
chenjianping 2021-06-28 18:55:51 +08:00
parent cf3213281d
commit 3f4a587849
4 changed files with 66 additions and 4 deletions

View File

@ -74,6 +74,15 @@ class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoC
return std::static_pointer_cast<T>(shared_from_this());
}
std::string GetProvider() const;
void SetProvider(const std::string &provider);
std::string GetProviderDevice() const;
void SetProviderDevice(const std::string &device);
void SetAllocator(const std::shared_ptr<Allocator> &allocator);
std::shared_ptr<Allocator> GetAllocator() const;
protected:
std::shared_ptr<Data> data_;
};

View File

@ -28,6 +28,8 @@ constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
constexpr auto kModelOptionCpuThreadAffinity = "mindspore.option.cpu.thread_affinity";
constexpr auto kModelOptionMaliGpuEnableFP16 = "mindspore.option.mali_gpu.enable_fp16";
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
constexpr auto kModelOptionProvider = "mindspore.option.provider";
constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
struct Context::Data {
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
@ -37,6 +39,7 @@ struct Context::Data {
struct DeviceInfoContext::Data {
std::map<std::string, std::any> params;
std::shared_ptr<Allocator> allocator = nullptr;
};
Context::Context() : data_(std::shared_ptr<Data>(new (std::nothrow) Data())) {}
@ -97,6 +100,54 @@ std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
DeviceInfoContext::DeviceInfoContext() : data_(std::shared_ptr<Data>(new (std::nothrow) Data())) {}
std::string DeviceInfoContext::GetProvider() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return "";
}
return GetValue<std::string>(data_, kModelOptionProvider);
}
void DeviceInfoContext::SetProvider(const std::string &provider) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionProvider] = provider;
}
std::string DeviceInfoContext::GetProviderDevice() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return "";
}
return GetValue<std::string>(data_, kModelOptionProviderDevice);
}
void DeviceInfoContext::SetProviderDevice(const std::string &device) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionProviderDevice] = device;
}
void DeviceInfoContext::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->allocator = allocator;
}
std::shared_ptr<Allocator> DeviceInfoContext::GetAllocator() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return nullptr;
}
return data_->allocator;
}
void CPUDeviceInfo::SetEnableFP16(bool is_fp16) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";

View File

@ -90,13 +90,15 @@ Status ModelImpl::Build() {
}
lite::DeviceInfo cpu_info = {0};
cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
model_context.device_list_.push_back({lite::DT_CPU, cpu_info});
model_context.device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
if (device_list.size() == 2) {
lite::DeviceInfo device_info = {0};
if (device_list[1]->GetDeviceType() == kMaliGPU) {
auto gpu_context = device_list[1]->Cast<MaliGPUDeviceInfo>();
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
model_context.device_list_.push_back({lite::DT_GPU, device_info});
model_context.device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
} else if (device_list[1]->GetDeviceType() == kKirinNPU) {
auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
device_info.npu_device_info_ = {npu_context->GetFrequency()};

View File

@ -155,7 +155,7 @@ int CustomSubGraph::Prepare() {
}
auto provider = nodes_[0]->desc().provider;
auto context = this->Context();
AllocatorPtr allocator = nullptr;
AllocatorPtr allocator = context->allocator;
auto iter = std::find_if(context->device_list_.begin(), context->device_list_.end(),
[&provider](const auto &dev) { return dev.provider_ == provider; });
if (iter != context->device_list_.end()) {
@ -173,7 +173,7 @@ int CustomSubGraph::Prepare() {
auto node = nodes_[nodes_.size() - 1];
for (auto tensor : node->out_tensors()) {
MS_ASSERT(tensor != nullptr);
tensor->set_allocator(this->Context()->allocator);
tensor->set_allocator(context->allocator);
}
return RET_OK;
}