!45884 [Golden-Stick] support quant lenet model lite inference

Merge pull request !45884 from yangruoqi713/gs
i-robot 2022-12-16 08:46:26 +00:00 committed by Gitee
commit 411a3c7e35
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 184 additions and 299 deletions

View File

@@ -40,16 +40,18 @@ __global__ void NudgeMinMaxPerChannel(float *input_min, float *input_max, const
float nudge_zp = 0.f;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_num; i += blockDim.x * gridDim.x) {
float max_data = input_max[i];
float min_data = input_min[i];
if (symmetric) {
input_max[i] = abs(input_min[i]) < input_max[i] ? input_max[i] : -input_min[i];
input_min[i] = abs(input_min[i]) < input_max[i] ? -input_max[i] : input_min[i];
max_data = abs(input_min[i]) < input_max[i] ? input_max[i] : -input_min[i];
min_data = abs(input_min[i]) < input_max[i] ? -input_max[i] : input_min[i];
}
if ((quant_max - quant_min) == 0 || (input_max[i] - input_min[i]) == 0) {
if ((quant_max - quant_min) == 0 || (max_data - min_data) == 0) {
scale[i] = 0.f;
zp_from_min = 0.f;
} else {
scale[i] = (input_max[i] - input_min[i]) / (quant_max - quant_min);
zp_from_min = quant_min - input_min[i] / scale[i];
scale[i] = (max_data - min_data) / (quant_max - quant_min);
zp_from_min = quant_min - min_data / scale[i];
}
if (zp_from_min <= quant_min) {
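For reference, the same per-channel computation can be sketched as a plain host-side helper (a minimal sketch, not part of this patch; the zero-point rounding and the final nudge_min/nudge_max lines follow the usual min/max nudging scheme and are an assumption, since the hunk above is truncated):

#include <cmath>

// Hypothetical host-side mirror of NudgeMinMaxPerChannel for a single channel.
void NudgeOneChannel(float min_data, float max_data, float quant_min, float quant_max,
                     bool symmetric, float *scale, float *nudge_min, float *nudge_max) {
  if (symmetric) {
    // Symmetrize the range around zero, keeping the larger magnitude.
    float sym_max = std::abs(min_data) < max_data ? max_data : -min_data;
    float sym_min = std::abs(min_data) < max_data ? -max_data : min_data;
    max_data = sym_max;
    min_data = sym_min;
  }
  float zp_from_min = 0.f;
  if ((quant_max - quant_min) == 0 || (max_data - min_data) == 0) {
    *scale = 0.f;
  } else {
    *scale = (max_data - min_data) / (quant_max - quant_min);
    zp_from_min = quant_min - min_data / *scale;
  }
  // Assumption: clamp and round the zero point, then derive the nudged range.
  float nudge_zp = zp_from_min <= quant_min ? quant_min
                   : zp_from_min >= quant_max ? quant_max
                                              : std::round(zp_from_min);
  *nudge_min = (quant_min - nudge_zp) * (*scale);
  *nudge_max = (quant_max - nudge_zp) * (*scale);
}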

View File

@@ -61,17 +61,19 @@ __global__ void NudgeMinMaxPerLayer(float *input_min, float *input_max, const fl
nudge_max[0] = 0.f;
nudge_min[0] = 0.f;
float max_data = input_max[0];
float min_data = input_min[0];
if (symmetric) {
input_max[0] = abs(input_min[0]) < input_max[0] ? input_max[0] : -input_min[0];
input_min[0] = abs(input_min[0]) < input_max[0] ? -input_max[0] : input_min[0];
max_data = abs(input_min[0]) < input_max[0] ? input_max[0] : -input_min[0];
min_data = abs(input_min[0]) < input_max[0] ? -input_max[0] : input_min[0];
}
if ((quant_max - quant_min) == 0 || (input_max[0] - input_min[0]) == 0) {
if ((quant_max - quant_min) == 0 || (max_data - min_data) == 0) {
scale[0] = 0.f;
zp_from_min = 0.f;
} else {
scale[0] = (input_max[0] - input_min[0]) / (quant_max - quant_min);
zp_from_min = quant_min - input_min[0] / scale[0];
scale[0] = (max_data - min_data) / (quant_max - quant_min);
zp_from_min = quant_min - min_data / scale[0];
}
float nudge_zp = 0.f;

View File

@@ -66,62 +66,34 @@ bool FakeQuantParam::get_is_perchannel() const {
return GetValue<bool>(value_ptr);
}
void FakeQuantParam::set_quant_param(const std::string &key, api::ValuePtr param, size_t channel_index) {
if (this->get_is_perchannel()) {
auto value_ptr = this->GetAttr(key);
std::vector<api::ValuePtr> params;
if (value_ptr != nullptr) {
params = GetValue<std::vector<api::ValuePtr>>(value_ptr);
}
if (channel_index == params.size()) {
params.emplace_back(param);
} else if (channel_index < params.size()) {
params[channel_index] = param;
} else {
MS_LOG(EXCEPTION) << "Please set quant parameter in ascending order of channels.";
}
(void)AddAttr(key, api::MakeValue(params));
} else {
if (channel_index != 0) {
MS_LOG(EXCEPTION) << "'channel_index' should be equal to zero while set a per-layer quant parameter, but got "
<< channel_index << ".";
}
std::vector<api::ValuePtr> params{param};
(void)AddAttr(key, api::MakeValue(params));
}
}
void FakeQuantParam::set_quant_param(const std::string &key, api::ValuePtr param) { (void)AddAttr(key, param); }
api::ValuePtr FakeQuantParam::get_quant_param(const std::string &key, size_t channel_index) const {
api::ValuePtr FakeQuantParam::get_quant_param(const std::string &key) const {
auto value_ptr = this->GetAttr(key);
if (value_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Quant parameter " << key << " not found!";
}
auto params = GetValue<std::vector<api::ValuePtr>>(value_ptr);
if (channel_index >= params.size()) {
MS_LOG(EXCEPTION) << "Channel index(" << channel_index << ") out of range of quant parameter size(" << params.size()
<< ")!";
}
return params[channel_index];
return value_ptr;
}
void FakeQuantParam::set_scale(double scale, size_t channel_index) {
this->set_quant_param(kAttrKeyLinearQuantParamScale, api::MakeValue(scale), channel_index);
void FakeQuantParam::set_scales(std::vector<float> scales) {
(void)this->AddAttr(kAttrKeyLinearQuantParamScale, api::MakeValue(scales));
}
double FakeQuantParam::get_scale(size_t channel_index) const {
auto scale = this->get_quant_param(kAttrKeyLinearQuantParamScale, channel_index);
MS_EXCEPTION_IF_NULL(scale);
return GetValue<double>(scale);
std::vector<float> FakeQuantParam::get_scales() const {
auto value_ptr = GetAttr(kAttrKeyLinearQuantParamScale);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<float>>(value_ptr);
}
void FakeQuantParam::set_zero_point(int64_t zero_point, size_t channel_index) {
this->set_quant_param(kAttrKeyLinearQuantParamZeroPoint, api::MakeValue(zero_point), channel_index);
void FakeQuantParam::set_zero_points(std::vector<int64_t> zero_points) {
(void)this->AddAttr(kAttrKeyLinearQuantParamZeroPoint, api::MakeValue(zero_points));
}
int64_t FakeQuantParam::get_zero_point(size_t channel_index) const {
auto zp = this->get_quant_param(kAttrKeyLinearQuantParamZeroPoint, channel_index);
MS_EXCEPTION_IF_NULL(zp);
return GetValue<int64_t>(zp);
std::vector<int64_t> FakeQuantParam::get_zero_points() const {
auto value_ptr = GetAttr(kAttrKeyLinearQuantParamZeroPoint);
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<std::vector<int64_t>>(value_ptr);
}
class FakeQuantParamInfer : public abstract::OpInferBase {

View File

@@ -89,48 +89,34 @@ class MIND_API FakeQuantParam : public BaseOperator {
///
/// \param[in] key Define the name of quant parameter.
/// \param[in] param Define the value of quant parameter.
/// \param[in] channel_index Define the index indicates which channel the quant parameter belong to. Default is 0
/// indicating first channel.
void set_quant_param(const std::string &key, api::ValuePtr param, size_t channel_index = 0);
void set_quant_param(const std::string &key, api::ValuePtr param);
/// \brief Method to get quant parameter named `key`.
///
/// \param[in] key Define the name of quant parameter.
/// \param[in] channel_index Define the index indicates which channel the quant parameter belong to. Default is 0
/// indicating first channel.
///
/// \return a ValuePtr represents quant parameter.
api::ValuePtr get_quant_param(const std::string &key, size_t channel_index = 0) const;
api::ValuePtr get_quant_param(const std::string &key) const;
/// \brief Method to get quant parameter named `scale` for linear algorithm.
/// \brief Method to set quant parameter named `scale` for linear algorithm.
///
/// \param[in] scale Define the value of quant parameter.
/// \param[in] channel_index Define the index indicates which channel the quant parameter belong to. Default is 0
/// indicating first channel.
void set_scale(double scale, size_t channel_index = 0);
void set_scales(std::vector<float> scales);
/// \brief Method to get quant parameter named `scale` for linear algorithm.
/// \brief Method to get quant parameters named `scale` for linear algorithm.
///
/// \param[in] channel_index Define the index indicates which channel the quant parameter belong to. Default is 0
/// indicating first channel.
///
/// \return a double as scale.
double get_scale(size_t channel_index = 0) const;
/// \return a float vector as scale parameters.
std::vector<float> get_scales() const;
/// \brief Method to get quant parameter named `zero_point` for linear algorithm.
/// \brief Method to set quant parameter named `zero_point` for linear algorithm.
///
/// \param[in] zero_point Define the value of quant parameter.
/// \param[in] channel_index Define the index indicates which channel the quant parameter belong to. Default is 0
/// indicating first channel.
void set_zero_point(int64_t zero_point, size_t channel_index = 0);
/// \param[in] zero_points Define the value of quant parameter.
void set_zero_points(std::vector<int64_t> zero_points);
/// \brief Method to get quant parameter named `zero_point` for linear algorithm.
/// \brief Method to get quant parameters named `zero_point` for linear algorithm.
///
/// \param[in] channel_index Define the index indicates which channel the quant parameter belong to. Default is 0
/// indicating first channel.
///
/// \return a int64_t as zero_point.
int64_t get_zero_point(size_t channel_index = 0) const;
/// \return an int64_t vector as zero_point parameters.
std::vector<int64_t> get_zero_points() const;
};
using FakeQuantParamPtr = std::shared_ptr<FakeQuantParam>;
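A minimal usage sketch of the reworked vector-based interface (not part of this patch). It mirrors the unit tests further below; kQuantDataTypeInt7 and kAttrKeyLinearQuantAlgoName are assumed to be the same constants those tests use:

#include <memory>
#include <vector>
#include "ops/fake_quant_param.h"

namespace mindspore::ops {
// Sketch only: build a per-channel FakeQuantParam and read the parameters back.
void FakeQuantParamUsageSketch() {
  auto op = std::make_shared<FakeQuantParam>();
  op->Init(kQuantDataTypeInt7, kAttrKeyLinearQuantAlgoName, /* is_perchannel = */ true);
  op->set_scales({0.5f, 0.25f});   // one scale per channel
  op->set_zero_points({0, 1});     // one zero point per channel
  op->set_quant_param("slb-rate", api::MakeValue<float>(1.0));  // extra algorithm-specific parameter
  std::vector<float> scales = op->get_scales();       // {0.5f, 0.25f}
  std::vector<int64_t> zps = op->get_zero_points();   // {0, 1}
  (void)scales;
  (void)zps;
}
}  // namespace mindspore::ops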

View File

@@ -17,176 +17,143 @@
#include <vector>
#include <memory>
#include <set>
#include <map>
#include <algorithm>
#include <utility>
#include "tools/converter/converter_context.h"
#include "tools/converter/quantizer/quant_param_holder.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "src/common/log_adapter.h"
#include "src/common/quant_utils.h"
#include "tools/common/node_util.h"
#include "tools/converter/parser/parser_utils.h"
#include "nnacl/op_base.h"
#include "ops/core_ops.h"
#include "ops/op_utils.h"
#include "ops/fake_quant_param.h"
namespace mindspore {
namespace lite {
namespace {
int ConvertInputQuantParam(const PrimitivePtr &prim, bool input_narrow_range, bool weight_narrow_range,
int32_t act_numbits, int32_t weight_numbits,
std::map<int, std::vector<schema::QuantParamT>> *input_quant_params) {
std::vector<schema::QuantParamT> quants;
int ConvertQuantParam(const api::SharedPtr<mindspore::ops::FakeQuantParam> &fake_quant_prim,
std::vector<schema::QuantParamT> *quant_params) {
MS_CHECK_TRUE_MSG(fake_quant_prim != nullptr, RET_NULL_PTR, "fake_quant_prim is nullptr.");
MS_CHECK_TRUE_MSG(quant_params != nullptr, RET_NULL_PTR, "quant_params is nullptr.");
schema::QuantParamT quant_param;
auto input_min = prim->GetAttr("input_minq");
auto input_max = prim->GetAttr("input_maxq");
if (input_min != nullptr && input_max != nullptr) {
auto input_min_ptr = input_min->cast<tensor::TensorPtr>();
MS_ASSERT(input_min_ptr != nullptr);
auto input_max_ptr = input_max->cast<tensor::TensorPtr>();
MS_ASSERT(input_max_ptr != nullptr);
MS_CHECK_TRUE_MSG(input_min_ptr->data_c() != nullptr, RET_ERROR, "input_min_ptr->data_c() is nullptr");
MS_CHECK_TRUE_MSG(input_max_ptr->data_c() != nullptr, RET_ERROR, "input_max_ptr->data_c() is nullptr");
auto *min_buf = static_cast<float *>(input_min_ptr->data_c());
auto *max_buf = static_cast<float *>(input_max_ptr->data_c());
quant_param.min = *min_buf;
quant_param.max = *max_buf;
auto ret = CalQuantizationParams(&quant_param, quant_param.min, quant_param.max, act_numbits, input_narrow_range);
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "Failed to calculate quant parameters.");
quants.emplace_back(quant_param);
input_quant_params->insert({0, quants});
auto scale = fake_quant_prim->get_scales();
auto zp = fake_quant_prim->get_zero_points();
if (scale.size() != zp.size()) {
MS_LOG(ERROR) << "The number of quant params scale and zero_points should be same.";
return RET_ERROR;
}
quants.clear();
auto filter_min = prim->GetAttr("filter_minq");
auto filter_max = prim->GetAttr("filter_maxq");
if (filter_min != nullptr && filter_max != nullptr) {
auto filter_min_ptr = filter_min->cast<tensor::TensorPtr>();
MS_ASSERT(filter_min_ptr != nullptr);
auto filter_max_ptr = filter_max->cast<tensor::TensorPtr>();
MS_ASSERT(filter_max_ptr != nullptr);
MS_CHECK_TRUE_MSG(filter_min_ptr->data_c() != nullptr, RET_ERROR, "filter_min_ptr->data_c() is nullptr");
MS_CHECK_TRUE_MSG(filter_max_ptr->data_c() != nullptr, RET_ERROR, "filter_max_ptr->data_c() is nullptr");
auto *min_buf = static_cast<float *>(filter_min_ptr->data_c());
auto *max_buf = static_cast<float *>(filter_max_ptr->data_c());
quant_param.min = FLT_MAX;
quant_param.max = FLT_MIN;
for (int i = 0; i < filter_min_ptr->ElementsNum(); ++i) {
schema::QuantParamT tmp_quant_param;
tmp_quant_param.min = *min_buf;
tmp_quant_param.max = *max_buf;
auto ret = CalQuantizationParams(&tmp_quant_param, tmp_quant_param.min, tmp_quant_param.max, weight_numbits,
weight_narrow_range);
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "Failed to calculate quant parameters.");
quants.emplace_back(tmp_quant_param);
min_buf++;
max_buf++;
}
input_quant_params->insert({1, quants});
quant_params->resize(scale.size());
for (size_t i = 0; i < scale.size(); i++) {
quant_param.inited = true;
quant_param.scale = scale[i];
quant_param.zeroPoint = zp[i];
(*quant_params)[i] = quant_param;
}
return lite::RET_OK;
}
int ConvertOutputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits,
std::map<int, std::vector<schema::QuantParamT>> *output_quant_params) {
int ConvertNodesQuantParam(const std::vector<std::shared_ptr<AnfNode>> &nodes,
std::map<int, std::vector<schema::QuantParamT>> *quant_params) {
std::vector<schema::QuantParamT> quants;
schema::QuantParamT quant_param;
auto outputMin = prim->GetAttr("output_minq");
auto outputMax = prim->GetAttr("output_maxq");
if (outputMin != nullptr && outputMax != nullptr) {
auto outputMinPtr = outputMin->cast<tensor::TensorPtr>();
auto outputMaxPtr = outputMax->cast<tensor::TensorPtr>();
MS_ASSERT(outputMinPtr != nullptr);
MS_ASSERT(outputMaxPtr != nullptr);
MS_CHECK_TRUE_MSG(outputMinPtr->data_c() != nullptr, RET_ERROR, "outputMinPtr->data_c() is nullptr");
MS_CHECK_TRUE_MSG(outputMaxPtr->data_c() != nullptr, RET_ERROR, "outputMaxPtr->data_c() is nullptr");
auto *minBuf = static_cast<float *>(outputMinPtr->data_c());
auto *maxBuf = static_cast<float *>(outputMaxPtr->data_c());
quant_param.min = *minBuf;
quant_param.max = *maxBuf;
auto ret = CalQuantizationParams(&quant_param, quant_param.min, quant_param.max, numbits, narrow_range);
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "Failed to calculate quant parameters.");
quants.emplace_back(quant_param);
output_quant_params->insert({0, quants});
}
return lite::RET_OK;
}
int GetNarrowRange(const PrimitivePtr &prim, const std::string &narrow_range_str, bool *const narrow_range_param) {
MS_CHECK_TRUE_MSG(narrow_range_param != nullptr, RET_ERROR, "narrow_range_param is nullptr");
auto narrow_range = prim->GetAttr(narrow_range_str);
if (narrow_range != nullptr) {
if (utils::isa<tensor::TensorPtr>(narrow_range)) {
auto narrow_range_tensor = narrow_range->cast<tensor::TensorPtr>();
MS_ASSERT(narrow_range_tensor != nullptr);
MS_CHECK_TRUE_MSG(narrow_range_tensor->data_c() != nullptr, RET_ERROR,
"narrow_range_tensor->data_c() is nullptr");
*narrow_range_param = *reinterpret_cast<bool *>(narrow_range_tensor->data_c());
} else if (utils::isa<ImmTraits<bool>::type>(narrow_range)) {
*narrow_range_param = GetValue<bool>(narrow_range);
} else {
MS_LOG(ERROR) << "valueptr is invalid.";
return lite::RET_ERROR;
for (size_t i = 0; i < nodes.size(); i++) {
quants.clear();
if (IsPrimitiveCNode(nodes[i], prim::kPrimFakeQuantParam)) {
auto fake_quant_prim =
ops::GetOperator<mindspore::ops::FakeQuantParam>(nodes[i]->cast<CNodePtr>()->input(0)->cast<ValueNodePtr>());
auto status = ConvertQuantParam(fake_quant_prim, &quants);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Convert quant param from FakeQuantParam operation failed.";
return lite::RET_ERROR;
}
}
if (!quants.empty()) {
quant_params->insert({i, quants});
}
}
return lite::RET_OK;
}
int GetNumBits(const PrimitivePtr &prim, const std::string &num_bits_str, int *const num_bits_param) {
MS_CHECK_TRUE_MSG(num_bits_param != nullptr, RET_ERROR, "num_bits_param is nullptr");
auto num_bits = prim->GetAttr(num_bits_str);
if (num_bits != nullptr) {
if (utils::isa<tensor::TensorPtr>(num_bits)) {
auto num_bits_tensor = num_bits->cast<tensor::TensorPtr>();
MS_ASSERT(num_bits_tensor != nullptr);
MS_CHECK_TRUE_MSG(num_bits_tensor->data_c() != nullptr, RET_ERROR, "num_bits_tensor->data_c() is nullptr");
MS_CHECK_TRUE_MSG(num_bits_tensor->data().nbytes() >= static_cast<int>(sizeof(int64_t)), RET_ERROR,
"num_bits_tensor->data_c() is not longer enough for int64_t");
*num_bits_param = *reinterpret_cast<int64_t *>(num_bits_tensor->data_c());
} else if (utils::isa<ImmTraits<int64_t>::type>(num_bits)) {
*num_bits_param = GetValue<int64_t>(num_bits);
} else {
MS_LOG(ERROR) << "valueptr is invalid.";
return lite::RET_ERROR;
int RemoveFakeQuantParam(const FuncGraphPtr &fg) {
MS_CHECK_TRUE_MSG(fg != nullptr, RET_NULL_PTR, "fg is nullptr.");
auto manager = fg->manager();
auto node_list = TopoSort(fg->get_return());
for (auto &node : node_list) {
if (IsPrimitiveCNode(node, prim::kPrimFakeQuantParam)) {
auto inputs = node->cast<CNodePtr>()->inputs();
if (std::any_of(inputs.begin(), inputs.end(), [](const std::shared_ptr<AnfNode> &input) {
return IsPrimitiveCNode(input, prim::kPrimFakeQuantParam);
})) {
MS_LOG(ERROR) << "Two FakeQuantParam operators can't be joined together in mindir origin model";
return RET_ERROR;
}
auto iter = manager->node_users().find(node);
if (iter != manager->node_users().end()) {
auto outputs_set = manager->node_users()[node];
if (std::any_of(outputs_set.begin(), outputs_set.end(),
[](const std::pair<std::shared_ptr<AnfNode>, int> &output) {
return IsPrimitiveCNode(output.first, prim::kPrimFakeQuantParam);
})) {
MS_LOG(ERROR) << "Two FakeQuantParam operators can't be joined together in mindir origin model";
return RET_ERROR;
}
}
auto pre_node = node->cast<CNodePtr>()->input(1);
(void)manager->Replace(node, pre_node);
}
}
return lite::RET_OK;
return RET_OK;
}
int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs) {
bool input_narrow_range_param = false;
auto status = GetNarrowRange(prim, "input_narrow_range", &input_narrow_range_param);
MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "Get input narrow range failed.");
bool weight_narrow_range_param = true;
status = GetNarrowRange(prim, "weight_narrow_range", &weight_narrow_range_param);
MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "Get weight narrow range failed.");
bool output_narrow_range_param = false;
status = GetNarrowRange(prim, "output_narrow_range", &output_narrow_range_param);
MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "Get output narrow range failed.");
int32_t act_num_bits_param = 8;
status = GetNumBits(prim, "act_num_bits", &act_num_bits_param);
MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "Get activation num_bits failed.");
int32_t weight_num_bits_param = 8;
status = GetNumBits(prim, "weight_num_bits", &weight_num_bits_param);
MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "Get weight num_bits failed.");
int GetNodeQuantParam(std::shared_ptr<AnfNode> anf_node, const PrimitivePtr &primitive,
const FuncGraphManagerPtr &manager) {
if (!utils::isa<CNodePtr>(anf_node)) {
MS_LOG(INFO) << "Only cnode need to convert primitive.";
return RET_NO_CHANGE;
}
std::map<int, std::vector<schema::QuantParamT>> input_quant_params;
std::map<int, std::vector<schema::QuantParamT>> output_quant_params;
status = ConvertInputQuantParam(prim, input_narrow_range_param, weight_narrow_range_param, act_num_bits_param,
weight_num_bits_param, &input_quant_params);
MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "Compute input quant param failed.");
status = ConvertOutputQuantParam(prim, output_narrow_range_param, act_num_bits_param, &output_quant_params);
MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "Compute output quant param failed.");
auto cnode = anf_node->cast<CNodePtr>();
auto inputs = cnode->inputs();
inputs.erase(inputs.begin());
auto status = ConvertNodesQuantParam(inputs, &input_quant_params);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert input quant param failed.";
return RET_ERROR;
}
auto iter = manager->node_users().find(anf_node);
std::vector<AnfNodePtr> outputs;
if (iter != manager->node_users().end()) {
auto outputs_set = manager->node_users()[anf_node];
std::transform(outputs_set.begin(), outputs_set.end(), std::back_inserter(outputs),
[](const std::pair<std::shared_ptr<AnfNode>, int> &output) { return output.first; });
status = ConvertNodesQuantParam(outputs, &output_quant_params);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert output quant param failed.";
return RET_ERROR;
}
if (output_quant_params.size() > 1 || (output_quant_params.size() == 1 && outputs.size() != 1)) {
MS_LOG(ERROR) << "There can only be one FakeQuantParam as the output of " << anf_node->fullname_with_scope();
return RET_ERROR;
}
}
if (!input_quant_params.empty() || !output_quant_params.empty()) {
auto quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
auto quant_params_holder = std::make_shared<QuantParamHolder>(inputs.size(), outputs.size());
MSLITE_CHECK_PTR(quant_params_holder);
for (auto &iter : input_quant_params) {
quant_params_holder->set_input_quant_param(iter.first, iter.second);
for (auto &input : input_quant_params) {
quant_params_holder->set_input_quant_param(input.first, input.second);
}
for (auto &iter : output_quant_params) {
quant_params_holder->set_output_quant_param(iter.first, iter.second);
for (auto &output : output_quant_params) {
quant_params_holder->set_output_quant_param(output.first, output.second);
}
prim->AddAttr("quant_params", quant_params_holder);
primitive->AddAttr("quant_params", quant_params_holder);
}
return lite::RET_OK;
return RET_OK;
}
} // namespace
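For context, the scale/zeroPoint pairs that ConvertQuantParam writes into schema::QuantParamT follow the usual linear quantization convention. A small illustration (hypothetical helpers, not part of this patch, assuming signed 8-bit data):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical illustration: how a stored scale/zero-point pair is typically applied.
int8_t QuantizeValue(float x, float scale, int32_t zero_point) {
  if (scale == 0.f) {
    return static_cast<int8_t>(zero_point);  // degenerate range: map everything to the zero point (assumption)
  }
  int32_t q = static_cast<int32_t>(std::round(x / scale)) + zero_point;
  return static_cast<int8_t>(std::min(127, std::max(-128, q)));
}

float DequantizeValue(int8_t q, float scale, int32_t zero_point) {
  return static_cast<float>(static_cast<int32_t>(q) - zero_point) * scale;
}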
@@ -216,7 +183,7 @@ int MindirAdjust::ValueNodeInt64Convert(AnfNodePtr anf_node) {
return lite::RET_NO_CHANGE;
}
auto value_node = anf_node->cast<ValueNodePtr>();
MS_ASSERT(value_node != nullptr);
MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "value_node is nullptr");
if (value_node->abstract() == nullptr) {
return lite::RET_NO_CHANGE;
}
@@ -260,7 +227,7 @@ int MindirAdjust::ValueNodeInt64Convert(AnfNodePtr anf_node) {
return lite::RET_NO_CHANGE;
}
int MindirAdjust::ComputeQuantParams(std::shared_ptr<AnfNode> anf_node) {
int MindirAdjust::ConvertQuantParams(std::shared_ptr<AnfNode> anf_node, const FuncGraphManagerPtr &manager) {
MS_CHECK_TRUE_MSG(anf_node != nullptr, RET_ERROR, "anf_node is nullptr");
if (!utils::isa<CNodePtr>(anf_node)) {
MS_LOG(INFO) << "only cnode need to convert primitive.";
@@ -290,30 +257,7 @@ int MindirAdjust::ComputeQuantParams(std::shared_ptr<AnfNode> anf_node) {
return lite::RET_ERROR;
}
}
auto inputs = cnode->inputs();
inputs.erase(inputs.begin());
if (ConvertQuantParam(primitive, inputs) != lite::RET_OK) {
MS_LOG(ERROR) << "compute quant param failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
int MindirAdjust::UpdateConv2DTransposeInput(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
if (!opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
return RET_OK;
}
auto inputs = cnode->inputs();
if (inputs.size() != opt::kInputSizeFour) {
MS_LOG(ERROR) << "the input size of mindir conv2dtranspose should be 4, but now is " << inputs.size()
<< ", please check.";
return RET_ERROR;
}
inputs.pop_back();
cnode->set_inputs(inputs);
return RET_OK;
return GetNodeQuantParam(anf_node, primitive, manager);
}
int MindirAdjust::ResetFuncGraph(const FuncGraphPtr &fg, std::set<FuncGraphPtr> all_func_graphs) {
@@ -322,11 +266,21 @@ int MindirAdjust::ResetFuncGraph(const FuncGraphPtr &fg, std::set<FuncGraphPtr>
MS_CHECK_TRUE_MSG(manager != nullptr, RET_NULL_PTR, "manager is nullptr.");
manager->Clear();
manager->AddFuncGraph(fg, true);
auto status = RemoveFakeQuantParam(fg);
if (status != RET_OK) {
MS_LOG(ERROR) << "Remove FakeQuantParam operators failed.";
return RET_ERROR;
}
for (auto &item : all_func_graphs) {
if (item == fg) {
continue;
}
manager->AddFuncGraph(item);
status = RemoveFakeQuantParam(item);
if (status != RET_OK) {
MS_LOG(ERROR) << "Remove FakeQuantParam operators failed.";
return RET_ERROR;
}
}
return RET_OK;
}
@@ -340,12 +294,14 @@ bool MindirAdjust::Run(const FuncGraphPtr &func_graph) {
std::set<FuncGraphPtr> all_func_graphs = {};
GetAllFuncGraph(func_graph, &all_func_graphs);
for (auto &graph : all_func_graphs) {
auto manager = graph->manager();
MS_CHECK_TRUE_MSG(manager != nullptr, RET_NULL_PTR, "manager is nullptr.");
auto node_list = TopoSort(graph->get_return());
int status = lite::RET_OK;
bool success_flag = true;
for (auto &node : node_list) {
if (utils::isa<CNodePtr>(node)) {
status = ComputeQuantParams(node);
status = ConvertQuantParams(node, manager);
} else if (utils::isa<ParameterPtr>(node)) {
status = AdjustInputDataType(node);
} else if (utils::isa<ValueNodePtr>(node)) {

View File

@@ -36,8 +36,7 @@ class MindirAdjust {
private:
int AdjustInputDataType(AnfNodePtr anf_node);
int ValueNodeInt64Convert(AnfNodePtr anf_node);
int ComputeQuantParams(AnfNodePtr anf_node);
int UpdateConv2DTransposeInput(const CNodePtr &cnode);
int ConvertQuantParams(AnfNodePtr anf_node, const FuncGraphManagerPtr &manager);
int ResetFuncGraph(const FuncGraphPtr &fg, std::set<FuncGraphPtr> all_func_graphs);
FmkType fmk_type_ = FmkType::kFmkTypeMs;

View File

@@ -117,8 +117,22 @@ class FakeQuantParam(Primitive):
@classmethod
def linear_quant_param(cls, quant_dtype, scale, zp, is_per_channel=False, **kwargs):
kwargs[FakeQuantParam.attr_key_linear_quant_scale] = scale
kwargs[FakeQuantParam.attr_key_linear_quant_zero_point] = zp
"""
Create a linear quantization operator based on scale and zero-point parameters.
"""
validator.check_value_type("scale", scale, [float, tuple, list], "FakeQuantParam")
if isinstance(scale, float):
scale_list = [scale]
else:
scale_list = scale
validator.check_value_type("zero_point", zp, [int, tuple, list], "FakeQuantParam")
if isinstance(zp, int):
zp_list = [zp]
else:
zp_list = zp
validator.check_value_type("is_per_channel", is_per_channel, [bool], "FakeQuantParam")
kwargs[FakeQuantParam.attr_key_linear_quant_scale] = scale_list
kwargs[FakeQuantParam.attr_key_linear_quant_zero_point] = zp_list
return cls(quant_dtype, FakeQuantParam.attr_value_linear_quant_algo_name, is_per_channel, **kwargs)

View File

@@ -49,59 +49,13 @@ TEST_F(TestFakeQuantParam, test_attr_perlayer) {
auto perchannel = ops->get_is_perchannel();
EXPECT_EQ(perchannel, false);
bool has_error = false;
try {
ops->set_scale(1.0, 1);
} catch (...) {
has_error = true;
}
EXPECT_EQ(has_error, true);
ops->set_scales({1.0});
auto scale = ops->get_scales();
EXPECT_EQ(scale[0], 1.0);
ops->set_scale(1.0);
auto scale = ops->get_scale();
EXPECT_EQ(scale, 1.0);
ops->set_zero_point(1);
auto zp = ops->get_zero_point();
EXPECT_EQ(zp, 1);
ops->set_quant_param("slb-rate", api::MakeValue<float>(1.0));
auto slb_rate_value = ops->get_quant_param("slb-rate");
EXPECT_EQ(slb_rate_value->isa<api::FP32Imm>(), true);
auto slb_rate_imm = slb_rate_value->cast<api::FP32ImmPtr>();
auto slb_rate = slb_rate_imm->value();
EXPECT_EQ(slb_rate, 1.0);
}
/// Feature: setter and getter of per-channel FakeQuantParam operation.
/// Description: call setter and getter of FakeQuantParam operation and compare result of getter with argument of
/// setter.
/// Expectation: success.
TEST_F(TestFakeQuantParam, test_attr_perchannel) {
auto ops = std::make_shared<FakeQuantParam>();
ops->Init(kQuantDataTypeInt7, kAttrKeyLinearQuantAlgoName, true);
auto quant_dtype = ops->get_quant_dtype();
EXPECT_EQ(quant_dtype, kQuantDataTypeInt7);
auto algo_name = ops->get_quant_algo_name();
EXPECT_EQ(algo_name, kAttrKeyLinearQuantAlgoName);
auto perchannel = ops->get_is_perchannel();
EXPECT_EQ(perchannel, true);
bool has_error = false;
try {
ops->set_scale(1.0, 1);
} catch (...) {
has_error = true;
}
EXPECT_EQ(has_error, true);
ops->set_scale(1.0);
auto scale = ops->get_scale();
EXPECT_EQ(scale, 1.0);
ops->set_zero_point(1);
auto zp = ops->get_zero_point();
EXPECT_EQ(zp, 1);
ops->set_zero_points({1});
auto zp = ops->get_zero_points();
EXPECT_EQ(zp[0], 1);
ops->set_quant_param("slb-rate", api::MakeValue<float>(1.0));
auto slb_rate_value = ops->get_quant_param("slb-rate");
@@ -117,8 +71,8 @@ TEST_F(TestFakeQuantParam, test_attr_perchannel) {
TEST_F(TestFakeQuantParam, test_infer_shape) {
auto ops = std::make_shared<FakeQuantParam>();
ops->Init(kQuantDataTypeInt7, kAttrKeyLinearQuantAlgoName, false);
ops->set_scale(1.0);
ops->set_zero_point(1);
ops->set_scales({1.0});
ops->set_zero_points({1});
auto input_x = TensorConstructUtils::CreateOnesTensor(kFloat32, std::vector<int64_t>{32, 3, 224, 224});
MS_EXCEPTION_IF_NULL(input_x);