forked from mindspore-Ecosystem/mindspore
support full quant restore to fp32
This commit is contained in:
parent
7e6f3dec6d
commit
92954179a0
|
@ -80,6 +80,7 @@ struct Context {
|
||||||
DeviceContextVector device_list_;
|
DeviceContextVector device_list_;
|
||||||
#endif // NOT_USE_STL
|
#endif // NOT_USE_STL
|
||||||
DelegatePtr delegate = nullptr;
|
DelegatePtr delegate = nullptr;
|
||||||
|
bool float_mode = false;
|
||||||
};
|
};
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
#endif // MINDSPORE_LITE_INCLUDE_CONTEXT_H_
|
#endif // MINDSPORE_LITE_INCLUDE_CONTEXT_H_
|
||||||
|
|
|
@ -503,6 +503,38 @@ void LiteSession::InitGraphOutputTensorMap(const lite::Model *model) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LiteSession::UpdateGraphOutputMap(const std::vector<kernel::LiteKernel *> &kernels) {
|
||||||
|
if (!context_->float_mode) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
output_tensor_map_.clear();
|
||||||
|
output_tensor_names_.clear();
|
||||||
|
output_node_map_.clear();
|
||||||
|
// output kernels may be modified during schedule runtime pass,need update node output tensor and node map
|
||||||
|
for (auto iter_kernel : kernels) {
|
||||||
|
auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(iter_kernel);
|
||||||
|
for (auto out_kernel : subgraph->out_nodes()) {
|
||||||
|
if (out_kernel->is_model_output()) {
|
||||||
|
for (auto out_tensor : out_kernel->out_tensors())
|
||||||
|
if (out_tensor->IsGraphOutput()) {
|
||||||
|
output_node_map_[out_kernel->name()].emplace_back(out_tensor);
|
||||||
|
if (!out_tensor->tensor_name().empty()) {
|
||||||
|
output_tensor_map_[out_tensor->tensor_name()] = out_tensor;
|
||||||
|
output_tensor_names_.emplace_back(out_tensor->tensor_name());
|
||||||
|
} else {
|
||||||
|
auto tensor_iter = std::find(tensors_.begin(), tensors_.end(), out_tensor);
|
||||||
|
if (tensor_iter != tensors_.end()) {
|
||||||
|
auto tensor_index = std::to_string(tensor_iter - tensors_.begin());
|
||||||
|
output_tensor_map_[tensor_index] = out_tensor;
|
||||||
|
output_tensor_names_.emplace_back(tensor_index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void LiteSession::AdjustModelOutputTensorInitRefCount(const lite::Model *model) {
|
void LiteSession::AdjustModelOutputTensorInitRefCount(const lite::Model *model) {
|
||||||
MS_ASSERT(model != nullptr);
|
MS_ASSERT(model != nullptr);
|
||||||
auto graph_out_size = model->output_indices_.size();
|
auto graph_out_size = model->output_indices_.size();
|
||||||
|
@ -675,7 +707,7 @@ int LiteSession::CompileGraph(Model *model) {
|
||||||
InitGraphOutputTensors(model);
|
InitGraphOutputTensors(model);
|
||||||
|
|
||||||
// scheduler kernels
|
// scheduler kernels
|
||||||
Scheduler scheduler(context_, ms_context_, model, &tensors_, inputs_, outputs_, is_train_session_, &is_infershape_,
|
Scheduler scheduler(context_, ms_context_, model, &tensors_, &inputs_, &outputs_, is_train_session_, &is_infershape_,
|
||||||
&is_control_flow_, execution_plan_, delegate_, delegate_device_type_);
|
&is_control_flow_, execution_plan_, delegate_, delegate_device_type_);
|
||||||
scheduler.SetupSchedulerCb(std::move(sched_cb_));
|
scheduler.SetupSchedulerCb(std::move(sched_cb_));
|
||||||
scheduler.SetConfig(config_info_);
|
scheduler.SetConfig(config_info_);
|
||||||
|
@ -686,6 +718,7 @@ int LiteSession::CompileGraph(Model *model) {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
InitGraphInOutTensorsMap(model);
|
InitGraphInOutTensorsMap(model);
|
||||||
|
UpdateGraphOutputMap(kernels_);
|
||||||
|
|
||||||
non_tail_call_kernels_ = scheduler.NonTailCallNodes();
|
non_tail_call_kernels_ = scheduler.NonTailCallNodes();
|
||||||
|
|
||||||
|
|
|
@ -135,6 +135,7 @@ class LiteSession : public session::LiteSession {
|
||||||
private:
|
private:
|
||||||
int IsolateOutputTensor();
|
int IsolateOutputTensor();
|
||||||
bool IsIsolatedSubGraph(const kernel::LiteKernel *kernel);
|
bool IsIsolatedSubGraph(const kernel::LiteKernel *kernel);
|
||||||
|
void UpdateGraphOutputMap(const std::vector<kernel::LiteKernel *> &kernel);
|
||||||
void UpdateLinkInfoForIsolateOutput();
|
void UpdateLinkInfoForIsolateOutput();
|
||||||
std::unordered_map<Tensor *, Tensor *> isolate_graph_output_map_; /* <calculate-tensor, graph-output-tensor> */
|
std::unordered_map<Tensor *, Tensor *> isolate_graph_output_map_; /* <calculate-tensor, graph-output-tensor> */
|
||||||
std::unordered_map<Tensor *, Tensor *> isolate_input_map_; /* <calculate-tensor, src-subgraph-input-tensor> */
|
std::unordered_map<Tensor *, Tensor *> isolate_input_map_; /* <calculate-tensor, src-subgraph-input-tensor> */
|
||||||
|
|
|
@ -300,6 +300,58 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *dst_kernels
|
||||||
return ConstructNormalSubGraphs(src_kernel, dst_kernels, &is_kernel_finish);
|
return ConstructNormalSubGraphs(src_kernel, dst_kernels, &is_kernel_finish);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
STATUS Scheduler::DelQuantDTypeCastKernel(std::vector<kernel::LiteKernel *> *kernels) {
|
||||||
|
for (auto iter = (*kernels).begin(); iter != (*kernels).end();) {
|
||||||
|
auto cur_kernel = *iter;
|
||||||
|
if (cur_kernel->subgraph_type() != kernel::kNotSubGraph) {
|
||||||
|
auto sub_inner_graph = reinterpret_cast<kernel::SubGraphKernel *>(cur_kernel);
|
||||||
|
auto &subgraph_nodes = sub_inner_graph->nodes();
|
||||||
|
if (DelQuantDTypeCastKernel(&subgraph_nodes) != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "DeleteRedundantTrans failed in subgraph.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (cur_kernel->type() != schema::PrimitiveType_QuantDTypeCast) {
|
||||||
|
iter++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto &post_kernels = cur_kernel->out_kernels();
|
||||||
|
auto &pre_kernels = cur_kernel->in_kernels();
|
||||||
|
if (pre_kernels.size() != 1 || cur_kernel->in_tensors().size() != 1) {
|
||||||
|
MS_LOG(ERROR) << "kernel input size error.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
// modify post kernel input to new kernel and new tensor
|
||||||
|
for (auto post_kernel : post_kernels) {
|
||||||
|
auto post_in_kernels = post_kernel->in_kernels();
|
||||||
|
auto post_input_iter = std::find(post_in_kernels.begin(), post_in_kernels.end(), cur_kernel);
|
||||||
|
*post_input_iter = pre_kernels[0];
|
||||||
|
post_kernel->set_in_tensor(cur_kernel->in_tensors()[0], post_input_iter - post_in_kernels.begin());
|
||||||
|
post_kernel->set_in_kernels(post_in_kernels);
|
||||||
|
}
|
||||||
|
auto pre_out_kernels = pre_kernels[0]->out_kernels();
|
||||||
|
auto pre_out_iter = std::find(pre_out_kernels.begin(), pre_out_kernels.end(), cur_kernel);
|
||||||
|
if (pre_out_iter != pre_out_kernels.end()) {
|
||||||
|
pre_out_kernels.erase(pre_out_iter);
|
||||||
|
pre_out_kernels.insert(pre_out_iter, post_kernels.begin(), post_kernels.end());
|
||||||
|
pre_kernels[0]->set_out_kernels(pre_kernels);
|
||||||
|
}
|
||||||
|
// update model output
|
||||||
|
if (cur_kernel->is_model_output()) {
|
||||||
|
pre_kernels[0]->set_is_model_output(true);
|
||||||
|
cur_kernel->in_tensors()[0]->set_category(Category::GRAPH_OUTPUT);
|
||||||
|
pre_kernels[0]->set_out_kernels({});
|
||||||
|
}
|
||||||
|
auto tensor_iter = std::find((*outputs_).begin(), (*outputs_).end(), cur_kernel->out_tensors()[0]);
|
||||||
|
if (tensor_iter != (*outputs_).end()) {
|
||||||
|
*tensor_iter = cur_kernel->in_tensors()[0];
|
||||||
|
}
|
||||||
|
iter = kernels->erase(iter);
|
||||||
|
delete cur_kernel;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
|
int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
|
||||||
int check_input_ret = CheckInputParam(dst_kernels);
|
int check_input_ret = CheckInputParam(dst_kernels);
|
||||||
if (check_input_ret != RET_OK) {
|
if (check_input_ret != RET_OK) {
|
||||||
|
@ -329,6 +381,14 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
|
||||||
MS_LOG(ERROR) << "Schedule graph to kernels failed.";
|
MS_LOG(ERROR) << "Schedule graph to kernels failed.";
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
if (context_->float_mode) {
|
||||||
|
kernel::LiteKernelUtil::FindAllInoutKernels(*dst_kernels);
|
||||||
|
ret = DelQuantDTypeCastKernel(dst_kernels);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Delete quant_dtype_cast kernel failed.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#ifndef DELEGATE_CLIP
|
#ifndef DELEGATE_CLIP
|
||||||
// Free the output tensor data of shape fusion.
|
// Free the output tensor data of shape fusion.
|
||||||
|
@ -420,8 +480,8 @@ int Scheduler::ReplaceDelegateKernels(std::vector<kernel::LiteKernel *> *dst_ker
|
||||||
kernels.push_back((*dst_kernels)[i]->kernel());
|
kernels.push_back((*dst_kernels)[i]->kernel());
|
||||||
}
|
}
|
||||||
|
|
||||||
ms_inputs_ = LiteTensorsToMSTensors(inputs_);
|
ms_inputs_ = LiteTensorsToMSTensors(*inputs_);
|
||||||
ms_outputs_ = LiteTensorsToMSTensors(outputs_);
|
ms_outputs_ = LiteTensorsToMSTensors(*outputs_);
|
||||||
auto schema_version = static_cast<SchemaVersion>(schema_version_);
|
auto schema_version = static_cast<SchemaVersion>(schema_version_);
|
||||||
DelegateModel<schema::Primitive> *model =
|
DelegateModel<schema::Primitive> *model =
|
||||||
new (std::nothrow) DelegateModel<schema::Primitive>(&kernels, ms_inputs_, ms_outputs_, primitives_, schema_version);
|
new (std::nothrow) DelegateModel<schema::Primitive>(&kernels, ms_inputs_, ms_outputs_, primitives_, schema_version);
|
||||||
|
@ -937,6 +997,9 @@ int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std:
|
||||||
cpu_desc.data_type = kNumberTypeFloat16;
|
cpu_desc.data_type = kNumberTypeFloat16;
|
||||||
}
|
}
|
||||||
int ret;
|
int ret;
|
||||||
|
if (context_->float_mode && op_parameter->quant_type_ == schema::QuantType_QUANT_ALL) {
|
||||||
|
op_parameter->quant_type_ = schema::QuantType_QUANT_WEIGHT;
|
||||||
|
}
|
||||||
#ifndef WEIGHT_DECODE_CLIP
|
#ifndef WEIGHT_DECODE_CLIP
|
||||||
ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type, src_model_->version_);
|
ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type, src_model_->version_);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
|
@ -1075,6 +1138,14 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
|
||||||
} else {
|
} else {
|
||||||
data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
|
data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
|
||||||
}
|
}
|
||||||
|
if (context_->float_mode) {
|
||||||
|
data_type = kNumberTypeFloat32;
|
||||||
|
for (auto tensor : in_tensors) {
|
||||||
|
if (tensor->data() == nullptr) {
|
||||||
|
tensor->set_data_type(kNumberTypeFloat32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
kernel::LiteKernel *kernel = nullptr;
|
kernel::LiteKernel *kernel = nullptr;
|
||||||
int status;
|
int status;
|
||||||
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
||||||
|
|
|
@ -44,8 +44,8 @@ constexpr int kDefaultDeviceType = -1;
|
||||||
class Scheduler {
|
class Scheduler {
|
||||||
public:
|
public:
|
||||||
Scheduler(InnerContext *ctx, const mindspore::Context *ms_ctx, Model *src_model, std::vector<Tensor *> *src_tensors,
|
Scheduler(InnerContext *ctx, const mindspore::Context *ms_ctx, Model *src_model, std::vector<Tensor *> *src_tensors,
|
||||||
const std::vector<Tensor *> &input_tensors, const std::vector<Tensor *> &output_tensors,
|
std::vector<Tensor *> *input_tensors, std::vector<Tensor *> *output_tensors, bool is_train_session,
|
||||||
bool is_train_session, int *is_infershape, bool *is_control_flow, std::map<std::string, TypeId> *executions,
|
int *is_infershape, bool *is_control_flow, std::map<std::string, TypeId> *executions,
|
||||||
std::shared_ptr<Delegate> delegate = nullptr, int delegate_device_type = -1)
|
std::shared_ptr<Delegate> delegate = nullptr, int delegate_device_type = -1)
|
||||||
: context_(ctx),
|
: context_(ctx),
|
||||||
ms_context_(ms_ctx),
|
ms_context_(ms_ctx),
|
||||||
|
@ -127,6 +127,7 @@ class Scheduler {
|
||||||
int RestoreSubGraphInput(const lite::Model::Node *partial_node);
|
int RestoreSubGraphInput(const lite::Model::Node *partial_node);
|
||||||
|
|
||||||
bool IsControlFlowPattern(const lite::Model::Node &partial_node);
|
bool IsControlFlowPattern(const lite::Model::Node &partial_node);
|
||||||
|
STATUS DelQuantDTypeCastKernel(std::vector<kernel::LiteKernel *> *kernels);
|
||||||
#ifdef ENABLE_FP16
|
#ifdef ENABLE_FP16
|
||||||
int SubGraphPreferDataType(const int &subgraph_index, TypeId *prefer_data_type);
|
int SubGraphPreferDataType(const int &subgraph_index, TypeId *prefer_data_type);
|
||||||
#endif
|
#endif
|
||||||
|
@ -155,8 +156,8 @@ class Scheduler {
|
||||||
const mindspore::Context *ms_context_ = nullptr;
|
const mindspore::Context *ms_context_ = nullptr;
|
||||||
Model *src_model_ = nullptr;
|
Model *src_model_ = nullptr;
|
||||||
std::vector<Tensor *> *src_tensors_;
|
std::vector<Tensor *> *src_tensors_;
|
||||||
const std::vector<Tensor *> &inputs_;
|
std::vector<Tensor *> *inputs_;
|
||||||
const std::vector<Tensor *> &outputs_;
|
std::vector<Tensor *> *outputs_;
|
||||||
std::vector<mindspore::MSTensor> ms_inputs_;
|
std::vector<mindspore::MSTensor> ms_inputs_;
|
||||||
std::vector<mindspore::MSTensor> ms_outputs_;
|
std::vector<mindspore::MSTensor> ms_outputs_;
|
||||||
std::vector<size_t> graph_output_node_indexes_;
|
std::vector<size_t> graph_output_node_indexes_;
|
||||||
|
|
|
@ -218,10 +218,6 @@ int GetDataIndex(const std::vector<int> &dims, int preferred_dim, int bucket_ind
|
||||||
|
|
||||||
int WeightDecoder::DequantWeight(lite::Tensor *input_tensor, int preferred_dim, TypeId dst_data_type) {
|
int WeightDecoder::DequantWeight(lite::Tensor *input_tensor, int preferred_dim, TypeId dst_data_type) {
|
||||||
MS_ASSERT(input_tensor != nullptr);
|
MS_ASSERT(input_tensor != nullptr);
|
||||||
if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
|
|
||||||
MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type();
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
if (input_tensor->quant_params().empty()) {
|
if (input_tensor->quant_params().empty()) {
|
||||||
MS_LOG(ERROR) << "No quant param.";
|
MS_LOG(ERROR) << "No quant param.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -264,6 +260,13 @@ int WeightDecoder::DequantWeight(lite::Tensor *input_tensor, int preferred_dim,
|
||||||
MS_LOG(ERROR) << "Float16 is not supported";
|
MS_LOG(ERROR) << "Float16 is not supported";
|
||||||
return RET_NOT_SUPPORT;
|
return RET_NOT_SUPPORT;
|
||||||
#endif
|
#endif
|
||||||
|
} else if (input_tensor->data_type() == kNumberTypeInt32 && dst_data_type == kNumberTypeFloat32) {
|
||||||
|
auto new_const_data = DequantData<int32_t, float>(input_tensor, preferred_dim);
|
||||||
|
CHECK_NULL_RETURN(new_const_data);
|
||||||
|
input_tensor->FreeData();
|
||||||
|
input_tensor->set_data(new_const_data);
|
||||||
|
input_tensor->set_own_data(true);
|
||||||
|
input_tensor->set_data_type(dst_data_type);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "Unsupported dequant from data_type(" << (input_tensor->data_type()) << ") to data_type("
|
MS_LOG(ERROR) << "Unsupported dequant from data_type(" << (input_tensor->data_type()) << ") to data_type("
|
||||||
<< dst_data_type << ")";
|
<< dst_data_type << ")";
|
||||||
|
@ -382,7 +385,8 @@ int WeightDecoder::DequantTensor(Tensor *tensor, int preferred_dim, TypeId dst_d
|
||||||
return RET_NO_CHANGE;
|
return RET_NO_CHANGE;
|
||||||
}
|
}
|
||||||
bool need_dequant = !tensor->quant_params().empty() && tensor->quant_params().front().inited &&
|
bool need_dequant = !tensor->quant_params().empty() && tensor->quant_params().front().inited &&
|
||||||
(tensor->data_type() == kNumberTypeInt8 || tensor->data_type() == kNumberTypeInt16);
|
(tensor->data_type() == kNumberTypeInt8 || tensor->data_type() == kNumberTypeInt16 ||
|
||||||
|
tensor->data_type() == kNumberTypeInt32);
|
||||||
if (!need_dequant) {
|
if (!need_dequant) {
|
||||||
return RET_NO_CHANGE;
|
return RET_NO_CHANGE;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue