!45884 [Golden-Stick] support quant lenet model lite inference
Merge pull request !45884 from yangruoqi713/gs
This commit is contained in:
commit
411a3c7e35
|
@ -40,16 +40,18 @@ __global__ void NudgeMinMaxPerChannel(float *input_min, float *input_max, const
|
|||
float nudge_zp = 0.f;
|
||||
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_num; i += blockDim.x * gridDim.x) {
|
||||
float max_data = input_max[i];
|
||||
float min_data = input_min[i];
|
||||
if (symmetric) {
|
||||
input_max[i] = abs(input_min[i]) < input_max[i] ? input_max[i] : -input_min[i];
|
||||
input_min[i] = abs(input_min[i]) < input_max[i] ? -input_max[i] : input_min[i];
|
||||
max_data = abs(input_min[i]) < input_max[i] ? input_max[i] : -input_min[i];
|
||||
min_data = abs(input_min[i]) < input_max[i] ? -input_max[i] : input_min[i];
|
||||
}
|
||||
if ((quant_max - quant_min) == 0 || (input_max[i] - input_min[i]) == 0) {
|
||||
if ((quant_max - quant_min) == 0 || (max_data - min_data) == 0) {
|
||||
scale[i] = 0.f;
|
||||
zp_from_min = 0.f;
|
||||
} else {
|
||||
scale[i] = (input_max[i] - input_min[i]) / (quant_max - quant_min);
|
||||
zp_from_min = quant_min - input_min[i] / scale[i];
|
||||
scale[i] = (max_data - min_data) / (quant_max - quant_min);
|
||||
zp_from_min = quant_min - min_data / scale[i];
|
||||
}
|
||||
|
||||
if (zp_from_min <= quant_min) {
|
||||
|
|
|
@ -61,17 +61,19 @@ __global__ void NudgeMinMaxPerLayer(float *input_min, float *input_max, const fl
|
|||
nudge_max[0] = 0.f;
|
||||
nudge_min[0] = 0.f;
|
||||
|
||||
float max_data = input_max[0];
|
||||
float min_data = input_min[0];
|
||||
if (symmetric) {
|
||||
input_max[0] = abs(input_min[0]) < input_max[0] ? input_max[0] : -input_min[0];
|
||||
input_min[0] = abs(input_min[0]) < input_max[0] ? -input_max[0] : input_min[0];
|
||||
max_data = abs(input_min[0]) < input_max[0] ? input_max[0] : -input_min[0];
|
||||
min_data = abs(input_min[0]) < input_max[0] ? -input_max[0] : input_min[0];
|
||||
}
|
||||
|
||||
if ((quant_max - quant_min) == 0 || (input_max[0] - input_min[0]) == 0) {
|
||||
if ((quant_max - quant_min) == 0 || (max_data - min_data) == 0) {  // degenerate range: compare max against min, not max against itself
|
||||
scale[0] = 0.f;
|
||||
zp_from_min = 0.f;
|
||||
} else {
|
||||
scale[0] = (input_max[0] - input_min[0]) / (quant_max - quant_min);
|
||||
zp_from_min = quant_min - input_min[0] / scale[0];
|
||||
scale[0] = (max_data - min_data) / (quant_max - quant_min);  // scale = float range / quantized range (mirrors the per-channel kernel)
|
||||
zp_from_min = quant_min - min_data / scale[0];  // zero point is derived from the (possibly symmetrized) minimum
|
||||
}
|
||||
|
||||
float nudge_zp = 0.f;
|
||||
|
|
|
@ -66,62 +66,34 @@ bool FakeQuantParam::get_is_perchannel() const {
|
|||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
void FakeQuantParam::set_quant_param(const std::string &key, api::ValuePtr param, size_t channel_index) {
|
||||
if (this->get_is_perchannel()) {
|
||||
auto value_ptr = this->GetAttr(key);
|
||||
std::vector<api::ValuePtr> params;
|
||||
if (value_ptr != nullptr) {
|
||||
params = GetValue<std::vector<api::ValuePtr>>(value_ptr);
|
||||
}
|
||||
if (channel_index == params.size()) {
|
||||
params.emplace_back(param);
|
||||
} else if (channel_index < params.size()) {
|
||||
params[channel_index] = param;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Please set quant parameter in ascending order of channels.";
|
||||
}
|
||||
(void)AddAttr(key, api::MakeValue(params));
|
||||
} else {
|
||||
if (channel_index != 0) {
|
||||
MS_LOG(EXCEPTION) << "'channel_index' should be equal to zero while set a per-layer quant parameter, but got "
|
||||
<< channel_index << ".";
|
||||
}
|
||||
std::vector<api::ValuePtr> params{param};
|
||||
(void)AddAttr(key, api::MakeValue(params));
|
||||
}
|
||||
}
|
||||
void FakeQuantParam::set_quant_param(const std::string &key, api::ValuePtr param) { (void)AddAttr(key, param); }
|
||||
|
||||
api::ValuePtr FakeQuantParam::get_quant_param(const std::string &key, size_t channel_index) const {
|
||||
api::ValuePtr FakeQuantParam::get_quant_param(const std::string &key) const {
|
||||
auto value_ptr = this->GetAttr(key);
|
||||
if (value_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Quant parameter " << key << " not found!";
|
||||
}
|
||||
auto params = GetValue<std::vector<api::ValuePtr>>(value_ptr);
|
||||
if (channel_index >= params.size()) {
|
||||
MS_LOG(EXCEPTION) << "Channel index(" << channel_index << ") out of range of quant parameter size(" << params.size()
|
||||
<< ")!";
|
||||
}
|
||||
return params[channel_index];
|
||||
return value_ptr;
|
||||
}
|
||||
|
||||
void FakeQuantParam::set_scale(double scale, size_t channel_index) {
|
||||
this->set_quant_param(kAttrKeyLinearQuantParamScale, api::MakeValue(scale), channel_index);
|
||||
void FakeQuantParam::set_scales(std::vector<float> scales) {
|
||||
(void)this->AddAttr(kAttrKeyLinearQuantParamScale, api::MakeValue(scales));
|
||||
}
|
||||
|
||||
double FakeQuantParam::get_scale(size_t channel_index) const {
|
||||
auto scale = this->get_quant_param(kAttrKeyLinearQuantParamScale, channel_index);
|
||||
MS_EXCEPTION_IF_NULL(scale);
|
||||
return GetValue<double>(scale);
|
||||
std::vector<float> FakeQuantParam::get_scales() const {
|
||||
auto value_ptr = GetAttr(kAttrKeyLinearQuantParamScale);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<float>>(value_ptr);
|
||||
}
|
||||
|
||||
void FakeQuantParam::set_zero_point(int64_t zero_point, size_t channel_index) {
|
||||
this->set_quant_param(kAttrKeyLinearQuantParamZeroPoint, api::MakeValue(zero_point), channel_index);
|
||||
void FakeQuantParam::set_zero_points(std::vector<int64_t> zero_points) {
|
||||
(void)this->AddAttr(kAttrKeyLinearQuantParamZeroPoint, api::MakeValue(zero_points));
|
||||
}
|
||||
|
||||
int64_t FakeQuantParam::get_zero_point(size_t channel_index) const {
|
||||
auto zp = this->get_quant_param(kAttrKeyLinearQuantParamZeroPoint, channel_index);
|
||||
MS_EXCEPTION_IF_NULL(zp);
|
||||
return GetValue<int64_t>(zp);
|
||||
std::vector<int64_t> FakeQuantParam::get_zero_points() const {
|
||||
auto value_ptr = GetAttr(kAttrKeyLinearQuantParamZeroPoint);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
class FakeQuantParamInfer : public abstract::OpInferBase {
|
||||
|
|
|
@ -89,48 +89,34 @@ class MIND_API FakeQuantParam : public BaseOperator {
|
|||
///
|
||||
/// \param[in] key Define the name of quant parameter.
|
||||
/// \param[in] param Define the value of quant parameter.
|
||||
/// \param[in] channel_index Define the index indicates which channel the quant parameter belong to. Default is 0
|
||||
/// indicating first channel.
|
||||
void set_quant_param(const std::string &key, api::ValuePtr param, size_t channel_index = 0);
|
||||
void set_quant_param(const std::string &key, api::ValuePtr param);
|
||||
|
||||
/// \brief Method to get quant parameter named `key`.
|
||||
///
|
||||
/// \param[in] key Define the name of quant parameter.
|
||||
/// \param[in] channel_index Define the index indicates which channel the quant parameter belong to. Default is 0
|
||||
/// indicating first channel.
|
||||
///
|
||||
/// \return a ValuePtr represents quant parameter.
|
||||
api::ValuePtr get_quant_param(const std::string &key, size_t channel_index = 0) const;
|
||||
api::ValuePtr get_quant_param(const std::string &key) const;
|
||||
|
||||
/// \brief Method to get quant parameter named `scale` for linear algorithm.
|
||||
/// \brief Method to set quant parameter named `scale` for linear algorithm.
|
||||
///
|
||||
/// \param[in] scale Define the value of quant parameter.
|
||||
/// \param[in] channel_index Define the index indicates which channel the quant parameter belong to. Default is 0
|
||||
/// indicating first channel.
|
||||
void set_scale(double scale, size_t channel_index = 0);
|
||||
void set_scales(std::vector<float> scales);
|
||||
|
||||
/// \brief Method to get quant parameter named `scale` for linear algorithm.
|
||||
/// \brief Method to get quant parameters named `scale` for linear algorithm.
|
||||
///
|
||||
/// \param[in] channel_index Define the index indicates which channel the quant parameter belong to. Default is 0
|
||||
/// indicating first channel.
|
||||
///
|
||||
/// \return a double as scale.
|
||||
double get_scale(size_t channel_index = 0) const;
|
||||
/// \return a float vector as scale parameters.
|
||||
std::vector<float> get_scales() const;
|
||||
|
||||
/// \brief Method to get quant parameter named `zero_point` for linear algorithm.
|
||||
/// \brief Method to set quant parameter named `zero_point` for linear algorithm.
|
||||
///
|
||||
/// \param[in] zero_point Define the value of quant parameter.
|
||||
/// \param[in] channel_index Define the index indicates which channel the quant parameter belong to. Default is 0
|
||||
/// indicating first channel.
|
||||
void set_zero_point(int64_t zero_point, size_t channel_index = 0);
|
||||
/// \param[in] zero_points Define the value of quant parameter.
|
||||
void set_zero_points(std::vector<int64_t> zero_points);
|
||||
|
||||
/// \brief Method to get quant parameter named `zero_point` for linear algorithm.
|
||||
/// \brief Method to get quant parameters named `zero_point` for linear algorithm.
|
||||
///
|
||||
/// \param[in] channel_index Define the index indicates which channel the quant parameter belong to. Default is 0
|
||||
/// indicating first channel.
|
||||
///
|
||||
/// \return a int64_t as zero_point.
|
||||
int64_t get_zero_point(size_t channel_index = 0) const;
|
||||
/// \return a int64_t vector as zero_point parameters.
|
||||
std::vector<int64_t> get_zero_points() const;
|
||||
};
|
||||
|
||||
using FakeQuantParamPtr = std::shared_ptr<FakeQuantParam>;
|
||||
|
|
|
@ -17,176 +17,143 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "tools/converter/converter_context.h"
|
||||
#include "tools/converter/quantizer/quant_param_holder.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/quant_utils.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "ops/core_ops.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/fake_quant_param.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
int ConvertInputQuantParam(const PrimitivePtr &prim, bool input_narrow_range, bool weight_narrow_range,
|
||||
int32_t act_numbits, int32_t weight_numbits,
|
||||
std::map<int, std::vector<schema::QuantParamT>> *input_quant_params) {
|
||||
std::vector<schema::QuantParamT> quants;
|
||||
int ConvertQuantParam(const api::SharedPtr<mindspore::ops::FakeQuantParam> &fake_quant_prim,
|
||||
std::vector<schema::QuantParamT> *quant_params) {
|
||||
MS_CHECK_TRUE_MSG(fake_quant_prim != nullptr, RET_NULL_PTR, "fake_quant_prim is nullptr.");
|
||||
MS_CHECK_TRUE_MSG(quant_params != nullptr, RET_NULL_PTR, "quant_params is nullptr.");
|
||||
schema::QuantParamT quant_param;
|
||||
auto input_min = prim->GetAttr("input_minq");
|
||||
auto input_max = prim->GetAttr("input_maxq");
|
||||
if (input_min != nullptr && input_max != nullptr) {
|
||||
auto input_min_ptr = input_min->cast<tensor::TensorPtr>();
|
||||
MS_ASSERT(input_min_ptr != nullptr);
|
||||
auto input_max_ptr = input_max->cast<tensor::TensorPtr>();
|
||||
MS_ASSERT(input_max_ptr != nullptr);
|
||||
MS_CHECK_TRUE_MSG(input_min_ptr->data_c() != nullptr, RET_ERROR, "input_min_ptr->data_c() is nullptr");
|
||||
MS_CHECK_TRUE_MSG(input_max_ptr->data_c() != nullptr, RET_ERROR, "input_max_ptr->data_c() is nullptr");
|
||||
auto *min_buf = static_cast<float *>(input_min_ptr->data_c());
|
||||
auto *max_buf = static_cast<float *>(input_max_ptr->data_c());
|
||||
quant_param.min = *min_buf;
|
||||
quant_param.max = *max_buf;
|
||||
auto ret = CalQuantizationParams(&quant_param, quant_param.min, quant_param.max, act_numbits, input_narrow_range);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "Failed to calculate quant parameters.");
|
||||
quants.emplace_back(quant_param);
|
||||
input_quant_params->insert({0, quants});
|
||||
auto scale = fake_quant_prim->get_scales();
|
||||
auto zp = fake_quant_prim->get_zero_points();
|
||||
if (scale.size() != zp.size()) {
|
||||
MS_LOG(ERROR) << "The number of quant params scale and zero_points should be same.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
quants.clear();
|
||||
auto filter_min = prim->GetAttr("filter_minq");
|
||||
auto filter_max = prim->GetAttr("filter_maxq");
|
||||
if (filter_min != nullptr && filter_max != nullptr) {
|
||||
auto filter_min_ptr = filter_min->cast<tensor::TensorPtr>();
|
||||
MS_ASSERT(filter_min_ptr != nullptr);
|
||||
auto filter_max_ptr = filter_max->cast<tensor::TensorPtr>();
|
||||
MS_ASSERT(filter_max_ptr != nullptr);
|
||||
MS_CHECK_TRUE_MSG(filter_min_ptr->data_c() != nullptr, RET_ERROR, "filter_min_ptr->data_c() is nullptr");
|
||||
MS_CHECK_TRUE_MSG(filter_max_ptr->data_c() != nullptr, RET_ERROR, "filter_max_ptr->data_c() is nullptr");
|
||||
auto *min_buf = static_cast<float *>(filter_min_ptr->data_c());
|
||||
auto *max_buf = static_cast<float *>(filter_max_ptr->data_c());
|
||||
quant_param.min = FLT_MAX;
|
||||
quant_param.max = FLT_MIN;
|
||||
for (int i = 0; i < filter_min_ptr->ElementsNum(); ++i) {
|
||||
schema::QuantParamT tmp_quant_param;
|
||||
tmp_quant_param.min = *min_buf;
|
||||
tmp_quant_param.max = *max_buf;
|
||||
auto ret = CalQuantizationParams(&tmp_quant_param, tmp_quant_param.min, tmp_quant_param.max, weight_numbits,
|
||||
weight_narrow_range);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "Failed to calculate quant parameters.");
|
||||
quants.emplace_back(tmp_quant_param);
|
||||
min_buf++;
|
||||
max_buf++;
|
||||
}
|
||||
input_quant_params->insert({1, quants});
|
||||
quant_params->resize(scale.size());
|
||||
for (size_t i = 0; i < scale.size(); i++) {
|
||||
quant_param.inited = True;
|
||||
quant_param.scale = scale[i];
|
||||
quant_param.zeroPoint = zp[i];
|
||||
(*quant_params)[i] = quant_param;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
int ConvertOutputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits,
|
||||
std::map<int, std::vector<schema::QuantParamT>> *output_quant_params) {
|
||||
int ConvertNodesQuantParam(const std::vector<std::shared_ptr<AnfNode>> &nodes,
|
||||
std::map<int, std::vector<schema::QuantParamT>> *quant_params) {
|
||||
std::vector<schema::QuantParamT> quants;
|
||||
schema::QuantParamT quant_param;
|
||||
auto outputMin = prim->GetAttr("output_minq");
|
||||
auto outputMax = prim->GetAttr("output_maxq");
|
||||
if (outputMin != nullptr && outputMax != nullptr) {
|
||||
auto outputMinPtr = outputMin->cast<tensor::TensorPtr>();
|
||||
auto outputMaxPtr = outputMax->cast<tensor::TensorPtr>();
|
||||
MS_ASSERT(outputMinPtr != nullptr);
|
||||
MS_ASSERT(outputMaxPtr != nullptr);
|
||||
MS_CHECK_TRUE_MSG(outputMinPtr->data_c() != nullptr, RET_ERROR, "outputMinPtr->data_c() is nullptr");
|
||||
MS_CHECK_TRUE_MSG(outputMaxPtr->data_c() != nullptr, RET_ERROR, "outputMaxPtr->data_c() is nullptr");
|
||||
auto *minBuf = static_cast<float *>(outputMinPtr->data_c());
|
||||
auto *maxBuf = static_cast<float *>(outputMaxPtr->data_c());
|
||||
quant_param.min = *minBuf;
|
||||
quant_param.max = *maxBuf;
|
||||
auto ret = CalQuantizationParams(&quant_param, quant_param.min, quant_param.max, numbits, narrow_range);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "Failed to calculate quant parameters.");
|
||||
quants.emplace_back(quant_param);
|
||||
output_quant_params->insert({0, quants});
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
int GetNarrowRange(const PrimitivePtr &prim, const std::string &narrow_range_str, bool *const narrow_range_param) {
|
||||
MS_CHECK_TRUE_MSG(narrow_range_param != nullptr, RET_ERROR, "narrow_range_param is nullptr");
|
||||
auto narrow_range = prim->GetAttr(narrow_range_str);
|
||||
if (narrow_range != nullptr) {
|
||||
if (utils::isa<tensor::TensorPtr>(narrow_range)) {
|
||||
auto narrow_range_tensor = narrow_range->cast<tensor::TensorPtr>();
|
||||
MS_ASSERT(narrow_range_tensor != nullptr);
|
||||
MS_CHECK_TRUE_MSG(narrow_range_tensor->data_c() != nullptr, RET_ERROR,
|
||||
"narrow_range_tensor->data_c() is nullptr");
|
||||
*narrow_range_param = *reinterpret_cast<bool *>(narrow_range_tensor->data_c());
|
||||
} else if (utils::isa<ImmTraits<bool>::type>(narrow_range)) {
|
||||
*narrow_range_param = GetValue<bool>(narrow_range);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "valueptr is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
for (size_t i = 0; i < nodes.size(); i++) {
|
||||
quants.clear();
|
||||
if (IsPrimitiveCNode(nodes[i], prim::kPrimFakeQuantParam)) {
|
||||
auto fake_quant_prim =
|
||||
ops::GetOperator<mindspore::ops::FakeQuantParam>(nodes[i]->cast<CNodePtr>()->input(0)->cast<ValueNodePtr>());
|
||||
auto status = ConvertQuantParam(fake_quant_prim, &quants);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert quant param from FakeQuantParam operation failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
if (!quants.empty()) {
|
||||
quant_params->insert({i, quants});
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
int GetNumBits(const PrimitivePtr &prim, const std::string &num_bits_str, int *const num_bits_param) {
|
||||
MS_CHECK_TRUE_MSG(num_bits_param != nullptr, RET_ERROR, "num_bits_param is nullptr");
|
||||
auto num_bits = prim->GetAttr(num_bits_str);
|
||||
if (num_bits != nullptr) {
|
||||
if (utils::isa<tensor::TensorPtr>(num_bits)) {
|
||||
auto num_bits_tensor = num_bits->cast<tensor::TensorPtr>();
|
||||
MS_ASSERT(num_bits_tensor != nullptr);
|
||||
MS_CHECK_TRUE_MSG(num_bits_tensor->data_c() != nullptr, RET_ERROR, "num_bits_tensor->data_c() is nullptr");
|
||||
MS_CHECK_TRUE_MSG(num_bits_tensor->data().nbytes() >= static_cast<int>(sizeof(int64_t)), RET_ERROR,
|
||||
"num_bits_tensor->data_c() is not longer enough for int64_t");
|
||||
*num_bits_param = *reinterpret_cast<int64_t *>(num_bits_tensor->data_c());
|
||||
} else if (utils::isa<ImmTraits<int64_t>::type>(num_bits)) {
|
||||
*num_bits_param = GetValue<int64_t>(num_bits);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "valueptr is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
int RemoveFakeQuantParam(const FuncGraphPtr &fg) {
|
||||
MS_CHECK_TRUE_MSG(fg != nullptr, RET_NULL_PTR, "fg is nullptr.");
|
||||
auto manager = fg->manager();
|
||||
auto node_list = TopoSort(fg->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimFakeQuantParam)) {
|
||||
auto inputs = node->cast<CNodePtr>()->inputs();
|
||||
if (std::any_of(inputs.begin(), inputs.end(), [](const std::shared_ptr<AnfNode> &input) {
|
||||
return IsPrimitiveCNode(input, prim::kPrimFakeQuantParam);
|
||||
})) {
|
||||
MS_LOG(ERROR) << "Two FakeQuantParam operators can't be joined together in mindir origin model";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto iter = manager->node_users().find(node);
|
||||
if (iter != manager->node_users().end()) {
|
||||
auto outputs_set = manager->node_users()[node];
|
||||
if (std::any_of(outputs_set.begin(), outputs_set.end(),
|
||||
[](const std::pair<std::shared_ptr<AnfNode>, int> &output) {
|
||||
return IsPrimitiveCNode(output.first, prim::kPrimFakeQuantParam);
|
||||
})) {
|
||||
MS_LOG(ERROR) << "Two FakeQuantParam operators can't be joined together in mindir origin model";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
auto pre_node = node->cast<CNodePtr>()->input(1);
|
||||
(void)manager->Replace(node, pre_node);
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
bool input_narrow_range_param = false;
|
||||
auto status = GetNarrowRange(prim, "input_narrow_range", &input_narrow_range_param);
|
||||
MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "Get input narrow range failed.");
|
||||
bool weight_narrow_range_param = true;
|
||||
status = GetNarrowRange(prim, "weight_narrow_range", &weight_narrow_range_param);
|
||||
MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "Get weight narrow range failed.");
|
||||
bool output_narrow_range_param = false;
|
||||
status = GetNarrowRange(prim, "output_narrow_range", &output_narrow_range_param);
|
||||
MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "Get output narrow range failed.");
|
||||
int32_t act_num_bits_param = 8;
|
||||
status = GetNumBits(prim, "act_num_bits", &act_num_bits_param);
|
||||
MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "Get activation num_bits failed.");
|
||||
int32_t weight_num_bits_param = 8;
|
||||
status = GetNumBits(prim, "weight_num_bits", &weight_num_bits_param);
|
||||
MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "Get weight num_bits failed.");
|
||||
int GetNodeQuantParam(std::shared_ptr<AnfNode> anf_node, const PrimitivePtr &primitive,
|
||||
const FuncGraphManagerPtr &manager) {
|
||||
if (!utils::isa<CNodePtr>(anf_node)) {
|
||||
MS_LOG(INFO) << "Only cnode need to convert primitive.";
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
|
||||
std::map<int, std::vector<schema::QuantParamT>> input_quant_params;
|
||||
std::map<int, std::vector<schema::QuantParamT>> output_quant_params;
|
||||
status = ConvertInputQuantParam(prim, input_narrow_range_param, weight_narrow_range_param, act_num_bits_param,
|
||||
weight_num_bits_param, &input_quant_params);
|
||||
MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "Compute input quant param failed.");
|
||||
status = ConvertOutputQuantParam(prim, output_narrow_range_param, act_num_bits_param, &output_quant_params);
|
||||
MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "Compute output quant param failed.");
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
auto inputs = cnode->inputs();
|
||||
inputs.erase(inputs.begin());
|
||||
auto status = ConvertNodesQuantParam(inputs, &input_quant_params);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert input quant param failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto iter = manager->node_users().find(anf_node);
|
||||
std::vector<AnfNodePtr> outputs;
|
||||
if (iter != manager->node_users().end()) {
|
||||
auto outputs_set = manager->node_users()[anf_node];
|
||||
std::transform(outputs_set.begin(), outputs_set.end(), std::back_inserter(outputs),
|
||||
[](const std::pair<std::shared_ptr<AnfNode>, int> &output) { return output.first; });
|
||||
status = ConvertNodesQuantParam(outputs, &output_quant_params);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert output quant param failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (output_quant_params.size() > 1 || (output_quant_params.size() == 1 && outputs.size() != 1)) {
|
||||
MS_LOG(ERROR) << "There can only be one FakeQuantParam as the output of " << anf_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
if (!input_quant_params.empty() || !output_quant_params.empty()) {
|
||||
auto quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
|
||||
auto quant_params_holder = std::make_shared<QuantParamHolder>(inputs.size(), outputs.size());
|
||||
MSLITE_CHECK_PTR(quant_params_holder);
|
||||
for (auto &iter : input_quant_params) {
|
||||
quant_params_holder->set_input_quant_param(iter.first, iter.second);
|
||||
for (auto &input : input_quant_params) {
|
||||
quant_params_holder->set_input_quant_param(input.first, input.second);
|
||||
}
|
||||
for (auto &iter : output_quant_params) {
|
||||
quant_params_holder->set_output_quant_param(iter.first, iter.second);
|
||||
for (auto &output : output_quant_params) {
|
||||
quant_params_holder->set_output_quant_param(output.first, output.second);
|
||||
}
|
||||
prim->AddAttr("quant_params", quant_params_holder);
|
||||
primitive->AddAttr("quant_params", quant_params_holder);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -216,7 +183,7 @@ int MindirAdjust::ValueNodeInt64Convert(AnfNodePtr anf_node) {
|
|||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto value_node = anf_node->cast<ValueNodePtr>();
|
||||
MS_ASSERT(value_node != nullptr);
|
||||
MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "value_node is nullptr");
|
||||
if (value_node->abstract() == nullptr) {
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
|
@ -260,7 +227,7 @@ int MindirAdjust::ValueNodeInt64Convert(AnfNodePtr anf_node) {
|
|||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
|
||||
int MindirAdjust::ComputeQuantParams(std::shared_ptr<AnfNode> anf_node) {
|
||||
int MindirAdjust::ConvertQuantParams(std::shared_ptr<AnfNode> anf_node, const FuncGraphManagerPtr &manager) {
|
||||
MS_CHECK_TRUE_MSG(anf_node != nullptr, RET_ERROR, "anf_node is nullptr");
|
||||
if (!utils::isa<CNodePtr>(anf_node)) {
|
||||
MS_LOG(INFO) << "only cnode need to convert primitive.";
|
||||
|
@ -290,30 +257,7 @@ int MindirAdjust::ComputeQuantParams(std::shared_ptr<AnfNode> anf_node) {
|
|||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
auto inputs = cnode->inputs();
|
||||
inputs.erase(inputs.begin());
|
||||
|
||||
if (ConvertQuantParam(primitive, inputs) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "compute quant param failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
int MindirAdjust::UpdateConv2DTransposeInput(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (!opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto inputs = cnode->inputs();
|
||||
if (inputs.size() != opt::kInputSizeFour) {
|
||||
MS_LOG(ERROR) << "the input size of mindir conv2dtranspose should be 4, but now is " << inputs.size()
|
||||
<< ", please check.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
inputs.pop_back();
|
||||
cnode->set_inputs(inputs);
|
||||
return RET_OK;
|
||||
return GetNodeQuantParam(anf_node, primitive, manager);
|
||||
}
|
||||
|
||||
int MindirAdjust::ResetFuncGraph(const FuncGraphPtr &fg, std::set<FuncGraphPtr> all_func_graphs) {
|
||||
|
@ -322,11 +266,21 @@ int MindirAdjust::ResetFuncGraph(const FuncGraphPtr &fg, std::set<FuncGraphPtr>
|
|||
MS_CHECK_TRUE_MSG(manager != nullptr, RET_NULL_PTR, "manager is nullptr.");
|
||||
manager->Clear();
|
||||
manager->AddFuncGraph(fg, true);
|
||||
auto status = RemoveFakeQuantParam(fg);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Remove FakeQuantParam operators failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (auto &item : all_func_graphs) {
|
||||
if (item == fg) {
|
||||
continue;
|
||||
}
|
||||
manager->AddFuncGraph(item);
|
||||
status = RemoveFakeQuantParam(item);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Remove FakeQuantParam operators failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -340,12 +294,14 @@ bool MindirAdjust::Run(const FuncGraphPtr &func_graph) {
|
|||
std::set<FuncGraphPtr> all_func_graphs = {};
|
||||
GetAllFuncGraph(func_graph, &all_func_graphs);
|
||||
for (auto &graph : all_func_graphs) {
|
||||
auto manager = graph->manager();
|
||||
MS_CHECK_TRUE_MSG(manager != nullptr, RET_NULL_PTR, "manager is nullptr.");
|
||||
auto node_list = TopoSort(graph->get_return());
|
||||
int status = lite::RET_OK;
|
||||
bool success_flag = true;
|
||||
for (auto &node : node_list) {
|
||||
if (utils::isa<CNodePtr>(node)) {
|
||||
status = ComputeQuantParams(node);
|
||||
status = ConvertQuantParams(node, manager);
|
||||
} else if (utils::isa<ParameterPtr>(node)) {
|
||||
status = AdjustInputDataType(node);
|
||||
} else if (utils::isa<ValueNodePtr>(node)) {
|
||||
|
|
|
@ -36,8 +36,7 @@ class MindirAdjust {
|
|||
private:
|
||||
int AdjustInputDataType(AnfNodePtr anf_node);
|
||||
int ValueNodeInt64Convert(AnfNodePtr anf_node);
|
||||
int ComputeQuantParams(AnfNodePtr anf_node);
|
||||
int UpdateConv2DTransposeInput(const CNodePtr &cnode);
|
||||
int ConvertQuantParams(AnfNodePtr anf_node, const FuncGraphManagerPtr &manager);
|
||||
int ResetFuncGraph(const FuncGraphPtr &fg, std::set<FuncGraphPtr> all_func_graphs);
|
||||
|
||||
FmkType fmk_type_ = FmkType::kFmkTypeMs;
|
||||
|
|
|
@ -117,8 +117,22 @@ class FakeQuantParam(Primitive):
|
|||
|
||||
@classmethod
|
||||
def linear_quant_param(cls, quant_dtype, scale, zp, is_per_channel=False, **kwargs):
|
||||
kwargs[FakeQuantParam.attr_key_linear_quant_scale] = scale
|
||||
kwargs[FakeQuantParam.attr_key_linear_quant_zero_point] = zp
|
||||
"""
|
||||
Create a linear quantization operator based on scale and zero-point parameter.
|
||||
"""
|
||||
validator.check_value_type("scale", scale, [float, tuple, list], "FakeQuantParam")
|
||||
if isinstance(scale, float):
|
||||
scale_list = [scale]
|
||||
else:
|
||||
scale_list = scale
|
||||
validator.check_value_type("zero_point", zp, [int, tuple, list], "FakeQuantParam")
|
||||
if isinstance(zp, int):
|
||||
zp_list = [zp]
|
||||
else:
|
||||
zp_list = zp
|
||||
validator.check_value_type("is_per_channel", is_per_channel, [bool], "FakeQuantParam")
|
||||
kwargs[FakeQuantParam.attr_key_linear_quant_scale] = scale_list
|
||||
kwargs[FakeQuantParam.attr_key_linear_quant_zero_point] = zp_list
|
||||
return cls(quant_dtype, FakeQuantParam.attr_value_linear_quant_algo_name, is_per_channel, **kwargs)
|
||||
|
||||
|
||||
|
|
|
@ -49,59 +49,13 @@ TEST_F(TestFakeQuantParam, test_attr_perlayer) {
|
|||
auto perchannel = ops->get_is_perchannel();
|
||||
EXPECT_EQ(perchannel, false);
|
||||
|
||||
bool has_error = false;
|
||||
try {
|
||||
ops->set_scale(1.0, 1);
|
||||
} catch (...) {
|
||||
has_error = true;
|
||||
}
|
||||
EXPECT_EQ(has_error, true);
|
||||
ops->set_scales({1.0});
|
||||
auto scale = ops->get_scales();
|
||||
EXPECT_EQ(scale[0], 1.0);
|
||||
|
||||
ops->set_scale(1.0);
|
||||
auto scale = ops->get_scale();
|
||||
EXPECT_EQ(scale, 1.0);
|
||||
|
||||
ops->set_zero_point(1);
|
||||
auto zp = ops->get_zero_point();
|
||||
EXPECT_EQ(zp, 1);
|
||||
|
||||
ops->set_quant_param("slb-rate", api::MakeValue<float>(1.0));
|
||||
auto slb_rate_value = ops->get_quant_param("slb-rate");
|
||||
EXPECT_EQ(slb_rate_value->isa<api::FP32Imm>(), true);
|
||||
auto slb_rate_imm = slb_rate_value->cast<api::FP32ImmPtr>();
|
||||
auto slb_rate = slb_rate_imm->value();
|
||||
EXPECT_EQ(slb_rate, 1.0);
|
||||
}
|
||||
|
||||
/// Feature: setter and getter of per-channel FakeQuantParam operation.
|
||||
/// Description: call setter and getter of FakeQuantParam operation and compare result of getter with argument of
|
||||
/// setter.
|
||||
/// Expectation: success.
|
||||
TEST_F(TestFakeQuantParam, test_attr_perchannel) {
|
||||
auto ops = std::make_shared<FakeQuantParam>();
|
||||
ops->Init(kQuantDataTypeInt7, kAttrKeyLinearQuantAlgoName, true);
|
||||
auto quant_dtype = ops->get_quant_dtype();
|
||||
EXPECT_EQ(quant_dtype, kQuantDataTypeInt7);
|
||||
auto algo_name = ops->get_quant_algo_name();
|
||||
EXPECT_EQ(algo_name, kAttrKeyLinearQuantAlgoName);
|
||||
auto perchannel = ops->get_is_perchannel();
|
||||
EXPECT_EQ(perchannel, true);
|
||||
|
||||
bool has_error = false;
|
||||
try {
|
||||
ops->set_scale(1.0, 1);
|
||||
} catch (...) {
|
||||
has_error = true;
|
||||
}
|
||||
EXPECT_EQ(has_error, true);
|
||||
|
||||
ops->set_scale(1.0);
|
||||
auto scale = ops->get_scale();
|
||||
EXPECT_EQ(scale, 1.0);
|
||||
|
||||
ops->set_zero_point(1);
|
||||
auto zp = ops->get_zero_point();
|
||||
EXPECT_EQ(zp, 1);
|
||||
ops->set_zero_points({1});
|
||||
auto zp = ops->get_zero_points();
|
||||
EXPECT_EQ(zp[0], 1);
|
||||
|
||||
ops->set_quant_param("slb-rate", api::MakeValue<float>(1.0));
|
||||
auto slb_rate_value = ops->get_quant_param("slb-rate");
|
||||
|
@ -117,8 +71,8 @@ TEST_F(TestFakeQuantParam, test_attr_perchannel) {
|
|||
TEST_F(TestFakeQuantParam, test_infer_shape) {
|
||||
auto ops = std::make_shared<FakeQuantParam>();
|
||||
ops->Init(kQuantDataTypeInt7, kAttrKeyLinearQuantAlgoName, false);
|
||||
ops->set_scale(1.0);
|
||||
ops->set_zero_point(1);
|
||||
ops->set_scales({1.0});
|
||||
ops->set_zero_points({1});
|
||||
|
||||
auto input_x = TensorConstructUtils::CreateOnesTensor(kFloat32, std::vector<int64_t>{32, 3, 224, 224});
|
||||
MS_EXCEPTION_IF_NULL(input_x);
|
||||
|
|
Loading…
Reference in New Issue