diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/addn_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/addn_infer.c
index 07133d18923..66b1ee9b5ae 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/addn_infer.c
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/addn_infer.c
@@ -53,9 +53,6 @@ int AddnInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
     if ((inputs[i]->shape_size_ != max_dims) && (GetElementNum(inputs[i]) != GetElementNum(inputs[max_dims_idx]))) {
       return NNACL_ERR;
     }
-    if (inputs[i]->data_type_ != inputs[0]->data_type_) {
-      return NNACL_ERR;
-    }
   }
 
   for (size_t d = 0; d < input->shape_size_; ++d) {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_infer.c
index edf38a52bb5..74f6ae4eeca 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_infer.c
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_infer.c
@@ -39,6 +39,10 @@ int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso
   size_t input_shape1_size = input1->shape_size_;
   output->format_ = input0->format_;
   output->data_type_ = input0->data_type_;
+  if ((input0->data_type_ == kNumberTypeInt8) && (input1->data_type_ == kNumberTypeFloat32)) {
+    output->data_type_ = input1->data_type_;
+  }
+
   if (!parameter->infer_flag_) {
     return NNACL_INFER_INVALID;
   }
diff --git a/mindspore/lite/examples/train_lenet/model/prepare_model.sh b/mindspore/lite/examples/train_lenet/model/prepare_model.sh
index ae6816d80fe..5656233b554 100755
--- a/mindspore/lite/examples/train_lenet/model/prepare_model.sh
+++ b/mindspore/lite/examples/train_lenet/model/prepare_model.sh
@@ -28,5 +28,10 @@ if [ ! -f "$CONVERTER" ]; then
 fi
 
 echo "============Converting========="
-LD_LIBRARY_PATH=./ $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod
+QUANT_OPTIONS=""
+if [[ ! -z ${QUANTIZE} ]]; then
+  echo "Quantizing weights"
+  QUANT_OPTIONS="--quantType=WeightQuant --bitNum=8 --quantWeightSize=100 --quantWeightChannel=15"
+fi
+LD_LIBRARY_PATH=./ $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod $QUANT_OPTIONS
diff --git a/mindspore/lite/examples/train_lenet/prepare_and_run.sh b/mindspore/lite/examples/train_lenet/prepare_and_run.sh
index 904f6c59a3d..d2757f4d236 100755
--- a/mindspore/lite/examples/train_lenet/prepare_and_run.sh
+++ b/mindspore/lite/examples/train_lenet/prepare_and_run.sh
@@ -2,7 +2,7 @@
 
 display_usage()
 {
-  echo -e "\nUsage: prepare_and_run.sh -D dataset_path [-d mindspore_docker] [-r release.tar.gz] [-t arm64|x86]\n"
+  echo -e "\nUsage: prepare_and_run.sh -D dataset_path [-d mindspore_docker] [-r release.tar.gz] [-t arm64|x86] [-q]\n"
 }
 
 checkopts()
@@ -10,7 +10,8 @@ checkopts()
 {
   TARGET="arm64"
   DOCKER=""
   MNIST_DATA_PATH=""
-  while getopts 'D:d:r:t:' opt
+  QUANTIZE=""
+  while getopts 'D:d:r:t:q' opt
   do
     case "${opt}" in
       D)
@@ -31,6 +32,9 @@ checkopts()
       r)
         TARBALL=$OPTARG
         ;;
+      q)
+        QUANTIZE="QUANTIZE"
+        ;;
       *)
         echo "Unknown option ${opt}!"
         display_usage
@@ -64,7 +68,7 @@ fi
 # Prepare the model
 cd model/ || exit 1
 rm -f *.ms
-./prepare_model.sh $DOCKER || exit 1
+QUANTIZE=${QUANTIZE} ./prepare_model.sh $DOCKER || exit 1
 cd ../
 
 # Copy the .ms model to the package folder
diff --git a/mindspore/lite/examples/train_lenet/src/net_runner.cc b/mindspore/lite/examples/train_lenet/src/net_runner.cc
index f52ae1c4dd4..5e4dea836fc 100644
--- a/mindspore/lite/examples/train_lenet/src/net_runner.cc
+++ b/mindspore/lite/examples/train_lenet/src/net_runner.cc
@@ -110,6 +110,9 @@ void NetRunner::InitAndFigureInputs() {
   MS_ASSERT(nullptr != session_);
 
   loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session_);
+  if (verbose_) {
+    loop_->SetKernelCallBack(nullptr, after_callback);
+  }
   acc_metrics_ = std::shared_ptr<AccuracyMetrics>(new AccuracyMetrics);
   loop_->Init({acc_metrics_.get()});
 
@@ -125,11 +128,11 @@ float NetRunner::CalculateAccuracy(int max_tests) {
   test_ds_ = Mnist(data_dir_ + "/test", "all");
-  TypeCast typecast_f("float32");
+  TypeCast typecast_f(mindspore::DataType::kNumberTypeFloat32);
   Resize resize({h_, w_});
   test_ds_ = test_ds_->Map({&resize, &typecast_f}, {"image"});
 
-  TypeCast typecast("int32");
+  TypeCast typecast(mindspore::DataType::kNumberTypeInt32);
   test_ds_ = test_ds_->Map({&typecast}, {"label"});
   test_ds_ = test_ds_->Batch(batch_size_, true);
 
@@ -144,14 +147,14 @@ float NetRunner::CalculateAccuracy(int max_tests) {
 
 int NetRunner::InitDB() {
   train_ds_ = Mnist(data_dir_ + "/train", "all");
 
-  TypeCast typecast_f("float32");
+  TypeCast typecast_f(mindspore::DataType::kNumberTypeFloat32);
   Resize resize({h_, w_});
   train_ds_ = train_ds_->Map({&resize, &typecast_f}, {"image"});
 
-  TypeCast typecast("int32");
+  TypeCast typecast(mindspore::DataType::kNumberTypeInt32);
   train_ds_ = train_ds_->Map({&typecast}, {"label"});
 
-  train_ds_ = train_ds_->Shuffle(2);
+  // train_ds_ = train_ds_->Shuffle(2);
   train_ds_ = train_ds_->Batch(batch_size_, true);
 
   if (verbose_) {
diff --git a/mindspore/lite/examples/transfer_learning/src/net_runner.cc b/mindspore/lite/examples/transfer_learning/src/net_runner.cc
index cc0db4ae537..844f0f60d29 100644
--- a/mindspore/lite/examples/transfer_learning/src/net_runner.cc
+++ b/mindspore/lite/examples/transfer_learning/src/net_runner.cc
@@ -187,7 +187,7 @@ int NetRunner::TrainLoop() {
     if (save_checkpoint_ != 0 && (i + 1) % save_checkpoint_ == 0) {
       auto cpkt_fn =
         ms_head_file_.substr(0, ms_head_file_.find_last_of('.')) + "_trained_" + std::to_string(i + 1) + ".ms";
-      session_->SaveToFile(cpkt_fn);
+      mindspore::lite::Model::Export(head_model_, cpkt_fn.c_str());
     }
 
     std::cout << i + 1 << ": Loss is " << loss << " [min=" << min_loss << "]" << std::endl;
@@ -213,7 +213,7 @@ int NetRunner::Main() {
 
   if (cycles_ > 0) {
     auto trained_fn = ms_head_file_.substr(0, ms_head_file_.find_last_of('.')) + "_trained.ms";
-    session_->SaveToFile(trained_fn);
+    mindspore::lite::Model::Export(head_model_, trained_fn.c_str());
   }
   return 0;
 }
diff --git a/mindspore/lite/examples/transfer_learning/src/net_runner.h b/mindspore/lite/examples/transfer_learning/src/net_runner.h
index a807bf706f3..e25da59a9b6 100644
--- a/mindspore/lite/examples/transfer_learning/src/net_runner.h
+++ b/mindspore/lite/examples/transfer_learning/src/net_runner.h
@@ -44,6 +44,8 @@ class NetRunner {
 
   DataSet ds_;
   mindspore::session::TrainSession *session_ = nullptr;
+  mindspore::lite::Model *backbone_model_ = nullptr;
+  mindspore::lite::Model *head_model_ = nullptr;
 
   std::string ms_backbone_file_ = "";
   std::string ms_head_file_ = "";
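
Note: the new -q switch is opt-in. It only changes how prepare_model.sh invokes the converter (via the exported QUANTIZE variable); the training code itself is unchanged. A usage sketch, with placeholder dataset and tarball paths:

    ./prepare_and_run.sh -D /path/to/MNIST/ -r mindspore-lite-VERSION.tar.gz -t x86 -q
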
diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc
index 5ff98aaa20c..725e228a1be 100644
--- a/mindspore/lite/tools/converter/converter_flags.cc
+++ b/mindspore/lite/tools/converter/converter_flags.cc
@@ -176,10 +176,6 @@ int Flags::InitTrainModel() {
       std::cerr << "INPUT ILLEGAL: train model converter supporting only FP32 output tensors";
       return RET_INPUT_PARAM_INVALID;
     }
-    if (this->quantType != QuantType_QUANT_NONE) {
-      std::cerr << "INPUT ILLEGAL: train model converter is not supporting quantization";
-      return RET_INPUT_PARAM_INVALID;
-    }
   }
   return RET_OK;
 }
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
index bed6a79a06c..a8482ad36e0 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
@@ -181,6 +181,57 @@ bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
   return true;
 }
 
+bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &inputNode) const {
+  if (inputNode == nullptr) {
+    MS_LOG(INFO) << "CanTensorQuantized input is nullptr!";
+    return false;
+  }
+  ParameterPtr paramNode = nullptr;
+
+  if (inputNode->isa<Parameter>()) {
+    paramNode = inputNode->cast<ParameterPtr>();
+  }
+
+  if (paramNode == nullptr) {
+    MS_LOG(INFO) << "CanTensorQuantized invalid paramNode!";
+    return false;
+  }
+
+  auto abstract_base = paramNode->abstract();
+  if (abstract_base == nullptr) {
+    MS_LOG(INFO) << "abstract is nullptr";
+    return false;
+  }
+
+  if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
+    MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
+    return false;
+  }
+
+  auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
+  if (weight_shape.size() < 2) {  // do not quant single dim tensors
+    return false;
+  }
+
+  size_t shapeSize = 1;
+  for (auto dim : weight_shape) {
+    shapeSize = shapeSize * dim;
+  }
+  if (shapeSize < m_weight_size_) {
+    MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
+    return false;
+  }
+
+  if (weight_shape.size() == 4) {  // assume Convolution
+    if (weight_shape[0] <= static_cast<int>(m_conv_weight_quant_channel_threshold_)) {
+      MS_LOG(INFO) << "channel less m_conv_weight_quant_channel_threshold_!" << weight_shape[0];
+      return false;
+    }
+  }
+
+  return true;
+}
+
 QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) {
   MS_ASSERT(primitive != nullptr);
   QuantParamHolderPtr quant_params_holder = nullptr;
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h
index cb56f80c2ba..8c12f78f66f 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.h
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H
-#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZE_UTIL_H_
+#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZE_UTIL_H_
 
 #include
 #include
@@ -83,6 +83,7 @@ class QuantStrategy {
   bool CanConvOpQuantized(const CNodePtr &node) const;
   bool CanMulOpQuantized(const CNodePtr &node) const;
   bool CanOpPostQuantized(AnfNodePtr &node) const;
+  bool CanTensorQuantized(const AnfNodePtr &inputNode) const;
 
   size_t m_weight_size_;
   size_t m_conv_weight_quant_channel_threshold_;
@@ -417,4 +418,4 @@ FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &);
 
 void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info);
 }  // namespace mindspore::lite::quant
-#endif
+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZE_UTIL_H_
diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
index c3235a37a8d..b5c96abb72b 100644
--- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
@@ -75,6 +75,7 @@ STATUS WeightQuantizer::SetAbstract(const tensor::TensorPtr &tensor_info, const
   auto quant_param_holder = GetCNodeQuantHolder(primitive);
   quant_param_holder->set_quant_type(schema::QuantType_QUANT_WEIGHT);
 
+  weight_quantized_tensors.insert({tensor_info, param_node});
   return RET_OK;
 }
 
@@ -244,6 +245,82 @@ STATUS WeightQuantizer::DoGatherQuantize(const CNodePtr &cnode) {
   return RET_OK;
 }
 
+STATUS WeightQuantizer::DoOptimizerQuantize(const CNodePtr &cnode) {
+  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
+  MS_ASSERT(primitive != nullptr);
+
+  std::vector<int> weight_indices = {2};
+  if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam)) {
+    weight_indices = {2, 3};
+  }
+  if (opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
+    weight_indices = {4, 6};
+  }
+
+  for (int idx : weight_indices) {
+    auto input = cnode->input(idx);
+    if (!quant_strategy_->CanTensorQuantized(input)) {
+      MS_LOG(INFO) << "Input " << idx << " of Optimizer is not quantizable";
+      continue;
+    }
+    ParameterPtr param_node;
+    tensor::TensorPtr tensor_info;
+    GetLiteParameter(input, &param_node, &tensor_info);
+    if (param_node == nullptr || tensor_info == nullptr || tensor_info->data_type() != TypeId::kNumberTypeFloat32) {
+      MS_LOG(INFO) << "This Optimizer op " << cnode->fullname_with_scope() << " can not quant weight";
+      return RET_OK;
+    }
+
+    auto status = RET_ERROR;
+    if (type_id_ == kNumberTypeInt8) {
+      status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                   false, type_id_, idx - 1);
+    } else if (type_id_ == kNumberTypeInt16) {
+      status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                    false, type_id_, idx - 1);
+    }
+    if (status != RET_OK && status != RET_CONTINUE) {
+      MS_LOG(ERROR) << "QuantFilter failed : " << status;
+      return status;
+    }
+    status = SetAbstract(tensor_info, param_node, primitive);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "SetAbstract failed : " << status;
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
+STATUS WeightQuantizer::DoMarkWeightQuantizeIfQuantized(const CNodePtr &cnode) {
+  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
+  if (primitive == nullptr) {
+    MS_LOG(ERROR) << "primitive is nullptr";
+    return RET_ERROR;
+  }
+
+  auto quant_param_holder = GetCNodeQuantHolder(primitive);
+  if (quant_param_holder->quant_type() == schema::QuantType_QUANT_WEIGHT) {
+    // already marked with QUANT_WEIGHT
+    return RET_OK;
+  }
+
+  for (size_t i = 1; i < cnode->size(); i++) {
+    auto inputNode = cnode->input(i);
+    if (inputNode->isa<Parameter>()) {
+      ParameterPtr param_node;
+      tensor::TensorPtr tensor_info;
+      GetLiteParameter(inputNode, &param_node, &tensor_info);
+      auto param = weight_quantized_tensors.find(tensor_info);
+      if (param != weight_quantized_tensors.end()) {
+        quant_param_holder->set_quant_type(schema::QuantType_QUANT_WEIGHT);
+        continue;
+      }
+    }
+  }
+  return RET_OK;
+}
+
 STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const PrimitivePtr &primitive,
                                                  const int &index) {
   auto op_name = cnode->fullname_with_scope();
@@ -649,6 +726,8 @@ STATUS WeightQuantizer::DoMixedQuant(const FuncGraphPtr &func_graph) {
 
 STATUS WeightQuantizer::DoFixedQuant(const FuncGraphPtr &func_graph) {
   MS_ASSERT(func_graph != nullptr);
+  weight_quantized_tensors.clear();
+
   for (auto &cnode : func_graph->GetOrderedCnodes()) {
     auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
     if (primitive == nullptr) {
@@ -681,10 +760,34 @@ STATUS WeightQuantizer::DoFixedQuant(const FuncGraphPtr &func_graph) {
         MS_LOG(ERROR) << "DoGatherQuantize error";
         return RET_ERROR;
       }
+    } else if ((opt::CheckPrimitiveType(cnode, prim::kPrimAdam)) || (opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) ||
+               (opt::CheckPrimitiveType(cnode, prim::kPrimApplyMomentum))) {
+      auto status = DoOptimizerQuantize(cnode);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "DoOptimizerQuantize error";
+        return RET_ERROR;
+      }
     } else {
       MS_LOG(DEBUG) << op_name << " of type: " << primitive->name() << " no need quant";
     }
   }
+  return MarkWeightQuantizationInNodes(func_graph);
+}
+
+STATUS WeightQuantizer::MarkWeightQuantizationInNodes(const FuncGraphPtr &func_graph) {
+  MS_ASSERT(func_graph != nullptr);
+  for (auto &cnode : func_graph->GetOrderedCnodes()) {
+    auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
+    if (primitive == nullptr) {
+      MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
+      continue;
+    }
+    auto status = DoMarkWeightQuantizeIfQuantized(cnode);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "MarkWeightQuantizationInNodes error marking " << cnode->fullname_with_scope();
+      return RET_ERROR;
+    }
+  }
   return RET_OK;
 }
diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h
index 111de02303c..13a54701de5 100644
--- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h
+++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H
-#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_
+#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_
 
 #include
 #include
@@ -43,6 +43,7 @@ class WeightQuantizer : public Quantizer {
   STATUS DoQuantize(FuncGraphPtr func_graph) override;
   STATUS DoConvQuantize(const CNodePtr &);
   STATUS DoMulQuantize(const CNodePtr &);
+  STATUS DoOptimizerQuantize(const CNodePtr &);
   STATUS DoLstmQuantize(const CNodePtr &cnode);
   STATUS DoGatherQuantize(const CNodePtr &cnode);
 
@@ -57,6 +58,7 @@ class WeightQuantizer : public Quantizer {
   std::unique_ptr<QuantStrategy> quant_strategy_;
   size_t bit_num_{8};
   std::string config_file_;
+  std::map<tensor::TensorPtr, ParameterPtr> weight_quantized_tensors;
   PostQuantConfig config_param_;
   std::vector> images_;  // multi_input, [[mode_input_0], [model_input_1]...]
   std::vector> fp32_output_tensors_;
@@ -65,6 +67,8 @@ class WeightQuantizer : public Quantizer {
   STATUS SetAbstract(const tensor::TensorPtr &tensor_info, const ParameterPtr &param_node,
                      const PrimitivePtr &primitive);
   STATUS DoFixedQuant(const FuncGraphPtr &);
+  STATUS MarkWeightQuantizationInNodes(const FuncGraphPtr &);
+  STATUS DoMarkWeightQuantizeIfQuantized(const CNodePtr &);
   STATUS RunFp32Graph(const FuncGraphPtr &);
 
   STATUS DoMixedQuantize(const FuncGraphPtr &func_graph);
@@ -74,6 +78,7 @@ class WeightQuantizer : public Quantizer {
   STATUS TryQuant(const int &bit_num_t, const ParameterPtr &param_node, const tensor::TensorPtr &tensor_info,
                   const PrimitivePtr &primitive);
   STATUS DoQuantSearch(const FuncGraphPtr &func_graph);
+  STATUS DoTensorQuantize(const CNodePtr &);
 };
 }  // namespace mindspore::lite::quant
-#endif
+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_
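
With QUANTIZE set, the converter call in prepare_model.sh expands to roughly the following (a sketch; the converter_lite binary name behind $CONVERTER is an assumption based on the release package layout):

    LD_LIBRARY_PATH=./ ./converter_lite --fmk=MINDIR --trainModel=true \
        --modelFile=lenet_tod.mindir --outputFile=lenet_tod \
        --quantType=WeightQuant --bitNum=8 --quantWeightSize=100 --quantWeightChannel=15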