lite train support mix precision model

This commit is contained in:
zhengjun10 2021-09-02 16:20:36 +08:00
parent 50847c9659
commit fc4e6dc533
14 changed files with 105 additions and 39 deletions

View File

@ -36,6 +36,7 @@ class MixPrecisionCfg {
bool dynamic_loss_scale_ = false; /**< Enable\disable dynamic loss scale during mix precision training */
float loss_scale_; /**< Initial loss scale factor */
uint32_t num_of_not_nan_iter_th_; /**< a threshold for modifying loss scale when dynamic loss scale is enabled */
bool is_raw_mix_precision_ = false; /**< Is mix precision model export from mindspore */
};
class TrainCfg {

View File

@ -16,7 +16,7 @@
import sys
import numpy as np
from mindspore import context, Tensor
from mindspore import context, Tensor, FixedLossScaleManager
import mindspore.common.dtype as mstype
from mindspore.train.serialization import export
from lenet import LeNet5
@ -32,5 +32,8 @@ x = Tensor(np.ones((BATCH_SIZE, 1, 32, 32)), mstype.float32)
label = Tensor(np.zeros([BATCH_SIZE]).astype(np.int32))
net = train_wrap(n)
export(net, x, label, file_name="lenet_tod", file_format='MINDIR')
loss_scale = 128.0
loss_scale_manager = FixedLossScaleManager(loss_scale, False)
mix_precision_net = train_wrap(n, None, None, None, loss_scale_manager)
export(mix_precision_net, x, label, file_name="mix_lenet_tod", file_format='MINDIR')
print("finished exporting")

View File

@ -46,4 +46,6 @@ if [[ ! -z ${QUANTIZE} ]]; then
QUANT_OPTIONS="--configFile=${WEIGHT_QUANT_CONFIG}"
fi
LD_LIBRARY_PATH=./:${LD_LIBRARY_PATH} $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod $QUANT_OPTIONS
if [ -n "$3" ]; then
LD_LIBRARY_PATH=./:${LD_LIBRARY_PATH} $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=mix_lenet_tod.mindir --outputFile=mix_lenet_tod
fi

View File

@ -16,9 +16,9 @@
import mindspore.nn as nn
from mindspore.common.parameter import ParameterTuple
from mindspore import amp
def train_wrap(net, loss_fn=None, optimizer=None, weights=None):
def train_wrap(net, loss_fn=None, optimizer=None, weights=None, loss_scale_manager=None):
"""
train_wrap
"""
@ -31,5 +31,8 @@ def train_wrap(net, loss_fn=None, optimizer=None, weights=None):
if optimizer is None:
optimizer = nn.Adam(weights, learning_rate=0.003, beta1=0.9, beta2=0.999, eps=1e-5, use_locking=False,
use_nesterov=False, weight_decay=4e-5, loss_scale=1.0)
train_net = nn.TrainOneStepCell(loss_net, optimizer)
if loss_scale_manager is None:
train_net = nn.TrainOneStepCell(loss_net, optimizer)
else:
train_net = amp.build_train_network(net, optimizer, loss_fn, level="O2", loss_scale_manager=loss_scale_manager)
return train_net

View File

@ -2,7 +2,7 @@
display_usage()
{
echo -e "\nUsage: prepare_and_run.sh -D dataset_path [-d mindspore_docker] [-r release.tar.gz] [-t arm64|x86] [-q] [-o] [-b virtual_batch] [-m mindir] [-e epochs_to_train]\n"
echo -e "\nUsage: prepare_and_run.sh -D dataset_path [-d mindspore_docker] [-r release.tar.gz] [-t arm64|x86] [-q] [-o] [-M] [-b virtual_batch] [-m mindir] [-e epochs_to_train]\n"
}
checkopts()
@ -15,7 +15,8 @@ checkopts()
FP16_FLAG=""
VIRTUAL_BATCH=-1
EPOCHS="-e 5"
while getopts 'D:b:d:e:m:oqr:t:' opt
MIX_FLAG=""
while getopts 'D:b:d:e:m:oqr:t:M:' opt
do
case "${opt}" in
b)
@ -42,6 +43,11 @@ checkopts()
r)
TARBALL=$OPTARG
;;
M)
MIX_FLAG="-m"
FP16_FLAG="-o"
echo $OPTARG
;;
t)
if [ "$OPTARG" == "arm64" ] || [ "$OPTARG" == "x86" ]; then
TARGET=$OPTARG
@ -98,7 +104,7 @@ fi
cd model/ || exit 1
rm -f *.ms
EXPORT=${EXPORT} QUANTIZE=${QUANTIZE} ./prepare_model.sh $BATCH $DOCKER || exit 1
EXPORT=${EXPORT} QUANTIZE=${QUANTIZE} ./prepare_model.sh $BATCH $DOCKER $MIX_FLAG || exit 1
cd ../
# Copy the .ms model to the package folder
@ -140,16 +146,28 @@ mv bin ${PACKAGE}/ || exit 1
if [ "${TARGET}" == "arm64" ]; then
cp ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android/libc++_shared.so ${PACKAGE}/lib/ || exit 1
echo "=======Pushing to device======="
adb push ${PACKAGE} /data/local/tmp/
echo "=======Pushing to device======="
adb push ${PACKAGE} /data/local/tmp/
if [ "${MIX_FLAG}" == "" ];then
echo "========Training on Device====="
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh train.sh ${EPOCHS} ${FP16_FLAG} -b ${VIRTUAL_BATCH}"
# origin model is fp32 model
echo "========Training on Device origin model is fp32====="
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh train.sh ${EPOCHS} ${FP16_FLAG} -b ${VIRTUAL_BATCH}"
echo
echo "===Evaluating trained Model origin model is fp32====="
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh eval.sh ${FP16_FLAG}"
echo
else
echo "========Training on Device origin model is fp16 ====="
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh train.sh ${EPOCHS} ${FP16_FLAG} -b ${VIRTUAL_BATCH} ${MIX_FLAG}"
echo
echo "===Evaluating trained Model origin model is fp16====="
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh eval.sh ${FP16_FLAG} ${MIX_FLAG}"
echo
fi
echo
echo "===Evaluating trained Model====="
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh eval.sh ${FP16_FLAG}"
echo
else
cd ${PACKAGE} || exit 1
echo "======Training Locally========="

