forked from mindspore-Ecosystem/mindspore
registry kernel: add support for the C++ API
This commit is contained in:
parent
cf3213281d
commit
3f4a587849
|
@ -74,6 +74,15 @@ class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoC
|
||||||
return std::static_pointer_cast<T>(shared_from_this());
|
return std::static_pointer_cast<T>(shared_from_this());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string GetProvider() const;
|
||||||
|
void SetProvider(const std::string &provider);
|
||||||
|
|
||||||
|
std::string GetProviderDevice() const;
|
||||||
|
void SetProviderDevice(const std::string &device);
|
||||||
|
|
||||||
|
void SetAllocator(const std::shared_ptr<Allocator> &allocator);
|
||||||
|
std::shared_ptr<Allocator> GetAllocator() const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::shared_ptr<Data> data_;
|
std::shared_ptr<Data> data_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -28,6 +28,8 @@ constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
|
||||||
constexpr auto kModelOptionCpuThreadAffinity = "mindspore.option.cpu.thread_affinity";
|
constexpr auto kModelOptionCpuThreadAffinity = "mindspore.option.cpu.thread_affinity";
|
||||||
constexpr auto kModelOptionMaliGpuEnableFP16 = "mindspore.option.mali_gpu.enable_fp16";
|
constexpr auto kModelOptionMaliGpuEnableFP16 = "mindspore.option.mali_gpu.enable_fp16";
|
||||||
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
|
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
|
||||||
|
// Keys under which provider-related settings are stored in the
// DeviceInfoContext option map.
constexpr auto kModelOptionProvider = "mindspore.option.provider";
constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
|
||||||
|
|
||||||
struct Context::Data {
|
struct Context::Data {
|
||||||
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
|
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
|
||||||
|
@ -37,6 +39,7 @@ struct Context::Data {
|
||||||
|
|
||||||
// Backing storage for DeviceInfoContext.
// Generic, string-keyed options are kept type-erased in `params`;
// the user-supplied allocator is stored separately.
struct DeviceInfoContext::Data {
  // Named device options (e.g. provider, provider device), type-erased.
  std::map<std::string, std::any> params;
  // Optional allocator; nullptr means none was configured.
  std::shared_ptr<Allocator> allocator{nullptr};
};
|
||||||
|
|
||||||
// Constructs a Context with freshly allocated backing data.
// Allocation uses nothrow-new, so data_ stays null on failure; the
// accessors guard against that case before dereferencing.
Context::Context() { data_.reset(new (std::nothrow) Data()); }
|
@ -97,6 +100,54 @@ std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
|
||||||
|
|
||||||
// Constructs a DeviceInfoContext with freshly allocated backing data.
// Uses nothrow-new, so data_ stays null on allocation failure; every
// getter/setter checks data_ before touching it.
DeviceInfoContext::DeviceInfoContext() { data_.reset(new (std::nothrow) Data()); }
|
|
||||||
|
std::string DeviceInfoContext::GetProvider() const {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
return GetValue<std::string>(data_, kModelOptionProvider);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeviceInfoContext::SetProvider(const std::string &provider) {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
data_->params[kModelOptionProvider] = provider;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string DeviceInfoContext::GetProviderDevice() const {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
return GetValue<std::string>(data_, kModelOptionProviderDevice);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeviceInfoContext::SetProviderDevice(const std::string &device) {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
data_->params[kModelOptionProviderDevice] = device;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeviceInfoContext::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
data_->allocator = allocator;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<Allocator> DeviceInfoContext::GetAllocator() const {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return data_->allocator;
|
||||||
|
}
|
||||||
|
|
||||||
void CPUDeviceInfo::SetEnableFP16(bool is_fp16) {
|
void CPUDeviceInfo::SetEnableFP16(bool is_fp16) {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
|
|
@ -90,13 +90,15 @@ Status ModelImpl::Build() {
|
||||||
}
|
}
|
||||||
lite::DeviceInfo cpu_info = {0};
|
lite::DeviceInfo cpu_info = {0};
|
||||||
cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
|
cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
|
||||||
model_context.device_list_.push_back({lite::DT_CPU, cpu_info});
|
model_context.device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
|
||||||
|
cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
|
||||||
if (device_list.size() == 2) {
|
if (device_list.size() == 2) {
|
||||||
lite::DeviceInfo device_info = {0};
|
lite::DeviceInfo device_info = {0};
|
||||||
if (device_list[1]->GetDeviceType() == kMaliGPU) {
|
if (device_list[1]->GetDeviceType() == kMaliGPU) {
|
||||||
auto gpu_context = device_list[1]->Cast<MaliGPUDeviceInfo>();
|
auto gpu_context = device_list[1]->Cast<MaliGPUDeviceInfo>();
|
||||||
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
|
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
|
||||||
model_context.device_list_.push_back({lite::DT_GPU, device_info});
|
model_context.device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
|
||||||
|
gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
|
||||||
} else if (device_list[1]->GetDeviceType() == kKirinNPU) {
|
} else if (device_list[1]->GetDeviceType() == kKirinNPU) {
|
||||||
auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
|
auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
|
||||||
device_info.npu_device_info_ = {npu_context->GetFrequency()};
|
device_info.npu_device_info_ = {npu_context->GetFrequency()};
|
||||||
|
|
|
@ -155,7 +155,7 @@ int CustomSubGraph::Prepare() {
|
||||||
}
|
}
|
||||||
auto provider = nodes_[0]->desc().provider;
|
auto provider = nodes_[0]->desc().provider;
|
||||||
auto context = this->Context();
|
auto context = this->Context();
|
||||||
AllocatorPtr allocator = nullptr;
|
AllocatorPtr allocator = context->allocator;
|
||||||
auto iter = std::find_if(context->device_list_.begin(), context->device_list_.end(),
|
auto iter = std::find_if(context->device_list_.begin(), context->device_list_.end(),
|
||||||
[&provider](const auto &dev) { return dev.provider_ == provider; });
|
[&provider](const auto &dev) { return dev.provider_ == provider; });
|
||||||
if (iter != context->device_list_.end()) {
|
if (iter != context->device_list_.end()) {
|
||||||
|
@ -173,7 +173,7 @@ int CustomSubGraph::Prepare() {
|
||||||
auto node = nodes_[nodes_.size() - 1];
|
auto node = nodes_[nodes_.size() - 1];
|
||||||
for (auto tensor : node->out_tensors()) {
|
for (auto tensor : node->out_tensors()) {
|
||||||
MS_ASSERT(tensor != nullptr);
|
MS_ASSERT(tensor != nullptr);
|
||||||
tensor->set_allocator(this->Context()->allocator);
|
tensor->set_allocator(context->allocator);
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue