!13574 fix lite training

From: @xutianchun
Reviewed-by: @HilbertDavid, @hangangqiang
Signed-off-by: @HilbertDavid

Commit 72cc142a0b
@@ -16,7 +16,7 @@ CFLAGS := -Ofast -std=c++17 \
 	-I . \
 	-I ./msl/train \
 	-I ./msl/train/minddata \
-	-I ./msl/train/third_party/flatbuffers/include
+	-I ./msl/tools/third_party/flatbuffers/include
 
 
 ifeq ($(TARGET),arm64)
@@ -79,15 +79,17 @@ cp model/*.ms ${PACKAGE}/model || exit 1
 cp scripts/*.sh ${PACKAGE}/
 
 # Copy the shared MindSpore ToD library
 tar -xzf ${TARBALL}
 mv mindspore-*/train/lib ${PACKAGE}/
 mv mindspore-*/train/minddata/lib/* ${PACKAGE}/lib/
 mv mindspore-*/train/minddata/third_party/libjpeg-turbo/lib/* ${PACKAGE}/lib/
+if [ "${TARGET}" == "arm64" ]; then
+  tar -xzf ${TARBALL} --wildcards --no-anchored hiai_ddk
+  mv mindspore-*/train/third_party/hiai_ddk/lib/* ${PACKAGE}/lib/
+fi
 
 rm -rf msl
-mkdir msl
-mv mindspore-*/* msl/
-rm -rf mindspore-*
+mv mindspore-* msl/
 
 # Copy the dataset to the package
 cp -r $MNIST_DATA_PATH ${PACKAGE}/dataset || exit 1
@@ -101,7 +101,7 @@ void NetRunner::InitAndFigureInputs() {
 
   session_ = mindspore::session::TrainSession::CreateSession(ms_file_, &context);
   MS_ASSERT(nullptr != session_);
-  loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session_, &context);
+  loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session_);
 
   acc_metrics_ = std::shared_ptr<AccuracyMetrics>(new AccuracyMetrics);
 
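The hunk above is the caller side of an API change made later in this commit: `TrainLoop::CreateTrainLoop` drops its `lite::Context *` and `batch_size` parameters. A minimal migration sketch, assuming only the calls visible in this diff ("model.ms" is a hypothetical placeholder for the example's `ms_file_` member):

#include <string>
#include "include/context.h"
#include "include/train/train_session.h"  // this header path also moves in this commit
#include "include/train/train_loop.h"

mindspore::session::TrainLoop *SetupLoop(const std::string &ms_file = "model.ms") {
  mindspore::lite::Context context;
  // CreateSession still takes the context...
  auto *session = mindspore::session::TrainSession::CreateSession(ms_file, &context);
  if (session == nullptr) return nullptr;  // stands in for the example's MS_ASSERT
  // ...but the train loop is now built from the session alone:
  //   old: CreateTrainLoop(session, &context, batch_size)
  //   new: CreateTrainLoop(session)
  return mindspore::session::TrainLoop::CreateTrainLoop(session);
}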
@@ -23,7 +23,6 @@
 #include <vector>
 #include <memory>
 #include <string>
-#include "include/train_session.h"
 #include "include/train/train_loop.h"
 #include "include/train/accuracy_metrics.h"
 #include "include/ms_tensor.h"
@@ -1,32 +1,39 @@
 BASE_DIR=$(realpath ../../../../)
 APP:=bin/net_runner
 MSLIB:=mindspore-lite
+LMDLIB:=-lminddata-lite
 MSDIR:=$(realpath package-$(TARGET)/lib)
+ifneq ("$(wildcard $(MSDIR)/libhiai.so)","")
+LHIAILIB:=-lhiai_ir_build -lhiai_ir -lhiai
+else
+LHIAILIB:=
+endif
 
 SRC:=src/net_runner.cc src/dataset.cc
 OBJ:=$(SRC:.cc=.o)
 
 CFLAGS := -Ofast -std=c++17 \
 	-I . \
 	-I ./msl/train \
-	-I ./msl/train/third_party/flatbuffers/include
+	-I ./msl/train/minddata \
+	-I ./msl/tools/third_party/flatbuffers/include
 
 
 ifeq ($(TARGET),arm64)
 CXX := ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin/clang++
 CFLAGS += --target=aarch64-none-linux-android21 --gcc-toolchain=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64 --sysroot=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot -fdata-sections -ffunction-sections
 LDFLAGS := --target=aarch64-none-linux-android21 --gcc-toolchain=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64 --sysroot=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot -Wl,--gc-sections
-LDFLAGS += -L$(MSDIR) -l$(MSLIB) -pthread -llog -latomic -lm -Wl,-rpath,$(MSDIR)
+LDFLAGS += -L$(MSDIR) -l$(MSLIB) $(LMDLIB) $(LHIAILIB) -pthread -llog -latomic -lm -Wl,-rpath,$(MSDIR)
 else
 CFLAGS += -g
-LDFLAGS := -L$(MSDIR) -l$(MSLIB) -lpthread -Wl,-rpath,$(MSDIR)
+LDFLAGS := -L$(MSDIR) -l$(MSLIB) $(LMDLIB) $(LHIAILIB) -lpthread -Wl,-rpath,$(MSDIR)
 endif
 LD := ${CXX}
 
 
 all:$(APP)
 
-$(APP): $(OBJ) $(MSDIR)/lib$(MSLIB).so
+$(APP): $(OBJ)
 	@mkdir -p bin
 	$(LD) $(OBJ) $(LDFLAGS) -o $@
 
@@ -8,6 +8,7 @@ fi
 echo "============Exporting=========="
 if [ -n "$1" ]; then
   DOCKER_IMG=$1
+  rm *.so*
   docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER_IMG} /bin/bash -c "python transfer_learning_export.py; chmod 444 transfer_learning_tod*.mindir; rm -rf __pycache__"
 else
   echo "MindSpore docker was not provided, attempting to run locally"
@@ -49,7 +49,7 @@ HEAD.weight.set_data(M.Tensor(np.random.normal(
     0, 0.1, HEAD.weight.data.shape).astype("float32")))
 HEAD.bias.set_data(M.Tensor(np.zeros(HEAD.bias.data.shape, dtype="float32")))
 
-sgd = M.nn.SGD(HEAD.trainable_params(), learning_rate=0.01, momentum=0.9,
+sgd = M.nn.SGD(HEAD.trainable_params(), learning_rate=0.015, momentum=0.9,
                dampening=0.01, weight_decay=0.0, nesterov=False, loss_scale=1.0)
 net = TrainWrap(HEAD, optimizer=sgd)
 backbone_out = M.Tensor(np.zeros([BATCH_SIZE, 1000]).astype(np.float32))
@@ -82,10 +82,13 @@ tar -xzf ${TARBALL}
 mv mindspore-*/train/lib ${PACKAGE}/
 mv mindspore-*/train/minddata/lib/* ${PACKAGE}/lib/
 mv mindspore-*/train/minddata/third_party/libjpeg-turbo/lib/* ${PACKAGE}/lib/
+if [ "${TARGET}" == "arm64" ]; then
+  tar -xzf ${TARBALL} --wildcards --no-anchored hiai_ddk
+  mv mindspore-*/train/third_party/hiai_ddk/lib/* ${PACKAGE}/lib/
+fi
+
 rm -rf msl
-mkdir msl
-mv mindspore-*/* msl/
-rm -rf mindspore-*
+mv mindspore-* msl/
 
 # Convert the dataset into the package
 ./prepare_dataset.sh ${PLACES_DATA_PATH} || exit 1
@@ -22,7 +22,7 @@
 #include <map>
 #include <vector>
 #include <string>
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "include/ms_tensor.h"
 #include "src/dataset.h"
 
@@ -24,7 +24,7 @@ namespace mindspore {
 namespace lite {
 
 constexpr int METRICS_CLASSIFICATION = 0;
-constexpr int METRICS_MULTILABLE = 1;
+constexpr int METRICS_MULTILABEL = 1;
 
 class AccuracyMetrics : public Metrics {
  public:
@@ -22,7 +22,7 @@
 #include <unordered_map>
 #include "include/train/train_loop_callback.h"
 #include "include/train/metrics.h"
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 
 namespace mindspore {
 class MSTensor;
@@ -41,10 +41,9 @@ class TrainLoop {
   /// \brief Static method to create a TrainLoop object
   ///
   /// \param[in] train_session Train session object as return from CreateSession\CreateTransferSession API
-  /// \param[in] context Defines the context of the session to be created
   ///
   /// \return Pointer of MindSpore Lite TrainLoop
-  static TrainLoop *CreateTrainLoop(session::TrainSession *train_session, lite::Context *context, int batch_size = -1);
+  static TrainLoop *CreateTrainLoop(session::TrainSession *train_session);
 
   /// \brief Class destructor
   virtual ~TrainLoop() = default;
@@ -17,7 +17,7 @@
 #include <jni.h>
 #include "common/ms_log.h"
 #include "common/jni_utils.h"
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "include/errorcode.h"
 
 extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_createSession(JNIEnv *env, jobject thiz,
@@ -55,32 +55,33 @@ int ActivationGradCPUKernel::DoActivation(int task_id) {
   size_t start = stride * task_id;
 
   auto error_code = RET_OK;
-
-  if (param_act_grad_->type_ == schema::ActivationType_RELU) {
-    error_code = ReluGrad(yt_addr + start, input_addr + start, count, output_addr + start);
-  } else if (param_act_grad_->type_ == schema::ActivationType_RELU6) {
-    error_code = Relu6Grad(yt_addr + start, input_addr + start, count, output_addr + start);
-  } else if (param_act_grad_->type_ == schema::ActivationType_LEAKY_RELU) {
-    error_code = LReluGrad(yt_addr + start, input_addr + start, count, output_addr + start, param_act_grad_->alpha_);
-  } else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) {
-    // Sigmoid gets the input tensors in reverse order!
-    error_code = SigmoidGrad(input_addr + start, yt_addr + start, count, output_addr + start);
-  } else if (param_act_grad_->type_ == schema::ActivationType_TANH) {
-    error_code = TanhGrad(input_addr + start, yt_addr + start, count, output_addr + start);
-  } else if (param_act_grad_->type_ == schema::ActivationType_HSWISH) {
-    error_code = HSwishGrad(yt_addr + start, input_addr + start, count, output_addr + start);
-  } else if (param_act_grad_->type_ == schema::ActivationType_HSIGMOID) {
-    error_code = HSigmoidGrad(yt_addr + start, input_addr + start, count, output_addr + start);
-  } else if (param_act_grad_->type_ == schema::ActivationType_ELU) {
-    error_code = EluGrad(yt_addr + start, input_addr + start, count, output_addr + start, param_act_grad_->alpha_);
-  } else if (param_act_grad_->type_ == schema::ActivationType_GELU) {
-    error_code = GeluGrad(yt_addr + start, input_addr + start, count, output_addr + start);
-  } else {
-    MS_LOG(ERROR) << "Activation type error";
-    return RET_ERROR;
-  }
-  if (error_code != RET_OK) {
-    return RET_ERROR;
+  if (count > 0) {
+    if (param_act_grad_->type_ == schema::ActivationType_RELU) {
+      error_code = ReluGrad(yt_addr + start, input_addr + start, count, output_addr + start);
+    } else if (param_act_grad_->type_ == schema::ActivationType_RELU6) {
+      error_code = Relu6Grad(yt_addr + start, input_addr + start, count, output_addr + start);
+    } else if (param_act_grad_->type_ == schema::ActivationType_LEAKY_RELU) {
+      error_code = LReluGrad(yt_addr + start, input_addr + start, count, output_addr + start, param_act_grad_->alpha_);
+    } else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) {
+      // Sigmoid gets the input tensors in reverse order!
+      error_code = SigmoidGrad(input_addr + start, yt_addr + start, count, output_addr + start);
+    } else if (param_act_grad_->type_ == schema::ActivationType_TANH) {
+      error_code = TanhGrad(yt_addr + start, input_addr + start, count, output_addr + start);
+    } else if (param_act_grad_->type_ == schema::ActivationType_HSWISH) {
+      error_code = HSwishGrad(yt_addr + start, input_addr + start, count, output_addr + start);
+    } else if (param_act_grad_->type_ == schema::ActivationType_HSIGMOID) {
+      error_code = HSigmoidGrad(yt_addr + start, input_addr + start, count, output_addr + start);
+    } else if (param_act_grad_->type_ == schema::ActivationType_ELU) {
+      error_code = EluGrad(yt_addr + start, input_addr + start, count, output_addr + start, param_act_grad_->alpha_);
+    } else if (param_act_grad_->type_ == schema::ActivationType_GELU) {
+      error_code = GeluGrad(yt_addr + start, input_addr + start, count, output_addr + start);
+    } else {
+      MS_LOG(ERROR) << "Activation type error";
+      return RET_ERROR;
+    }
+    if (error_code != RET_OK) {
+      return RET_ERROR;
+    }
   }
   return RET_OK;
 }
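Besides wrapping the dispatch in `if (count > 0)`, the rewritten block changes the `TanhGrad` call: the forward output `yt_addr` now comes first, matching the other activations, while `SigmoidGrad` keeps its documented reversed order. The likely rationale (an assumption, not stated in the commit) is that tanh's derivative, like sigmoid's, is computed from the forward output rather than the input. A small self-contained check of that identity:

#include <cassert>
#include <cmath>

// For y = tanh(x), dy/dx = 1 - y^2, so the backward pass only needs the
// forward output y and the incoming gradient dy -- no access to x required.
static double TanhGradScalar(double y, double dy) { return dy * (1.0 - y * y); }

int main() {
  const double x = 0.5, eps = 1e-6;
  const double y = std::tanh(x);
  // Central finite difference as a reference derivative.
  const double numeric = (std::tanh(x + eps) - std::tanh(x - eps)) / (2.0 * eps);
  assert(std::fabs(TanhGradScalar(y, 1.0) - numeric) < 1e-9);
  return 0;
}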
@@ -32,8 +32,8 @@ namespace mindspore::kernel {
 
 int AdamCPUKernel::ReSize() { return RET_OK; }
 
-int DoAdam(float *m, float *v, float *gradient, float *weight, float beta1, float beta2, float beta1_power,
-           float beta2_power, float eps, float learning_rate, bool nesterov, size_t start, size_t end) {
+static int DoAdam(float *m, float *v, float *gradient, float *weight, float beta1, float beta2, float beta1_power,
+                  float beta2_power, float eps, float learning_rate, bool nesterov, int start, int end) {
   if ((1.f - beta1_power) <= 0.0f) {
     MS_LOG(ERROR) << "divisor cannot be 0 or below";
     return RET_ERROR;
@@ -47,13 +47,13 @@ int DoAdam(float *m, float *v, float *gradient, float *weight, float beta1, floa
   const float one_minus_beta1 = 1.f - beta1;
   const float one_minus_beta2 = 1.f - beta2;
   if (nesterov) {  // Nadam
-    for (size_t i = start; i < end; ++i) {
+    for (int i = start; i < end; ++i) {
       m[i] += (gradient[i] - m[i]) * one_minus_beta1;
       v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2;
       weight[i] -= update_lr * (m[i] * beta1 + one_minus_beta1 * gradient[i]) / (std::sqrt(v[i]) + eps);
     }
   } else {
-    for (size_t i = start; i < end; ++i) {
+    for (int i = start; i < end; ++i) {
       m[i] += (gradient[i] - m[i]) * one_minus_beta1;
       v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2;
       weight[i] -= update_lr * m[i] / (std::sqrt(v[i]) + eps);
@@ -77,7 +77,6 @@ int AdamCPUKernel::Execute(int task_id) {
 
   int stride = UP_DIV(length, thread_count_);
   int count = MSMIN(stride, length - stride * task_id);
-
   int start = stride * task_id;
   int end = start + count;
 
@@ -30,15 +30,15 @@ using mindspore::schema::PrimitiveType_ApplyMomentum;
 namespace mindspore::kernel {
 int ApplyMomentumCPUKernel::ReSize() { return RET_OK; }
 
-int DoApplyMomentum(float *weight, float *accumulate, float learning_rate, float *gradient, float moment, bool nesterov,
-                    size_t start, size_t end) {
+static int DoApplyMomentum(float *weight, float *accumulate, float learning_rate, float *gradient, float moment,
+                           bool nesterov, int start, int end) {
   if (nesterov) {
-    for (size_t i = start; i < end; i++) {
+    for (int i = start; i < end; i++) {
       accumulate[i] = accumulate[i] * moment + gradient[i];
       weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate;
     }
   } else {
-    for (size_t i = start; i < end; i++) {
+    for (int i = start; i < end; i++) {
       accumulate[i] = accumulate[i] * moment + gradient[i];
       weight[i] -= accumulate[i] * learning_rate;
     }
@@ -56,6 +56,7 @@ int ApplyMomentumCPUKernel::Execute(int task_id) {
 
   int stride = UP_DIV(length, thread_count_);
   int count = MSMIN(stride, length - stride * task_id);
+  count = (count < 0) ? 0 : count;
   int start = stride * task_id;
   int end = start + count;
 
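This clamp (and the identical ones added to the other grad kernels below) addresses the work-slicing arithmetic: `UP_DIV` rounds the per-thread `stride` up, so for trailing threads `length - stride * task_id` can drop below zero, and `MSMIN` hands that negative value back as `count`. A standalone sketch of the arithmetic, with the macros inlined and made-up sizes:

#include <algorithm>
#include <cstdio>

int main() {
  const int length = 10, threads = 8;                    // made-up workload split
  const int stride = (length + threads - 1) / threads;   // UP_DIV(10, 8) == 2
  for (int task_id = 0; task_id < threads; ++task_id) {
    int count = std::min(stride, length - stride * task_id);  // MSMIN
    // Without the clamp, task 6 gets count == -2 and task 7 gets count == -4.
    count = (count < 0) ? 0 : count;  // the line this commit adds
    std::printf("task %d: start=%d count=%d\n", task_id, stride * task_id, count);
  }
  return 0;
}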
@@ -72,7 +72,9 @@ int ArithmeticSelfGradCPUKernel::DoArithmeticSelfGrad(int task_id) {
   int count = MSMIN(stride, length - stride * task_id);
   int start = stride * task_id;
 
-  (*self_grad_operation_)(dy + start, in_x + start, dx + start, count);
+  if (count > 0) {
+    (*self_grad_operation_)(dy + start, in_x + start, dx + start, count);
+  }
   return RET_OK;
 }
 
@@ -41,7 +41,9 @@ int AssignCPUKernel::Execute(int task_id) {
 
   int start = stride * task_id;
 
-  memcpy(&(x[start]), &(y[start]), count * sizeof(float));
+  if (count > 0) {
+    memcpy(&(x[start]), &(y[start]), count * sizeof(float));
+  }
   return RET_OK;
 }
 
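Here the guard is essential rather than cosmetic: `memcpy`'s size parameter is a `size_t`, so a negative `count` would be converted to an enormous unsigned byte count before the copy even starts. A tiny illustration of the conversion (printing only, no copy):

#include <cstddef>
#include <cstdio>

int main() {
  const int count = -4;  // what a trailing thread can compute without a clamp
  const size_t bytes = static_cast<size_t>(count) * sizeof(float);
  // On an LP64 target this prints 18446744073709551600: a memcpy of that
  // size would run off the end of any buffer, hence the if (count > 0) guard.
  std::printf("%zu\n", bytes);
  return 0;
}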
@@ -76,6 +76,7 @@ int BNGradCPUKernel::Execute(int task_id) {
   int total = spatial * batch;
   int stride = UP_DIV(total, thread_num);
   int count = MSMIN(stride, total - stride * task_id);
+  count = (count < 0) ? 0 : count;
   switch (stage) {
     case 0: {
       for (int job = task_id; job < 4; job += thread_num) {
@@ -108,6 +108,7 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) {
   float *mat_tmp = mat_workspace + mat_alloc_;
   int stride = UP_DIV(batch, thread_num);
   int count = MSMIN(stride, batch - stride * task_id);
+  count = (count < 0) ? 0 : count;
   int start = stride * task_id;
   int end = start + count;
 
@@ -115,6 +116,7 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) {
 #ifdef ENABLE_ARM
   stride = UP_DIV(k_h * k_w, thread_num);
   count = MSMIN(stride, k_h * k_w - stride * task_id);
+  count = (count < 0) ? 0 : count;
   start = stride * task_id;
   ConvDwFilterGrad(x_addr, dy_addr, dw_addr, start, count, conv_param);
 #else
@@ -92,6 +92,7 @@ int ConvolutionGradInputCPUKernel::Execute(int task_id) {
   float *mat_workspace = workspace_temp + ws_size_;
   int stride = UP_DIV(batch, thread_num);
   int count = MSMIN(stride, batch - stride * task_id);
+  count = (count < 0) ? 0 : count;
   int start = stride * task_id;
   int end = start + count;
 
@@ -67,22 +67,24 @@ int DropoutCPUKernel::Execute(int task_id) {
   int stride = UP_DIV(length, thread_count_);
   int count = MSMIN(stride, length - stride * task_id);
 
-  size_t start = stride * task_id;
-  size_t end = start + count;
+  int start = stride * task_id;
+  int end = start + count;
 
   if (param == nullptr) {
     MS_LOG(ERROR) << "Dropout op_parameter_ nullptr";
     return RET_NULL_PTR;
   }
-  if (IsEval()) {
-    std::copy(&(input_ptr[start]), &(input_ptr[end]), &(output_ptr[start]));
-  } else {
-    std::default_random_engine generator;
-    std::bernoulli_distribution distribution(param->ratio_);
+  if (count > 0) {
+    if (IsEval()) {
+      std::copy(&(input_ptr[start]), &(input_ptr[end]), &(output_ptr[start]));
+    } else {
+      std::default_random_engine generator;
+      std::bernoulli_distribution distribution(param->ratio_);
 
-    for (size_t i = start; i < end; i++) {
-      mask[i] = distribution(generator);
-      output_ptr[i] = input_ptr[i] * mask[i] * scale_;
-    }
-  }
+      for (int i = start; i < end; i++) {
+        mask[i] = distribution(generator);
+        output_ptr[i] = input_ptr[i] * mask[i] * scale_;
+      }
+    }
+  }
   return RET_OK;
@@ -46,7 +46,7 @@ int NegGradCPUKernel::DoNegGrad(int task_id) {
 
   int stride = UP_DIV(length, thread_count_);
   int count = MSMIN(stride, length - stride * task_id);
-
+  count = (count < 0) ? 0 : count;
   int start = stride * task_id;
 
   ElementNegative(dy + start, dx + start, count);
@@ -50,7 +50,7 @@ int PowerGradCPUKernel::Execute(int task_id) {
 
   int stride = UP_DIV(length, thread_count_);
   int count = MSMIN(stride, length - stride * task_id);
-
+  count = (count < 0) ? 0 : count;
   int start = stride * task_id;
   int end = start + count;
 
@@ -33,21 +33,21 @@ namespace mindspore::kernel {
 int SgdCPUKernel::ReSize() { return RET_OK; }
 
 int DoSgd(float *weight, float *accumulate, float *gradient, float learning_rate, float dampening, float moment,
-          bool nesterov, size_t start, size_t end) {
+          bool nesterov, int start, int end) {
   if (moment > 0.f) {
     if (nesterov) {
-      for (size_t i = start; i < end; ++i) {
+      for (int i = start; i < end; ++i) {
         accumulate[i] = accumulate[i] * moment + gradient[i] * (1.f - dampening);
         weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate;
       }
     } else {
-      for (size_t i = start; i < end; ++i) {
+      for (int i = start; i < end; ++i) {
        accumulate[i] = accumulate[i] * moment + gradient[i] * (1.f - dampening);
         weight[i] -= accumulate[i] * learning_rate;
       }
     }
   } else {
-    for (size_t i = start; i < end; ++i) {
+    for (int i = start; i < end; ++i) {
       weight[i] -= gradient[i] * learning_rate;
     }
   }
@@ -55,14 +55,14 @@ int DoSgd(float *weight, float *accumulate, float *gradient, float learning_rate
 }
 
 int DoSgdInit(float *weight, float *accumulate, float *gradient, float *stat, float learning_rate, float dampening,
-              float moment, bool nesterov, size_t start, size_t end) {
+              float moment, bool nesterov, int start, int end) {
   std::copy(&(gradient[start]), &(gradient[end]), &(accumulate[start]));
   if (nesterov) {
-    for (size_t i = start; i < end; ++i) {
+    for (int i = start; i < end; ++i) {
       weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate;
     }
   } else {
-    for (size_t i = start; i < end; ++i) {
+    for (int i = start; i < end; ++i) {
       weight[i] -= accumulate[i] * learning_rate;
     }
   }
@@ -80,7 +80,7 @@ int SgdCPUKernel::Execute(int task_id) {
 
   int stride = UP_DIV(length, thread_count_);
   int count = MSMIN(stride, length - stride * task_id);
-
+  count = (count < 0) ? 0 : count;
   int start = stride * task_id;
   int end = start + count;
 
@@ -97,16 +97,18 @@ int SgdCPUKernel::ExecuteInit(int task_id) {
   auto gradient = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
   float moment = reinterpret_cast<float *>(in_tensors_.at(4)->MutableData())[0];
   auto stat = reinterpret_cast<float *>(in_tensors_.at(5)->MutableData());
-  size_t length = in_tensors_.at(0)->ElementsNum();
+  int length = in_tensors_.at(0)->ElementsNum();
 
-  size_t stride = UP_DIV(length, thread_count_);
-  size_t count = MSMIN(stride, length - stride * task_id);
+  int stride = UP_DIV(length, thread_count_);
+  int count = MSMIN(stride, length - stride * task_id);
 
-  size_t start = stride * task_id;
-  size_t end = start + count;
+  int start = stride * task_id;
+  int end = start + count;
 
-  DoSgdInit(weight, accumulate, gradient, stat, learning_rate, sgd_param_->dampening_, moment,
-            sgd_param_->use_nesterov_, start, end);
+  if (count > 0) {
+    DoSgdInit(weight, accumulate, gradient, stat, learning_rate, sgd_param_->dampening_, moment,
+              sgd_param_->use_nesterov_, start, end);
+  }
   return RET_OK;
 }
 
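The `size_t` to `int` switch in `ExecuteInit` is the substantive part of this hunk. With unsigned types, `length - stride * task_id` wraps instead of going negative, so `MSMIN` returns `stride` and `DoSgdInit`'s `std::copy(&(gradient[start]), &(gradient[end]), ...)` starts past the end of the tensor. A standalone comparison of the two code paths, with `UP_DIV`/`MSMIN` inlined and made-up sizes:

#include <algorithm>
#include <cstdio>

int main() {
  const int length = 10, threads = 8, task_id = 7;  // last worker thread

  // Old unsigned path: the subtraction wraps, so min() keeps stride.
  const size_t u_stride = (length + threads - 1) / threads;               // 2
  const size_t u_count =
      std::min(u_stride, static_cast<size_t>(length) - u_stride * task_id);
  std::printf("size_t count = %zu (slice [14, 16) of 10 elements!)\n", u_count);

  // New signed path: the subtraction goes negative and can be guarded.
  const int stride = (length + threads - 1) / threads;                    // 2
  int count = std::min(stride, length - stride * task_id);                // -4
  count = (count < 0) ? 0 : count;
  std::printf("int count    = %d (empty slice, nothing copied)\n", count);
  return 0;
}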
@@ -40,7 +40,7 @@ int SmoothL1LossGradCPUKernel::Execute(int task_id) {
 
   int stride = UP_DIV(length, thread_count_);
   int count = MSMIN(stride, length - stride * task_id);
-
+  count = (count < 0) ? 0 : count;
   int start = stride * task_id;
   int end = start + count;
 
@@ -19,7 +19,7 @@
 #include <vector>
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "src/common/utils.h"
 #include "src/train/train_utils.h"
 
@@ -20,7 +20,7 @@
 #include <utility>
 #include <vector>
 #include <iostream>
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "src/common/utils.h"
 #include "src/tensor.h"
 
@@ -23,7 +23,7 @@
 #include <fstream>
 #include <memory>
 #include "include/errorcode.h"
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "src/common/utils.h"
 #include "src/tensor.h"
 
@@ -20,7 +20,7 @@
 #include <memory>
 #include <algorithm>
 #include "include/errorcode.h"
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "include/iterator.h"
 #include "src/common/log_adapter.h"
 
@@ -168,8 +168,7 @@ int TrainLoop::LoadPartialData(std::vector<tensor::MSTensor *> inputs, dataset::
 
 }  // namespace lite
 
-session::TrainLoop *session::TrainLoop::CreateTrainLoop(session::TrainSession *train_session, lite::Context *context,
-                                                        int batch_size) {
+session::TrainLoop *session::TrainLoop::CreateTrainLoop(session::TrainSession *train_session) {
   auto loop = new (std::nothrow) lite::TrainLoop(train_session);
   return loop;
 }
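Note that the implementation allocates with `new (std::nothrow)` and returns the pointer unchecked, so the factory yields `nullptr` on allocation failure instead of throwing. A minimal caller-side guard for the new one-argument overload (a sketch, assuming only what this diff shows):

#include "include/train/train_loop.h"
#include "include/train/train_session.h"

mindspore::session::TrainLoop *MakeTrainLoop(mindspore::session::TrainSession *session) {
  auto *loop = mindspore::session::TrainLoop::CreateTrainLoop(session);
  if (loop == nullptr) {
    // new (std::nothrow) inside the factory failed; handle before use.
    return nullptr;
  }
  return loop;
}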
@@ -23,7 +23,6 @@
 #include "include/errorcode.h"
 #include "include/train/train_loop.h"
 #include "include/train/metrics.h"
-#include "include/train_session.h"
 #include "include/datasets.h"
 #include "include/iterator.h"
 #include "src/common/log_adapter.h"
@@ -20,7 +20,7 @@
 #include <tuple>
 #include <unordered_map>
 #include <memory>
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "src/train/train_model.h"
 #include "src/lite_session.h"
 
@@ -20,7 +20,6 @@
 #include <tuple>
 #include <unordered_map>
 #include <utility>
-#include "include/train_session.h"
 #include "src/train/train_model.h"
 #include "src/lite_session.h"
 #include "src/train/train_session.h"
@@ -1,5 +1,5 @@
 mini_alexnet
-#nin
+nin
 lenet
 mobilenetv1
 mobilenetv2
@@ -10,5 +10,5 @@ effnet_tune
 googlenet
 densenet
 shufflenetv2
-#xception
+# xception
 # LAST
@@ -24,7 +24,7 @@
 
 #include "schema/inner/model_generated.h"
 #include "common/common_test.h"
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "include/context.h"
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
@@ -33,7 +33,7 @@
 #include "tools/common/flag_parser.h"
 #include "src/common/file_utils.h"
 #include "src/common/utils.h"
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 
 namespace mindspore::lite {
 enum MS_API DataType { kImage = 0, kBinary = 1 };
@@ -156,7 +156,6 @@ class MS_API NetTrain {
       std::cout << refOutput[j] << " ";
     }
     for (int j = 0; j < size; j++) {
-      std::cout << std::endl;
       if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) {
         std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl;
         MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
@@ -126,6 +126,7 @@ constexpr auto kNameArgMinWithValue = "ArgMinWithValue";
 constexpr auto kNameBatchMatMul = "BatchMatMul";
 constexpr auto kNameFusedBatchNormEx = "FusedBatchNormEx";
 constexpr auto kNameFusedBatchNormGradEx = "FusedBatchNormGradEx";
+constexpr auto kNameFusedBatchNormGradCPU = "FusedBatchNormGradCPU";
 constexpr auto kNameHSigmoid = "HSigmoid";
 constexpr auto kNameHSigmoidGrad = "HSigmoidGrad";
 constexpr auto kNameHSwish = "HSwish";
@@ -549,6 +550,7 @@ REGIST_PRIMITIVE_ADJUST(kNameEluGrad, MoveAttrMapActivationGrad)
 REGIST_PRIMITIVE_ADJUST(kNameExp, MoveAttrMapCommon<ops::ExpFusion>)
 REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormEx, MoveAttrMapCommon<ops::FusedBatchNorm>)
 REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradEx, MoveAttrMapCommon<ops::BatchNormGrad>)
+REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradCPU, MoveAttrMapCommon<ops::BatchNormGrad>)
 REGIST_PRIMITIVE_ADJUST(kNameGeLU, MoveAttrMapActivation)
 REGIST_PRIMITIVE_ADJUST(kNameGeLUGrad, MoveAttrMapActivationGrad)
 REGIST_PRIMITIVE_ADJUST(kNameHSigmoid, MoveAttrMapActivation)