!6717 fix quant relative

Merge pull request !6717 from yankai10/0922_merge
mindspore-ci-bot 2020-09-23 11:39:03 +08:00 committed by Gitee
commit a3260757c3
11 changed files with 89 additions and 69 deletions

View File

@@ -46,6 +46,12 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
       return RET_ERROR;
     }
   }
+  if (GetQuantType() == schema::QuantType_AwareTraining) {
+    std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
+    std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
+    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
+    SetOutputQuantParam(vecOutputQuantParam);
+  }
   return RET_OK;
 }
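The hunk above is the substantive change in this file: Add now populates quant params for aware-training graphs at all, and it records only the output side (there is no SetInputQuantParam call here, unlike the conv and matmul ops below). The next four hunks only thread the new inputs argument through an existing PopulaterQuantParam call. A minimal compilable sketch of the guarded-population pattern; QuantParamT, QuantType, and Populate are stand-ins, not the real MindSpore Lite API:

// sketch_add_quant_guard.cc -- stand-in types, not the real MindSpore Lite API.
#include <iostream>
#include <vector>

struct QuantParamT {  // stand-in for schema::QuantParamT
  double scale = 1.0;
  int zeroPoint = 0;
};
enum class QuantType { None, AwareTraining };  // stand-in for schema::QuantType

using QuantParams = std::vector<std::vector<QuantParamT>>;

// Placeholder for PopulaterQuantParam: derives one param set per tensor.
void Populate(QuantParams *inputs, QuantParams *outputs) {
  inputs->push_back({{0.5, 0}});
  outputs->push_back({{0.25, 0}});
}

int UnPackAttr(QuantType quant_type) {
  // Quant params are only populated for aware-training graphs.
  if (quant_type == QuantType::AwareTraining) {
    QuantParams vecInputQuantParam;
    QuantParams vecOutputQuantParam;
    Populate(&vecInputQuantParam, &vecOutputQuantParam);
    // Add records only output params here; Conv2D/MatMul record both sides.
    std::cout << "output scale: " << vecOutputQuantParam[0][0].scale << '\n';
  }
  return 0;  // RET_OK
}

int main() { return UnPackAttr(QuantType::AwareTraining); }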

View File

@@ -267,7 +267,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
   if (GetQuantType() == schema::QuantType_AwareTraining) {
     std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
     std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
-    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
+    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
     SetInputQuantParam(vecInputQuantParam);
     SetOutputQuantParam(vecOutputQuantParam);
   }

View File

@@ -130,7 +130,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
   if (GetQuantType() == schema::QuantType_AwareTraining) {
     std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
     std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
-    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
+    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
     SetInputQuantParam(vecInputQuantParam);
     SetOutputQuantParam(vecOutputQuantParam);
   }

View File

@@ -140,7 +140,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
   if (GetQuantType() == schema::QuantType_AwareTraining) {
     std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
     std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
-    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
+    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
     SetInputQuantParam(vecInputQuantParam);
     SetOutputQuantParam(vecOutputQuantParam);
   }

View File

@@ -60,7 +60,7 @@ int MatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
   if (GetQuantType() == schema::QuantType_AwareTraining) {
     std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
     std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
-    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
+    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
     SetInputQuantParam(vecInputQuantParam);
     SetOutputQuantParam(vecOutputQuantParam);
   }

View File

@@ -158,7 +158,8 @@ void PrimitiveC::CalQuantParam(const double &mean, const double &stdDev, float *
 void PrimitiveC::PopulaterQuantParam(const Primitive &prim,
                                      std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
-                                     std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
+                                     std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam,
+                                     const std::vector<AnfNodePtr> &inputs) {
   auto narrow_range = prim.GetAttr("narrow_range");
   bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
   auto num_bits = prim.GetAttr("num_bits");
@@ -179,12 +180,14 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim,
   } else {
     auto inputMin = prim.GetAttr("input_minq");
     auto inputMax = prim.GetAttr("input_maxq");
-    auto inputMinPtr = inputMin->cast<TensorPtr>();
-    auto inputMaxPtr = inputMax->cast<TensorPtr>();
-    float *minBuf = static_cast<float *>(inputMinPtr->data_c());
-    float *maxBuf = static_cast<float *>(inputMaxPtr->data_c());
-    quantParam.min = *minBuf;
-    quantParam.max = *maxBuf;
+    if (inputMin != nullptr && inputMax != nullptr) {
+      auto inputMinPtr = inputMin->cast<TensorPtr>();
+      auto inputMaxPtr = inputMax->cast<TensorPtr>();
+      float *minBuf = static_cast<float *>(inputMinPtr->data_c());
+      float *maxBuf = static_cast<float *>(inputMaxPtr->data_c());
+      quantParam.min = *minBuf;
+      quantParam.max = *maxBuf;
+    }
   }
   quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
                                numbitsRangeQuantParam);
@@ -212,13 +215,15 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim,
     vecInputQuantParam->emplace_back(quants);
   }
-  quants.clear();
-  quantParam.min = 0.0;
-  quantParam.max = 0.0;
-  quantParam.zeroPoint = 0;
-  quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(0).scale;
-  quants.emplace_back(quantParam);
-  vecInputQuantParam->emplace_back(quants);
+  if (vecInputQuantParam->size() == kDoubleNum) {
+    quants.clear();
+    quantParam.min = 0.0;
+    quantParam.max = 0.0;
+    quantParam.zeroPoint = 0;
+    quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(0).scale;
+    quants.emplace_back(quantParam);
+    vecInputQuantParam->emplace_back(quants);
+  }
   quants.clear();
   auto outputMin = prim.GetAttr("output_minq");
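Both guards in this file are defensive: prim.GetAttr("input_minq")/"input_maxq" can come back null when a graph was exported without those attributes, and the pairwise scale product is only meaningful when exactly two input param sets (kDoubleNum) were collected. A small sketch of the lookup-guard shape, using a hypothetical attribute map in place of the real Primitive::GetAttr/TensorPtr machinery:

// sketch_attr_guard.cc -- attribute lookup guard; all names are stand-ins.
#include <map>
#include <memory>
#include <string>
#include <vector>

struct Tensor { std::vector<float> data; };  // stand-in for tensor::Tensor
using AttrMap = std::map<std::string, std::shared_ptr<Tensor>>;

struct QuantParamT { float min = 0.0f; float max = 0.0f; };

// Reads "input_minq"/"input_maxq" only when both attributes exist;
// the unguarded version dereferenced the lookup result unconditionally.
QuantParamT ReadInputRange(const AttrMap &attrs) {
  QuantParamT qp;  // keeps the 0.0 defaults when the attributes are absent
  auto min_it = attrs.find("input_minq");
  auto max_it = attrs.find("input_maxq");
  if (min_it != attrs.end() && max_it != attrs.end() &&
      min_it->second != nullptr && max_it->second != nullptr &&
      !min_it->second->data.empty() && !max_it->second->data.empty()) {
    qp.min = min_it->second->data.front();
    qp.max = max_it->second->data.front();
  }
  return qp;
}

int main() {
  AttrMap attrs;  // no min/max exported: ReadInputRange falls back to defaults
  QuantParamT qp = ReadInputRange(attrs);
  return qp.min == 0.0f ? 0 : 1;
}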

View File

@@ -39,8 +39,8 @@ constexpr uint32_t kDoubleNum = 2;
 constexpr uint32_t kMultiNum = 3;
 constexpr uint32_t kDimension_4d = 4;
-const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeInt32,
-                                        kNumberTypeFloat32, kNumberTypeFloat16};
+const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat32,
+                                        kNumberTypeFloat16};
 #ifdef PRIMITIVE_WRITEABLE
 using TensorPtr = std::shared_ptr<mindspore::tensor::Tensor>;
@@ -119,7 +119,8 @@ class PrimitiveC : public mindspore::Primitive {
   static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
                                             const schema::QuantType &quantType);
   void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
-                           std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
+                           std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam,
+                           const std::vector<AnfNodePtr> &inputs);
   void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
 protected:

View File

@@ -98,29 +98,28 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
   // activation
   auto input_quant_params = primitive->GetInputQuantParams();
   auto node_type = (schema::PrimitiveType)primitive->Type();
-  if (input_quant_params.empty()) {
-    MS_LOG(DEBUG) << "node: " << dst_node->name << " input quant params is empty";
-    return RET_OK;
-  }
-  for (size_t i = 0; i < input_quant_params.size(); i++) {
-    if (i >= dst_node->inputIndex.size()) {
-      MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size()
-                    << " quant_params; but only " << dst_node->inputIndex.size() << " input";
-      return RET_PARAM_INVALID;
-    }
-    auto activate_index = dst_node->inputIndex[i];
-    auto tensor_input = meta_graph->allTensors[activate_index].get();
-    if (tensor_input->quantParams.empty()) {
-      for (auto input_quant_param : input_quant_params[i]) {
-        std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
-            std::make_unique<schema::QuantParamT>(input_quant_param);
-        MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
-                      << " zp: " << input_quant_param_ptr->zeroPoint;
-        tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
-      }
-    }
-  }
+  if (!input_quant_params.empty()) {
+    for (size_t i = 0; i < input_quant_params.size(); i++) {
+      if (i >= dst_node->inputIndex.size()) {
+        MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size()
+                      << " quant_params; but only " << dst_node->inputIndex.size() << " input";
+        return RET_PARAM_INVALID;
+      }
+      auto activate_index = dst_node->inputIndex[i];
+      auto tensor_input = meta_graph->allTensors[activate_index].get();
+      if (tensor_input->quantParams.empty()) {
+        for (auto input_quant_param : input_quant_params[i]) {
+          std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
+              std::make_unique<schema::QuantParamT>(input_quant_param);
+          MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
+                        << " zp: " << input_quant_param_ptr->zeroPoint;
+          tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
+        }
+      }
+    }
+  } else {
+    MS_LOG(DEBUG) << "node: " << dst_node->name << " input quant params is empty";
+  }
   // output
   auto output_index = dst_node->outputIndex[0];
   auto tensor_output = meta_graph->allTensors[output_index].get();
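This restructure is behavioral, not cosmetic: the old code returned RET_OK as soon as the input quant params were empty, which also skipped the output section that follows, so such nodes never had their output quant params exported. A reduced sketch of the control-flow change, with stand-in types:

// sketch_convert_quant.cc -- control-flow change; types are stand-ins.
#include <iostream>
#include <vector>

struct QuantParamT { double scale = 1.0; };
using PerTensorParams = std::vector<std::vector<QuantParamT>>;

void ConvertQuantParam(const PerTensorParams &input_params,
                       const std::vector<QuantParamT> &output_params) {
  if (!input_params.empty()) {
    for (size_t i = 0; i < input_params.size(); ++i) {
      std::cout << "input[" << i << "] sets: " << input_params[i].size() << '\n';
    }
  } else {
    // Before the fix this branch was `return;`, so the loop below never ran.
    std::cout << "input quant params is empty\n";
  }
  for (const QuantParamT &q : output_params) {  // output side always converted
    std::cout << "output scale: " << q.scale << '\n';
  }
}

int main() {
  ConvertQuantParam({}, {QuantParamT{0.5}});  // empty inputs, output still emitted
  return 0;
}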
@@ -171,7 +170,7 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &
 }

 int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
-                                    schema::CNodeT *return_node) {
+                                     schema::CNodeT *return_node) {
   MS_ASSERT(nullptr != meta_graph);
   MS_ASSERT(nullptr != return_node);
   for (size_t i = 1; i < cnode->inputs().size(); i++) {
@@ -210,9 +209,9 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
     if (primitive_c->Type() == schema::PrimitiveType_TupleGetItem ||
         primitive_c->Type() == schema::PrimitiveType_MakeTuple
 #ifdef SUPPORT_TRAIN
-            || primitive_c->Type() == schema::PrimitiveType_Depend
+        || primitive_c->Type() == schema::PrimitiveType_Depend
 #endif
-        ) {
+    ) {
       continue;
     }
     RemoveIfMakeTuple(cnode);
@@ -403,8 +402,7 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
       if (value_track->isa<Int32Imm>()) {
         shape.push_back((GetValue<int>(value_track)));
       } else {
-        MS_LOG(ERROR) << "Value type is ValueSequence is not integer, it is "
-                      << value_track->ToString() << ".";
+        MS_LOG(ERROR) << "Value type is ValueSequence is not integer, it is " << value_track->ToString() << ".";
       }
     }
     if (shape.size()) {
@@ -417,10 +415,10 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
       node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
       output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
       meta_graphT->allTensors.emplace_back(std::move(paramTensor));
     }
-    } else {
-      MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << ".";
-    }
+  } else {
+    MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << ".";
+  }
 #endif
   } else if (value->isa<Number>()) {
     MS_LOG(INFO) << "Value is a number.";
View File

@@ -54,8 +54,7 @@ static const std::vector<schema::PrimitiveType> nhwcOpDualInputList = {
 static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = {
 #ifdef SUPPORT_TRAIN
-    schema::PrimitiveType_PoolingGrad,
-    schema::PrimitiveType_ActivationGrad
+    schema::PrimitiveType_PoolingGrad, schema::PrimitiveType_ActivationGrad
 #endif
 };
@@ -66,20 +65,21 @@ static const std::vector<schema::PrimitiveType> fp32FullOpList = {
 static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {};

 static const std::vector<schema::PrimitiveType> int8OpList = {
-    schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
-    schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
-    schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
-    schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax,
-    schema::PrimitiveType_Reshape, schema::PrimitiveType_Activation,
-    schema::PrimitiveType_Resize, schema::PrimitiveType_FullConnection,
-    schema::PrimitiveType_ArgMax, schema::PrimitiveType_ArgMin,
-    schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm,
-    schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Div,
-    schema::PrimitiveType_Mul, schema::PrimitiveType_Slice,
-    schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split,
-    schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub,
-    schema::PrimitiveType_TopK, schema::PrimitiveType_Unsqueeze,
-    schema::PrimitiveType_MatMul, schema::PrimitiveType_Pad};
+    schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
+    schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
+    schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
+    schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax,
+    schema::PrimitiveType_Reshape, schema::PrimitiveType_Activation,
+    schema::PrimitiveType_Resize, schema::PrimitiveType_FullConnection,
+    schema::PrimitiveType_ArgMax, schema::PrimitiveType_ArgMin,
+    schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm,
+    schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Div,
+    schema::PrimitiveType_Mul, schema::PrimitiveType_Slice,
+    schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split,
+    schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub,
+    schema::PrimitiveType_StridedSlice, schema::PrimitiveType_TopK,
+    schema::PrimitiveType_Unsqueeze, schema::PrimitiveType_MatMul,
+    schema::PrimitiveType_Pad};
 static const std::vector<schema::PrimitiveType> needInsertOpList = {
     schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,

View File

@@ -16,6 +16,7 @@
 #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
 #include <string>
+#include <set>
 #include "tools/common/converter_op_utils.h"
 #include "tools/common/node_util.h"
 #include "src/common/common.h"
@@ -26,6 +27,9 @@ namespace lite {
 #define kMinInputNum 1
 #define kOutputNum 1

+static const std::set<schema::PrimitiveType> NoNeedDtypeTransList = {
+    PrimitiveType_QuantDTypeCast, PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw};
+
 STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
@@ -134,7 +138,8 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
     if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) {
       continue;
     }
-    if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) {
+    auto iterType = GetCNodeTType(**iter);
+    if (NoNeedDtypeTransList.find(iterType) != NoNeedDtypeTransList.end()) {
       continue;
     }
     bool needInsertPost = true;
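Replacing the single equality test with a named skip set means QuantDTypeCast and the two layout-permute ops are now all exempt from dtype-transform insertion, and any future exemption is a one-line addition to the set. A self-contained sketch of the lookup; the enum values are stand-ins for the schema types:

// sketch_skip_set.cc -- set membership replacing an equality chain; stand-ins.
#include <set>

enum class PrimitiveType { QuantDTypeCast, Nchw2Nhwc, Nhwc2Nchw, Conv2D };

static const std::set<PrimitiveType> kNoNeedDtypeTrans = {
    PrimitiveType::QuantDTypeCast, PrimitiveType::Nchw2Nhwc,
    PrimitiveType::Nhwc2Nchw};

bool NeedsDtypeTrans(PrimitiveType t) {
  // Before: if (t == PrimitiveType::QuantDTypeCast) { skip } -- one op only.
  // After: any op found in the skip set bypasses dtype-transform insertion.
  return kNoNeedDtypeTrans.find(t) == kNoNeedDtypeTrans.end();
}

int main() { return NeedsDtypeTrans(PrimitiveType::Conv2D) ? 0 : 1; }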

View File

@@ -167,7 +167,11 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
     auto &outTensor = graph->allTensors.at(node.outputIndex.at(i));
     MS_ASSERT(outTensor != nullptr);
     auto outQuantParam = GetTensorQuantParam(outTensor);
-    if (outQuantParam == nullptr || outQuantParam->inited) {
+    if (outQuantParam == nullptr) {
+      outTensor->quantParams.emplace_back(std::move(inQuantParam));
+      continue;
+    }
+    if (outQuantParam->inited) {
       continue;
     }
     outTensor->quantParams.front() = std::move(inQuantParam);
@@ -232,7 +236,7 @@ class CalcConcat : public QuantParamCalcer {
       MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!";
       return RET_ERROR;
     }
-    outTensor->quantParams.front() = std::move(outQuantParam);
+    outTensor->quantParams.emplace_back(std::move(outQuantParam));
     outputParamDone++;
   }
@@ -417,7 +421,7 @@ class CalcToSet : public QuantParamCalcer {
     MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
     auto &outTensor = graph->allTensors.at(node.outputIndex.front());
     MS_ASSERT(outTensor != nullptr);
-    outTensor->quantParams.front() = std::move(quantParam);
+    outTensor->quantParams.emplace_back(std::move(quantParam));
     outputParamDone++;
   }
   return RET_OK;
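The three hunks above in this file all guard the same hazard: assigning through quantParams.front() when the vector is empty is undefined behavior, since front() on an empty std::vector dereferences nothing. LinearCalcer now appends the propagated input param when the output tensor had none; CalcConcat and CalcToSet switch to appending outright, presumably because those output tensors carry no prior entry at that point. A minimal sketch of the hazard and one safe shape for the write (conditional here for illustration), with stand-in types:

// sketch_safe_param_write.cc -- empty-vector hazard; types are stand-ins.
#include <memory>
#include <vector>

struct QuantParamT { double scale = 1.0; bool inited = false; };
using QuantParamPtr = std::unique_ptr<QuantParamT>;

void SetOutputParam(std::vector<QuantParamPtr> *quantParams, QuantParamPtr param) {
  if (quantParams->empty()) {
    // front() here would dereference into nothing -- UB. Append instead.
    quantParams->emplace_back(std::move(param));
  } else {
    quantParams->front() = std::move(param);
  }
}

int main() {
  std::vector<QuantParamPtr> params;  // empty: takes the append path
  SetOutputParam(&params, std::make_unique<QuantParamT>());
  return params.size() == 1 ? 0 : 1;
}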
@@ -475,6 +479,7 @@ QuantParamCalcRegister::QuantParamCalcRegister() {
   _registerMap[schema::PrimitiveType_Pooling] = linearCalcer;
   _registerMap[schema::PrimitiveType_Resize] = linearCalcer;
   _registerMap[schema::PrimitiveType_Reshape] = linearCalcer;
+  _registerMap[schema::PrimitiveType_StridedSlice] = linearCalcer;
   _registerMap[schema::PrimitiveType_Shape] = linearCalcer;
   _registerMap[schema::PrimitiveType_SoftMax] = std::make_shared<CalcToSet>(0, 1);
   _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer;
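Together with the int8OpList addition earlier in this PR, this registration appears to complete StridedSlice support under aware-training quantization: the list membership keeps the dtype-transform pass from wrapping the op, and the linear calcer then carries input quant params through to its output. A sketch of the registry pattern; the types are stand-ins:

// sketch_calcer_registry.cc -- op-type-to-calcer registry; names are stand-ins.
#include <map>
#include <memory>

struct QuantParamCalcer { virtual ~QuantParamCalcer() = default; };
struct LinearCalcer : QuantParamCalcer {};  // propagates params unchanged

enum class PrimitiveType { Reshape, StridedSlice, SoftMax };

int main() {
  std::map<PrimitiveType, std::shared_ptr<QuantParamCalcer>> registerMap;
  auto linearCalcer = std::make_shared<LinearCalcer>();  // one shared instance
  registerMap[PrimitiveType::Reshape] = linearCalcer;
  registerMap[PrimitiveType::StridedSlice] = linearCalcer;  // the new entry
  return registerMap.count(PrimitiveType::StridedSlice) == 1 ? 0 : 1;
}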