View File

@ -15,4 +15,10 @@
# ============================================================================
# an simple tutorial as follows, more parameters can be setting
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/lenet_tod_trained.ms -e 0 -d dataset $1
is_mix_model=$(echo "$@" | grep "m")
if [[ "$is_mix_model" != "" ]]
then
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/mix_lenet_tod_trained.ms -e 0 -d dataset $1
else
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/lenet_tod_trained.ms -e 0 -d dataset $1
fi

View File

@ -15,4 +15,10 @@
# ============================================================================
# an simple tutorial as follows, more parameters can be setting
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/lenet_tod.ms -d dataset "$@"
is_mix_model=$(echo "$@" | grep "m")
if [[ "$is_mix_model" != "" ]]
then
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/mix_lenet_tod.ms -d dataset "$@"
else
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/lenet_tod.ms -d dataset "$@"
fi

View File

@ -15,26 +15,27 @@
*/
#include "src/net_runner.h"
#include <math.h>
#include <getopt.h>
#include <stdio.h>
#include <malloc.h>
#include <cstring>
#include <math.h>
#include <stdio.h>
#include <chrono>
#include <iostream>
#include <cstring>
#include <fstream>
#include <iostream>
#include <utility>
#include "include/context.h"
#include "include/train/loss_monitor.h"
#include "include/train/ckpt_saver.h"
#include "include/train/lr_scheduler.h"
#include "include/train/accuracy_metrics.h"
#include "include/train/train_session.h"
#include "include/train/classification_train_accuracy_monitor.h"
#include "src/utils.h"
#include "include/dataset/datasets.h"
#include "include/dataset/vision_lite.h"
#include "include/dataset/transforms.h"
#include "include/dataset/vision_lite.h"
#include "include/train/accuracy_metrics.h"
#include "include/train/ckpt_saver.h"
#include "include/train/classification_train_accuracy_monitor.h"
#include "include/train/loss_monitor.h"
#include "include/train/lr_scheduler.h"
#include "include/train/train_cfg.h"
#include "include/train/train_session.h"
#include "src/utils.h"
using mindspore::dataset::Dataset;
using mindspore::dataset::Mnist;
@ -149,7 +150,9 @@ void NetRunner::InitAndFigureInputs() {
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
context.thread_num_ = 2;
session_ = mindspore::session::TrainSession::CreateTrainSession(ms_file_, &context, true);
mindspore::lite::TrainCfg train_cfg;
train_cfg.mix_precision_cfg_.is_raw_mix_precision_ = is_raw_mix_precision_;
session_ = mindspore::session::TrainSession::CreateTrainSession(ms_file_, &context, true, &train_cfg);
MS_ASSERT(session_ != nullptr);
session_->SetupVirtualBatch(virtual_batch_);
@ -277,6 +280,9 @@ bool NetRunner::ReadArgs(int argc, char *argv[]) {
case 'b':
virtual_batch_ = atoi(optarg);
break;
case 'r':
is_raw_mix_precision_ = atoi(optarg);
break;
case 'h':
default:
Usage();

View File

@ -63,6 +63,7 @@ class NetRunner {
int batch_size_ = 32;
int h_ = 32;
int w_ = 32;
bool is_raw_mix_precision_ = false;
};
#endif // MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_NET_RUNNER_H_

View File

@ -28,24 +28,28 @@ class MixPrecisionCfg {
this->loss_scale_ = 128.0f;
this->keep_batchnorm_fp32_ = true;
this->num_of_not_nan_iter_th_ = 1000;
this->is_raw_mix_precision_ = false;
}
MixPrecisionCfg(const MixPrecisionCfg &rhs) {
this->dynamic_loss_scale_ = rhs.dynamic_loss_scale_;
this->loss_scale_ = rhs.loss_scale_;
this->keep_batchnorm_fp32_ = rhs.keep_batchnorm_fp32_;
this->num_of_not_nan_iter_th_ = rhs.num_of_not_nan_iter_th_;
this->is_raw_mix_precision_ = rhs.is_raw_mix_precision_;
}
MixPrecisionCfg &operator=(MixPrecisionCfg const &rhs) {
this->dynamic_loss_scale_ = rhs.dynamic_loss_scale_;
this->loss_scale_ = rhs.loss_scale_;
this->keep_batchnorm_fp32_ = rhs.keep_batchnorm_fp32_;
this->num_of_not_nan_iter_th_ = rhs.num_of_not_nan_iter_th_;
this->is_raw_mix_precision_ = rhs.is_raw_mix_precision_;
return *this;
}
bool dynamic_loss_scale_ = false; /**< Enable\disable dynamic loss scale during mix precision training */
float loss_scale_; /**< Initial loss scale factor */
bool keep_batchnorm_fp32_ = true; /**< Keep batch norm in FP32 while training */
uint32_t num_of_not_nan_iter_th_; /**< a threshold for modifying loss scale when dynamic loss scale is enabled */
bool dynamic_loss_scale_ = false; /**< Enable\disable dynamic loss scale during mix precision training */
float loss_scale_; /**< Initial loss scale factor */
bool keep_batchnorm_fp32_ = true; /**< Keep batch norm in FP32 while training */
uint32_t num_of_not_nan_iter_th_; /**< a threshold for modifying loss scale when dynamic loss scale is enabled */
bool is_raw_mix_precision_ = false; /**< Is mix precision model export from mindspore */
};
/// \brief TrainCfg defined for holding train configuration.

View File

@ -35,6 +35,7 @@ Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cf
l_train_cfg->mix_precision_cfg_.loss_scale_ = a_train_cfg->mix_precision_cfg_.loss_scale_;
l_train_cfg->mix_precision_cfg_.keep_batchnorm_fp32_ = (a_train_cfg->optimization_level_ != kO3);
l_train_cfg->mix_precision_cfg_.num_of_not_nan_iter_th_ = a_train_cfg->mix_precision_cfg_.num_of_not_nan_iter_th_;
l_train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = a_train_cfg->mix_precision_cfg_.is_raw_mix_precision_;
l_train_cfg->accumulate_gradients_ = a_train_cfg->accumulate_gradients_;
return kSuccess;
}

View File

@ -121,6 +121,16 @@ int TrainSession::InitCallBack() {
if (!context_->IsCpuFloat16Enabled()) {
return false;
}
if (cfg_.mix_precision_cfg_.is_raw_mix_precision_) {
auto out_tensor_indexs = node->output_indices_;
if (out_tensor_indexs.empty()) {
MS_LOG(DEBUG) << "Debug: " << node->name_ << " fp32";
return false;
}
auto is_fp16 = model_->all_tensors_.at(out_tensor_indexs[0])->dataType() == kNumberTypeFloat16;
MS_LOG(DEBUG) << "Debug: " << node->name_ << ((is_fp16) ? " fp16" : " fp32");
return is_fp16;
}
auto node_type = GetPrimitiveType(node->primitive_, SCHEMA_VERSION::SCHEMA_CUR);
if (node_type == schema::PrimitiveType_Cast) {
return false;
@ -128,7 +138,7 @@ int TrainSession::InitCallBack() {
auto in_size = node->input_indices_.size();
bool force_fp16 = false;
for (std::size_t k = 0; k < in_size; k++) {
schema::Tensor *tensor = model_.get()->all_tensors_.at(node->input_indices_[k]);
schema::Tensor *tensor = model_->all_tensors_.at(node->input_indices_[k]);
if ((tensor->dataType() == kNumberTypeFloat16) && (tensor->nodeType() == NodeType_ValueNode)) {
force_fp16 = true;
break;
@ -437,7 +447,7 @@ int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &a
return lite::RET_NULL_PTR;
}
auto &run_kernels = (train_mode_) ? train_kernels_ : inference_kernels_;
if (context_->IsCpuFloat16Enabled()) {
if (context_->IsCpuFloat16Enabled() && !cfg_.mix_precision_cfg_.is_raw_mix_precision_) {
ret = MixPrecisionExecKernels(before, after, run_kernels);
} else {
ret = ExecKernels(before, after, run_kernels);
@ -1077,7 +1087,7 @@ session::LiteSession *session::TrainSession::CreateTrainSession(const std::strin
}
}
mindspore::lite::InnerContext *inner_context = new (std::nothrow) mindspore::lite::InnerContext(context);
auto *inner_context = new (std::nothrow) mindspore::lite::InnerContext(context);
auto ret = session->Init(inner_context, cfg);
if (ret != mindspore::lite::RET_OK) {
MS_LOG(ERROR) << "init session failed";

View File

@ -350,6 +350,7 @@ std::unique_ptr<session::LiteSession> NetTrain::CreateAndRunNetworkForTrain(cons
} else {
MS_LOG(INFO) << "CreateTrainSession from model file" << filename.c_str();
std::cout << "CreateTrainSession from model file " << filename.c_str() << std::endl;
std::cout << "Is raw mix precision model: " << train_cfg.mix_precision_cfg_.is_raw_mix_precision_ << std::endl;
session = std::unique_ptr<session::LiteSession>(
session::TrainSession::CreateTrainSession(filename, &context, true, &train_cfg));
if (session == nullptr) {
@ -412,6 +413,7 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string
if (flags_->loss_name_ != "") {
train_cfg.loss_name_ = flags_->loss_name_;
}
train_cfg.mix_precision_cfg_.is_raw_mix_precision_ = flags_->is_raw_mix_precision_;
std::unique_ptr<session::LiteSession> session;
if (train_session) {
session = CreateAndRunNetworkForTrain(filename, bb_filename, context, train_cfg, epochs);

View File

@ -75,6 +75,8 @@ class MS_API NetTrainFlags : public virtual FlagParser {
AddFlag(&NetTrainFlags::virtual_batch_, "virtualBatch", "use virtual batch", false);
AddFlag(&NetTrainFlags::resize_dims_in_, "inputShapes",
"Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
AddFlag(&NetTrainFlags::is_raw_mix_precision_, "isRawMixPrecision",
"If model is mix precision export from MindSpore,please set true", false);
}
~NetTrainFlags() override = default;
@ -107,6 +109,7 @@ class MS_API NetTrainFlags : public virtual FlagParser {
std::vector<std::vector<int>> resize_dims_;
std::string loss_name_ = "";
std::string inference_file_ = "";
bool is_raw_mix_precision_ = false;
};
class MS_API NetTrain {