forked from mindspore-Ecosystem/mindspore
!6717 fix quant relative
Merge pull request !6717 from yankai10/0922_merge
commit a3260757c3
@@ -46,6 +46,12 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs)
       return RET_ERROR;
     }
   }
+  if (GetQuantType() == schema::QuantType_AwareTraining) {
+    std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
+    std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
+    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
+    SetOutputQuantParam(vecOutputQuantParam);
+  }
   return RET_OK;
 }
 
@@ -267,7 +267,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs)
   if (GetQuantType() == schema::QuantType_AwareTraining) {
     std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
     std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
-    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
+    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
    SetInputQuantParam(vecInputQuantParam);
    SetOutputQuantParam(vecOutputQuantParam);
  }
@@ -130,7 +130,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs)
   if (GetQuantType() == schema::QuantType_AwareTraining) {
     std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
     std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
-    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
+    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
    SetInputQuantParam(vecInputQuantParam);
    SetOutputQuantParam(vecOutputQuantParam);
  }
@@ -140,7 +140,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs)
   if (GetQuantType() == schema::QuantType_AwareTraining) {
     std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
     std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
-    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
+    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
    SetInputQuantParam(vecInputQuantParam);
    SetOutputQuantParam(vecOutputQuantParam);
  }
@@ -60,7 +60,7 @@ int MatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs)
   if (GetQuantType() == schema::QuantType_AwareTraining) {
     std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
     std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
-    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
+    PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
    SetInputQuantParam(vecInputQuantParam);
    SetOutputQuantParam(vecOutputQuantParam);
  }
@@ -158,7 +158,8 @@ void PrimitiveC::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax)
 
 void PrimitiveC::PopulaterQuantParam(const Primitive &prim,
                                      std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
-                                     std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
+                                     std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam,
+                                     const std::vector<AnfNodePtr> &inputs) {
   auto narrow_range = prim.GetAttr("narrow_range");
   bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
   auto num_bits = prim.GetAttr("num_bits");
@@ -179,12 +180,14 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim,
   } else {
     auto inputMin = prim.GetAttr("input_minq");
     auto inputMax = prim.GetAttr("input_maxq");
-    auto inputMinPtr = inputMin->cast<TensorPtr>();
-    auto inputMaxPtr = inputMax->cast<TensorPtr>();
-    float *minBuf = static_cast<float *>(inputMinPtr->data_c());
-    float *maxBuf = static_cast<float *>(inputMaxPtr->data_c());
-    quantParam.min = *minBuf;
-    quantParam.max = *maxBuf;
+    if (inputMin != nullptr && inputMax != nullptr) {
+      auto inputMinPtr = inputMin->cast<TensorPtr>();
+      auto inputMaxPtr = inputMax->cast<TensorPtr>();
+      float *minBuf = static_cast<float *>(inputMinPtr->data_c());
+      float *maxBuf = static_cast<float *>(inputMaxPtr->data_c());
+      quantParam.min = *minBuf;
+      quantParam.max = *maxBuf;
+    }
   }
   quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
                                numbitsRangeQuantParam);
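The guard above is the core of the fix: `prim.GetAttr(...)` can return a null pointer when an op carries no recorded min/max range, and dereferencing it through `cast<TensorPtr>()` crashed the converter. A minimal standalone sketch of the same pattern, using a hypothetical `FakePrim` stand-in rather than MindSpore's real `Primitive` API:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Illustrative stand-in for an attribute map whose lookups can miss.
struct FakePrim {
  std::unordered_map<std::string, std::shared_ptr<float>> attrs;
  std::shared_ptr<float> GetAttr(const std::string &name) const {
    auto it = attrs.find(name);
    return it == attrs.end() ? nullptr : it->second;  // a miss yields nullptr, not a throw
  }
};

int main() {
  FakePrim prim;
  prim.attrs["input_minq"] = std::make_shared<float>(-6.0f);
  // "input_maxq" deliberately absent, as it can be for ops without recorded ranges.
  auto minAttr = prim.GetAttr("input_minq");
  auto maxAttr = prim.GetAttr("input_maxq");
  float min = 0.0f, max = 0.0f;  // keep the zero defaults when either bound is missing
  if (minAttr != nullptr && maxAttr != nullptr) {
    min = *minAttr;
    max = *maxAttr;
  }
  std::cout << "min=" << min << " max=" << max << "\n";  // prints min=0 max=0
  return 0;
}
```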
@@ -212,13 +215,15 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim,
     vecInputQuantParam->emplace_back(quants);
   }
 
-  quants.clear();
-  quantParam.min = 0.0;
-  quantParam.max = 0.0;
-  quantParam.zeroPoint = 0;
-  quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(0).scale;
-  quants.emplace_back(quantParam);
-  vecInputQuantParam->emplace_back(quants);
+  if (vecInputQuantParam->size() == kDoubleNum) {
+    quants.clear();
+    quantParam.min = 0.0;
+    quantParam.max = 0.0;
+    quantParam.zeroPoint = 0;
+    quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(0).scale;
+    quants.emplace_back(quantParam);
+    vecInputQuantParam->emplace_back(quants);
+  }
 
   quants.clear();
   auto outputMin = prim.GetAttr("output_minq");
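The size check matters because `at(1)` on a container holding only the activation's param group throws `std::out_of_range`. A short sketch of the convention it protects (bias scale = input scale × weight scale), with plain hypothetical structs in place of the schema types:

```cpp
#include <iostream>
#include <vector>

struct QuantParam { double scale = 1.0; int zeroPoint = 0; };
constexpr size_t kDoubleNum = 2;  // both input and weight param groups present

int main() {
  std::vector<std::vector<QuantParam>> inputParams;
  inputParams.push_back({QuantParam{0.02}});   // activation scale
  inputParams.push_back({QuantParam{0.005}});  // weight scale
  // Derive the bias scale only when both groups exist; at(1) on a
  // one-element vector would throw std::out_of_range otherwise.
  if (inputParams.size() == kDoubleNum) {
    QuantParam bias;
    bias.scale = inputParams.at(0).at(0).scale * inputParams.at(1).at(0).scale;
    inputParams.push_back({bias});
  }
  std::cout << "groups=" << inputParams.size()
            << " biasScale=" << inputParams.back().at(0).scale << "\n";
  return 0;
}
```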
@@ -39,8 +39,8 @@ constexpr uint32_t kDoubleNum = 2;
 constexpr uint32_t kMultiNum = 3;
 constexpr uint32_t kDimension_4d = 4;
 
-const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeInt32,
-                                        kNumberTypeFloat32, kNumberTypeFloat16};
+const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat32,
+                                        kNumberTypeFloat16};
 
 #ifdef PRIMITIVE_WRITEABLE
 using TensorPtr = std::shared_ptr<mindspore::tensor::Tensor>;
@@ -119,7 +119,8 @@ class PrimitiveC : public mindspore::Primitive {
   static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
                                             const schema::QuantType &quantType);
   void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
-                           std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
+                           std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam,
+                           const std::vector<AnfNodePtr> &inputs);
   void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
 
  protected:
@@ -98,29 +98,28 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
   // activation
   auto input_quant_params = primitive->GetInputQuantParams();
-  auto node_type = (schema::PrimitiveType)primitive->Type();
-  if (input_quant_params.empty()) {
-    MS_LOG(DEBUG) << "node: " << dst_node->name << " input quant params is empty";
-    return RET_OK;
-  }
-  for (size_t i = 0; i < input_quant_params.size(); i++) {
-    if (i >= dst_node->inputIndex.size()) {
-      MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size()
-                    << " quant_params; but only " << dst_node->inputIndex.size() << " input";
-      return RET_PARAM_INVALID;
-    }
-    auto activate_index = dst_node->inputIndex[i];
-    auto tensor_input = meta_graph->allTensors[activate_index].get();
-    if (tensor_input->quantParams.empty()) {
-      for (auto input_quant_param : input_quant_params[i]) {
-        std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
-            std::make_unique<schema::QuantParamT>(input_quant_param);
-        MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
-                      << " zp: " << input_quant_param_ptr->zeroPoint;
-        tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
-      }
-    }
-  }
+  if (!input_quant_params.empty()) {
+    for (size_t i = 0; i < input_quant_params.size(); i++) {
+      if (i >= dst_node->inputIndex.size()) {
+        MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size()
+                      << " quant_params; but only " << dst_node->inputIndex.size() << " input";
+        return RET_PARAM_INVALID;
+      }
+      auto activate_index = dst_node->inputIndex[i];
+      auto tensor_input = meta_graph->allTensors[activate_index].get();
+      if (tensor_input->quantParams.empty()) {
+        for (auto input_quant_param : input_quant_params[i]) {
+          std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
+              std::make_unique<schema::QuantParamT>(input_quant_param);
+          MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
+                        << " zp: " << input_quant_param_ptr->zeroPoint;
+          tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
+        }
+      }
+    }
+  } else {
+    MS_LOG(DEBUG) << "node: " << dst_node->name << " input quant params is empty";
+  }
 
   // output
   auto output_index = dst_node->outputIndex[0];
   auto tensor_output = meta_graph->allTensors[output_index].get();
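The restructure appears to be more than cosmetic: the old early `return RET_OK` on empty input params skipped the `// output` section that follows, so nodes without input quant params never had their output quant params exported. Wrapping the input phase in a guard lets control fall through. A toy sketch of that control-flow change, with plain ints standing in for the schema types:

```cpp
#include <iostream>
#include <vector>

// Early `return` on empty input params used to skip the output-param phase
// entirely; guarding the input phase instead lets the output phase always run.
int ConvertParams(const std::vector<int> &inputParams, const std::vector<int> &outputParams) {
  if (!inputParams.empty()) {
    for (int p : inputParams) std::cout << "input param " << p << "\n";
  } else {
    std::cout << "input params empty, still converting outputs\n";
  }
  for (int p : outputParams) std::cout << "output param " << p << "\n";  // always reached now
  return 0;
}

int main() {
  ConvertParams({}, {7});  // output param 7 is converted even with no input params
  return 0;
}
```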
@@ -171,7 +170,7 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT)
 }
 
 int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
-                                      schema::CNodeT *return_node) {
+                                     schema::CNodeT *return_node) {
   MS_ASSERT(nullptr != meta_graph);
   MS_ASSERT(nullptr != return_node);
   for (size_t i = 1; i < cnode->inputs().size(); i++) {
@@ -210,9 +209,9 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
     if (primitive_c->Type() == schema::PrimitiveType_TupleGetItem ||
         primitive_c->Type() == schema::PrimitiveType_MakeTuple
 #ifdef SUPPORT_TRAIN
-          || primitive_c->Type() == schema::PrimitiveType_Depend
+        || primitive_c->Type() == schema::PrimitiveType_Depend
 #endif
-        ) {
+    ) {
       continue;
     }
     RemoveIfMakeTuple(cnode);
@@ -403,8 +402,7 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
       if (value_track->isa<Int32Imm>()) {
         shape.push_back((GetValue<int>(value_track)));
       } else {
-        MS_LOG(ERROR) << "Value type is ValueSequence is not integer, it is "
-                      << value_track->ToString() << ".";
+        MS_LOG(ERROR) << "Value type is ValueSequence is not integer, it is " << value_track->ToString() << ".";
       }
     }
     if (shape.size()) {
@@ -417,10 +415,10 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
         node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
         output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
         meta_graphT->allTensors.emplace_back(std::move(paramTensor));
       }
-      } else {
-        MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << ".";
-      }
+    } else {
+      MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << ".";
+    }
 #endif
   } else if (value->isa<Number>()) {
     MS_LOG(INFO) << "Value is a number.";
@@ -54,8 +54,7 @@ static const std::vector<schema::PrimitiveType> nhwcOpDualInputList = {
 
 static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = {
 #ifdef SUPPORT_TRAIN
-    schema::PrimitiveType_PoolingGrad,
-    schema::PrimitiveType_ActivationGrad
+    schema::PrimitiveType_PoolingGrad, schema::PrimitiveType_ActivationGrad
 #endif
 };
 
@@ -66,20 +65,21 @@ static const std::vector<schema::PrimitiveType> fp32FullOpList = {
 static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {};
 
 static const std::vector<schema::PrimitiveType> int8OpList = {
-    schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
-    schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
-    schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
-    schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax,
-    schema::PrimitiveType_Reshape, schema::PrimitiveType_Activation,
-    schema::PrimitiveType_Resize, schema::PrimitiveType_FullConnection,
-    schema::PrimitiveType_ArgMax, schema::PrimitiveType_ArgMin,
-    schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm,
-    schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Div,
-    schema::PrimitiveType_Mul, schema::PrimitiveType_Slice,
-    schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split,
-    schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub,
-    schema::PrimitiveType_TopK, schema::PrimitiveType_Unsqueeze,
-    schema::PrimitiveType_MatMul, schema::PrimitiveType_Pad};
+    schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
+    schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
+    schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
+    schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax,
+    schema::PrimitiveType_Reshape, schema::PrimitiveType_Activation,
+    schema::PrimitiveType_Resize, schema::PrimitiveType_FullConnection,
+    schema::PrimitiveType_ArgMax, schema::PrimitiveType_ArgMin,
+    schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm,
+    schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Div,
+    schema::PrimitiveType_Mul, schema::PrimitiveType_Slice,
+    schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split,
+    schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub,
+    schema::PrimitiveType_StridedSlice, schema::PrimitiveType_TopK,
+    schema::PrimitiveType_Unsqueeze, schema::PrimitiveType_MatMul,
+    schema::PrimitiveType_Pad};
 
 static const std::vector<schema::PrimitiveType> needInsertOpList = {
     schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
@@ -16,6 +16,7 @@
 
 #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
 #include <string>
+#include <set>
 #include "tools/common/converter_op_utils.h"
 #include "tools/common/node_util.h"
 #include "src/common/common.h"
@@ -26,6 +27,9 @@ namespace lite {
 #define kMinInputNum 1
 #define kOutputNum 1
 
+static const std::set<schema::PrimitiveType> NoNeedDtypeTransList = {
+    PrimitiveType_QuantDTypeCast, PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw};
+
 STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
 
@@ -134,7 +138,8 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
     if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) {
       continue;
     }
-    if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) {
+    auto iterType = GetCNodeTType(**iter);
+    if (NoNeedDtypeTransList.find(iterType) != NoNeedDtypeTransList.end()) {
       continue;
     }
     bool needInsertPost = true;
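Swapping one equality test for a lookup in `NoNeedDtypeTransList` means future exempt ops only need a new set entry, not more `||` clauses. A self-contained sketch of the pattern with a stand-in enum (all names here illustrative):

```cpp
#include <iostream>
#include <set>

enum class OpType { QuantDTypeCast, Nchw2Nhwc, Nhwc2Nchw, Conv2D };

// Ops that never need a dtype-transform node inserted around them.
static const std::set<OpType> kNoNeedDtypeTrans = {
    OpType::QuantDTypeCast, OpType::Nchw2Nhwc, OpType::Nhwc2Nchw};

bool NeedsDtypeTrans(OpType t) {
  // One set lookup replaces a growing chain of `t == A || t == B || ...`.
  return kNoNeedDtypeTrans.find(t) == kNoNeedDtypeTrans.end();
}

int main() {
  std::cout << NeedsDtypeTrans(OpType::Nchw2Nhwc) << "\n";  // 0: exempt
  std::cout << NeedsDtypeTrans(OpType::Conv2D) << "\n";     // 1: needs transform
  return 0;
}
```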
@@ -167,7 +167,11 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
       auto &outTensor = graph->allTensors.at(node.outputIndex.at(i));
       MS_ASSERT(outTensor != nullptr);
       auto outQuantParam = GetTensorQuantParam(outTensor);
-      if (outQuantParam == nullptr || outQuantParam->inited) {
+      if (outQuantParam == nullptr) {
+        outTensor->quantParams.emplace_back(std::move(inQuantParam));
+        continue;
+      }
+      if (outQuantParam->inited) {
         continue;
       }
       outTensor->quantParams.front() = std::move(inQuantParam);
@@ -232,7 +236,7 @@ class CalcConcat : public QuantParamCalcer {
         MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!";
         return RET_ERROR;
       }
-      outTensor->quantParams.front() = std::move(outQuantParam);
+      outTensor->quantParams.emplace_back(std::move(outQuantParam));
       outputParamDone++;
     }
 
@@ -417,7 +421,7 @@ class CalcToSet : public QuantParamCalcer {
     MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
     auto &outTensor = graph->allTensors.at(node.outputIndex.front());
     MS_ASSERT(outTensor != nullptr);
-    outTensor->quantParams.front() = std::move(quantParam);
+    outTensor->quantParams.emplace_back(std::move(quantParam));
     outputParamDone++;
   }
   return RET_OK;
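All three calcer fixes replace writes through `front()` with `emplace_back`, because calling `front()` on an empty `std::vector` is undefined behavior rather than an append: there is no element to assign into. A minimal illustration, assuming a toy QuantParam type:

```cpp
#include <iostream>
#include <memory>
#include <vector>

struct QuantParam { double scale = 1.0; bool inited = false; };

int main() {
  std::vector<std::unique_ptr<QuantParam>> quantParams;  // starts empty
  auto p = std::make_unique<QuantParam>();
  p->scale = 0.125;

  // quantParams.front() = std::move(p);  // UB on an empty vector: no slot exists
  quantParams.emplace_back(std::move(p));  // safe: creates the slot and fills it

  std::cout << "params=" << quantParams.size()
            << " scale=" << quantParams.front()->scale << "\n";  // front() is valid now
  return 0;
}
```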
@@ -475,6 +479,7 @@ QuantParamCalcRegister::QuantParamCalcRegister() {
   _registerMap[schema::PrimitiveType_Pooling] = linearCalcer;
   _registerMap[schema::PrimitiveType_Resize] = linearCalcer;
   _registerMap[schema::PrimitiveType_Reshape] = linearCalcer;
+  _registerMap[schema::PrimitiveType_StridedSlice] = linearCalcer;
   _registerMap[schema::PrimitiveType_Shape] = linearCalcer;
   _registerMap[schema::PrimitiveType_SoftMax] = std::make_shared<CalcToSet>(0, 1);
   _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer;
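Enabling int8 StridedSlice takes two coordinated entries in this commit: the `int8OpList` addition above and this calcer registration. A sketch of the registry pattern itself, with hypothetical names standing in for the schema types:

```cpp
#include <iostream>
#include <map>
#include <memory>

enum class PrimType { Conv2D, StridedSlice, SoftMax };

// Base calcer: derives a node's output quant params from its inputs.
struct QuantParamCalcer {
  virtual ~QuantParamCalcer() = default;
  virtual const char *Name() const { return "base"; }
};
// Linear ops pass the input's scale/zero-point through unchanged.
struct LinearCalcer : QuantParamCalcer {
  const char *Name() const override { return "linear"; }
};

int main() {
  std::map<PrimType, std::shared_ptr<QuantParamCalcer>> registerMap;
  auto linear = std::make_shared<LinearCalcer>();
  registerMap[PrimType::StridedSlice] = linear;  // new op reuses an existing calcer

  auto it = registerMap.find(PrimType::StridedSlice);
  std::cout << (it != registerMap.end() ? it->second->Name() : "none") << "\n";  // linear
  return 0;
}
```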