forked from mindspore-Ecosystem/mindspore
!8634 fix quant mindir inference
From: @yankai10 Reviewed-by: Signed-off-by:
commit 5a602e5288
@@ -220,36 +220,28 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, const std::vector<An
auto narrow_range = prim.GetAttr("narrow_range");
bool narrowRangeQuantParam = narrow_range != nullptr ? GetValue<bool>(narrow_range) : false;
auto num_bits = prim.GetAttr("num_bits");
int32_t numbitsRangeQuantParam = num_bits != nullptr ? GetValue<int32_t>(num_bits) : 8;
int32_t numbitsRangeQuantParam = num_bits != nullptr ? GetValue<int64_t>(num_bits) : 8;

std::vector<schema::QuantParamT> quants;
schema::QuantParamT quantParam;
auto mean = prim.GetAttr("mean");
auto std_dev = prim.GetAttr("std_dev");
if (mean != nullptr && std_dev != nullptr) {
auto meanValue = GetValue<double>(mean);
auto stddevValue = GetValue<double>(std_dev);
float mMin = 0.0;
float mMax = 0.0;
CalFloatScopeByMeanAndStddev(meanValue, stddevValue, &mMin, &mMax);
quantParam.min = mMin;
quantParam.max = mMax;
} else {
auto inputMin = prim.GetAttr("input_minq");
auto inputMax = prim.GetAttr("input_maxq");
if (inputMin != nullptr && inputMax != nullptr) {
auto inputMinPtr = inputMin->cast<TensorPtr>();
auto inputMaxPtr = inputMax->cast<TensorPtr>();
auto *minBuf = static_cast<float *>(inputMinPtr->data_c());
auto *maxBuf = static_cast<float *>(inputMaxPtr->data_c());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
auto inputMin = prim.GetAttr("input_minq");
auto inputMax = prim.GetAttr("input_maxq");
if (inputMin != nullptr && inputMax != nullptr) {
auto inputMinPtr = inputMin->cast<TensorPtr>();
auto inputMaxPtr = inputMax->cast<TensorPtr>();
auto *minBuf = static_cast<float *>(inputMinPtr->data_c());
auto *maxBuf = static_cast<float *>(inputMaxPtr->data_c());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
auto ret = quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Can't calculate quant parameters";
return;
}
quants.emplace_back(quantParam);
input_quant_param_.emplace_back(quants);
}
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
input_quant_param_.emplace_back(quants);

quants.clear();
auto filterMin = prim.GetAttr("filter_minq");
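Note on the hunk above: the only attribute read that changes is "num_bits". A MindIR graph exported from Python presumably stores that attribute as a 64-bit integer, so the old GetValue<int32_t> read does not match the stored type; the patch reads it as int64_t and lets the value narrow into the existing int32_t local. A minimal sketch of the intended read (the explicit cast is only illustrative, not part of the patch):

  auto num_bits = prim.GetAttr("num_bits");   // int64 attribute in the exported graph
  int32_t numbitsRangeQuantParam =
      num_bits != nullptr ? static_cast<int32_t>(GetValue<int64_t>(num_bits)) : 8;  // typically 8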
@@ -267,7 +259,11 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, const std::vector<An
minBuf++;
maxBuf++;
}
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, true, numbitsRangeQuantParam);
auto ret = quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, true, numbitsRangeQuantParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Can't calculate quant parameters";
return;
}
quants.emplace_back(quantParam);
input_quant_param_.emplace_back(quants);
}
@@ -300,8 +296,12 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, const std::vector<An
float *maxBuf = static_cast<float *>(outputMaxPtr->data_c());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
auto ret = quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Can't calculate quant parameters";
return;
}
quants.emplace_back(quantParam);
output_quant_param_.emplace_back(quants);
} else {
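For context on the three hunks above: every call to quant::CalQuantizationParams is now checked, and the population routine returns early on failure instead of emplacing an unvalidated QuantParamT. The values being computed are the usual asymmetric min/max quantization constants; below is a self-contained sketch of what such a calculation computes. The struct and function names are hypothetical, and the real quant::CalQuantizationParams may treat narrow_range and degenerate ranges differently.

  #include <algorithm>
  #include <cmath>
  #include <cstdint>

  struct SimpleQuantParam { double scale; int32_t zero_point; };

  // Hedged sketch: scale/zero-point for an asymmetric num_bits quantizer.
  SimpleQuantParam CalcQuantParams(float min, float max, bool narrow_range, int num_bits) {
    const int32_t quant_min = narrow_range ? 1 : 0;
    const int32_t quant_max = (1 << num_bits) - 1;
    min = std::min(min, 0.0f);      // the representable range must contain zero
    max = std::max(max, 0.0f);
    double scale = (static_cast<double>(max) - min) / (quant_max - quant_min);
    if (scale == 0.0) scale = 1.0;  // degenerate range (min == max == 0)
    const auto zero_point = static_cast<int32_t>(std::round(quant_min - min / scale));
    return {scale, zero_point};
  }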
@@ -35,7 +35,7 @@ class ParamValueLite : public Value {
tensor_size_ = 0;
}
}

MS_DECLARE_PARENT(ParamValueLite, Value)
size_t tensor_size() const { return tensor_size_; }
void set_tensor_size(size_t size) { tensor_size_ = size; }
void *tensor_addr() const { return tensor_addr_; }
@@ -460,6 +460,16 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
} else if (value->isa<Number>()) {
MS_LOG(INFO) << "Value is a number.";
return RET_OK;
} else if (value->isa<mindspore::ParamValueLite>()) {
auto valueLite = std::dynamic_pointer_cast<ParamValueLite>(value);
paramTensor->data.resize(valueLite->tensor_size());
paramTensor->format = schema::Format(valueLite->format());
paramTensor->dataType = valueLite->tensor_type();
paramTensor->dims = valueLite->tensor_shape();
memcpy(paramTensor->data.data(), valueLite->tensor_addr(), valueLite->tensor_size());
node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(paramTensor));
} else {
MS_LOG(ERROR) << "Not support value type , need add support.";
return RET_ERROR;
@@ -452,24 +452,36 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
shape.push_back(attr_tensor.dims(i));
}
std::vector<int64_t> shape_vector;
std::vector<int> shape_vector;
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
tensor::TensorPtr tensor_info =
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape_vector);
[](const int32_t &value) { return static_cast<int>(value); });
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
param_value->set_tensor_shape(shape_vector);
param_value->set_tensor_type(kDefaultValueSwitchMap[attr_tensor_type]);
const std::string &tensor_buf = attr_tensor.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size());
if (EOK != ret) {
MS_LOG(ERROR) << "memcpy_s error";
auto tensor_data = new (std::nothrow) char[tensor_buf.size()];
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "Tensor_data is nullptr";
return false;
}
auto new_value_node = NewValueNode(MakeValue(tensor_info));
auto ret = memcpy_s(tensor_data, tensor_buf.size(), tensor_buf.data(), tensor_buf.size());
if (ret != EOK) {
delete[] tensor_data;
MS_LOG(ERROR) << "Memcpy error: " << ret;
return false;
}
param_value->set_tensor_addr(tensor_data);
param_value->set_tensor_size(tensor_buf.size());
auto new_value_node = NewValueNode(MakeValue(param_value));
if (new_value_node == nullptr) {
MS_LOG(ERROR) << "Make valuenode fail";
return false;
}
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
std::vector<int64_t> shape_vector_int64;
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector_int64),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector_int64);
new_value_node->set_abstract(abstract_tensor);
anfnode_build_map_[value_node_name] = new_value_node;
return true;
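Read together with the AnfExporter hunk above, the weight round-trip for quantized MindIR models now goes through ParamValueLite instead of a tensor::Tensor value node: the importer copies the protobuf raw bytes into a heap buffer owned by the ParamValueLite, and the exporter later copies that buffer into the flatbuffer tensor. A condensed, hedged sketch of the two sides, using only the accessors that appear in this patch (error handling and the abstract/shape bookkeeping omitted):

  // import side: wrap the raw protobuf bytes
  auto param_value = std::make_shared<ParamValueLite>();
  param_value->set_tensor_shape(shape_vector);
  param_value->set_tensor_type(kDefaultValueSwitchMap[attr_tensor_type]);
  param_value->set_tensor_addr(tensor_data);        // heap copy of attr_tensor.raw_data()
  param_value->set_tensor_size(tensor_buf.size());
  auto value_node = NewValueNode(MakeValue(param_value));

  // export side: copy the same bytes into the schema tensor
  paramTensor->dims = valueLite->tensor_shape();
  paramTensor->dataType = valueLite->tensor_type();
  paramTensor->data.resize(valueLite->tensor_size());
  memcpy(paramTensor->data.data(), valueLite->tensor_addr(), valueLite->tensor_size());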
@@ -398,7 +398,14 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) {
ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node) {
MS_ASSERT(node != nullptr);
if (!utils::isa<ParameterPtr>(node)) {
MS_LOG(ERROR) << "get lite param value node must paramter";
if (utils::isa<ValueNodePtr>(node)) {
auto valueNode = node->cast<ValueNodePtr>();
auto value = std::dynamic_pointer_cast<ParamValueLite>(valueNode->value());
if (value != nullptr) {
return value;
}
}
MS_LOG(DEBUG) << "get lite param value node neither parameternode or valuenode";
return nullptr;
}
auto param = node->cast<ParameterPtr>();
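The hunk above appears to replace the hard error for non-Parameter nodes with a ValueNode fallback: the old body logged "get lite param value node must paramter" and returned nullptr, while the new body first tries to pull a ParamValueLite out of a ValueNode (which is what the importer now produces, see the previous file) before falling back to the DEBUG log. A hedged sketch of the resulting control flow:

  // node may be either a Parameter (original path) or a ValueNode (new path)
  if (!utils::isa<ParameterPtr>(node)) {
    if (utils::isa<ValueNodePtr>(node)) {
      auto value_node = node->cast<ValueNodePtr>();
      if (auto value = std::dynamic_pointer_cast<ParamValueLite>(value_node->value())) {
        return value;   // weight imported as a ValueNode holding a ParamValueLite
      }
    }
    return nullptr;     // neither a Parameter nor a usable ValueNode
  }
  auto param = node->cast<ParameterPtr>();  // Parameter path continues as before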