Support restoring fully quantized models to fp32

This commit is contained in:
zhengjun10 2022-02-16 10:43:16 +08:00
parent 7e6f3dec6d
commit 92954179a0
6 changed files with 123 additions and 12 deletions

View File

@ -80,6 +80,7 @@ struct Context {
DeviceContextVector device_list_;
#endif // NOT_USE_STL
DelegatePtr delegate = nullptr;
bool float_mode = false;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_INCLUDE_CONTEXT_H_

View File

@ -503,6 +503,38 @@ void LiteSession::InitGraphOutputTensorMap(const lite::Model *model) {
}
}
void LiteSession::UpdateGraphOutputMap(const std::vector<kernel::LiteKernel *> &kernels) {
  if (!context_->float_mode) {
    return;
  }
  // Scheduler runtime passes may have replaced the kernels/tensors that produce
  // graph outputs, so in float mode the maps built from the model are stale:
  // rebuild them from the scheduled subgraph kernels.
  output_tensor_map_.clear();
  output_tensor_names_.clear();
  output_node_map_.clear();
  for (auto kernel_item : kernels) {
    auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernel_item);
    for (auto node : subgraph->out_nodes()) {
      if (!node->is_model_output()) {
        continue;
      }
      for (auto tensor : node->out_tensors()) {
        if (!tensor->IsGraphOutput()) {
          continue;
        }
        output_node_map_[node->name()].emplace_back(tensor);
        auto key = tensor->tensor_name();
        if (key.empty()) {
          // Unnamed tensor: fall back to its index in the session tensor list.
          auto pos = std::find(tensors_.begin(), tensors_.end(), tensor);
          if (pos == tensors_.end()) {
            continue;
          }
          key = std::to_string(pos - tensors_.begin());
        }
        output_tensor_map_[key] = tensor;
        output_tensor_names_.emplace_back(key);
      }
    }
  }
}
void LiteSession::AdjustModelOutputTensorInitRefCount(const lite::Model *model) {
MS_ASSERT(model != nullptr);
auto graph_out_size = model->output_indices_.size();
@ -675,7 +707,7 @@ int LiteSession::CompileGraph(Model *model) {
InitGraphOutputTensors(model);
// scheduler kernels
Scheduler scheduler(context_, ms_context_, model, &tensors_, inputs_, outputs_, is_train_session_, &is_infershape_,
Scheduler scheduler(context_, ms_context_, model, &tensors_, &inputs_, &outputs_, is_train_session_, &is_infershape_,
&is_control_flow_, execution_plan_, delegate_, delegate_device_type_);
scheduler.SetupSchedulerCb(std::move(sched_cb_));
scheduler.SetConfig(config_info_);
@ -686,6 +718,7 @@ int LiteSession::CompileGraph(Model *model) {
return ret;
}
InitGraphInOutTensorsMap(model);
UpdateGraphOutputMap(kernels_);
non_tail_call_kernels_ = scheduler.NonTailCallNodes();

View File

@ -135,6 +135,7 @@ class LiteSession : public session::LiteSession {
private:
int IsolateOutputTensor();
bool IsIsolatedSubGraph(const kernel::LiteKernel *kernel);
void UpdateGraphOutputMap(const std::vector<kernel::LiteKernel *> &kernel);
void UpdateLinkInfoForIsolateOutput();
std::unordered_map<Tensor *, Tensor *> isolate_graph_output_map_; /* <calculate-tensor, graph-output-tensor> */
std::unordered_map<Tensor *, Tensor *> isolate_input_map_; /* <calculate-tensor, src-subgraph-input-tensor> */

View File

@ -300,6 +300,58 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *dst_kernels
return ConstructNormalSubGraphs(src_kernel, dst_kernels, &is_kernel_finish);
}
STATUS Scheduler::DelQuantDTypeCastKernel(std::vector<kernel::LiteKernel *> *kernels) {
  // Removes QuantDTypeCast kernels (recursively, inside subgraphs too) by
  // re-wiring each one's predecessor directly to its successors, so a fully
  // quantized model can be executed in float mode without the cast nodes.
  for (auto iter = (*kernels).begin(); iter != (*kernels).end();) {
    auto cur_kernel = *iter;
    if (cur_kernel->subgraph_type() != kernel::kNotSubGraph) {
      // Recurse into the subgraph's node list first.
      auto sub_inner_graph = reinterpret_cast<kernel::SubGraphKernel *>(cur_kernel);
      auto &subgraph_nodes = sub_inner_graph->nodes();
      if (DelQuantDTypeCastKernel(&subgraph_nodes) != RET_OK) {
        MS_LOG(ERROR) << "DelQuantDTypeCastKernel failed in subgraph.";
        return RET_ERROR;
      }
    }
    if (cur_kernel->type() != schema::PrimitiveType_QuantDTypeCast) {
      ++iter;
      continue;
    }
    auto &post_kernels = cur_kernel->out_kernels();
    auto &pre_kernels = cur_kernel->in_kernels();
    // A removable cast must have exactly one producer and one input tensor,
    // and at least one output tensor to splice out of the graph outputs.
    if (pre_kernels.size() != 1 || cur_kernel->in_tensors().size() != 1 || cur_kernel->out_tensors().empty()) {
      MS_LOG(ERROR) << "kernel input size error.";
      return RET_ERROR;
    }
    // Point every post kernel's input at the pre kernel and its input tensor.
    for (auto post_kernel : post_kernels) {
      auto post_in_kernels = post_kernel->in_kernels();
      auto post_input_iter = std::find(post_in_kernels.begin(), post_in_kernels.end(), cur_kernel);
      if (post_input_iter == post_in_kernels.end()) {
        MS_LOG(ERROR) << "post kernel input kernels don't contain quant_dtype_cast kernel.";
        return RET_ERROR;
      }
      *post_input_iter = pre_kernels[0];
      post_kernel->set_in_tensor(cur_kernel->in_tensors()[0], post_input_iter - post_in_kernels.begin());
      post_kernel->set_in_kernels(post_in_kernels);
    }
    // Replace cur_kernel in the pre kernel's output list with cur_kernel's
    // post kernels. Use the iterator returned by erase: the original iterator
    // is invalidated by the erase.
    auto pre_out_kernels = pre_kernels[0]->out_kernels();
    auto pre_out_iter = std::find(pre_out_kernels.begin(), pre_out_kernels.end(), cur_kernel);
    if (pre_out_iter != pre_out_kernels.end()) {
      pre_out_iter = pre_out_kernels.erase(pre_out_iter);
      pre_out_kernels.insert(pre_out_iter, post_kernels.begin(), post_kernels.end());
      // Fix: was set_out_kernels(pre_kernels), which assigned the *input*
      // kernel list as the outputs and discarded the rebuilt list.
      pre_kernels[0]->set_out_kernels(pre_out_kernels);
    }
    // If the cast produced a model output, promote the pre kernel instead.
    if (cur_kernel->is_model_output()) {
      pre_kernels[0]->set_is_model_output(true);
      cur_kernel->in_tensors()[0]->set_category(Category::GRAPH_OUTPUT);
      pre_kernels[0]->set_out_kernels({});
    }
    // Swap the removed kernel's output tensor for its input tensor in the
    // graph output tensor list.
    auto tensor_iter = std::find((*outputs_).begin(), (*outputs_).end(), cur_kernel->out_tensors()[0]);
    if (tensor_iter != (*outputs_).end()) {
      *tensor_iter = cur_kernel->in_tensors()[0];
    }
    iter = kernels->erase(iter);
    delete cur_kernel;
  }
  return RET_OK;
}
int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
int check_input_ret = CheckInputParam(dst_kernels);
if (check_input_ret != RET_OK) {
@ -329,6 +381,14 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
MS_LOG(ERROR) << "Schedule graph to kernels failed.";
return ret;
}
if (context_->float_mode) {
kernel::LiteKernelUtil::FindAllInoutKernels(*dst_kernels);
ret = DelQuantDTypeCastKernel(dst_kernels);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Delete quant_dtype_cast kernel failed.";
return ret;
}
}
#ifndef DELEGATE_CLIP
// Free the output tensor data of shape fusion.
@ -420,8 +480,8 @@ int Scheduler::ReplaceDelegateKernels(std::vector<kernel::LiteKernel *> *dst_ker
kernels.push_back((*dst_kernels)[i]->kernel());
}
ms_inputs_ = LiteTensorsToMSTensors(inputs_);
ms_outputs_ = LiteTensorsToMSTensors(outputs_);
ms_inputs_ = LiteTensorsToMSTensors(*inputs_);
ms_outputs_ = LiteTensorsToMSTensors(*outputs_);
auto schema_version = static_cast<SchemaVersion>(schema_version_);
DelegateModel<schema::Primitive> *model =
new (std::nothrow) DelegateModel<schema::Primitive>(&kernels, ms_inputs_, ms_outputs_, primitives_, schema_version);
@ -937,6 +997,9 @@ int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std:
cpu_desc.data_type = kNumberTypeFloat16;
}
int ret;
if (context_->float_mode && op_parameter->quant_type_ == schema::QuantType_QUANT_ALL) {
op_parameter->quant_type_ = schema::QuantType_QUANT_WEIGHT;
}
#ifndef WEIGHT_DECODE_CLIP
ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type, src_model_->version_);
if (ret != RET_OK) {
@ -1075,6 +1138,14 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
} else {
data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
}
if (context_->float_mode) {
data_type = kNumberTypeFloat32;
for (auto tensor : in_tensors) {
if (tensor->data() == nullptr) {
tensor->set_data_type(kNumberTypeFloat32);
}
}
}
kernel::LiteKernel *kernel = nullptr;
int status;
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP

View File

@ -44,8 +44,8 @@ constexpr int kDefaultDeviceType = -1;
class Scheduler {
public:
Scheduler(InnerContext *ctx, const mindspore::Context *ms_ctx, Model *src_model, std::vector<Tensor *> *src_tensors,
const std::vector<Tensor *> &input_tensors, const std::vector<Tensor *> &output_tensors,
bool is_train_session, int *is_infershape, bool *is_control_flow, std::map<std::string, TypeId> *executions,
std::vector<Tensor *> *input_tensors, std::vector<Tensor *> *output_tensors, bool is_train_session,
int *is_infershape, bool *is_control_flow, std::map<std::string, TypeId> *executions,
std::shared_ptr<Delegate> delegate = nullptr, int delegate_device_type = -1)
: context_(ctx),
ms_context_(ms_ctx),
@ -127,6 +127,7 @@ class Scheduler {
int RestoreSubGraphInput(const lite::Model::Node *partial_node);
bool IsControlFlowPattern(const lite::Model::Node &partial_node);
STATUS DelQuantDTypeCastKernel(std::vector<kernel::LiteKernel *> *kernels);
#ifdef ENABLE_FP16
int SubGraphPreferDataType(const int &subgraph_index, TypeId *prefer_data_type);
#endif
@ -155,8 +156,8 @@ class Scheduler {
const mindspore::Context *ms_context_ = nullptr;
Model *src_model_ = nullptr;
std::vector<Tensor *> *src_tensors_;
const std::vector<Tensor *> &inputs_;
const std::vector<Tensor *> &outputs_;
std::vector<Tensor *> *inputs_;
std::vector<Tensor *> *outputs_;
std::vector<mindspore::MSTensor> ms_inputs_;
std::vector<mindspore::MSTensor> ms_outputs_;
std::vector<size_t> graph_output_node_indexes_;

View File

@ -218,10 +218,6 @@ int GetDataIndex(const std::vector<int> &dims, int preferred_dim, int bucket_ind
int WeightDecoder::DequantWeight(lite::Tensor *input_tensor, int preferred_dim, TypeId dst_data_type) {
MS_ASSERT(input_tensor != nullptr);
if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type();
return RET_ERROR;
}
if (input_tensor->quant_params().empty()) {
MS_LOG(ERROR) << "No quant param.";
return RET_ERROR;
@ -264,6 +260,13 @@ int WeightDecoder::DequantWeight(lite::Tensor *input_tensor, int preferred_dim,
MS_LOG(ERROR) << "Float16 is not supported";
return RET_NOT_SUPPORT;
#endif
} else if (input_tensor->data_type() == kNumberTypeInt32 && dst_data_type == kNumberTypeFloat32) {
auto new_const_data = DequantData<int32_t, float>(input_tensor, preferred_dim);
CHECK_NULL_RETURN(new_const_data);
input_tensor->FreeData();
input_tensor->set_data(new_const_data);
input_tensor->set_own_data(true);
input_tensor->set_data_type(dst_data_type);
} else {
MS_LOG(ERROR) << "Unsupported dequant from data_type(" << (input_tensor->data_type()) << ") to data_type("
<< dst_data_type << ")";
@ -382,7 +385,8 @@ int WeightDecoder::DequantTensor(Tensor *tensor, int preferred_dim, TypeId dst_d
return RET_NO_CHANGE;
}
bool need_dequant = !tensor->quant_params().empty() && tensor->quant_params().front().inited &&
(tensor->data_type() == kNumberTypeInt8 || tensor->data_type() == kNumberTypeInt16);
(tensor->data_type() == kNumberTypeInt8 || tensor->data_type() == kNumberTypeInt16 ||
tensor->data_type() == kNumberTypeInt32);
if (!need_dequant) {
return RET_NO_CHANGE;
}