forked from mindspore-Ecosystem/mindspore
lite train support mix precision model
This commit is contained in:
parent
50847c9659
commit
fc4e6dc533
|
@ -36,6 +36,7 @@ class MixPrecisionCfg {
|
|||
bool dynamic_loss_scale_ = false; /**< Enable\disable dynamic loss scale during mix precision training */
|
||||
float loss_scale_; /**< Initial loss scale factor */
|
||||
uint32_t num_of_not_nan_iter_th_; /**< a threshold for modifying loss scale when dynamic loss scale is enabled */
|
||||
bool is_raw_mix_precision_ = false; /**< Is mix precision model export from mindspore */
|
||||
};
|
||||
|
||||
class TrainCfg {
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
import sys
|
||||
import numpy as np
|
||||
from mindspore import context, Tensor
|
||||
from mindspore import context, Tensor, FixedLossScaleManager
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.serialization import export
|
||||
from lenet import LeNet5
|
||||
|
@ -32,5 +32,8 @@ x = Tensor(np.ones((BATCH_SIZE, 1, 32, 32)), mstype.float32)
|
|||
label = Tensor(np.zeros([BATCH_SIZE]).astype(np.int32))
|
||||
net = train_wrap(n)
|
||||
export(net, x, label, file_name="lenet_tod", file_format='MINDIR')
|
||||
|
||||
loss_scale = 128.0
|
||||
loss_scale_manager = FixedLossScaleManager(loss_scale, False)
|
||||
mix_precision_net = train_wrap(n, None, None, None, loss_scale_manager)
|
||||
export(mix_precision_net, x, label, file_name="mix_lenet_tod", file_format='MINDIR')
|
||||
print("finished exporting")
|
||||
|
|
|
@ -46,4 +46,6 @@ if [[ ! -z ${QUANTIZE} ]]; then
|
|||
QUANT_OPTIONS="--configFile=${WEIGHT_QUANT_CONFIG}"
|
||||
fi
|
||||
LD_LIBRARY_PATH=./:${LD_LIBRARY_PATH} $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod $QUANT_OPTIONS
|
||||
|
||||
if [ -n "$3" ]; then
|
||||
LD_LIBRARY_PATH=./:${LD_LIBRARY_PATH} $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=mix_lenet_tod.mindir --outputFile=mix_lenet_tod
|
||||
fi
|
||||
|
|
|
@ -16,9 +16,9 @@
|
|||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore import amp
|
||||
|
||||
|
||||
def train_wrap(net, loss_fn=None, optimizer=None, weights=None):
|
||||
def train_wrap(net, loss_fn=None, optimizer=None, weights=None, loss_scale_manager=None):
|
||||
"""
|
||||
train_wrap
|
||||
"""
|
||||
|
@ -31,5 +31,8 @@ def train_wrap(net, loss_fn=None, optimizer=None, weights=None):
|
|||
if optimizer is None:
|
||||
optimizer = nn.Adam(weights, learning_rate=0.003, beta1=0.9, beta2=0.999, eps=1e-5, use_locking=False,
|
||||
use_nesterov=False, weight_decay=4e-5, loss_scale=1.0)
|
||||
train_net = nn.TrainOneStepCell(loss_net, optimizer)
|
||||
if loss_scale_manager is None:
|
||||
train_net = nn.TrainOneStepCell(loss_net, optimizer)
|
||||
else:
|
||||
train_net = amp.build_train_network(net, optimizer, loss_fn, level="O2", loss_scale_manager=loss_scale_manager)
|
||||
return train_net
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
display_usage()
|
||||
{
|
||||
echo -e "\nUsage: prepare_and_run.sh -D dataset_path [-d mindspore_docker] [-r release.tar.gz] [-t arm64|x86] [-q] [-o] [-b virtual_batch] [-m mindir] [-e epochs_to_train]\n"
|
||||
echo -e "\nUsage: prepare_and_run.sh -D dataset_path [-d mindspore_docker] [-r release.tar.gz] [-t arm64|x86] [-q] [-o] [-M] [-b virtual_batch] [-m mindir] [-e epochs_to_train]\n"
|
||||
}
|
||||
|
||||
checkopts()
|
||||
|
@ -15,7 +15,8 @@ checkopts()
|
|||
FP16_FLAG=""
|
||||
VIRTUAL_BATCH=-1
|
||||
EPOCHS="-e 5"
|
||||
while getopts 'D:b:d:e:m:oqr:t:' opt
|
||||
MIX_FLAG=""
|
||||
while getopts 'D:b:d:e:m:oqr:t:M:' opt
|
||||
do
|
||||
case "${opt}" in
|
||||
b)
|
||||
|
@ -42,6 +43,11 @@ checkopts()
|
|||
r)
|
||||
TARBALL=$OPTARG
|
||||
;;
|
||||
M)
|
||||
MIX_FLAG="-m"
|
||||
FP16_FLAG="-o"
|
||||
echo $OPTARG
|
||||
;;
|
||||
t)
|
||||
if [ "$OPTARG" == "arm64" ] || [ "$OPTARG" == "x86" ]; then
|
||||
TARGET=$OPTARG
|
||||
|
@ -98,7 +104,7 @@ fi
|
|||
|
||||
cd model/ || exit 1
|
||||
rm -f *.ms
|
||||
EXPORT=${EXPORT} QUANTIZE=${QUANTIZE} ./prepare_model.sh $BATCH $DOCKER || exit 1
|
||||
EXPORT=${EXPORT} QUANTIZE=${QUANTIZE} ./prepare_model.sh $BATCH $DOCKER $MIX_FLAG || exit 1
|
||||
cd ../
|
||||
|
||||
# Copy the .ms model to the package folder
|
||||
|
@ -140,16 +146,28 @@ mv bin ${PACKAGE}/ || exit 1
|
|||
if [ "${TARGET}" == "arm64" ]; then
|
||||
cp ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android/libc++_shared.so ${PACKAGE}/lib/ || exit 1
|
||||
|
||||
echo "=======Pushing to device======="
|
||||
adb push ${PACKAGE} /data/local/tmp/
|
||||
echo "=======Pushing to device======="
|
||||
adb push ${PACKAGE} /data/local/tmp/
|
||||
if [ "${MIX_FLAG}" == "" ];then
|
||||
|
||||
echo "========Training on Device====="
|
||||
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh train.sh ${EPOCHS} ${FP16_FLAG} -b ${VIRTUAL_BATCH}"
|
||||
# origin model is fp32 model
|
||||
echo "========Training on Device origin model is fp32====="
|
||||
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh train.sh ${EPOCHS} ${FP16_FLAG} -b ${VIRTUAL_BATCH}"
|
||||
|
||||
echo
|
||||
echo "===Evaluating trained Model origin model is fp32====="
|
||||
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh eval.sh ${FP16_FLAG}"
|
||||
echo
|
||||
else
|
||||
echo "========Training on Device origin model is fp16 ====="
|
||||
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh train.sh ${EPOCHS} ${FP16_FLAG} -b ${VIRTUAL_BATCH} ${MIX_FLAG}"
|
||||
|
||||
echo
|
||||
echo "===Evaluating trained Model origin model is fp16====="
|
||||
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh eval.sh ${FP16_FLAG} ${MIX_FLAG}"
|
||||
echo
|
||||
fi
|
||||
|
||||
echo
|
||||
echo "===Evaluating trained Model====="
|
||||
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh eval.sh ${FP16_FLAG}"
|
||||
echo
|
||||
else
|
||||
cd ${PACKAGE} || exit 1
|
||||
echo "======Training Locally========="
|
||||
|
|
|
@ -15,4 +15,10 @@
|
|||
# ============================================================================
|
||||
|
||||
# an simple tutorial as follows, more parameters can be setting
|
||||
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/lenet_tod_trained.ms -e 0 -d dataset $1
|
||||
is_mix_model=$(echo "$@" | grep "m")
|
||||
if [[ "$is_mix_model" != "" ]]
|
||||
then
|
||||
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/mix_lenet_tod_trained.ms -e 0 -d dataset $1
|
||||
else
|
||||
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/lenet_tod_trained.ms -e 0 -d dataset $1
|
||||
fi
|
|
@ -15,4 +15,10 @@
|
|||
# ============================================================================
|
||||
|
||||
# an simple tutorial as follows, more parameters can be setting
|
||||
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/lenet_tod.ms -d dataset "$@"
|
||||
is_mix_model=$(echo "$@" | grep "m")
|
||||
if [[ "$is_mix_model" != "" ]]
|
||||
then
|
||||
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/mix_lenet_tod.ms -d dataset "$@"
|
||||
else
|
||||
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/lenet_tod.ms -d dataset "$@"
|
||||
fi
|
||||
|
|
|
@ -15,26 +15,27 @@
|
|||
*/
|
||||
|
||||
#include "src/net_runner.h"
|
||||
#include <math.h>
|
||||
#include <getopt.h>
|
||||
#include <stdio.h>
|
||||
#include <malloc.h>
|
||||
#include <cstring>
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
#include "include/context.h"
|
||||
#include "include/train/loss_monitor.h"
|
||||
#include "include/train/ckpt_saver.h"
|
||||
#include "include/train/lr_scheduler.h"
|
||||
#include "include/train/accuracy_metrics.h"
|
||||
#include "include/train/train_session.h"
|
||||
#include "include/train/classification_train_accuracy_monitor.h"
|
||||
#include "src/utils.h"
|
||||
#include "include/dataset/datasets.h"
|
||||
#include "include/dataset/vision_lite.h"
|
||||
#include "include/dataset/transforms.h"
|
||||
#include "include/dataset/vision_lite.h"
|
||||
#include "include/train/accuracy_metrics.h"
|
||||
#include "include/train/ckpt_saver.h"
|
||||
#include "include/train/classification_train_accuracy_monitor.h"
|
||||
#include "include/train/loss_monitor.h"
|
||||
#include "include/train/lr_scheduler.h"
|
||||
#include "include/train/train_cfg.h"
|
||||
#include "include/train/train_session.h"
|
||||
#include "src/utils.h"
|
||||
|
||||
using mindspore::dataset::Dataset;
|
||||
using mindspore::dataset::Mnist;
|
||||
|
@ -149,7 +150,9 @@ void NetRunner::InitAndFigureInputs() {
|
|||
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
|
||||
context.thread_num_ = 2;
|
||||
|
||||
session_ = mindspore::session::TrainSession::CreateTrainSession(ms_file_, &context, true);
|
||||
mindspore::lite::TrainCfg train_cfg;
|
||||
train_cfg.mix_precision_cfg_.is_raw_mix_precision_ = is_raw_mix_precision_;
|
||||
session_ = mindspore::session::TrainSession::CreateTrainSession(ms_file_, &context, true, &train_cfg);
|
||||
MS_ASSERT(session_ != nullptr);
|
||||
|
||||
session_->SetupVirtualBatch(virtual_batch_);
|
||||
|
@ -277,6 +280,9 @@ bool NetRunner::ReadArgs(int argc, char *argv[]) {
|
|||
case 'b':
|
||||
virtual_batch_ = atoi(optarg);
|
||||
break;
|
||||
case 'r':
|
||||
is_raw_mix_precision_ = atoi(optarg);
|
||||
break;
|
||||
case 'h':
|
||||
default:
|
||||
Usage();
|
||||
|
|
|
@ -63,6 +63,7 @@ class NetRunner {
|
|||
int batch_size_ = 32;
|
||||
int h_ = 32;
|
||||
int w_ = 32;
|
||||
bool is_raw_mix_precision_ = false;
|
||||
};
|
||||
|
||||
#endif // MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_NET_RUNNER_H_
|
||||
|
|
|
@ -28,24 +28,28 @@ class MixPrecisionCfg {
|
|||
this->loss_scale_ = 128.0f;
|
||||
this->keep_batchnorm_fp32_ = true;
|
||||
this->num_of_not_nan_iter_th_ = 1000;
|
||||
this->is_raw_mix_precision_ = false;
|
||||
}
|
||||
MixPrecisionCfg(const MixPrecisionCfg &rhs) {
|
||||
this->dynamic_loss_scale_ = rhs.dynamic_loss_scale_;
|
||||
this->loss_scale_ = rhs.loss_scale_;
|
||||
this->keep_batchnorm_fp32_ = rhs.keep_batchnorm_fp32_;
|
||||
this->num_of_not_nan_iter_th_ = rhs.num_of_not_nan_iter_th_;
|
||||
this->is_raw_mix_precision_ = rhs.is_raw_mix_precision_;
|
||||
}
|
||||
MixPrecisionCfg &operator=(MixPrecisionCfg const &rhs) {
|
||||
this->dynamic_loss_scale_ = rhs.dynamic_loss_scale_;
|
||||
this->loss_scale_ = rhs.loss_scale_;
|
||||
this->keep_batchnorm_fp32_ = rhs.keep_batchnorm_fp32_;
|
||||
this->num_of_not_nan_iter_th_ = rhs.num_of_not_nan_iter_th_;
|
||||
this->is_raw_mix_precision_ = rhs.is_raw_mix_precision_;
|
||||
return *this;
|
||||
}
|
||||
bool dynamic_loss_scale_ = false; /**< Enable\disable dynamic loss scale during mix precision training */
|
||||
float loss_scale_; /**< Initial loss scale factor */
|
||||
bool keep_batchnorm_fp32_ = true; /**< Keep batch norm in FP32 while training */
|
||||
uint32_t num_of_not_nan_iter_th_; /**< a threshold for modifying loss scale when dynamic loss scale is enabled */
|
||||
bool dynamic_loss_scale_ = false; /**< Enable\disable dynamic loss scale during mix precision training */
|
||||
float loss_scale_; /**< Initial loss scale factor */
|
||||
bool keep_batchnorm_fp32_ = true; /**< Keep batch norm in FP32 while training */
|
||||
uint32_t num_of_not_nan_iter_th_; /**< a threshold for modifying loss scale when dynamic loss scale is enabled */
|
||||
bool is_raw_mix_precision_ = false; /**< Is mix precision model export from mindspore */
|
||||
};
|
||||
|
||||
/// \brief TrainCfg defined for holding train configuration.
|
||||
|
|
|
@ -35,6 +35,7 @@ Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cf
|
|||
l_train_cfg->mix_precision_cfg_.loss_scale_ = a_train_cfg->mix_precision_cfg_.loss_scale_;
|
||||
l_train_cfg->mix_precision_cfg_.keep_batchnorm_fp32_ = (a_train_cfg->optimization_level_ != kO3);
|
||||
l_train_cfg->mix_precision_cfg_.num_of_not_nan_iter_th_ = a_train_cfg->mix_precision_cfg_.num_of_not_nan_iter_th_;
|
||||
l_train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = a_train_cfg->mix_precision_cfg_.is_raw_mix_precision_;
|
||||
l_train_cfg->accumulate_gradients_ = a_train_cfg->accumulate_gradients_;
|
||||
return kSuccess;
|
||||
}
|
||||
|
|
|
@ -121,6 +121,16 @@ int TrainSession::InitCallBack() {
|
|||
if (!context_->IsCpuFloat16Enabled()) {
|
||||
return false;
|
||||
}
|
||||
if (cfg_.mix_precision_cfg_.is_raw_mix_precision_) {
|
||||
auto out_tensor_indexs = node->output_indices_;
|
||||
if (out_tensor_indexs.empty()) {
|
||||
MS_LOG(DEBUG) << "Debug: " << node->name_ << " fp32";
|
||||
return false;
|
||||
}
|
||||
auto is_fp16 = model_->all_tensors_.at(out_tensor_indexs[0])->dataType() == kNumberTypeFloat16;
|
||||
MS_LOG(DEBUG) << "Debug: " << node->name_ << ((is_fp16) ? " fp16" : " fp32");
|
||||
return is_fp16;
|
||||
}
|
||||
auto node_type = GetPrimitiveType(node->primitive_, SCHEMA_VERSION::SCHEMA_CUR);
|
||||
if (node_type == schema::PrimitiveType_Cast) {
|
||||
return false;
|
||||
|
@ -128,7 +138,7 @@ int TrainSession::InitCallBack() {
|
|||
auto in_size = node->input_indices_.size();
|
||||
bool force_fp16 = false;
|
||||
for (std::size_t k = 0; k < in_size; k++) {
|
||||
schema::Tensor *tensor = model_.get()->all_tensors_.at(node->input_indices_[k]);
|
||||
schema::Tensor *tensor = model_->all_tensors_.at(node->input_indices_[k]);
|
||||
if ((tensor->dataType() == kNumberTypeFloat16) && (tensor->nodeType() == NodeType_ValueNode)) {
|
||||
force_fp16 = true;
|
||||
break;
|
||||
|
@ -437,7 +447,7 @@ int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &a
|
|||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
auto &run_kernels = (train_mode_) ? train_kernels_ : inference_kernels_;
|
||||
if (context_->IsCpuFloat16Enabled()) {
|
||||
if (context_->IsCpuFloat16Enabled() && !cfg_.mix_precision_cfg_.is_raw_mix_precision_) {
|
||||
ret = MixPrecisionExecKernels(before, after, run_kernels);
|
||||
} else {
|
||||
ret = ExecKernels(before, after, run_kernels);
|
||||
|
@ -1077,7 +1087,7 @@ session::LiteSession *session::TrainSession::CreateTrainSession(const std::strin
|
|||
}
|
||||
}
|
||||
|
||||
mindspore::lite::InnerContext *inner_context = new (std::nothrow) mindspore::lite::InnerContext(context);
|
||||
auto *inner_context = new (std::nothrow) mindspore::lite::InnerContext(context);
|
||||
auto ret = session->Init(inner_context, cfg);
|
||||
if (ret != mindspore::lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "init session failed";
|
||||
|
|
|
@ -350,6 +350,7 @@ std::unique_ptr<session::LiteSession> NetTrain::CreateAndRunNetworkForTrain(cons
|
|||
} else {
|
||||
MS_LOG(INFO) << "CreateTrainSession from model file" << filename.c_str();
|
||||
std::cout << "CreateTrainSession from model file " << filename.c_str() << std::endl;
|
||||
std::cout << "Is raw mix precision model: " << train_cfg.mix_precision_cfg_.is_raw_mix_precision_ << std::endl;
|
||||
session = std::unique_ptr<session::LiteSession>(
|
||||
session::TrainSession::CreateTrainSession(filename, &context, true, &train_cfg));
|
||||
if (session == nullptr) {
|
||||
|
@ -412,6 +413,7 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string
|
|||
if (flags_->loss_name_ != "") {
|
||||
train_cfg.loss_name_ = flags_->loss_name_;
|
||||
}
|
||||
train_cfg.mix_precision_cfg_.is_raw_mix_precision_ = flags_->is_raw_mix_precision_;
|
||||
std::unique_ptr<session::LiteSession> session;
|
||||
if (train_session) {
|
||||
session = CreateAndRunNetworkForTrain(filename, bb_filename, context, train_cfg, epochs);
|
||||
|
|
|
@ -75,6 +75,8 @@ class MS_API NetTrainFlags : public virtual FlagParser {
|
|||
AddFlag(&NetTrainFlags::virtual_batch_, "virtualBatch", "use virtual batch", false);
|
||||
AddFlag(&NetTrainFlags::resize_dims_in_, "inputShapes",
|
||||
"Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
|
||||
AddFlag(&NetTrainFlags::is_raw_mix_precision_, "isRawMixPrecision",
|
||||
"If model is mix precision export from MindSpore,please set true", false);
|
||||
}
|
||||
|
||||
~NetTrainFlags() override = default;
|
||||
|
@ -107,6 +109,7 @@ class MS_API NetTrainFlags : public virtual FlagParser {
|
|||
std::vector<std::vector<int>> resize_dims_;
|
||||
std::string loss_name_ = "";
|
||||
std::string inference_file_ = "";
|
||||
bool is_raw_mix_precision_ = false;
|
||||
};
|
||||
|
||||
class MS_API NetTrain {
|
||||
|
|
Loading…
Reference in New Issue