forked from mindspore-Ecosystem/mindspore
!15349 [MS][LITE][TOD] Add quantization to convert of ToD and Dequantization in RunTime
From: @ehaleva Reviewed-by: @HilbertDavid,@hangangqiang Signed-off-by: @HilbertDavid
This commit is contained in:
commit
d0a2f3866d
|
@ -53,9 +53,6 @@ int AddnInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
|
|||
if ((inputs[i]->shape_size_ != max_dims) && (GetElementNum(inputs[i]) != GetElementNum(inputs[max_dims_idx]))) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
if (inputs[i]->data_type_ != inputs[0]->data_type_) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t d = 0; d < input->shape_size_; ++d) {
|
||||
|
|
|
@ -39,6 +39,10 @@ int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso
|
|||
size_t input_shape1_size = input1->shape_size_;
|
||||
output->format_ = input0->format_;
|
||||
output->data_type_ = input0->data_type_;
|
||||
if ((input0->data_type_ == kNumberTypeInt8) && (input1->data_type_ == kNumberTypeFloat32)) {
|
||||
output->data_type_ = input1->data_type_;
|
||||
}
|
||||
|
||||
if (!parameter->infer_flag_) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
|
|
|
@ -28,5 +28,10 @@ if [ ! -f "$CONVERTER" ]; then
|
|||
fi
|
||||
|
||||
echo "============Converting========="
|
||||
LD_LIBRARY_PATH=./ $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod
|
||||
QUANT_OPTIONS=""
|
||||
if [[ ! -z ${QUANTIZE} ]]; then
|
||||
echo "Quantizing weights"
|
||||
QUANT_OPTIONS="--quantType=WeightQuant --bitNum=8 --quantWeightSize=100 --quantWeightChannel=15"
|
||||
fi
|
||||
LD_LIBRARY_PATH=./ $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod $QUANT_OPTIONS
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
display_usage()
|
||||
{
|
||||
echo -e "\nUsage: prepare_and_run.sh -D dataset_path [-d mindspore_docker] [-r release.tar.gz] [-t arm64|x86]\n"
|
||||
echo -e "\nUsage: prepare_and_run.sh -D dataset_path [-d mindspore_docker] [-r release.tar.gz] [-t arm64|x86] [-q]\n"
|
||||
}
|
||||
|
||||
checkopts()
|
||||
|
@ -10,7 +10,8 @@ checkopts()
|
|||
TARGET="arm64"
|
||||
DOCKER=""
|
||||
MNIST_DATA_PATH=""
|
||||
while getopts 'D:d:r:t:' opt
|
||||
QUANTIZE=""
|
||||
while getopts 'D:d:r:t:q' opt
|
||||
do
|
||||
case "${opt}" in
|
||||
D)
|
||||
|
@ -31,6 +32,9 @@ checkopts()
|
|||
r)
|
||||
TARBALL=$OPTARG
|
||||
;;
|
||||
q)
|
||||
QUANTIZE="QUANTIZE"
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option ${opt}!"
|
||||
display_usage
|
||||
|
@ -64,7 +68,7 @@ fi
|
|||
# Prepare the model
|
||||
cd model/ || exit 1
|
||||
rm -f *.ms
|
||||
./prepare_model.sh $DOCKER || exit 1
|
||||
QUANTIZE=${QUANTIZE} ./prepare_model.sh $DOCKER || exit 1
|
||||
cd ../
|
||||
|
||||
# Copy the .ms model to the package folder
|
||||
|
|
|
@ -110,6 +110,9 @@ void NetRunner::InitAndFigureInputs() {
|
|||
MS_ASSERT(nullptr != session_);
|
||||
loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session_);
|
||||
|
||||
if (verbose_) {
|
||||
loop_->SetKernelCallBack(nullptr, after_callback);
|
||||
}
|
||||
acc_metrics_ = std::shared_ptr<AccuracyMetrics>(new AccuracyMetrics);
|
||||
|
||||
loop_->Init({acc_metrics_.get()});
|
||||
|
@ -125,11 +128,11 @@ void NetRunner::InitAndFigureInputs() {
|
|||
|
||||
float NetRunner::CalculateAccuracy(int max_tests) {
|
||||
test_ds_ = Mnist(data_dir_ + "/test", "all");
|
||||
TypeCast typecast_f("float32");
|
||||
TypeCast typecast_f(mindspore::DataType::kNumberTypeFloat32);
|
||||
Resize resize({h_, w_});
|
||||
test_ds_ = test_ds_->Map({&resize, &typecast_f}, {"image"});
|
||||
|
||||
TypeCast typecast("int32");
|
||||
TypeCast typecast(mindspore::DataType::kNumberTypeInt32);
|
||||
test_ds_ = test_ds_->Map({&typecast}, {"label"});
|
||||
test_ds_ = test_ds_->Batch(batch_size_, true);
|
||||
|
||||
|
@ -144,14 +147,14 @@ float NetRunner::CalculateAccuracy(int max_tests) {
|
|||
int NetRunner::InitDB() {
|
||||
train_ds_ = Mnist(data_dir_ + "/train", "all");
|
||||
|
||||
TypeCast typecast_f("float32");
|
||||
TypeCast typecast_f(mindspore::DataType::kNumberTypeFloat32);
|
||||
Resize resize({h_, w_});
|
||||
train_ds_ = train_ds_->Map({&resize, &typecast_f}, {"image"});
|
||||
|
||||
TypeCast typecast("int32");
|
||||
TypeCast typecast(mindspore::DataType::kNumberTypeInt32);
|
||||
train_ds_ = train_ds_->Map({&typecast}, {"label"});
|
||||
|
||||
train_ds_ = train_ds_->Shuffle(2);
|
||||
// train_ds_ = train_ds_->Shuffle(2);
|
||||
train_ds_ = train_ds_->Batch(batch_size_, true);
|
||||
|
||||
if (verbose_) {
|
||||
|
|
|
@ -187,7 +187,7 @@ int NetRunner::TrainLoop() {
|
|||
if (save_checkpoint_ != 0 && (i + 1) % save_checkpoint_ == 0) {
|
||||
auto cpkt_fn =
|
||||
ms_head_file_.substr(0, ms_head_file_.find_last_of('.')) + "_trained_" + std::to_string(i + 1) + ".ms";
|
||||
session_->SaveToFile(cpkt_fn);
|
||||
mindspore::lite::Model::Export(head_model_, cpkt_fn.c_str());
|
||||
}
|
||||
|
||||
std::cout << i + 1 << ": Loss is " << loss << " [min=" << min_loss << "]" << std::endl;
|
||||
|
@ -213,7 +213,7 @@ int NetRunner::Main() {
|
|||
|
||||
if (cycles_ > 0) {
|
||||
auto trained_fn = ms_head_file_.substr(0, ms_head_file_.find_last_of('.')) + "_trained.ms";
|
||||
session_->SaveToFile(trained_fn);
|
||||
mindspore::lite::Model::Export(head_model_, trained_fn.c_str());
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -44,6 +44,8 @@ class NetRunner {
|
|||
|
||||
DataSet ds_;
|
||||
mindspore::session::TrainSession *session_ = nullptr;
|
||||
mindspore::lite::Model *backbone_model_ = nullptr;
|
||||
mindspore::lite::Model *head_model_ = nullptr;
|
||||
|
||||
std::string ms_backbone_file_ = "";
|
||||
std::string ms_head_file_ = "";
|
||||
|
|
|
@ -176,10 +176,6 @@ int Flags::InitTrainModel() {
|
|||
std::cerr << "INPUT ILLEGAL: train model converter supporting only FP32 output tensors";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
if (this->quantType != QuantType_QUANT_NONE) {
|
||||
std::cerr << "INPUT ILLEGAL: train model converter is not supporting quantization";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -181,6 +181,57 @@ bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &inputNode) const {
|
||||
if (inputNode == nullptr) {
|
||||
MS_LOG(INFO) << "CanTensorQuantized input is nullptr!";
|
||||
return false;
|
||||
}
|
||||
ParameterPtr paramNode = nullptr;
|
||||
|
||||
if (inputNode->isa<Parameter>()) {
|
||||
paramNode = inputNode->cast<ParameterPtr>();
|
||||
}
|
||||
|
||||
if (paramNode == nullptr) {
|
||||
MS_LOG(INFO) << "CanTensorQuantized invalid paramNode!";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto abstract_base = paramNode->abstract();
|
||||
if (abstract_base == nullptr) {
|
||||
MS_LOG(INFO) << "abstract is nullptr";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
|
||||
MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
|
||||
return false;
|
||||
}
|
||||
|
||||
auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
|
||||
if (weight_shape.size() < 2) { // do not quant single dim tensors
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t shapeSize = 1;
|
||||
for (auto dim : weight_shape) {
|
||||
shapeSize = shapeSize * dim;
|
||||
}
|
||||
if (shapeSize < m_weight_size_) {
|
||||
MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (weight_shape.size() == 4) { // assume Convolution
|
||||
if (weight_shape[0] <= static_cast<int>(m_conv_weight_quant_channel_threshold_)) {
|
||||
MS_LOG(INFO) << "channel less m_conv_weight_quant_channel_threshold_!" << weight_shape[0];
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) {
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
QuantParamHolderPtr quant_params_holder = nullptr;
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZE_UTIL_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZE_UTIL_H_
|
||||
|
||||
#include <dirent.h>
|
||||
#include <sys/stat.h>
|
||||
|
@ -83,6 +83,7 @@ class QuantStrategy {
|
|||
bool CanConvOpQuantized(const CNodePtr &node) const;
|
||||
bool CanMulOpQuantized(const CNodePtr &node) const;
|
||||
bool CanOpPostQuantized(AnfNodePtr &node) const;
|
||||
bool CanTensorQuantized(const AnfNodePtr &inputNode) const;
|
||||
|
||||
size_t m_weight_size_;
|
||||
size_t m_conv_weight_quant_channel_threshold_;
|
||||
|
@ -417,4 +418,4 @@ FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &);
|
|||
|
||||
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info);
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZE_UTIL_H_
|
||||
|
|
|
@ -75,6 +75,7 @@ STATUS WeightQuantizer::SetAbstract(const tensor::TensorPtr &tensor_info, const
|
|||
auto quant_param_holder = GetCNodeQuantHolder(primitive);
|
||||
quant_param_holder->set_quant_type(schema::QuantType_QUANT_WEIGHT);
|
||||
|
||||
weight_quantized_tensors.insert({tensor_info, param_node});
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -244,6 +245,82 @@ STATUS WeightQuantizer::DoGatherQuantize(const CNodePtr &cnode) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
// Quantizes the float32 tensor inputs of an optimizer node. DoFixedQuant calls
// this for Adam, SGD and ApplyMomentum nodes.
//
// The inputs examined are selected by position: {2, 3} for Adam, {4, 6} for
// SGD and {2} otherwise. Each selected input is
//   1. skipped if it is out of range for this node or rejected by
//      QuantStrategy::CanTensorQuantized;
//   2. resolved to its parameter/tensor via GetLiteParameter;
//   3. quantized with QuantFilter<int8_t> or QuantFilter<int16_t> depending on
//      type_id_;
//   4. recorded through SetAbstract.
//
// Returns RET_OK on success (including when nothing was quantizable),
// RET_ERROR if the primitive is missing or SetAbstract fails, and the
// QuantFilter status if that call fails. If type_id_ is neither int8 nor
// int16 no filter runs and RET_ERROR is returned.
STATUS WeightQuantizer::DoOptimizerQuantize(const CNodePtr &cnode) {
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_ERROR;
  }

  // Input positions holding the tensors to quantize for each optimizer type.
  std::vector<int> weight_indices = {2};
  if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam)) {
    weight_indices = {2, 3};
  } else if (opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
    weight_indices = {4, 6};
  }

  for (int idx : weight_indices) {
    // The indices above are hard-coded; never index past the node's real inputs.
    if (static_cast<size_t>(idx) >= cnode->size()) {
      MS_LOG(INFO) << "Input " << idx << " of Optimizer does not exist";
      continue;
    }
    auto input = cnode->input(idx);
    if (!quant_strategy_->CanTensorQuantized(input)) {
      MS_LOG(INFO) << "Input " << idx << " of Optimizer is not quantizable";
      continue;
    }

    ParameterPtr param_node;
    tensor::TensorPtr tensor_info;
    GetLiteParameter(input, &param_node, &tensor_info);
    if (param_node == nullptr || tensor_info == nullptr || tensor_info->data_type() != TypeId::kNumberTypeFloat32) {
      MS_LOG(INFO) << "This optimizer op " << cnode->fullname_with_scope() << " can not quant weight";
      // NOTE(review): this stops processing the remaining indices as well;
      // confirm whether `continue` was intended here.
      return RET_OK;
    }

    // Stays RET_ERROR when type_id_ matches neither supported integer type.
    auto status = RET_ERROR;
    // `idx - 1` is presumably the input position not counting the primitive
    // slot at input(0) - TODO confirm against QuantFilter's last parameter.
    if (type_id_ == kNumberTypeInt8) {
      status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
                                   false, type_id_, idx - 1);
    } else if (type_id_ == kNumberTypeInt16) {
      status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
                                    false, type_id_, idx - 1);
    }
    if (status != RET_OK && status != RET_CONTINUE) {
      MS_LOG(ERROR) << "QuantFilter failed : " << status;
      return status;
    }

    status = SetAbstract(tensor_info, param_node, primitive);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "SetAbstract failed : " << status;
      return RET_ERROR;
    }
  }
  return RET_OK;
}
|
||||
|
||||
// Marks `cnode` as QUANT_WEIGHT when at least one of its Parameter inputs
// refers to a tensor recorded in weight_quantized_tensors (the map that
// SetAbstract fills after a tensor has been quantized).
//
// Nodes whose quant holder is already QUANT_WEIGHT are left untouched.
// Returns RET_ERROR when the node has no primitive or no quant-param holder,
// RET_OK otherwise (whether or not the node was marked).
STATUS WeightQuantizer::DoMarkWeightQuantizeIfQuantized(const CNodePtr &cnode) {
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_ERROR;
  }

  auto quant_param_holder = GetCNodeQuantHolder(primitive);
  if (quant_param_holder == nullptr) {
    MS_LOG(ERROR) << "quant_param_holder is nullptr";
    return RET_ERROR;
  }
  if (quant_param_holder->quant_type() == schema::QuantType_QUANT_WEIGHT) {
    // already marked with QUANT_WEIGHT
    return RET_OK;
  }

  // input(0) is the primitive, so real inputs start at index 1.
  for (size_t i = 1; i < cnode->size(); i++) {
    auto inputNode = cnode->input(i);
    if (inputNode == nullptr || !inputNode->isa<Parameter>()) {
      continue;
    }
    ParameterPtr param_node;
    tensor::TensorPtr tensor_info;
    GetLiteParameter(inputNode, &param_node, &tensor_info);
    if (weight_quantized_tensors.find(tensor_info) != weight_quantized_tensors.end()) {
      quant_param_holder->set_quant_type(schema::QuantType_QUANT_WEIGHT);
      // One quantized input is enough; further inputs cannot change the mark.
      break;
    }
  }
  return RET_OK;
}
|
||||
|
||||
STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const PrimitivePtr &primitive,
|
||||
const int &index) {
|
||||
auto op_name = cnode->fullname_with_scope();
|
||||
|
@ -649,6 +726,8 @@ STATUS WeightQuantizer::DoMixedQuant(const FuncGraphPtr &func_graph) {
|
|||
|
||||
STATUS WeightQuantizer::DoFixedQuant(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
weight_quantized_tensors.clear();
|
||||
|
||||
for (auto &cnode : func_graph->GetOrderedCnodes()) {
|
||||
auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
|
||||
if (primitive == nullptr) {
|
||||
|
@ -681,10 +760,34 @@ STATUS WeightQuantizer::DoFixedQuant(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(ERROR) << "DoGatherQuantize error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if ((opt::CheckPrimitiveType(cnode, prim::kPrimAdam)) || (opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) ||
|
||||
(opt::CheckPrimitiveType(cnode, prim::kPrimApplyMomentum))) {
|
||||
auto status = DoOptimizerQuantize(cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoOptimizerQuantize error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(DEBUG) << op_name << " of type: " << primitive->name() << " no need quant";
|
||||
}
|
||||
}
|
||||
return MarkWeightQuantizationInNodes(func_graph);
|
||||
}
|
||||
|
||||
// Walks every ordered CNode of `func_graph` and lets
// DoMarkWeightQuantizeIfQuantized tag those that consume a quantized weight.
// Nodes without a primitive are skipped with a DEBUG log. The first marking
// failure aborts the walk with RET_ERROR; otherwise RET_OK is returned.
STATUS WeightQuantizer::MarkWeightQuantizationInNodes(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  for (auto &node : func_graph->GetOrderedCnodes()) {
    auto node_primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(node->input(0));
    if (node_primitive == nullptr) {
      MS_LOG(DEBUG) << node->fullname_with_scope() << " : primitive is nullptr";
      continue;
    }
    if (DoMarkWeightQuantizeIfQuantized(node) != RET_OK) {
      MS_LOG(ERROR) << "MarkWeightQuantizationInNodes error marking " << node->fullname_with_scope();
      return RET_ERROR;
    }
  }
  return RET_OK;
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_
|
||||
|
||||
#include <future>
|
||||
#include <memory>
|
||||
|
@ -43,6 +43,7 @@ class WeightQuantizer : public Quantizer {
|
|||
STATUS DoQuantize(FuncGraphPtr func_graph) override;
|
||||
STATUS DoConvQuantize(const CNodePtr &);
|
||||
STATUS DoMulQuantize(const CNodePtr &);
|
||||
STATUS DoOptimizerQuantize(const CNodePtr &);
|
||||
STATUS DoLstmQuantize(const CNodePtr &cnode);
|
||||
STATUS DoGatherQuantize(const CNodePtr &cnode);
|
||||
|
||||
|
@ -57,6 +58,7 @@ class WeightQuantizer : public Quantizer {
|
|||
std::unique_ptr<QuantStrategy> quant_strategy_;
|
||||
size_t bit_num_{8};
|
||||
std::string config_file_;
|
||||
std::map<tensor::TensorPtr, ParameterPtr> weight_quantized_tensors;
|
||||
PostQuantConfig config_param_;
|
||||
std::vector<std::vector<std::string>> images_; // multi_input, [[mode_input_0], [model_input_1]...]
|
||||
std::vector<std::unordered_map<std::string, mindspore::tensor::MSTensor *>> fp32_output_tensors_;
|
||||
|
@ -65,6 +67,8 @@ class WeightQuantizer : public Quantizer {
|
|||
STATUS SetAbstract(const tensor::TensorPtr &tensor_info, const ParameterPtr &param_node,
|
||||
const PrimitivePtr &primitive);
|
||||
STATUS DoFixedQuant(const FuncGraphPtr &);
|
||||
STATUS MarkWeightQuantizationInNodes(const FuncGraphPtr &);
|
||||
STATUS DoMarkWeightQuantizeIfQuantized(const CNodePtr &);
|
||||
STATUS RunFp32Graph(const FuncGraphPtr &);
|
||||
|
||||
STATUS DoMixedQuantize(const FuncGraphPtr &func_graph);
|
||||
|
@ -74,6 +78,7 @@ class WeightQuantizer : public Quantizer {
|
|||
STATUS TryQuant(const int &bit_num_t, const ParameterPtr &param_node, const tensor::TensorPtr &tensor_info,
|
||||
const PrimitivePtr &primitive);
|
||||
STATUS DoQuantSearch(const FuncGraphPtr &func_graph);
|
||||
STATUS DoTensorQuantize(const CNodePtr &);
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_
|
||||
|
|
Loading…
Reference in New Issue