!30126 [MS][LITE] add multiModelHW api for quant model

Merge pull request !30126 from zhengjun10/quant
This commit is contained in:
i-robot 2022-02-17 09:37:16 +00:00 committed by Gitee
commit c74d4c8007
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 36 additions and 4 deletions

View File

@ -99,6 +99,16 @@ class MS_API Context {
/// \return Pointer to the custom delegate.
std::shared_ptr<Delegate> GetDelegate() const;
/// \brief Set whether a quantized model should be run as a float model in multi-device scenarios.
///
/// \param[in] float_mode true: run the quantized model as a float model; false: do not run it as a float model.
void SetMultiModalHW(bool float_mode);
/// \brief Get whether the quantized model is configured to run as a float model.
///
/// \return true if the model is set to run as a float model; false otherwise. false is also returned when the
/// context's internal data is invalid (null).
bool GetMultiModalHW() const;
/// \brief Get a mutable reference of DeviceInfoContext vector in this context. Only MindSpore Lite supports
/// heterogeneous scenarios with multiple members in the vector.
///

View File

@ -80,7 +80,7 @@ struct Context {
DeviceContextVector device_list_;
#endif // NOT_USE_STL
DelegatePtr delegate = nullptr;
bool float_mode = false;
bool float_mode = false; /**< convert full quant model to float model */
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_INCLUDE_CONTEXT_H_

View File

@ -160,6 +160,22 @@ std::shared_ptr<Delegate> Context::GetDelegate() const {
return data_->delegate;
}
// Stores the float-mode flag in the context's private data.
// When the private data is missing, an error is logged and the call has no effect.
void Context::SetMultiModalHW(bool float_mode) {
  if (data_ != nullptr) {
    data_->float_mode = float_mode;
    return;
  }
  MS_LOG(ERROR) << "Invalid context.";
}
// Reads the float-mode flag from the context's private data.
// When the private data is missing, an error is logged and false is returned.
bool Context::GetMultiModalHW() const {
  if (data_ != nullptr) {
    return data_->float_mode;
  }
  MS_LOG(ERROR) << "Invalid context.";
  return false;
}
std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
static std::vector<std::shared_ptr<DeviceInfoContext>> empty{};
if (data_ == nullptr) {

View File

@ -36,6 +36,7 @@ struct Context::Data {
std::vector<int32_t> affinity_core_list_;
int affinity_mode_ = 0;
std::shared_ptr<Delegate> delegate = nullptr;
bool float_mode = false;
};
struct DeviceInfoContext::Data {

View File

@ -21,11 +21,13 @@ constexpr static int kMaxNumOfDevices = 3;
// Copies the thread count, parallel flag, affinity core list, delegate and float-mode flag into `inner_context`.
// `inner_context` is dereferenced here without a null check.
// NOTE(review): the `delegate`/`inner_context` parameter line appears twice below; this looks like the pre-change
// and post-change forms of the signature shown together. Only the second form (ending with `bool float_mode`)
// matches the body, which assigns `float_mode` — confirm against the actual source file.
void ContextUtils::SetContextAttr(int32_t thread_num, bool enable_parallel,
const std::vector<int32_t> &affinity_core_list,
const std::shared_ptr<Delegate> &delegate, lite::InnerContext *inner_context) {
const std::shared_ptr<Delegate> &delegate, lite::InnerContext *inner_context,
bool float_mode) {
inner_context->thread_num_ = thread_num;
inner_context->enable_parallel_ = enable_parallel;
inner_context->affinity_core_list_ = affinity_core_list;
inner_context->delegate = delegate;
inner_context->float_mode = float_mode;
}
Status ContextUtils::AddCpuDevice(const std::shared_ptr<Allocator> &allocator, int affinity_mode, bool enable_fp16,
@ -81,7 +83,7 @@ lite::InnerContext *ContextUtils::Convert(Context *context) {
return nullptr;
}
SetContextAttr(context->GetThreadNum(), context->GetEnableParallel(), context->GetThreadAffinityCoreList(),
context->GetDelegate(), inner_context.get());
context->GetDelegate(), inner_context.get(), context->GetMultiModalHW());
inner_context->device_list_.clear();
Status ret = kLiteError;
for (auto &device : device_list) {

View File

@ -34,7 +34,8 @@ class ContextUtils {
private:
static void SetContextAttr(int32_t thread_num, bool enable_parallel, const std::vector<int32_t> &affinity_core_list,
const std::shared_ptr<Delegate> &delegate, lite::InnerContext *inner_context);
const std::shared_ptr<Delegate> &delegate, lite::InnerContext *inner_context,
bool float_mode = false);
static Status AddCpuDevice(const std::shared_ptr<Allocator> &allocator, int affinity_mode, bool enable_fp16,
const std::string &provider, const std::string &provider_device,
lite::InnerContext *inner_context);

View File

@ -54,6 +54,7 @@ InnerContext::InnerContext(const Context *context) {
this->affinity_core_list_ = context->affinity_core_list_;
SetContextDevice(context);
this->delegate = context->delegate;
this->float_mode = context->float_mode;
}
InitDeviceFp16();
}

View File

@ -347,6 +347,7 @@ STATUS Scheduler::DelQuantDTypeCastKernel(std::vector<kernel::LiteKernel *> *ker
*tensor_iter = cur_kernel->in_tensors()[0];
}
iter = kernels->erase(iter);
MS_LOG(DEBUG) << "Delete kernel: " << cur_kernel->name();
delete cur_kernel;
}
return RET_OK;