forked from mindspore-Ecosystem/mindspore
registry kernel support cxx api
This commit is contained in:
parent
cf3213281d
commit
3f4a587849
|
@ -74,6 +74,15 @@ class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoC
|
|||
return std::static_pointer_cast<T>(shared_from_this());
|
||||
}
|
||||
|
||||
std::string GetProvider() const;
|
||||
void SetProvider(const std::string &provider);
|
||||
|
||||
std::string GetProviderDevice() const;
|
||||
void SetProviderDevice(const std::string &device);
|
||||
|
||||
void SetAllocator(const std::shared_ptr<Allocator> &allocator);
|
||||
std::shared_ptr<Allocator> GetAllocator() const;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
|
|
@ -28,6 +28,8 @@ constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
|
|||
constexpr auto kModelOptionCpuThreadAffinity = "mindspore.option.cpu.thread_affinity";
|
||||
constexpr auto kModelOptionMaliGpuEnableFP16 = "mindspore.option.mali_gpu.enable_fp16";
|
||||
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
|
||||
constexpr auto kModelOptionProvider = "mindspore.option.provider";
|
||||
constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
|
||||
|
||||
struct Context::Data {
|
||||
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
|
||||
|
@ -37,6 +39,7 @@ struct Context::Data {
|
|||
|
||||
struct DeviceInfoContext::Data {
|
||||
std::map<std::string, std::any> params;
|
||||
std::shared_ptr<Allocator> allocator = nullptr;
|
||||
};
|
||||
|
||||
Context::Context() : data_(std::shared_ptr<Data>(new (std::nothrow) Data())) {}
|
||||
|
@ -97,6 +100,54 @@ std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
|
|||
|
||||
DeviceInfoContext::DeviceInfoContext() : data_(std::shared_ptr<Data>(new (std::nothrow) Data())) {}
|
||||
|
||||
std::string DeviceInfoContext::GetProvider() const {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return "";
|
||||
}
|
||||
return GetValue<std::string>(data_, kModelOptionProvider);
|
||||
}
|
||||
|
||||
void DeviceInfoContext::SetProvider(const std::string &provider) {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->params[kModelOptionProvider] = provider;
|
||||
}
|
||||
|
||||
std::string DeviceInfoContext::GetProviderDevice() const {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return "";
|
||||
}
|
||||
return GetValue<std::string>(data_, kModelOptionProviderDevice);
|
||||
}
|
||||
|
||||
void DeviceInfoContext::SetProviderDevice(const std::string &device) {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->params[kModelOptionProviderDevice] = device;
|
||||
}
|
||||
|
||||
void DeviceInfoContext::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->allocator = allocator;
|
||||
}
|
||||
|
||||
std::shared_ptr<Allocator> DeviceInfoContext::GetAllocator() const {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return nullptr;
|
||||
}
|
||||
return data_->allocator;
|
||||
}
|
||||
|
||||
void CPUDeviceInfo::SetEnableFP16(bool is_fp16) {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
|
|
|
@ -90,13 +90,15 @@ Status ModelImpl::Build() {
|
|||
}
|
||||
lite::DeviceInfo cpu_info = {0};
|
||||
cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
|
||||
model_context.device_list_.push_back({lite::DT_CPU, cpu_info});
|
||||
model_context.device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
|
||||
cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
|
||||
if (device_list.size() == 2) {
|
||||
lite::DeviceInfo device_info = {0};
|
||||
if (device_list[1]->GetDeviceType() == kMaliGPU) {
|
||||
auto gpu_context = device_list[1]->Cast<MaliGPUDeviceInfo>();
|
||||
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
|
||||
model_context.device_list_.push_back({lite::DT_GPU, device_info});
|
||||
model_context.device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
|
||||
gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
|
||||
} else if (device_list[1]->GetDeviceType() == kKirinNPU) {
|
||||
auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
|
||||
device_info.npu_device_info_ = {npu_context->GetFrequency()};
|
||||
|
|
|
@ -155,7 +155,7 @@ int CustomSubGraph::Prepare() {
|
|||
}
|
||||
auto provider = nodes_[0]->desc().provider;
|
||||
auto context = this->Context();
|
||||
AllocatorPtr allocator = nullptr;
|
||||
AllocatorPtr allocator = context->allocator;
|
||||
auto iter = std::find_if(context->device_list_.begin(), context->device_list_.end(),
|
||||
[&provider](const auto &dev) { return dev.provider_ == provider; });
|
||||
if (iter != context->device_list_.end()) {
|
||||
|
@ -173,7 +173,7 @@ int CustomSubGraph::Prepare() {
|
|||
auto node = nodes_[nodes_.size() - 1];
|
||||
for (auto tensor : node->out_tensors()) {
|
||||
MS_ASSERT(tensor != nullptr);
|
||||
tensor->set_allocator(this->Context()->allocator);
|
||||
tensor->set_allocator(context->allocator);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue