tod add train loop

This commit is contained in:
yoni 2021-01-17 18:17:30 +02:00
parent 8008843562
commit 33d7741904
96 changed files with 2418 additions and 531 deletions

View File

@ -395,6 +395,7 @@ if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" || "X$ENABLE_GPU" = "
git submodule update --init --recursive akg
fi
build_exit()
{
echo "$@" >&2
@ -596,33 +597,40 @@ build_lite()
build_lite_java_arm64() {
# build mindspore-lite arm64
if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch64.tar.gz" ]]; then
JTARBALL=mindspore-lite-${VERSION_STR}-inference-android-aarch64
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
JTARBALL=mindspore-lite-${VERSION_STR}-train-android-aarch64
fi
if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/${JTARBALL}.tar.gz" ]]; then
build_lite "arm64" "off"
fi
# copy arm64 so
cd ${BASEPATH}/output/
rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch64
tar -zxvf mindspore-lite-${VERSION_STR}-inference-android-aarch64.tar.gz
rm -rf ${JTARBALL}
tar -zxvf ${JTARBALL}.tar.gz
[ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/arm64-v8a/
mkdir -p ${JAVA_PATH}/java/app/libs/arm64-v8a/
cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch64/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
echo mindspore-lite-${VERSION_STR}-inference-android-aarch64
[ -n "${VERSION_STR}" ] && rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch64
cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
[ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL}
}
build_lite_java_arm32() {
# build mindspore-lite arm32
if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch32.tar.gz" ]]; then
JTARBALL=mindspore-lite-${VERSION_STR}-inference-android-aarch32
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
JTARBALL=mindspore-lite-${VERSION_STR}-train-android-aarch32
fi
if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/${JTARBALL}.tar.gz" ]]; then
build_lite "arm32" "off"
fi
# copy arm32 so
cd ${BASEPATH}/output/
rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch32
tar -zxvf mindspore-lite-${VERSION_STR}-inference-android-aarch32.tar.gz
rm -rf ${JTARBALL}
tar -zxvf ${JTARBALL}.tar.gz
[ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/armeabi-v7a/
mkdir -p ${JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch32/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
[ -n "${VERSION_STR}" ] && rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch32
cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
[ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL}
}
build_jni_arm64() {
@ -635,7 +643,7 @@ build_jni_arm64() {
-DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="arm64-v8a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \
-DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
-DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \
-DPLATFORM_ARM64=on "${JAVA_PATH}/java/app/src/main/native"
-DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DPLATFORM_ARM64=on "${JAVA_PATH}/java/app/src/main/native"
make -j$THREAD_NUM
if [[ $? -ne 0 ]]; then
echo "---------------- mindspore lite: build jni arm64 failed----------------"
@ -655,7 +663,7 @@ build_jni_arm32() {
-DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \
-DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
-DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \
-DPLATFORM_ARM32=on "${JAVA_PATH}/java/app/src/main/native"
-DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DPLATFORM_ARM32=on "${JAVA_PATH}/java/app/src/main/native"
make -j$THREAD_NUM
if [[ $? -ne 0 ]]; then
echo "---------------- mindspore lite: build jni arm32 failed----------------"

View File

@ -3,7 +3,7 @@ APP:=bin/net_runner
MSLIB:=mindspore-lite
MSDIR:=$(realpath package-$(TARGET)/lib)
SRC:=src/net_runner.cc src/dataset.cc
SRC:=src/net_runner.cc src/dataset.cc src/data_callbacks.cc
OBJ:=$(SRC:.cc=.o)
CFLAGS := -Ofast -std=c++17 \

View File

@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_ACCURACY_MONITOR_H_
#define MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_ACCURACY_MONITOR_H_
#include <vector>
#include <string>
#include <utility>
#include <unordered_map>
#include "include/train/train_loop.h"
#include "src/dataset.h"
using GraphPoint = std::pair<int, float>;
class AccuracyMonitor : public mindspore::session::TrainLoopCallBack {
public:
explicit AccuracyMonitor(DataSet *dataset, int check_every_n, int max_steps = -1)
: ds_(dataset), check_every_n_(check_every_n), max_steps_(max_steps) {}
int EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) override;
const std::vector<GraphPoint> &GetAccuracyPoints() const { return accuracies_; }
private:
DataSet *ds_;
std::vector<GraphPoint> accuracies_;
int check_every_n_;
int max_steps_;
};
#endif // MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_ACCURACY_MONITOR_H_

View File

@ -0,0 +1,103 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <math.h>
#include <getopt.h>
#include <cstring>
#include <iostream>
#include <fstream>
#include <utility>
#include "src/net_runner.h"
#include "include/context.h"
#include "src/utils.h"
#include "src/data_loader.h"
#include "src/accuracy_monitor.h"
static unsigned int seed = time(NULL);
std::vector<int> FillInputDataUtil(const mindspore::session::TrainLoopCallBackData &cb_data,
const std::vector<DataLabelTuple> &dataset, bool serially) {
static unsigned int idx = 1;
int total_size = dataset.size();
std::vector<int> labels_vec;
auto inputs = cb_data.session_->GetInputs();
char *input_data = reinterpret_cast<char *>(inputs.at(0)->MutableData());
auto labels = reinterpret_cast<float *>(inputs.at(1)->MutableData());
int batch_size = inputs.at(0)->shape()[0];
int num_of_classes = inputs.at(1)->shape()[1];
int data_size = inputs.at(0)->Size() / batch_size;
MS_ASSERT(total_size > 0);
MS_ASSERT(input_data != nullptr);
std::fill(labels, labels + inputs.at(1)->ElementsNum(), 0.f);
for (int i = 0; i < batch_size; i++) {
if (serially) {
idx = ++idx % total_size;
} else {
idx = rand_r(&seed) % total_size;
}
int label = 0;
char *data = nullptr;
std::tie(data, label) = dataset[idx];
std::copy(data, data + data_size, input_data + i * data_size);
labels[i * num_of_classes + label] = 1.0; // Model expects labels in onehot representation
labels_vec.push_back(label);
}
return labels_vec;
}
void DataLoader::StepBegin(const mindspore::session::TrainLoopCallBackData &cb_data) {
FillInputDataUtil(cb_data, ds_->train_data(), false);
}
int AccuracyMonitor::EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) {
if ((cb_data.epoch_ + 1) % check_every_n_ != 0) return mindspore::session::RET_CONTINUE;
float accuracy = 0.0;
auto inputs = cb_data.session_->GetInputs();
int batch_size = inputs.at(0)->shape()[0];
int num_of_classes = ds_->num_of_classes();
int tests = ds_->test_data().size() / batch_size;
if (max_steps_ != -1 && tests > max_steps_) tests = max_steps_;
cb_data.session_->Eval();
for (int i = 0; i < tests; i++) {
auto labels = FillInputDataUtil(cb_data, ds_->test_data(), false);
cb_data.session_->RunGraph();
auto outputs = cb_data.session_->GetPredictions();
for (auto it = outputs.begin(); it != outputs.end(); ++it) {
if (it->second->ElementsNum() == batch_size * num_of_classes) {
auto scores = reinterpret_cast<float *>(it->second->MutableData());
for (int b = 0; b < batch_size; b++) {
int max_idx = 0;
float max_score = scores[num_of_classes * b];
for (int c = 1; c < num_of_classes; c++) {
if (scores[num_of_classes * b + c] > max_score) {
max_score = scores[num_of_classes * b + c];
max_idx = c;
}
}
if (labels[b] == max_idx) accuracy += 1.0;
}
break;
}
}
}
accuracy /= static_cast<float>(batch_size * tests);
accuracies_.push_back(std::make_pair(cb_data.epoch_, accuracy));
std::cout << cb_data.epoch_ + 1 << ":\tAccuracy is " << accuracy << std::endl;
cb_data.session_->Train();
return mindspore::session::RET_CONTINUE;
}

View File

@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_DATA_LOADER_H_
#define MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_DATA_LOADER_H_
#include <vector>
#include <string>
#include <utility>
#include <unordered_map>
#include "include/train/train_loop.h"
#include "src/dataset.h"
class DataLoader : public mindspore::session::TrainLoopCallBack {
public:
explicit DataLoader(DataSet *dataset) : ds_(dataset) {}
void StepBegin(const mindspore::session::TrainLoopCallBackData &cb_data) override;
private:
DataSet *ds_;
};
#endif // MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_DATA_LOADER_H_

View File

@ -20,10 +20,21 @@
#include <cstring>
#include <iostream>
#include <fstream>
#include <utility>
#include "include/context.h"
#include "include/train/loss_monitor.h"
#include "include/train/ckpt_saver.h"
#include "include/train/lr_scheduler.h"
#include "include/train/classification_train_accuracy_monitor.h"
#include "src/utils.h"
#include "src/data_loader.h"
#include "src/accuracy_monitor.h"
using mindspore::session::TrainLoopCallBack;
using mindspore::session::TrainLoopCallBackData;
static unsigned int seed = time(NULL);
unsigned int NetRunner::seed_ = time(NULL);
// Definition of callback function after forwarding operator.
bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
@ -54,15 +65,18 @@ bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inpu
}
NetRunner::~NetRunner() {
if (session_ != nullptr) delete session_;
if (loop_ != nullptr) delete loop_;
}
void NetRunner::InitAndFigureInputs() {
mindspore::lite::Context context;
context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = mindspore::lite::NO_BIND;
context.thread_num_ = 1;
context.device_list_[0].device_info_.cpu_device_info_.enable_float16_ = false;
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
context.thread_num_ = 2;
session_ = mindspore::session::TrainSession::CreateSession(ms_file_, &context);
loop_ = mindspore::session::TrainLoop::CreateTrainLoop(ms_file_, &context);
session_ = loop_->train_session();
MS_ASSERT(nullptr != session_);
auto inputs = session_->GetInputs();
@ -76,71 +90,10 @@ void NetRunner::InitAndFigureInputs() {
}
}
mindspore::tensor::MSTensor *NetRunner::SearchOutputsForSize(size_t size) const {
auto outputs = session_->GetOutputs();
for (auto it = outputs.begin(); it != outputs.end(); ++it) {
if (it->second->ElementsNum() == size) return it->second;
}
std::cout << "Model does not have an output tensor with size " << size << std::endl;
return nullptr;
}
std::vector<int> NetRunner::FillInputData(const std::vector<DataLabelTuple> &dataset, bool serially) const {
std::vector<int> labels_vec;
static unsigned int idx = 1;
int total_size = dataset.size();
auto inputs = session_->GetInputs();
char *input_data = reinterpret_cast<char *>(inputs.at(data_index_)->MutableData());
auto labels = reinterpret_cast<float *>(inputs.at(label_index_)->MutableData());
MS_ASSERT(total_size > 0);
MS_ASSERT(input_data != nullptr);
std::fill(labels, labels + inputs.at(label_index_)->ElementsNum(), 0.f);
for (int i = 0; i < batch_size_; i++) {
if (serially) {
idx = ++idx % total_size;
} else {
idx = rand_r(&seed_) % total_size;
}
int label = 0;
char *data = nullptr;
std::tie(data, label) = dataset[idx];
std::memcpy(input_data + i * data_size_, data, data_size_);
labels[i * num_of_classes_ + label] = 1.0; // Model expects labels in onehot representation
labels_vec.push_back(label);
}
return labels_vec;
}
float NetRunner::CalculateAccuracy(int max_tests) const {
float accuracy = 0.0;
const std::vector<DataLabelTuple> test_set = ds_.test_data();
int tests = test_set.size() / batch_size_;
if (max_tests != -1 && tests < max_tests) tests = max_tests;
session_->Eval();
for (int i = 0; i < tests; i++) {
auto labels = FillInputData(test_set, (max_tests == -1));
session_->RunGraph();
auto outputsv = SearchOutputsForSize(batch_size_ * num_of_classes_);
MS_ASSERT(outputsv != nullptr);
auto scores = reinterpret_cast<float *>(outputsv->MutableData());
for (int b = 0; b < batch_size_; b++) {
int max_idx = 0;
float max_score = scores[num_of_classes_ * b];
for (int c = 0; c < num_of_classes_; c++) {
if (scores[num_of_classes_ * b + c] > max_score) {
max_score = scores[num_of_classes_ * b + c];
max_idx = c;
}
}
if (labels[b] == max_idx) accuracy += 1.0;
}
}
session_->Train();
accuracy /= static_cast<float>(batch_size_ * tests);
return accuracy;
float NetRunner::CalculateAccuracy(int max_tests) {
AccuracyMonitor test_am(&ds_, 1, max_tests);
test_am.EpochEnd(TrainLoopCallBackData(true, 0, session_, loop_));
return 0.0;
}
int NetRunner::InitDB() {
@ -155,35 +108,17 @@ int NetRunner::InitDB() {
return ret;
}
float NetRunner::GetLoss() const {
auto outputsv = SearchOutputsForSize(1); // Search for Loss which is a single value tensor
MS_ASSERT(outputsv != nullptr);
auto loss = reinterpret_cast<float *>(outputsv->MutableData());
return loss[0];
}
int NetRunner::TrainLoop() {
session_->Train();
float min_loss = 1000.;
float max_acc = 0.;
for (int i = 0; i < cycles_; i++) {
FillInputData(ds_.train_data());
session_->RunGraph(nullptr, verbose_ ? after_callback : nullptr);
float loss = GetLoss();
if (min_loss > loss) min_loss = loss;
struct mindspore::lite::StepLRLambda step_lr_lambda(100, 0.9);
mindspore::lite::LRScheduler step_lr_sched(mindspore::lite::StepLRLambda, static_cast<void *>(&step_lr_lambda), 100);
if (save_checkpoint_ != 0 && (i + 1) % save_checkpoint_ == 0) {
auto cpkt_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained_" + std::to_string(i + 1) + ".ms";
session_->SaveToFile(cpkt_fn);
}
mindspore::lite::LossMonitor lm(100);
// mindspore::lite::ClassificationTrainAccuracyMonitor am(10);
mindspore::lite::CkptSaver cs(1000, std::string("lenet"));
AccuracyMonitor test_am(&ds_, 500, 10);
DataLoader dl(&ds_);
if ((i + 1) % 100 == 0) {
float acc = CalculateAccuracy(10);
if (max_acc < acc) max_acc = acc;
std::cout << i + 1 << ":\tLoss is " << std::setw(7) << loss << " [min=" << min_loss << "] "
<< " max_acc=" << max_acc << std::endl;
}
}
loop_->Train(cycles_, std::vector<TrainLoopCallBack *>{&dl, &lm, &test_am, &cs, &step_lr_sched});
return 0;
}
@ -194,8 +129,7 @@ int NetRunner::Main() {
TrainLoop();
float acc = CalculateAccuracy();
std::cout << "accuracy = " << acc << std::endl;
CalculateAccuracy();
if (cycles_ > 0) {
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained_" + std::to_string(cycles_) + ".ms";

View File

@ -23,6 +23,7 @@
#include <vector>
#include <string>
#include "include/train_session.h"
#include "include/train/train_loop.h"
#include "include/ms_tensor.h"
#include "src/dataset.h"
@ -38,12 +39,13 @@ class NetRunner {
int InitDB();
int TrainLoop();
std::vector<int> FillInputData(const std::vector<DataLabelTuple> &dataset, bool is_train_set = false) const;
float CalculateAccuracy(int max_tests = -1) const;
float CalculateAccuracy(int max_tests = -1);
float GetLoss() const;
mindspore::tensor::MSTensor *SearchOutputsForSize(size_t size) const;
DataSet ds_;
mindspore::session::TrainSession *session_ = nullptr;
mindspore::session::TrainLoop *loop_ = nullptr;
std::string ms_file_ = "";
std::string data_dir_ = "";

View File

@ -17,9 +17,10 @@
#include "src/net_runner.h"
#include <math.h>
#include <getopt.h>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <fstream>
#include <iostream>
#include "include/context.h"
#include "src/utils.h"
@ -113,7 +114,7 @@ std::vector<int> NetRunner::FillInputData(const std::vector<DataLabelTuple> &dat
int label = 0;
char *data = nullptr;
std::tie(data, label) = dataset[idx];
std::memcpy(input_data + i * data_size_, data, data_size_);
std::copy(data, data + data_size, input_data + i * data_size);
labels[i * num_of_classes_ + label] = 1.0; // Model expects labels in onehot representation
labels_vec.push_back(label);
}

View File

@ -0,0 +1,51 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_
#include <stdio.h>
#include <vector>
#include <string>
#include <utility>
#include <unordered_map>
#include "include/train/train_loop.h"
using GraphPoint = std::pair<int, float>;
namespace mindspore {
namespace lite {
class CkptSaver : public session::TrainLoopCallBack {
public:
CkptSaver(int save_every_n, const std::string &filename_prefix)
: save_every_n_(save_every_n), filename_prefix_(filename_prefix) {}
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override {
if ((cb_data.epoch_ + 1) % save_every_n_ == 0) {
auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms";
remove(cpkt_fn.c_str());
cb_data.session_->SaveToFile(cpkt_fn);
}
return session::RET_CONTINUE;
}
private:
int save_every_n_;
std::string filename_prefix_;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_

View File

@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_CLASSIFICATION_TRAIN_ACCURACY_MONITOR_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_CLASSIFICATION_TRAIN_ACCURACY_MONITOR_H_
#include <vector>
#include <string>
#include <utility>
#include <climits>
#include <unordered_map>
#include "include/train/train_loop.h"
using GraphPoint = std::pair<int, float>;
namespace mindspore {
namespace lite {
class ClassificationTrainAccuracyMonitor : public session::TrainLoopCallBack {
public:
explicit ClassificationTrainAccuracyMonitor(int print_every_n = INT_MAX) : print_every_n_(print_every_n) {}
virtual ~ClassificationTrainAccuracyMonitor() = default;
void Begin(const session::TrainLoopCallBackData &cb_data) override;
void EpochBegin(const session::TrainLoopCallBackData &cb_data) override;
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override;
void StepEnd(const session::TrainLoopCallBackData &cb_data) override;
const std::vector<GraphPoint> &GetAccuracyPoints() const { return accuracies_; }
private:
std::vector<GraphPoint> accuracies_;
int print_every_n_ = 0;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_CLASSIFICATION_TRAIN_ACCURACY_MONITOR_H_

View File

@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_LOSS_MONITOR_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_LOSS_MONITOR_H_
#include <vector>
#include <string>
#include <utility>
#include <climits>
#include <unordered_map>
#include "include/train/train_loop_callback.h"
using GraphPoint = std::pair<int, float>;
namespace mindspore {
namespace lite {
class LossMonitor : public session::TrainLoopCallBack {
public:
explicit LossMonitor(int print_every_n = INT_MAX) : print_every_n_(print_every_n) {}
virtual ~LossMonitor() = default;
void Begin(const session::TrainLoopCallBackData &cb_data) override;
void EpochBegin(const session::TrainLoopCallBackData &cb_data) override;
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override;
void StepEnd(const session::TrainLoopCallBackData &cb_data) override;
const std::vector<GraphPoint> &GetLossPoints() const { return losses_; }
private:
std::vector<GraphPoint> losses_;
int print_every_n_;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_LOSS_MONITOR_H_

View File

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_LR_SCHEDULER_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_LR_SCHEDULER_H_
#include <vector>
#include <string>
#include <utility>
#include <functional>
#include <unordered_map>
#include "include/train/train_loop_callback.h"
namespace mindspore {
namespace lite {
constexpr int DONT_UPDATE_LR = 0;
constexpr int UPDATE_LR = 1;
using LR_Lambda = std::function<int(float *lr, int epoch, void *cb_data)>;
/// \brief Multiply the LR by a factor of gamma every epoch
int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication);
/// \brief Multiply the LR by a factor of gamma every step_size
int StepLRLambda(float *lr, int epoch, void *step_size);
struct StepLRLambda {
StepLRLambda(int step, float g) : step_size(step), gamma(g) {}
int step_size; // period of LR decay
float gamma; // LR decay factor
};
class LRScheduler : public session::TrainLoopCallBack {
public:
explicit LRScheduler(LR_Lambda lambda_func, void *lr_cb_data = nullptr, int step_ = 1);
virtual ~LRScheduler() = default;
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override;
private:
LR_Lambda lambda_func_;
void *lr_data_ = nullptr;
int step_ = 1;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_LR_SCHEDULER_H_

View File

@ -0,0 +1,69 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_
#include <vector>
#include <string>
#include <tuple>
#include <unordered_map>
#include "include/train/train_loop_callback.h"
#include "include/train_session.h"
namespace mindspore {
namespace session {
class TrainLoop {
public:
/// \brief Static method to create a TrainLoop object
///
/// \param[in] filename Filename to read flatbuffer from
/// \param[in] context Defines the context of the session to be created
///
/// \return Pointer of MindSpore Lite TrainLoop
static TrainLoop *CreateTrainLoop(const std::string &model_filename, lite::Context *context, int batch_size = -1);
/// \brief Class destructor
virtual ~TrainLoop() = default;
/// \brief Resets the epoch counter
///
/// \return 0 on success or -1 in case of error
virtual int Reset() = 0; // resets the epoch counter to 0.
/// \brief Accessor to the TrainSession
///
/// \return pointer of the train_session
virtual session::TrainSession *train_session() = 0;
/// \brief Accessor to the Session KernelCallbacks
///
/// \param[in] before Define a call_back_function to be called before running each node.
/// \param[in] after Define a call_back_function called after running each node.
///
/// \return 0 on success or -1 in case of error
virtual int SetKernelCallBack(const KernelCallBack &before, const KernelCallBack &after) = 0;
/// \brief Performs the training Loop
///
/// \param[in] epoch The number of epochs to run
/// \param[in] cbs A vector of TrainLoopCallBack objects
///
/// \return 0 on success or -1 in case of error
virtual int Train(int epochs, std::vector<TrainLoopCallBack *> cbs) = 0;
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_

View File

@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_
#include <vector>
#include <string>
#include <tuple>
#include <unordered_map>
namespace mindspore {
namespace session {
class TrainSession;
class TrainLoop;
struct TrainLoopCallBackData {
TrainLoopCallBackData(bool train_mode, int epoch, TrainSession *session, TrainLoop *loop)
: train_mode_(train_mode), epoch_(epoch), session_(session), loop_(loop) {}
bool train_mode_; /**< training mode of TrainSession object */
unsigned int epoch_; /**< the current training epoch (starts at 0) */
unsigned int step_ = 0; /**< the current step within the epoch */
TrainSession *session_; /**< pointer to the TrainSession */
TrainLoop *loop_;
};
constexpr int RET_CONTINUE = 0;
constexpr int RET_STOP_TRAINING = 1;
constexpr int RET_EXIT = 2;
class TrainLoopCallBack {
public:
virtual ~TrainLoopCallBack() = default;
virtual void Begin(const TrainLoopCallBackData &cb_data) {}
virtual void End(const TrainLoopCallBackData &cb_data) {}
virtual void EpochBegin(const TrainLoopCallBackData &cb_data) {}
virtual int EpochEnd(const TrainLoopCallBackData &cb_data) { return RET_CONTINUE; }
virtual void StepBegin(const TrainLoopCallBackData &cb_data) {}
virtual void StepEnd(const TrainLoopCallBackData &cb_data) {}
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_

View File

@ -83,6 +83,23 @@ class TrainSession : public session::LiteSession {
/// \return boolean indication if model is in eval mode
bool IsEval() { return train_mode_ == false; }
/// \brief Sets the Learning Rate of the training
///
/// \param[in] learning_rate to set
///
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int SetLearningRate(float learning_rate) = 0;
/// \brief Gets the Learning Rate of the training
///
/// \return learning rate. 0.0 if no optimizer was found
virtual float GetLearningRate() = 0;
/// \brief Get output MindSpore Lite MSTensors of Training model prediction
///
/// \return The map of output tensor name and MindSpore Lite MSTensor.
virtual std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetPredictions() const = 0;
protected:
bool train_mode_ = false;
};

View File

@ -0,0 +1,178 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.lite;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.mindspore.lite.config.MSConfig;
public class TrainSession {
static {
System.loadLibrary("mindspore-lite-jni");
}
private long sessionPtr;
public TrainSession() {
this.sessionPtr = 0;
}
public boolean init(String modelFilename, MSConfig config) {
this.sessionPtr = createSession(modelFilename, config.getMSConfigPtr());
return this.sessionPtr != 0;
}
public long getSessionPtr() {
return sessionPtr;
}
public void bindThread(boolean if_bind) {
this.bindThread(this.sessionPtr, if_bind);
}
public boolean runGraph() {
return this.runGraph(this.sessionPtr);
}
public List<MSTensor> getInputs() {
List<Long> ret = this.getInputs(this.sessionPtr);
ArrayList<MSTensor> tensors = new ArrayList<MSTensor>();
for (Long ms_tensor_addr : ret) {
MSTensor msTensor = new MSTensor(ms_tensor_addr);
tensors.add(msTensor);
}
return tensors;
}
public MSTensor getInputsByTensorName(String tensorName) {
Long tensor_addr = this.getInputsByTensorName(this.sessionPtr, tensorName);
if(tensor_addr == null){
return null;
}
MSTensor msTensor = new MSTensor(tensor_addr);
return msTensor;
}
public List<MSTensor> getOutputsByNodeName(String nodeName) {
List<Long> ret = this.getOutputsByNodeName(this.sessionPtr, nodeName);
ArrayList<MSTensor> tensors = new ArrayList<>();
for (Long msTensorAddr : ret) {
MSTensor msTensor = new MSTensor(msTensorAddr);
tensors.add(msTensor);
}
return tensors;
}
public Map<String, MSTensor> getOutputMapByTensor() {
Map<String, Long> ret = this.getOutputMapByTensor(this.sessionPtr);
Map<String, MSTensor> tensorMap = new HashMap<>();
Set<Map.Entry<String, Long>> entrySet = ret.entrySet();
for (Map.Entry<String, Long> entry : entrySet) {
String name = entry.getKey();
Long msTensorAddr = entry.getValue();
tensorMap.put(name, new MSTensor(msTensorAddr));
}
return tensorMap;
}
public List<String> getOutputTensorNames() {
return getOutputTensorNames(this.sessionPtr);
}
public MSTensor getOutputByTensorName(String tensorName) {
Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName);
if(tensor_addr == null){
return null;
}
return new MSTensor(tensor_addr);
}
public void free() {
this.free(this.sessionPtr);
this.sessionPtr = 0;
}
public boolean resize(List<MSTensor> inputs, int[][] dims) {
long[] inputs_array = new long[inputs.size()];
for (int i = 0; i < inputs.size(); i++) {
inputs_array[i] = inputs.get(i).getMSTensorPtr();
}
return this.resize(this.sessionPtr, inputs_array, dims);
}
public boolean saveToFile(String modelFilename) {
return this.saveToFile(this.sessionPtr, modelFilename);
}
public boolean train() {
return this.train(this.sessionPtr);
}
public boolean eval() {
return this.eval(this.sessionPtr);
}
public boolean isTrain() {
return this.isTrain(this.sessionPtr);
}
public boolean isEval() {
return this.isEval(this.sessionPtr);
}
public boolean setLearningRate(float learning_rate) {
return this.setLearningRate(this.sessionPtr, learning_rate);
}
private native long createSession(String modelFilename, long msConfigPtr);
private native void bindThread(long sessionPtr, boolean if_bind);
private native boolean runGraph(long sessionPtr);
private native List<Long> getInputs(long sessionPtr);
private native long getInputsByTensorName(long sessionPtr, String tensorName);
private native List<Long> getOutputsByNodeName(long sessionPtr, String nodeName);
private native Map<String, Long> getOutputMapByTensor(long sessionPtr);
private native List<String> getOutputTensorNames(long sessionPtr);
private native long getOutputByTensorName(long sessionPtr, String tensorName);
private native void free(long sessionPtr);
private native boolean resize(long sessionPtr, long[] inputs, int[][] dims);
private native boolean saveToFile(long sessionPtr, String modelFilename);
private native boolean train(long sessionPtr);
private native boolean eval(long sessionPtr);
private native boolean isTrain(long sessionPtr);
private native boolean isEval(long sessionPtr);
private native boolean setLearningRate(long sessionPtr, float learning_rate);
}

View File

@ -7,8 +7,10 @@ set(PLATFORM_ARM "on")
set(MS_VERSION_MAJOR ${MS_VERSION_MAJOR})
set(MS_VERSION_MINOR ${MS_VERSION_MINOR})
set(MS_VERSION_REVISION ${MS_VERSION_REVISION})
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \
-DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \
-DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
#set for cross-compiling toolchain
@ -16,16 +18,16 @@ set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
if (ENABLE_VERBOSE)
if(ENABLE_VERBOSE)
set(CMAKE_VERBOSE_MAKEFILE on)
endif ()
endif()
if (PLATFORM_ARM32)
if(PLATFORM_ARM32)
add_compile_definitions(ENABLE_ARM32)
endif ()
if (PLATFORM_ARM64)
endif()
if(PLATFORM_ARM64)
add_compile_definitions(ENABLE_ARM64)
endif ()
endif()
set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../../../..)
set(LITE_DIR ${TOP_DIR}/mindspore/lite)
@ -40,15 +42,25 @@ include_directories(${LITE_DIR}/build) ## flatbuffers
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../libs/${ANDROID_ABI}/)
add_library(mindspore-lite-jni SHARED
${CMAKE_CURRENT_SOURCE_DIR}/common/jni_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/model.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/version.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_config.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/lite_session.cpp
)
set(JNI_SRC
${CMAKE_CURRENT_SOURCE_DIR}/common/jni_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/model.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/version.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_config.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/lite_session.cpp
)
if(SUPPORT_TRAIN)
set(JNI_SRC
${JNI_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp
)
endif()
add_library(mindspore-lite-jni SHARED ${JNI_SRC})
find_library(log-lib log)
target_link_libraries(mindspore-lite-jni mindspore-lite ${log-lib})
target_link_libraries(mindspore-lite-jni mindspore-lite ${log-lib})

View File

@ -0,0 +1,305 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <jni.h>
#include "common/ms_log.h"
#include "common/jni_utils.h"
#include "include/train_session.h"
#include "include/errorcode.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_createSession(JNIEnv *env, jobject thiz,
jstring model_file_name,
jlong ms_config_ptr) {
auto *pointer = reinterpret_cast<void *>(ms_config_ptr);
if (pointer == nullptr) {
MS_LOGE("Context pointer from java is nullptr");
return jlong(nullptr);
}
auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
auto session = mindspore::session::TrainSession::CreateSession(JstringToChar(env, model_file_name), lite_context_ptr);
if (session == nullptr) {
MS_LOGE("CreateSession failed");
return jlong(nullptr);
}
return jlong(session);
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_TrainSession_bindThread(JNIEnv *env, jobject thiz,
jlong session_ptr, jboolean if_bind) {
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer);
train_session_ptr->BindThread(if_bind);
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_runGraph(JNIEnv *env, jobject thiz,
jlong session_ptr) {
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer);
auto ret = train_session_ptr->RunGraph();
return (jboolean)(ret == mindspore::lite::RET_OK);
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_TrainSession_getInputs(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer);
auto inputs = train_session_ptr->GetInputs();
for (auto input : inputs) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
}
return ret;
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_getInputsByTensorName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring tensor_name) {
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return jlong(nullptr);
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer);
auto input = train_session_ptr->GetInputsByTensorName(JstringToChar(env, tensor_name));
return jlong(input);
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_TrainSession_getOutputsByNodeName(JNIEnv *env,
jobject thiz,
jlong session_ptr,
jstring node_name) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer);
auto inputs = train_session_ptr->GetOutputsByNodeName(JstringToChar(env, node_name));
for (auto input : inputs) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
}
return ret;
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_TrainSession_getOutputMapByTensor(JNIEnv *env,
jobject thiz,
jlong session_ptr) {
jclass hash_map_clazz = env->FindClass("java/util/HashMap");
jmethodID hash_map_construct = env->GetMethodID(hash_map_clazz, "<init>", "()V");
jobject hash_map = env->NewObject(hash_map_clazz, hash_map_construct);
jmethodID hash_map_put =
env->GetMethodID(hash_map_clazz, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return hash_map;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer);
auto outputs = train_session_ptr->GetOutputs();
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
for (auto output_iter : outputs) {
auto node_name = output_iter.first;
auto ms_tensor = output_iter.second;
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(ms_tensor));
env->CallObjectMethod(hash_map, hash_map_put, env->NewStringUTF(node_name.c_str()), tensor_addr);
}
return hash_map;
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_TrainSession_getOutputTensorNames(JNIEnv *env,
jobject thiz,
jlong session_ptr) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer);
auto output_names = train_session_ptr->GetOutputTensorNames();
for (auto output_name : output_names) {
env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str()));
}
return ret;
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_getOutputByTensorName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring tensor_name) {
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return jlong(nullptr);
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer);
auto output = train_session_ptr->GetOutputByTensorName(JstringToChar(env, tensor_name));
return jlong(output);
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_TrainSession_free(JNIEnv *env, jobject thiz,
jlong session_ptr) {
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer);
delete (train_session_ptr);
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_resize(JNIEnv *env, jobject thiz,
jlong session_ptr, jlongArray inputs,
jobjectArray dims) {
std::vector<std::vector<int>> c_dims;
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return false;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer);
jsize input_size = static_cast<int>(env->GetArrayLength(inputs));
jlong *input_data = env->GetLongArrayElements(inputs, nullptr);
std::vector<mindspore::tensor::MSTensor *> c_inputs;
for (int i = 0; i < input_size; i++) {
auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
if (tensor_pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return false;
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(tensor_pointer);
c_inputs.push_back(ms_tensor_ptr);
}
jsize tensor_size = static_cast<int>(env->GetArrayLength(dims));
for (int i = 0; i < tensor_size; i++) {
jintArray array = static_cast<jintArray>(env->GetObjectArrayElement(dims, i));
jsize dim_size = static_cast<int>(env->GetArrayLength(array));
jint *dim_data = env->GetIntArrayElements(array, nullptr);
std::vector<int> tensor_dims;
for (int j = 0; j < dim_size; j++) {
tensor_dims.push_back(dim_data[j]);
}
c_dims.push_back(tensor_dims);
}
int ret = train_session_ptr->Resize(c_inputs, c_dims);
return (jboolean)(ret == mindspore::lite::RET_OK);
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_saveToFile(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring model_file_name) {
auto *session_pointer = reinterpret_cast<void *>(session_ptr);
if (session_pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer);
auto ret = train_session_ptr->SaveToFile(JstringToChar(env, model_file_name));
return (jboolean)(ret == 0);
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_train(JNIEnv *env, jobject thiz,
jlong session_ptr) {
auto *session_pointer = reinterpret_cast<void *>(session_ptr);
if (session_pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer);
auto ret = train_session_ptr->Train();
return (jboolean)(ret == mindspore::lite::RET_OK);
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_eval(JNIEnv *env, jobject thiz,
jlong session_ptr) {
auto *session_pointer = reinterpret_cast<void *>(session_ptr);
if (session_pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer);
auto ret = train_session_ptr->Eval();
return (jboolean)(ret == mindspore::lite::RET_OK);
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_isTrain(JNIEnv *env, jobject thiz,
jlong session_ptr) {
auto *session_pointer = reinterpret_cast<void *>(session_ptr);
if (session_pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer);
auto ret = train_session_ptr->IsTrain();
return (jboolean)(ret);
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_isEval(JNIEnv *env, jobject thiz,
jlong session_ptr) {
auto *session_pointer = reinterpret_cast<void *>(session_ptr);
if (session_pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer);
auto ret = train_session_ptr->IsEval();
return (jboolean)(ret);
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_setLearningRate(JNIEnv *env, jobject thiz,
jlong session_ptr,
jfloat learning_rate) {
auto *session_pointer = reinterpret_cast<void *>(session_ptr);
if (session_pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer);
auto ret = train_session_ptr->SetLearningRate(learning_rate);
return (jboolean)(ret == mindspore::lite::RET_OK);
}

View File

@ -800,8 +800,8 @@ int ElementSubRelu6(const float *in0, const float *in1, float *out, int size) {
int BroadcastDiv(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size,
ArithmeticParameter *param) {
TileDimensionsFp32(in0, in1, tile_in0, tile_in0, param);
return ElementDiv(tile_in0, tile_in0, out, size);
TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param);
return ElementDiv(tile_in0, tile_in1, out, size);
}
int ElementDiv(const float *in0, const float *in1, float *out, int size) {

View File

@ -21,32 +21,48 @@
#include "nnacl/errorcode.h"
inline int ReluGrad(float *src0, float *src1, size_t length, float *dst) {
for (size_t i = 0; i < length; ++i) {
if (src1[i] > 0) {
dst[i] = src0[i];
} else {
dst[i] = 0;
}
int i = 0;
#ifdef ENABLE_ARM
float32x4_t zero_4 = vdupq_n_f32(0.0f);
for (; i < length - 4; i += 4) {
float32x4_t src1_4 = vld1q_f32(src1 + i);
float32x4_t src0_4 = vld1q_f32(src0 + i);
uint32x4_t mask_4 = vcgtq_f32(src1_4, zero_4);
float32x4_t dst_4 = vbslq_f32(mask_4, src0_4, zero_4);
vst1q_f32(dst + i, dst_4);
}
#endif
for (; i < length; ++i) {
dst[i] = (src1[i] > 0.0f) ? src0[i] : 0.0f;
}
return NNACL_OK;
}
int Relu6Grad(float *src0, float *src1, size_t length, float *dst) {
for (size_t i = 0; i < length; ++i) {
if (src1[i] > 0.0f && src1[i] <= 6.0f) {
dst[i] = src0[i];
} else {
dst[i] = 0.0f;
}
int i = 0;
#ifdef ENABLE_ARM
float32x4_t zero_4 = vdupq_n_f32(0.0f);
float32x4_t six_4 = vdupq_n_f32(6.0f);
for (; i < length - 4; i += 4) {
float32x4_t src1_4 = vld1q_f32(src1 + i);
float32x4_t src0_4 = vld1q_f32(src0 + i);
float32x4_t max_4 = vmaxq_f32(src1_4, zero_4);
float32x4_t min_max_4 = vminq_f32(max_4, six_4);
uint32x4_t mask_4 = vceqq_f32(min_max_4, src1_4);
float32x4_t dst_4 = vbslq_f32(mask_4, src0_4, zero_4);
vst1q_f32(dst + i, dst_4);
}
#endif
for (; i < length; ++i) {
dst[i] = (src1[i] > 0.0f && src1[i] <= 6.0f) ? src0[i] : 0.0f;
}
return NNACL_OK;
}
int LReluGrad(float *src0, float *src1, size_t length, float *dst, float alpha) {
for (size_t i = 0; i < length; ++i) {
dst[i] = src1[i] > 0.0f ? 1.0f : alpha;
dst[i] = src1[i] > 0.0f ? src0[i] : alpha * src0[i];
}
ElementMul(src0, dst, dst, length);
return NNACL_OK;
}

View File

@ -17,55 +17,36 @@
#include <string.h>
#include "nnacl/fp32_grad/batch_norm.h"
void sumSpatialBatch(const float *in, size_t size, int ch, float *out) {
memset(out, 0, ch * sizeof(float));
for (size_t i = 0; i < size; i++) {
const float *ptr = in + (i * ch);
for (size_t c = 0; c < ch; c++) {
out[c] += ptr[c];
}
}
}
void backwardX(const float *in, const float *dout, const float *scale, const size_t size, int channels, float *mean,
float *invar, float *dxhathat_sum, float *dxhat_sum, float *out) {
const float N = (size);
for (size_t i = 0; i < size; i++) {
for (size_t f = 0; f < channels; f++) {
size_t ix = i * channels + f;
float x_hat = (in[ix] - mean[f]) * invar[f];
float dx_hat = dout[ix] * scale[f];
dxhat_sum[f] += dx_hat;
dxhathat_sum[f] += dx_hat * x_hat;
}
}
for (size_t i = 0; i < size; i++) {
for (size_t f = 0; f < channels; f++) {
size_t ix = i * channels + f;
float x_hat = (in[ix] - mean[f]) * invar[f];
float dx_hat = dout[ix] * scale[f];
out[ix] = 1.0f / N * (invar[f]) * (N * dx_hat - dxhat_sum[f] - x_hat * dxhathat_sum[f]);
}
}
}
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch, int n,
int size, float *scale_updates) {
size_t i, b, f;
memset(scale_updates, 0, n * sizeof(float));
for (b = 0; b < batch; ++b) {
for (i = 0; i < size; ++i) {
for (f = 0; f < n; ++f) {
int index = (b * size + i) * n + f;
float x_norm = (x[index] - mean[f]) * invar[f];
scale_updates[f] += (delta[index] * x_norm);
}
}
}
}
void var2Invar(float *save_var, size_t size, float eps) {
for (size_t i = 0; i < size; i++) {
void var2Invar(float *save_var, int size, float eps) {
for (int i = 0; i < size; i++) {
save_var[i] = 1.0f / sqrt(save_var[i] + eps);
}
}
void backwardAll(const float *restrict in, const float *restrict yt, const float *restrict mean,
const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dxhat_sum,
float *restrict dxhathat_sum, float *restrict dbias, float *restrict dscale, float *restrict dx) {
float N = (float)size;
for (int i = 0; i < size; i++) {
for (int c = 0; c < ch; c++) {
int ix = i * ch + c;
dbias[c] += yt[ix];
// dscale
float x_hat = (in[ix] - mean[c]) * invar[c];
dscale[c] += (yt[ix] * x_hat);
// dx_1
float dx_hat = yt[ix] * scale[c];
dxhat_sum[c] += dx_hat;
dxhathat_sum[c] += dx_hat * x_hat;
}
}
for (int i = 0; i < size; i++) {
for (int c = 0; c < ch; c++) {
// dx_2
int ix = i * ch + c;
float x_hat = (in[ix] - mean[c]) * invar[c];
float dx_hat = yt[ix] * scale[c];
dx[ix] = 1.0f / N * (invar[c]) * (N * dx_hat - dxhat_sum[c] - x_hat * dxhathat_sum[c]);
}
}
}

View File

@ -29,13 +29,9 @@ typedef struct BNGradParameter {
extern "C" {
#endif
void sumSpatialBatch(const float *in, size_t size, int ch, float *out);
void backwardX(const float *in, const float *dout, const float *scale, const size_t size, int channels, float *mean,
float *invar, float *xhat_sum, float *dxhat_sum, float *out);
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch, int n,
int size, float *scale_updates);
void var2Invar(float *save_var, size_t size, float eps);
void var2Invar(float *save_var, int size, float eps);
void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale, float *dx);
#ifdef __cplusplus
}
#endif

View File

@ -20,7 +20,7 @@
int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y,
const float *weight, const float *dloss, float *dx) {
const float epsilon = 1e-12;
const float epsilon = 1e-12f;
if (reduction == 0) {
for (int i = 0; i < input_size; i++) {
float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);

View File

@ -21,7 +21,7 @@
#endif
#include "nnacl/fp32/matmul_fp32.h"
static void addv(const float *restrict v1, float *restrict v2, float beta, int row, int col, int stride) {
void AddMatrix(const float *restrict v1, float *restrict v2, float beta, int row, int col, int stride) {
const float *src_ptr = v1;
float *dst_ptr = v2;
for (int r = 0; r < row; r++) {
@ -86,7 +86,8 @@ static void RowMajor2Row12MajorStride(const float *src_ptr, float *dst_ptr, int
return;
}
static void RowMajor2Col12MajorStride(const float *src_ptr, float *dst_ptr, size_t row, size_t col, int lead) {
static void RowMajor2Col12MajorStride(const float *restrict src_ptr, float *restrict dst_ptr, size_t row, size_t col,
int lead) {
size_t row_up_12 = UP_ROUND(row, C12NUM);
size_t row12 = row / C12NUM * C12NUM;
size_t col4 = col / C4NUM * C4NUM;
@ -549,7 +550,7 @@ void GemmMatmulPlus(int ta, int tb, int M, int N, int K, float alpha, const floa
#else
MatMulOpt(mat_a_input, mat_b_input, output, gcb->bias, gcb->atype, K, M, N, ldc, OutType_Nhwc);
#endif
if (incremental) addv(output, mat_c, beta, M, N, ldc);
if (incremental) AddMatrix(output, mat_c, beta, M, N, ldc);
gcb->mat_a = mat_a_input;
gcb->mat_b = mat_b_input;
}

View File

@ -37,6 +37,7 @@ void GemmMatmul(int ta, int tb, int M, int N, int K, float alpha, const float *m
int ldb, float beta, float *mat_c, int ldc, float *workspace);
int MatSize(int row, int col, int round);
int MatSizeTotal(int row, int col, int deep, int inc);
void AddMatrix(const float *v1, float *v2, float beta, int row, int col, int stride);
#ifdef __cplusplus
}
#endif

View File

@ -18,9 +18,8 @@
#include "nnacl/fp32_grad/pack_ext.h"
#include "nnacl/pack.h"
static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); }
void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int rows, int start) {
void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int real_cal_num,
int start) {
const int pad_left = conv_param->pad_l_;
const int pad_up = conv_param->pad_u_;
@ -43,22 +42,43 @@ void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParamet
int kernel_row, kernel_col;
for (int i = 0; i < rows; i++) {
int block_start = start + i;
int input_h = block_start / output_w * stride_h;
int input_w = block_start % output_w * stride_w;
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
int input_row = -pad_up + kernel_row * dilation_h + input_h;
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_col = -pad_left + kernel_col * dilation_w + input_w;
if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) {
const int offset = (input_row * in_width + input_col) * tot_channels;
memcpy(data_col, in_data + offset, sizeof(float) * channels);
data_col += channels;
} else {
memset(data_col, 0, sizeof(float) * channels);
data_col += channels;
if (channels == 1) {
for (int i = 0; i < real_cal_num; i++) {
int block_start = start + i;
int input_h = block_start / output_w * stride_h;
int input_w = block_start % output_w * stride_w;
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
int input_row = -pad_up + kernel_row * dilation_h + input_h;
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_col = -pad_left + kernel_col * dilation_w + input_w;
if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) {
const int offset = (input_row * in_width + input_col) * tot_channels;
*data_col = in_data[offset];
data_col++;
} else {
*data_col = 0;
data_col++;
}
}
}
}
} else {
for (int i = 0; i < real_cal_num; i++) {
int block_start = start + i;
int input_h = block_start / output_w * stride_h;
int input_w = block_start % output_w * stride_w;
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
int input_row = -pad_up + kernel_row * dilation_h + input_h;
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_col = -pad_left + kernel_col * dilation_w + input_w;
if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) {
const int offset = (input_row * in_width + input_col) * tot_channels;
memcpy(data_col, in_data + offset, sizeof(float) * channels);
data_col += channels;
} else {
memset(data_col, 0, sizeof(float) * channels);
data_col += channels;
}
}
}
}
@ -70,7 +90,6 @@ void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *con
rolling_im2col_hwc(input_data, packed_input, conv_param, real_cal_num, block_index);
}
// output matrix is (kernel_h*kernel_w*channels)X(output_h*output_w)
void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, bool transpose) {
const int pad_left = conv_param->pad_l_;
const int pad_up = conv_param->pad_u_;
@ -100,14 +119,14 @@ void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_row = -pad_up + kernel_row * dilation_h;
for (output_rows = output_h; output_rows; output_rows--) {
if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) {
if (!((unsigned)(input_row) < (unsigned)(in_height))) {
for (output_col = output_w; output_col; output_col--) {
*(data_row++) = 0;
}
} else {
int input_col = -pad_left + kernel_col * dilation_w;
for (output_col = output_w; output_col; output_col--) {
if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) {
if (((unsigned)(input_col) < (unsigned)(in_width))) {
const int offset = (input_row * in_width + input_col) * tot_channels + channel;
*(data_row++) = in_data[offset];
} else {
@ -127,14 +146,14 @@ void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv
for (channel = 0; channel < channels; channel++) {
int input_row = -pad_up + kernel_row * dilation_h;
for (output_rows = output_h; output_rows; output_rows--) {
if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) {
if (!((unsigned)(input_row) < (unsigned)(in_height))) {
for (output_col = output_w; output_col; output_col--) {
*(data_row++) = 0;
}
} else {
int input_col = -pad_left + kernel_col * dilation_w;
for (output_col = output_w; output_col; output_col--) {
if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) {
if (((unsigned)(input_col) < (unsigned)(in_width))) {
const int offset = (input_row * in_width + input_col) * tot_channels + channel;
*(data_row++) = in_data[offset];
} else {
@ -150,7 +169,6 @@ void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv
}
}
}
void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start) {
const int pad_left = conv_param->pad_l_;
const int pad_up = conv_param->pad_u_;
@ -177,14 +195,14 @@ void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParamet
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
for (output_rows = start; output_rows < start + rows; output_rows++) {
int input_row = -pad_up + kernel_row * dilation_h + output_rows * stride_h;
if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) {
if (!((unsigned)(input_row) < (unsigned)(in_height))) {
for (output_col = output_w; output_col; output_col--) {
*(data_row++) = 0;
}
} else {
int input_col = -pad_left + kernel_col * dilation_w;
for (output_col = output_w; output_col; output_col--) {
if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) {
if (((unsigned)(input_col) < (unsigned)(in_width))) {
const int offset = (input_row * in_width + input_col) * tot_channels + channel;
*(data_row++) = in_data[offset];
} else {
@ -193,7 +211,6 @@ void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParamet
input_col += stride_w;
}
}
// input_row += stride_h;
}
}
}
@ -232,8 +249,7 @@ void col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv
int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset;
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset;
if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) {
if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) {
int offset = (input_row * in_width + input_col) * tot_channels;
float *data_im_ptr = &data_im[offset];
for (int i = 0; i < channels; i++) {
@ -271,20 +287,36 @@ void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParamet
int kernel_row, kernel_col;
for (int r = 0; r < rows; r++) {
int output_col = (start + r) % output_w;
int output_row = (start + r) / output_w;
int row_stride_offset = output_row * stride_h;
int col_stride_offset = output_col * stride_w;
// for (output_col = 0; output_col < output_w; output_col++)
{
if (channels == 1) {
for (int r = 0; r < rows; r++) {
int output_col = (start + r) % output_w;
int output_row = (start + r) / output_w;
int row_stride_offset = output_row * stride_h;
int col_stride_offset = output_col * stride_w;
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset;
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset;
if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) {
if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) {
int offset = (input_row * in_width + input_col) * tot_channels;
float *data_im_ptr = &data_im[offset];
*data_im_ptr += *data_col;
}
data_col++;
}
}
}
} else {
for (int r = 0; r < rows; r++) {
int output_col = (start + r) % output_w;
int output_row = (start + r) / output_w;
int row_stride_offset = output_row * stride_h;
int col_stride_offset = output_col * stride_w;
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset;
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset;
if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) {
int offset = (input_row * in_width + input_col) * tot_channels;
float *data_im_ptr = &data_im[offset];
for (int i = 0; i < channels; i++) {

View File

@ -308,7 +308,7 @@ table SoftmaxCrossEntropy {
}
table SparseSoftmaxCrossEntropy {
isGrad: int;
isGrad: bool;
}
table make_tuple {
@ -1225,11 +1225,9 @@ table SmoothL1LossGrad {
}
table SigmoidCrossEntropyWithLogits {
beta : float;
}
table SigmoidCrossEntropyWithLogitsGrad {
beta : float;
}
table Reciprocal {

View File

@ -65,7 +65,10 @@ if(SUPPORT_TRAIN)
${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_model.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_loop.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/loss_monitor.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/lr_scheduler.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc
)
endif()

View File

@ -19,6 +19,9 @@
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
#ifdef SUPPORT_TRAIN
#include <tuple>
#endif
namespace mindspore {
namespace lite {
@ -53,12 +56,20 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
}
string paddingmode = "REFLECT";
if (prim.GetAttr("mode") == nullptr) {
MS_LOG(ERROR) << "get mode failed!";
delete this->primitive_;
delete attr;
this->primitive_ = nullptr;
attr = nullptr;
return RET_ERROR;
#ifdef SUPPORT_TRAIN
if (prim.name() == "Pad") {
paddingmode = "CONSTANT";
} else {
#endif
MS_LOG(ERROR) << "get mode failed!";
delete this->primitive_;
delete attr;
this->primitive_ = nullptr;
attr = nullptr;
return RET_ERROR;
#ifdef SUPPORT_TRAIN
}
#endif
} else {
paddingmode = GetValue<string>(prim.GetAttr("mode"));
}
@ -66,6 +77,21 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
attr->paddingMode = schema::PaddingMode_REFLECT;
} else if (paddingmode == "SYMMETRIC") {
attr->paddingMode = schema::PaddingMode_SYMMETRIC;
#ifdef SUPPORT_TRAIN
} else if (paddingmode == "CONSTANT") {
attr->paddingMode = schema::PaddingMode_CONSTANT;
if (prim.GetAttr("paddings") != nullptr) {
auto paddings = prim.GetAttr("paddings");
auto str = (*paddings).ToString();
std::replace(str.begin(), str.end(), ',', ' ');
std::replace(str.begin(), str.end(), ')', ' ');
std::replace(str.begin(), str.end(), '(', ' ');
std::stringstream ss(str);
for (int i; ss >> i;) {
attr->paddings.push_back(i);
}
}
#endif
} else {
MS_LOG(ERROR) << "model type not supported!";
delete this->primitive_;

View File

@ -674,7 +674,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
} else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" ||
op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) {
return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType);
} else if ((op_type == "MaxPoolGrad") || (op_type == "AvgPoolGrad") || (op_type == "AvgPoolGradGpu")) {
} else if ((op_type == "MaxPoolGrad") || (op_type == "AvgPoolGrad") || (op_type == "AvgPoolGradGpu") ||
(op_type == "AvgPoolGradCpu")) {
return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType);
} else if (op_type == "Conv2DBackpropFilter") {
return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType);
@ -684,7 +685,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<BNGrad>(prim, inputs, quantType);
} else if (op_type == "FlattenGrad") {
return NewPrimitiveC<FlattenGrad>(prim, inputs, quantType);
} else if (op_type == "FusedBatchNormGrad") {
} else if ((op_type == "FusedBatchNormGrad") || (op_type == "FusedBatchNormGradCpu")) {
return NewPrimitiveC<BNGrad>(prim, inputs, quantType);
} else if (op_type == "PowerGrad") {
return NewPrimitiveC<PowerGrad>(prim, inputs, quantType);
@ -714,6 +715,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<SigmoidCrossEntropyWithLogits>(prim, inputs, quantType);
} else if (op_type == "SigmoidCrossEntropyWithLogitsGrad") {
return NewPrimitiveC<SigmoidCrossEntropyWithLogitsGrad>(prim, inputs, quantType);
} else if (op_type == "Pad") {
return NewPrimitiveC<Pad>(prim, inputs, quantType);
#else
} else if (op_type == "Conv2DBackpropInput") {
return NewPrimitiveC<DeConv2D>(prim, inputs, quantType);

View File

@ -237,6 +237,9 @@ int Convolution1x1CPUKernel::Run() {
MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!";
return RET_MEMORY_FAILED;
}
if (IsTrain()) {
PackWeight();
}
for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
output_ptr_ = src_out + batch_index * matmul_param_->row_ * matmul_param_->col_;
@ -261,4 +264,45 @@ int Convolution1x1CPUKernel::Run() {
}
return RET_OK;
}
void Convolution1x1CPUKernel::PackWeight() {
auto filter_tensor = in_tensors_.at(kWeightIndex);
auto input_channel = filter_tensor->Channel();
auto output_channel = filter_tensor->Batch();
#ifdef ENABLE_AVX
row_tile_ = C6NUM;
col_tile_ = C16NUM;
#elif defined(ENABLE_SSE)
row_tile_ = C4NUM;
col_tile_ = C8NUM;
#elif defined(ENABLE_ARM32)
row_tile_ = C12NUM;
col_tile_ = C4NUM;
#else
row_tile_ = C12NUM;
col_tile_ = C8NUM;
#endif
int size = input_channel * UP_ROUND(output_channel, col_tile_) * sizeof(float);
int down_size = input_channel * DOWN_DIV(output_channel, col_tile_) * col_tile_ * sizeof(float);
memset(reinterpret_cast<char *>(weight_ptr_) + down_size, 0, size - down_size);
#ifdef ENABLE_AVX
RowMajor2Col16Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
input_channel);
#elif defined(ENABLE_ARM32)
RowMajor2Col4Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
input_channel);
#else
RowMajor2Col8Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
input_channel);
#endif
}
int Convolution1x1CPUKernel::Eval() {
LiteKernel::Eval();
PackWeight();
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_FP32_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_FP32_H_
#include <float.h>
#include <vector>
@ -42,6 +42,7 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
int Init() override;
int Run() override;
int ReSize() override;
int Eval() override;
public:
int DoConv1x1(int task_id);
@ -53,6 +54,7 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
void InitConv1x1MatmulParam();
void FreeTmpBuffer();
void PackMatmulInput(const float *src_ptr, float *dst_ptr, int row, int col);
void PackWeight();
private:
MatMulParameter *matmul_param_ = nullptr;
@ -70,4 +72,4 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
int col_tile_ = 0;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_FP32_H_

View File

@ -47,6 +47,15 @@ class ConvolutionDelegateCPUKernel : public LiteKernel {
static float *CopyData(lite::Tensor *tensor);
void FreeCopiedData();
int Eval() override {
LiteKernel::Eval();
return conv_kernel_->Eval();
}
int Train() override {
LiteKernel::Train();
return conv_kernel_->Train();
}
protected:
bool need_free_weight_ = false;
bool need_free_bias_ = false;

View File

@ -127,6 +127,10 @@ int ConvolutionDepthwise3x3CPUKernel::Run() {
return ret;
}
if (IsTrain()) {
PackWeight();
}
auto input_tensor = in_tensors_.at(kInputIndex);
input_ptr_ = reinterpret_cast<float *>(input_tensor->data_c());
@ -146,4 +150,18 @@ int ConvolutionDepthwise3x3CPUKernel::Run() {
context_->allocator->Free(buffer_);
return RET_OK;
}
void ConvolutionDepthwise3x3CPUKernel::PackWeight() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
weight_tensor->Batch());
}
int ConvolutionDepthwise3x3CPUKernel::Eval() {
LiteKernel::Eval();
PackWeight();
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_
#include <vector>
#include "src/lite_kernel.h"
@ -37,8 +37,10 @@ class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel {
int InitWeightBias();
int Execute(int task_id);
int Eval() override;
private:
void PackWeight();
int InitBuffer();
SlidingWindowParam *sliding_ = nullptr;
float *packed_weight_ = nullptr;
@ -48,4 +50,4 @@ class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel {
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_

View File

@ -15,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h"
#include <limits>
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h"
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h"
#include "schema/model_generated.h"
@ -104,6 +105,10 @@ int ConvDwRun(void *cdata, int task_id) {
}
int ConvolutionDepthwiseCPUKernel::Run() {
if (IsTrain()) {
PackWeight();
}
auto input_tensor = in_tensors_.at(kInputIndex);
input_ptr_ = reinterpret_cast<float *>(input_tensor->MutableData());
@ -118,6 +123,19 @@ int ConvolutionDepthwiseCPUKernel::Run() {
return RET_OK;
}
void ConvolutionDepthwiseCPUKernel::PackWeight() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
weight_tensor->Batch());
}
int ConvolutionDepthwiseCPUKernel::Eval() {
LiteKernel::Eval();
PackWeight();
return RET_OK;
}
kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const InnerContext *ctx, const kernel::KernelKey &desc,

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_FP32_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_FP32_H_
#include <vector>
#include "src/lite_kernel.h"
@ -37,12 +37,14 @@ class ConvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel {
int InitWeightBias();
int Execute(int task_id);
int Eval() override;
private:
void PackWeight();
float *packed_weight_ = nullptr;
float *input_ptr_ = nullptr;
float *output_ptr_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_FP32_H_

View File

@ -190,6 +190,10 @@ int ConvolutionDepthwiseIndirectCPUKernel::Run() {
packed_input_ = input_ptr;
}
if (IsTrain()) {
PackWeight();
}
auto output_tensor = out_tensors_.at(kOutputIndex);
output_ptr_ = reinterpret_cast<float *>(output_tensor->data_c());
@ -205,4 +209,23 @@ int ConvolutionDepthwiseIndirectCPUKernel::Run() {
}
return RET_OK;
}
void ConvolutionDepthwiseIndirectCPUKernel::PackWeight() {
auto weight_tensor = in_tensors_[kWeightIndex];
auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
#ifdef ENABLE_AVX
PackDepthwiseIndirectWeightC8Fp32(origin_weight, packed_weight_, weight_tensor->Height(), weight_tensor->Width(),
weight_tensor->Batch());
#else
PackDepthwiseIndirectWeightC4Fp32(origin_weight, packed_weight_, weight_tensor->Height(), weight_tensor->Width(),
weight_tensor->Batch());
#endif
}
int ConvolutionDepthwiseIndirectCPUKernel::Eval() {
LiteKernel::Eval();
PackWeight();
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_FP32_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_FP32_H_
#include <vector>
#include "src/lite_kernel.h"
@ -37,10 +37,12 @@ class ConvolutionDepthwiseIndirectCPUKernel : public ConvolutionBaseCPUKernel {
int InitWeightBias();
int Execute(int task_id);
int Eval() override;
private:
int MallocIndirectBuffer();
int MallocPackedInput();
void PackWeight();
int step_w = 0;
int step_h = 0;
float **indirect_buffer_ = nullptr;
@ -51,4 +53,4 @@ class ConvolutionDepthwiseIndirectCPUKernel : public ConvolutionBaseCPUKernel {
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_H_
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_FP32_H_

View File

@ -145,6 +145,11 @@ int ConvolutionDepthwiseSWCPUKernel::Run() {
FreePackedInputOutput();
return RET_ERROR;
}
if (IsTrain()) {
PackWeight();
}
auto input_tensor = in_tensors_.at(kInputIndex);
auto input_ptr = reinterpret_cast<float *>(input_tensor->MutableData());
@ -183,4 +188,18 @@ void ConvolutionDepthwiseSWCPUKernel::FreePackedInputOutput() {
packed_output_ = nullptr;
}
}
void ConvolutionDepthwiseSWCPUKernel::PackWeight() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
PackNCHWToNC4HW4Fp32(origin_weight, packed_weight_, 1, weight_tensor->Height() * weight_tensor->Width(),
weight_tensor->Batch());
}
int ConvolutionDepthwiseSWCPUKernel::Eval() {
LiteKernel::Eval();
PackWeight();
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_FP32_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_FP32_H_
#include <vector>
#include "src/lite_kernel.h"
@ -37,10 +37,12 @@ class ConvolutionDepthwiseSWCPUKernel : public ConvolutionBaseCPUKernel {
int InitWeightBias();
int Execute(int task_id);
int Eval() override;
private:
int InitPackedInputOutput();
void FreePackedInputOutput();
void PackWeight();
SlidingWindowParam *sliding_ = nullptr;
float *packed_weight_ = nullptr;
float *packed_input_ = nullptr;
@ -49,4 +51,4 @@ class ConvolutionDepthwiseSWCPUKernel : public ConvolutionBaseCPUKernel {
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_H_
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_FP32_H_

View File

@ -150,6 +150,9 @@ int ConvolutionCPUKernel::Run() {
FreeTmpBuffer();
return RET_ERROR;
}
if (IsTrain()) {
PackWeight();
}
ret = ParallelLaunch(this->context_->thread_pool_, ConvolutionImpl, this, thread_count_);
if (ret != RET_OK) {
@ -158,4 +161,37 @@ int ConvolutionCPUKernel::Run() {
FreeTmpBuffer();
return ret;
}
void ConvolutionCPUKernel::PackWeight() {
auto filter_tensor = in_tensors_.at(kWeightIndex);
int in_channel = filter_tensor->Channel();
int out_channel = filter_tensor->Batch();
int kernel_plane = filter_tensor->Height() * filter_tensor->Width();
#ifdef ENABLE_AVX
const int oc_block = C16NUM;
#elif ENABLE_ARM32
const int oc_block = C4NUM;
#else
const int oc_block = C8NUM;
#endif
int oc_block_num = UP_ROUND(out_channel, oc_block);
int pack_weight_size = oc_block_num * in_channel * kernel_plane;
auto origin_weight = reinterpret_cast<float *>(filter_tensor->data_c());
memset(packed_weight_, 0, pack_weight_size * sizeof(float));
#ifdef ENABLE_AVX
RowMajor2Col16Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
#elif ENABLE_ARM32
RowMajor2Col4Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
#else
RowMajor2Col8Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
#endif
}
int ConvolutionCPUKernel::Eval() {
LiteKernel::Eval();
PackWeight();
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -46,7 +46,10 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
int Run() override;
virtual int RunImpl(int task_id);
int Eval() override;
protected:
void PackWeight();
void FreeTmpBuffer() {
if (packed_input_ != nullptr) {
ctx_->allocator->Free(packed_input_);

View File

@ -58,10 +58,12 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
// set data
auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float);
trans_weight_ = reinterpret_cast<float *>(malloc(trans_matrix_data_size));
if (trans_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc matrix_buffer failed.";
return RET_MEMORY_FAILED;
trans_weight_ = reinterpret_cast<float *>(malloc(trans_matrix_data_size));
if (trans_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc matrix_buffer failed.";
return RET_MEMORY_FAILED;
}
}
memset(trans_weight_, 0, trans_matrix_data_size);
@ -217,6 +219,9 @@ int ConvolutionWinogradCPUKernel::Run() {
FreeTmpBuffer();
return RET_ERROR;
}
if (IsTrain()) {
InitWeightBias();
}
ret = ParallelLaunch(this->context_->thread_pool_, ConvolutionWinogradImpl, this, thread_count_);
if (ret != RET_OK) {
@ -226,4 +231,11 @@ int ConvolutionWinogradCPUKernel::Run() {
FreeTmpBuffer();
return ret;
}
int ConvolutionWinogradCPUKernel::Eval() {
LiteKernel::Eval();
InitWeightBias();
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_FP32_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_FP32_H_
#include <vector>
#include "src/lite_kernel.h"
@ -43,6 +43,7 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
int Init() override;
int ReSize() override;
int Run() override;
int Eval() override;
int RunImpl(int task_id);
int InitWeightBias();
int InitTmpBuffer();
@ -84,4 +85,4 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_FP32_H_

View File

@ -48,33 +48,27 @@ int ActivationGradCPUKernel::DoActivation(int task_id) {
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
int length = in_tensors_.at(0)->ElementsNum();
int stride = UP_DIV(length, 1);
int stride = UP_DIV(length, thread_count_);
int count = MSMIN(stride, length - stride * task_id);
size_t start = stride * task_id;
auto error_code = RET_OK;
if (param_act_grad_->type_ == schema::ActivationType_RELU) {
error_code =
ReluGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id);
error_code = ReluGrad(yt_addr + start, input_addr + start, count, output_addr + start);
} else if (param_act_grad_->type_ == schema::ActivationType_RELU6) {
error_code =
Relu6Grad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id);
error_code = Relu6Grad(yt_addr + start, input_addr + start, count, output_addr + start);
} else if (param_act_grad_->type_ == schema::ActivationType_LEAKY_RELU) {
error_code = LReluGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count,
output_addr + stride * task_id, param_act_grad_->alpha_);
error_code = LReluGrad(yt_addr + start, input_addr + start, count, output_addr + start, param_act_grad_->alpha_);
} else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) {
// Sigmoid gets the input tensors in reverse order!
error_code =
SigmoidGrad(input_addr + stride * task_id, yt_addr + stride * task_id, count, output_addr + stride * task_id);
error_code = SigmoidGrad(input_addr + start, yt_addr + start, count, output_addr + start);
} else if (param_act_grad_->type_ == schema::ActivationType_TANH) {
error_code =
TanhGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id);
error_code = TanhGrad(yt_addr + start, input_addr + start, count, output_addr + start);
} else if (param_act_grad_->type_ == schema::ActivationType_HSWISH) {
error_code =
HSwishGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id);
error_code = HSwishGrad(yt_addr + start, input_addr + start, count, output_addr + start);
} else if (param_act_grad_->type_ == schema::ActivationType_HSIGMOID) {
error_code =
HSigmoidGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id);
error_code = HSigmoidGrad(yt_addr + start, input_addr + start, count, output_addr + start);
} else {
MS_LOG(ERROR) << "Activation type error";
return RET_ERROR;
@ -97,7 +91,7 @@ int ActivationGradRun(void *cdata, int task_id) {
}
int ActivationGradCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, ActivationGradRun, this, 1);
int error_code = ParallelLaunch(this->context_->thread_pool_, ActivationGradRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Activation Grad function error error_code[" << error_code << "]";
return RET_ERROR;

View File

@ -27,7 +27,7 @@ class ActivationGradCPUKernel : public LiteKernel {
explicit ActivationGradCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(param, inputs, outputs, ctx, primitive) {
: LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
param_act_grad_ = reinterpret_cast<ActivationParameter *>(param);
}
~ActivationGradCPUKernel() override = default;
@ -39,6 +39,7 @@ class ActivationGradCPUKernel : public LiteKernel {
private:
ActivationParameter *param_act_grad_;
int thread_count_;
};
} // namespace mindspore::kernel

View File

@ -33,17 +33,23 @@ namespace mindspore::kernel {
int AdamCPUKernel::ReSize() { return RET_OK; }
int AdamCPUKernel::Execute(int task_id) {
auto weight = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
auto m = reinterpret_cast<float *>(in_tensors_[1]->MutableData());
auto v = reinterpret_cast<float *>(in_tensors_[2]->MutableData());
auto beta1_power = reinterpret_cast<float *>(in_tensors_[3]->MutableData())[0];
auto beta2_power = reinterpret_cast<float *>(in_tensors_[4]->MutableData())[0];
auto learning_rate = reinterpret_cast<float *>(in_tensors_[5]->MutableData())[0];
auto beta1 = reinterpret_cast<float *>(in_tensors_[6]->MutableData())[0];
auto beta2 = reinterpret_cast<float *>(in_tensors_[7]->MutableData())[0];
auto eps = reinterpret_cast<float *>(in_tensors_[8]->MutableData())[0];
auto gradient = reinterpret_cast<float *>(in_tensors_[9]->MutableData());
size_t elem_num = in_tensors_[0]->ElementsNum();
auto weight = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
auto m = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
auto v = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
auto beta1_power = reinterpret_cast<float *>(in_tensors_.at(3)->MutableData())[0];
auto beta2_power = reinterpret_cast<float *>(in_tensors_.at(4)->MutableData())[0];
auto learning_rate = reinterpret_cast<float *>(in_tensors_.at(5)->MutableData())[0];
auto beta1 = reinterpret_cast<float *>(in_tensors_.at(6)->MutableData())[0];
auto beta2 = reinterpret_cast<float *>(in_tensors_.at(7)->MutableData())[0];
auto eps = reinterpret_cast<float *>(in_tensors_.at(8)->MutableData())[0];
auto gradient = reinterpret_cast<float *>(in_tensors_.at(9)->MutableData());
size_t length = in_tensors_.at(0)->ElementsNum();
size_t stride = UP_DIV(length, thread_count_);
size_t count = MSMIN(stride, length - stride * task_id);
size_t start = stride * task_id;
size_t end = start + count;
if ((1.f - beta1_power) <= 0.0f) {
MS_LOG(ERROR) << "divisor cannot be 0 or below";
@ -55,17 +61,19 @@ int AdamCPUKernel::Execute(int task_id) {
}
auto update_lr = learning_rate * std::sqrt(1.f - beta2_power) / (1.f - beta1_power);
const float one_minus_beta1 = 1.f - beta1;
const float one_minus_beta2 = 1.f - beta2;
if (adam_param_->use_nesterov_) { // Nadam
for (size_t i = 0; i < elem_num; ++i) {
m[i] += (gradient[i] - m[i]) * (1.f - beta1);
v[i] += (gradient[i] * gradient[i] - v[i]) * (1.f - beta2);
weight[i] -= update_lr * (m[i] * beta1 + (1.f - beta1) * gradient[i]) / (std::sqrt(v[i]) + eps);
for (size_t i = start; i < end; ++i) {
m[i] += (gradient[i] - m[i]) * one_minus_beta1;
v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2;
weight[i] -= update_lr * (m[i] * beta1 + one_minus_beta1 * gradient[i]) / (std::sqrt(v[i]) + eps);
}
} else {
for (size_t i = 0; i < elem_num; ++i) {
m[i] += (gradient[i] - m[i]) * (1.f - beta1);
v[i] += (gradient[i] * gradient[i] - v[i]) * (1.f - beta2);
for (size_t i = start; i < end; ++i) {
m[i] += (gradient[i] - m[i]) * one_minus_beta1;
v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2;
weight[i] -= update_lr * m[i] / (std::sqrt(v[i]) + eps);
}
}
@ -84,7 +92,7 @@ int AdamRun(void *cdata, int task_id) {
}
int AdamCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, AdamRun, this, 1);
int error_code = ParallelLaunch(this->context_->thread_pool_, AdamRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Adam function error error_code[" << error_code << "]";
return RET_ERROR;
@ -92,6 +100,17 @@ int AdamCPUKernel::Run() {
return RET_OK;
}
int AdamCPUKernel::SetLearningRate(float lr) {
auto learning_rate_tensor = reinterpret_cast<float *>(in_tensors_.at(5)->MutableData());
learning_rate_tensor[0] = lr;
return RET_OK;
}
float AdamCPUKernel::GetLearningRate() {
auto learning_rate_tensor = reinterpret_cast<float *>(in_tensors_.at(5)->MutableData());
return learning_rate_tensor[0];
}
int AdamCPUKernel::Init() { return RET_OK; }
kernel::LiteKernel *CpuAdamFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,

View File

@ -18,25 +18,28 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/train/optimizer_kernel.h"
#include "nnacl/fp32_grad/optimizer.h"
namespace mindspore::kernel {
class AdamCPUKernel : public LiteKernel {
class AdamCPUKernel : public OptimizerKernel {
public:
explicit AdamCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
: OptimizerKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
adam_param_ = reinterpret_cast<AdamParameter *>(parameter);
}
~AdamCPUKernel() override {}
int Init() override;
int ReSize() override;
int Run() override;
int SetLearningRate(float lr) override;
float GetLearningRate() override;
int Execute(int task_id);
private:
int thread_count_;
AdamParameter *adam_param_;
};
} // namespace mindspore::kernel

View File

@ -31,20 +31,26 @@ namespace mindspore::kernel {
int ApplyMomentumCPUKernel::ReSize() { return RET_OK; }
int ApplyMomentumCPUKernel::Execute(int task_id) {
auto weight = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
auto accumulate = reinterpret_cast<float *>(in_tensors_[1]->MutableData());
float learning_rate = reinterpret_cast<float *>(in_tensors_[2]->MutableData())[0];
auto gradient = reinterpret_cast<float *>(in_tensors_[3]->MutableData());
float moment = reinterpret_cast<float *>(in_tensors_[4]->MutableData())[0];
size_t elem_num = in_tensors_[0]->ElementsNum();
auto weight = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
auto accumulate = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
float learning_rate = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData())[0];
auto gradient = reinterpret_cast<float *>(in_tensors_.at(3)->MutableData());
float moment = reinterpret_cast<float *>(in_tensors_.at(4)->MutableData())[0];
size_t length = in_tensors_.at(0)->ElementsNum();
size_t stride = UP_DIV(length, thread_count_);
size_t count = MSMIN(stride, length - stride * task_id);
size_t start = stride * task_id;
size_t end = start + count;
if (apply_momentum_param_->use_nesterov_) {
for (size_t i = 0; i < elem_num; ++i) {
for (size_t i = start; i < end; ++i) {
accumulate[i] = accumulate[i] * moment + gradient[i];
weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate;
}
} else {
for (size_t i = 0; i < elem_num; ++i) {
for (size_t i = start; i < end; ++i) {
accumulate[i] = accumulate[i] * moment + gradient[i];
weight[i] -= accumulate[i] * learning_rate;
}
@ -64,7 +70,7 @@ int ApplyMomentumRun(void *cdata, int task_id) {
}
int ApplyMomentumCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, ApplyMomentumRun, this, 1);
int error_code = ParallelLaunch(this->context_->thread_pool_, ApplyMomentumRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Apply Momentum function error error_code[" << error_code << "]";
return RET_ERROR;
@ -74,6 +80,17 @@ int ApplyMomentumCPUKernel::Run() {
int ApplyMomentumCPUKernel::Init() { return RET_OK; }
int ApplyMomentumCPUKernel::SetLearningRate(float lr) {
auto learning_rate_tensor = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
learning_rate_tensor[0] = lr;
return RET_OK;
}
float ApplyMomentumCPUKernel::GetLearningRate() {
auto learning_rate_tensor = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
return learning_rate_tensor[0];
}
kernel::LiteKernel *CpuApplyMomentumFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter, const lite::InnerContext *ctx,

View File

@ -18,16 +18,18 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_APPLY_MOMENTUM_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/train/optimizer_kernel.h"
#include "nnacl/fp32_grad/optimizer.h"
namespace mindspore::kernel {
class ApplyMomentumCPUKernel : public LiteKernel {
class ApplyMomentumCPUKernel : public OptimizerKernel {
public:
explicit ApplyMomentumCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), apply_momentum_param_(nullptr) {
: OptimizerKernel(parameter, inputs, outputs, ctx, primitive),
thread_count_(ctx->thread_num_),
apply_momentum_param_(nullptr) {
apply_momentum_param_ = reinterpret_cast<ApplyMomentumParameter *>(parameter);
}
~ApplyMomentumCPUKernel() override {}
@ -35,8 +37,11 @@ class ApplyMomentumCPUKernel : public LiteKernel {
int ReSize() override;
int Run() override;
int Execute(int task_id);
int SetLearningRate(float lr) override;
float GetLearningRate() override;
private:
int thread_count_;
ApplyMomentumParameter *apply_momentum_param_;
};
} // namespace mindspore::kernel

View File

@ -49,27 +49,24 @@ int ArithmeticSelfGradCPUKernel::Init() {
return RET_OK;
}
int ArithmeticSelfGradCPUKernel::DoArithmeticSelfGrad(int thread_id) {
auto dy = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
auto in_x = reinterpret_cast<float *>(in_tensors_[1]->MutableData());
auto dx = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
int dy_size = in_tensors_.at(0)->ElementsNum();
int size = MSMIN(thread_stride_, static_cast<int>(dy_size - thread_id * thread_stride_));
if (size <= 0) {
return RET_OK;
}
int offset = thread_id * thread_stride_;
(*self_grad_operation_)(dy + offset, in_x + offset, dx + offset, size);
int ArithmeticSelfGradCPUKernel::DoArithmeticSelfGrad(int task_id) {
auto dy = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
auto in_x = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
auto dx = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
size_t length = in_tensors_.at(0)->ElementsNum();
size_t stride = UP_DIV(length, thread_count_);
size_t count = MSMIN(stride, length - stride * task_id);
size_t start = stride * task_id;
(*self_grad_operation_)(dy + start, in_x + start, dx + start, count);
return RET_OK;
}
int ArithmeticSelfGradCPUKernel::ReSize() { return RET_OK; }
int ArithmeticSelfGradCPUKernel::Run() {
int dy_size = in_tensors_.at(0)->ElementsNum();
op_parameter_->thread_num_ = MSMIN(op_parameter_->thread_num_, static_cast<int>(dy_size));
thread_stride_ = UP_DIV(dy_size, op_parameter_->thread_num_);
auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticSelfGradRun, this, op_parameter_->thread_num_);
auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticSelfGradRun, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "parallel launch fail!ret: " << ret;
return ret;

View File

@ -30,7 +30,7 @@ class ArithmeticSelfGradCPUKernel : public LiteKernel {
ArithmeticSelfGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {}
~ArithmeticSelfGradCPUKernel() override {}
int Init() override;
int ReSize() override;
@ -38,7 +38,7 @@ class ArithmeticSelfGradCPUKernel : public LiteKernel {
int DoArithmeticSelfGrad(int thread_id);
private:
int thread_stride_;
int thread_count_;
ArithmeticSelfGradOperation self_grad_operation_;
};
} // namespace mindspore::kernel

View File

@ -32,11 +32,16 @@ namespace mindspore::kernel {
int AssignCPUKernel::ReSize() { return RET_OK; }
int AssignCPUKernel::Execute(int task_id) {
auto x = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
auto y = reinterpret_cast<float *>(in_tensors_[1]->MutableData());
size_t size = in_tensors_[0]->Size();
auto x = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
auto y = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
size_t length = in_tensors_.at(0)->ElementsNum();
memcpy(x, y, size);
size_t stride = UP_DIV(length, thread_count_);
size_t count = MSMIN(stride, length - stride * task_id);
size_t start = stride * task_id;
memcpy(&(x[start]), &(y[start]), count * sizeof(float));
return RET_OK;
}
@ -52,7 +57,7 @@ int AssignRun(void *cdata, int task_id) {
}
int AssignCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, AssignRun, this, 1);
int error_code = ParallelLaunch(this->context_->thread_pool_, AssignRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Assign function error error_code[" << error_code << "]";
return RET_ERROR;

View File

@ -27,12 +27,15 @@ class AssignCPUKernel : public LiteKernel {
explicit AssignCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {}
~AssignCPUKernel() override {}
int Init() override;
int ReSize() override;
int Run() override;
int Execute(int task_id);
protected:
int thread_count_ = 1;
};
} // namespace mindspore::kernel

View File

@ -29,7 +29,7 @@ using mindspore::schema::PrimitiveType_BiasGrad;
namespace mindspore::kernel {
int BiasGradCPUKernel::Init() {
int BiasGradCPUKernel::ReSize() {
auto dims = in_tensors_[0]->shape();
bias_param->ndim_ = dims.size();
for (unsigned int i = 0; i < bias_param->ndim_; i++) {
@ -44,7 +44,12 @@ int BiasGradCPUKernel::Init() {
return RET_OK;
}
int BiasGradCPUKernel::ReSize() { return RET_OK; }
int BiasGradCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int BiasGradCPUKernel::Execute(int task_id) {
auto in = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());

View File

@ -31,17 +31,16 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_BNGrad;
namespace mindspore::kernel {
int BNGradCPUKernel::Init() {
int BNGradCPUKernel::ReSize() {
auto *input_x = in_tensors_.at(1);
int channels = input_x->shape().at(kNHWC_C);
set_workspace_size(2 * channels * sizeof(float));
return RET_OK;
}
int BNGradCPUKernel::ReSize() { return RET_OK; }
int BNGradCPUKernel::Init() { return ReSize(); }
int BNGradCPUKernel::Execute(int task_id) {
auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_);
auto *input_yt = in_tensors_.at(0);
auto *input_x = in_tensors_.at(1);
auto *input_scale = in_tensors_.at(2);
@ -54,10 +53,9 @@ int BNGradCPUKernel::Execute(int task_id) {
auto *output_dx = out_tensors_.at(0);
auto *output_scale = out_tensors_.at(1);
auto *output_bias = out_tensors_.at(2);
size_t batch = input_x->Batch();
size_t channels = input_x->Channel();
size_t spatial = input_x->Height() * input_x->Width();
float eps = bn_param->epsilon_;
int32_t batch = input_x->Batch();
int32_t channels = input_x->Channel();
int32_t spatial = input_x->Height() * input_x->Width();
float *workspace_temp = static_cast<float *>(workspace());
std::fill(workspace_temp, workspace_temp + workspace_size() / sizeof(*workspace_temp), 0.f);
@ -68,34 +66,32 @@ int BNGradCPUKernel::Execute(int task_id) {
float *yt = reinterpret_cast<float *>(input_yt->MutableData());
float *scale = reinterpret_cast<float *>(input_scale->MutableData());
float *dx = reinterpret_cast<float *>(output_dx->MutableData());
float *dscale = reinterpret_cast<float *>(output_scale->MutableData());
float *dbias = reinterpret_cast<float *>(output_bias->MutableData());
var2Invar(save_var, input_var->ElementsNum(), eps);
// dx
backwardX(x, yt, scale, batch * spatial, channels, save_mean, save_var, dxhat_sum, dxhathat_sum, dx);
// dbias
sumSpatialBatch(yt, batch * spatial, channels, dbias);
// dscale
backwardScale(x, save_mean, save_var, yt, batch, channels, spatial, dscale);
float *dscale = reinterpret_cast<float *>(output_scale->MutableData());
std::fill(dbias, dbias + channels, 0.f);
std::fill(dscale, dscale + channels, 0.f);
backwardAll(x, yt, save_mean, save_var, scale, batch * spatial, channels, dxhat_sum, dxhathat_sum, dbias, dscale, dx);
return RET_OK;
}
int BNGradRun(void *cdata, int task_id) {
MS_ASSERT(cdata != nullptr);
auto bn_kernel = reinterpret_cast<BNGradCPUKernel *>(cdata);
if (task_id == 0) {
auto error_code = bn_kernel->Execute(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "BNGradRun error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
auto error_code = bn_kernel->Execute(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "BNGradRun error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int BNGradCPUKernel::Run() {
auto *input_var = in_tensors_.at(4);
float *save_var = reinterpret_cast<float *>(input_var->MutableData());
auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_);
float eps = bn_param->epsilon_;
var2Invar(save_var, input_var->ElementsNum(), eps);
int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, 1);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]";

View File

@ -26,7 +26,7 @@ using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
int ConvolutionTrainCPUKernel::Init() {
int ConvolutionTrainCPUKernel::ReSize() {
if (in_tensors_.size() < 2) {
MS_LOG(ERROR) << "Convolution should have at least two inputs";
return RET_ERROR;
@ -54,13 +54,21 @@ int ConvolutionTrainCPUKernel::Init() {
conv_param_->group_ = (conv_param_->group_ == 0) ? conv_param_->input_channel_ : conv_param_->group_;
const int n = conv_param_->output_channel_ * conv_param_->group_;
const int k = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ / conv_param_->group_;
ws_size = chunk * k;
int mat_alloc = MatSizeTotal(chunk, n, k, 0);
set_workspace_size((ws_size + mat_alloc) * sizeof(float));
ws_size_ = chunk_ * k;
int mat_alloc = MatSizeTotal(chunk_, n, k, 0);
set_workspace_size((ws_size_ + mat_alloc) * sizeof(float));
do_img2col_ = (conv_param_->kernel_h_ == 1) && (conv_param_->kernel_w_ == 1) && (conv_param_->pad_d_ == 0) &&
(conv_param_->pad_u_ == 0) && (conv_param_->pad_l_ == 0) && (conv_param_->pad_r_ == 0) &&
(conv_param_->dilation_h_ == 1) && (conv_param_->dilation_w_ == 1) &&
(conv_param_->stride_h_ == 1) && (conv_param_->stride_w_ == 1) && (conv_param_->group_ == 1)
? false
: true;
return RET_OK;
}
int ConvolutionTrainCPUKernel::ReSize() { return RET_OK; }
int ConvolutionTrainCPUKernel::Init() { return ReSize(); }
int ConvolutionTrainCPUKernel::Execute(int task_id) {
auto conv_param_ = reinterpret_cast<ConvParameter *>(op_parameter_);
@ -87,17 +95,34 @@ int ConvolutionTrainCPUKernel::Execute(int task_id) {
const int n = out_ch / groups;
const int k = k_h * k_w * in_ch / groups;
float *workspace_temp = static_cast<float *>(workspace());
float *mat_workspace = workspace_temp + ws_size;
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < groups; ++j) {
for (int ci = 0; ci < m; ci += chunk) {
int real_chunk = MSMIN(m - ci, chunk);
float *mat_a = workspace_temp;
const float *mat_b = w_addr + j * nweights / groups;
float *mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups) + ci * out_ch;
float *im = x_addr + (i * groups) * (in_ch / groups) * in_h * in_w + j * (in_ch / groups);
RollingIm2ColPackUnitFp32(im, conv_param_, mat_a, real_chunk, ci);
GemmMatmul(0, 1, real_chunk, n, k, 1, mat_a, k, mat_b, k, 0, mat_c, out_ch, mat_workspace);
float *mat_workspace = workspace_temp + ws_size_;
if (do_img2col_) {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < groups; ++j) {
for (int ci = 0; ci < m; ci += chunk_) {
int real_chunk = MSMIN(m - ci, chunk_);
float *mat_a = workspace_temp;
const float *mat_b = w_addr + j * nweights / groups;
float *mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups) + ci * out_ch;
float *im = x_addr + i * in_ch * in_h * in_w + j * (in_ch / groups);
RollingIm2ColPackUnitFp32(im, conv_param_, mat_a, real_chunk, ci);
GemmMatmul(0, 1, real_chunk, n, k, 1, mat_a, k, mat_b, k, 0, mat_c, out_ch, mat_workspace);
}
}
}
} else {
const float *mat_b = w_addr;
const size_t in_plane_size = in_ch * in_h * in_w;
for (int i = 0; i < batch; ++i) {
float *im = x_addr + i * in_plane_size;
for (int ci = 0; ci < m; ci += chunk_) {
int real_chunk = MSMIN(m - ci, chunk_);
float *mat_c = y_addr + i * n * m + ci * out_ch;
int input_height = ci / out_w * conv_param_->stride_h_;
int input_width = ci % out_w * conv_param_->stride_w_;
int offset = (input_height * in_w + input_width) * in_ch;
GemmMatmul(0, 1, real_chunk, n, k, 1, im + offset, k, mat_b, k, 0, mat_c, out_ch, mat_workspace);
}
}
}

View File

@ -35,11 +35,12 @@ class ConvolutionTrainCPUKernel : public LiteKernel {
int Execute(int task_id);
private:
int ws_size = 0;
int ws_size_ = 0;
bool do_img2col_ = true;
#ifdef ENABLE_ARM32
const int chunk = C4NUM;
const int chunk_ = C4NUM * 2;
#else
const int chunk = C12NUM;
const int chunk_ = C12NUM * 2;
#endif
};

View File

@ -29,7 +29,7 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2DGradFilter;
namespace mindspore::kernel {
int ConvolutionGradFilterCPUKernel::Init() {
int ConvolutionGradFilterCPUKernel::ReSize() {
// dy is in input 0
// x is in input 1
// dw is output 0
@ -51,16 +51,25 @@ int ConvolutionGradFilterCPUKernel::Init() {
conv_param->output_h_ = dy_tensor->shape()[kNHWC_H];
conv_param->output_w_ = dy_tensor->shape()[kNHWC_W];
ws_size = chunk * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_;
ws_size_ = chunk_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_;
int n = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_;
int k = conv_param->output_channel_ / conv_param->group_;
size_t mat_alloc = MatSizeTotal(k, n, chunk, n);
set_workspace_size((ws_size + mat_alloc) * sizeof(float));
int thread_num = context_->thread_num_;
mat_alloc_ = MatSizeTotal(k, n, chunk_, 0);
set_workspace_size((ws_size_ + mat_alloc_ + (k * n)) * thread_num * sizeof(float));
do_img2col_ = (conv_param->kernel_h_ == 1) && (conv_param->kernel_w_ == 1) && (conv_param->pad_d_ == 0) &&
(conv_param->pad_u_ == 0) && (conv_param->pad_l_ == 0) && (conv_param->pad_r_ == 0) &&
(conv_param->dilation_h_ == 1) && (conv_param->dilation_w_ == 1) && (conv_param->stride_h_ == 1) &&
(conv_param->stride_w_ == 1) && (conv_param->group_ == 1)
? false
: true;
return RET_OK;
}
int ConvolutionGradFilterCPUKernel::ReSize() { return RET_OK; }
int ConvolutionGradFilterCPUKernel::Init() { return ReSize(); }
int ConvolutionGradFilterCPUKernel::Execute(int task_id) {
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_);
@ -72,7 +81,6 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) {
auto dy_addr = reinterpret_cast<float *>(input_dy->MutableData());
auto dw_addr = reinterpret_cast<float *>(out_dw->MutableData());
int i, j;
int nweights = out_dw->ElementsNum();
int in_ch = conv_param->input_channel_;
int in_h = conv_param->input_h_;
@ -88,22 +96,45 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) {
int m = out_h * out_w;
int n = k_h * k_w * in_ch / groups;
int k = out_ch / groups;
int thread_num = context_->thread_num_;
float *workspace_temp = reinterpret_cast<float *>(workspace());
float *mat_workspace = workspace_temp + ws_size;
// zero out pointer
memset(dw_addr, 0, out_dw->Size());
for (i = 0; i < batch; ++i) {
for (j = 0; j < groups; ++j) {
for (int ci = 0; ci < m; ci += chunk) {
int real_chunk = MSMIN(m - ci, chunk);
float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch;
float *mat_b = workspace_temp;
float *mat_c = dw_addr + j * nweights / groups;
float *im = x_addr + (i * in_ch * in_h * in_w) + j * (in_ch / groups);
memset(mat_b, 0, n * real_chunk * sizeof(float));
RollingIm2ColPackUnitFp32(im, conv_param, mat_b, real_chunk, ci);
GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, mat_b, n, 1, mat_c, n, mat_workspace);
float *mat_workspace = workspace_temp + ws_size_ * thread_num + task_id * (mat_alloc_ + k * n);
float *mat_tmp = mat_workspace + mat_alloc_;
int stride = UP_DIV(batch, thread_num);
int count = MSMIN(stride, batch - stride * task_id);
int start = stride * task_id;
int end = start + count;
if (do_img2col_) {
for (int i = start; i < end; ++i) {
for (int j = 0; j < groups; ++j) {
for (int ci = 0; ci < m; ci += chunk_) {
int real_chunk = MSMIN(m - ci, chunk_);
float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch;
float *mat_b = workspace_temp + task_id * ws_size_;
float *mat_c = dw_addr + j * nweights / groups;
float *im = x_addr + (i * in_ch * in_h * in_w) + j * (in_ch / groups);
RollingIm2ColPackUnitFp32(im, conv_param, mat_b, real_chunk, ci);
GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, mat_b, n, 0, mat_tmp, n, mat_workspace);
std::unique_lock<std::mutex> merge_lock(lock_);
AddMatrix(mat_tmp, mat_c, 1, k, n, n);
}
}
}
} else {
float *mat_c = dw_addr;
const size_t in_plane_size = in_ch * in_h * in_w;
for (int i = start; i < end; ++i) {
for (int ci = 0; ci < m; ci += chunk_) {
int real_chunk = MSMIN(m - ci, chunk_);
float *mat_a = dy_addr + i * m * k + ci * out_ch;
float *im = x_addr + i * in_plane_size;
int input_h = ci / out_w * conv_param->stride_h_;
int input_w = ci % out_w * conv_param->stride_w_;
int offset = (input_h * in_w + input_w) * in_ch;
GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, im + offset, n, 0, mat_tmp, n, mat_workspace);
std::unique_lock<std::mutex> merge_lock(lock_);
AddMatrix(mat_tmp, mat_c, 1, k, n, n);
}
}
}
@ -122,7 +153,10 @@ int ConvolutionGradFilterRun(void *cdata, int task_id) {
}
int ConvolutionGradFilterCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradFilterRun, this, 1);
auto *out_dw = out_tensors_.at(0);
auto dw_addr = reinterpret_cast<float *>(out_dw->MutableData());
memset(dw_addr, 0, out_dw->Size());
int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradFilterRun, this, context_->thread_num_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv filter function error error_code[" << error_code << "]";
return RET_ERROR;

View File

@ -36,11 +36,14 @@ class ConvolutionGradFilterCPUKernel : public LiteKernel {
int Execute(int task_id);
private:
size_t ws_size = 0;
size_t ws_size_ = 0;
bool do_img2col_ = true;
std::mutex lock_;
size_t mat_alloc_ = 0;
#ifdef ENABLE_ARM32
const int chunk = C4NUM;
const int chunk_ = C4NUM * 2;
#else
const int chunk = C12NUM;
const int chunk_ = C12NUM * 2;
#endif
};
} // namespace mindspore::kernel

View File

@ -30,7 +30,7 @@ using mindspore::schema::PrimitiveType_Conv2DGradInput;
using mindspore::schema::PrimitiveType_GroupConv2DGradInput;
namespace mindspore::kernel {
int ConvolutionGradInputCPUKernel::Init() {
int ConvolutionGradInputCPUKernel::ReSize() {
auto *dy_tensor = in_tensors_.at(kInputIndex);
MS_ASSERT(dy_tensor != nullptr);
auto *weight_tensor = in_tensors_.at(kWeightIndex);
@ -51,18 +51,17 @@ int ConvolutionGradInputCPUKernel::Init() {
conv_param->output_h_ = dy_tensor->shape()[kNHWC_H];
conv_param->output_w_ = dy_tensor->shape()[kNHWC_W];
ws_size = chunk * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_;
ws_size_ = chunk_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_;
int n = conv_param->kernel_w_ * conv_param->kernel_h_ * conv_param->input_channel_ / conv_param->group_;
int k = conv_param->output_channel_ / conv_param->group_;
size_t mat_alloc = MatSizeTotal(chunk, n, k, 0);
set_workspace_size((ws_size + mat_alloc) * sizeof(float));
int thread_num = context_->thread_num_;
mat_alloc_ = MatSizeTotal(chunk_, n, k, 0);
set_workspace_size((ws_size_ + mat_alloc_) * sizeof(float) * thread_num);
return RET_OK;
}
int ConvolutionGradInputCPUKernel::ReSize() { return RET_OK; }
int ConvolutionGradInputCPUKernel::Init() { return ReSize(); }
int ConvolutionGradInputCPUKernel::Execute(int task_id) {
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_);
@ -86,17 +85,21 @@ int ConvolutionGradInputCPUKernel::Execute(int task_id) {
int groups = conv_param->group_;
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int thread_num = context_->thread_num_;
int m = out_h * out_w;
int n = k_w * k_h * in_ch / groups;
int k = out_ch / groups;
float *workspace_temp = reinterpret_cast<float *>(workspace());
float *mat_workspace = workspace_temp + ws_size;
memset(dx_addr, 0, sizeof(float) * batch * in_ch * in_h * in_w);
for (i = 0; i < batch; ++i) {
float *workspace_temp = reinterpret_cast<float *>(workspace()) + task_id * (mat_alloc_ + ws_size_);
float *mat_workspace = workspace_temp + ws_size_;
int stride = UP_DIV(batch, thread_num);
int count = MSMIN(stride, batch - stride * task_id);
int start = stride * task_id;
int end = start + count;
for (i = start; i < end; ++i) {
for (j = 0; j < groups; ++j) {
GemmCb gcb;
for (int ci = 0; ci < m; ci += chunk) {
for (int ci = 0; ci < m; ci += chunk_) {
float *mat_b = nullptr;
if (ci == 0) {
mat_b = w_addr + j * nweights / groups;
@ -108,7 +111,7 @@ int ConvolutionGradInputCPUKernel::Execute(int task_id) {
mat_b = gcb.mat_b;
gcb.cb = 1;
}
int real_chunk = MSMIN(m - ci, chunk);
int real_chunk = MSMIN(m - ci, chunk_);
float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch;
float *mat_c = workspace_temp;
GemmMatmulPlus(0, 0, real_chunk, n, k, 1, mat_a, out_ch, mat_b, n, 0, mat_c, n, mat_workspace, &gcb);
@ -133,7 +136,15 @@ int ConvolutionGradInputRun(void *cdata, int task_id) {
}
int ConvolutionGradInputCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradInputRun, this, 1);
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_);
int batch = conv_param->output_batch_;
int in_ch = conv_param->input_channel_;
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
auto *out_dx = out_tensors_.at(0);
auto dx_addr = reinterpret_cast<float *>(out_dx->MutableData());
memset(dx_addr, 0, sizeof(float) * batch * in_ch * in_h * in_w);
int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradInputRun, this, context_->thread_num_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "bias function error error_code[" << error_code << "]";
return RET_ERROR;

View File

@ -35,11 +35,12 @@ class ConvolutionGradInputCPUKernel : public LiteKernel {
int Execute(int task_id);
private:
size_t ws_size = 0;
size_t ws_size_ = 0;
size_t mat_alloc_ = 0;
#ifdef ENABLE_ARM32
const int chunk = C4NUM;
const int chunk_ = C4NUM;
#else
const int chunk = C12NUM;
const int chunk_ = C12NUM;
#endif
};
} // namespace mindspore::kernel

View File

@ -61,19 +61,26 @@ int DropoutCPUKernel::Execute(int task_id) {
auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->MutableData());
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
auto mask = reinterpret_cast<float *>(out_tensors_.at(1)->MutableData());
auto length = in_tensors_.at(kInputIndex)->ElementsNum();
auto param = reinterpret_cast<DropoutParameter *>(op_parameter_);
auto length = in_tensors_.at(kInputIndex)->ElementsNum();
int stride = UP_DIV(length, thread_count_);
int count = MSMIN(stride, length - stride * task_id);
size_t start = stride * task_id;
size_t end = start + count;
if (param == nullptr) {
MS_LOG(ERROR) << "Dropout op_parameter_ nullptr";
return RET_NULL_PTR;
}
if (IsEval()) {
std::copy(input_ptr, input_ptr + length, output_ptr);
std::copy(&(input_ptr[start]), &(input_ptr[end]), &(output_ptr[start]));
} else {
std::default_random_engine generator;
std::bernoulli_distribution distribution(param->ratio_);
for (int i = 0; i < length; i++) {
for (size_t i = start; i < end; i++) {
mask[i] = distribution(generator);
output_ptr[i] = input_ptr[i] * mask[i] * scale_;
}
@ -92,7 +99,7 @@ int RunDropout(void *cdata, int task_id) {
}
int DropoutCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, RunDropout, this, 1);
int error_code = ParallelLaunch(this->context_->thread_pool_, RunDropout, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Dropout function error error_code[" << error_code << "]";
return RET_ERROR;

View File

@ -25,7 +25,7 @@ class DropoutCPUKernel : public LiteKernel {
DropoutCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {}
~DropoutCPUKernel() override = default;
@ -35,7 +35,8 @@ class DropoutCPUKernel : public LiteKernel {
int Execute(int task_id);
private:
float scale_;
float scale_ = 1.0;
int thread_count_ = 1;
};
} // namespace mindspore::kernel

View File

@ -62,7 +62,13 @@ int DropoutGradCPUKernel::Execute(int task_id) {
auto mask_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
auto length = in_tensors_.at(kInputIndex)->ElementsNum();
DropoutGrad(yt_ptr, mask_ptr, output_ptr, length, scale_);
int stride = UP_DIV(length, thread_count_);
int count = MSMIN(stride, length - stride * task_id);
size_t start = stride * task_id;
DropoutGrad(&(yt_ptr[start]), &(mask_ptr[start]), &(output_ptr[start]), count, scale_);
return RET_OK;
}
@ -78,7 +84,7 @@ int RunDropoutGrad(void *cdata, int task_id) {
}
int DropoutGradCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, RunDropoutGrad, this, 1);
int error_code = ParallelLaunch(this->context_->thread_pool_, RunDropoutGrad, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Dropout Grad function error error_code[" << error_code << "]";
return RET_ERROR;

View File

@ -25,7 +25,7 @@ class DropoutGradCPUKernel : public LiteKernel {
DropoutGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {}
~DropoutGradCPUKernel() override = default;
@ -36,6 +36,7 @@ class DropoutGradCPUKernel : public LiteKernel {
private:
float scale_;
int thread_count_ = 1;
};
} // namespace mindspore::kernel

View File

@ -29,36 +29,34 @@ using mindspore::schema::PrimitiveType_NegGrad;
namespace mindspore::kernel {
namespace {
int NegGradRun(void *cdata, int thread_id) {
int NegGradRun(void *cdata, int task_id) {
MS_ASSERT(cdata != nullptr);
auto kernel = reinterpret_cast<NegGradCPUKernel *>(cdata);
MS_ASSERT(kernel != nullptr);
return kernel->DoNegGrad(thread_id);
return kernel->DoNegGrad(task_id);
}
} // namespace
int NegGradCPUKernel::Init() { return RET_OK; }
int NegGradCPUKernel::DoNegGrad(int thread_id) {
auto dy = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
auto dx = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
int dy_size = in_tensors_.at(0)->ElementsNum();
int size = MSMIN(thread_stride_, static_cast<int>(dy_size - thread_id * thread_stride_));
if (size <= 0) {
return RET_OK;
}
int offset = thread_id * thread_stride_;
ElementNegative(dy + offset, dx + offset, size);
int NegGradCPUKernel::DoNegGrad(int task_id) {
auto dy = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
auto dx = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
size_t length = in_tensors_.at(0)->ElementsNum();
size_t stride = UP_DIV(length, thread_count_);
size_t count = MSMIN(stride, length - stride * task_id);
size_t start = stride * task_id;
ElementNegative(dy + start, dx + start, count);
return RET_OK;
}
int NegGradCPUKernel::ReSize() { return RET_OK; }
int NegGradCPUKernel::Run() {
int dy_size = in_tensors_.at(0)->ElementsNum();
op_parameter_->thread_num_ = MSMIN(op_parameter_->thread_num_, static_cast<int>(dy_size));
thread_stride_ = UP_DIV(dy_size, op_parameter_->thread_num_);
auto ret = ParallelLaunch(this->context_->thread_pool_, NegGradRun, this, op_parameter_->thread_num_);
auto ret = ParallelLaunch(this->context_->thread_pool_, NegGradRun, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "parallel launch fail!ret: " << ret;
return ret;

View File

@ -28,7 +28,7 @@ class NegGradCPUKernel : public LiteKernel {
explicit NegGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {}
~NegGradCPUKernel() override {}
int Init() override;
int ReSize() override;
@ -36,7 +36,7 @@ class NegGradCPUKernel : public LiteKernel {
int DoNegGrad(int thread_id);
private:
int thread_stride_;
int thread_count_;
};
} // namespace mindspore::kernel

View File

@ -29,7 +29,7 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_PoolingGrad;
namespace mindspore::kernel {
int PoolingGradCPUKernel::Init() {
int PoolingGradCPUKernel::ReSize() {
PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(op_parameter_);
auto in_shape = in_tensors_.at(0)->shape();
@ -59,7 +59,7 @@ int PoolingGradCPUKernel::Init() {
return RET_OK;
}
int PoolingGradCPUKernel::ReSize() { return RET_OK; }
int PoolingGradCPUKernel::Init() { return ReSize(); }
int PoolingGradCPUKernel::Execute(int task_id) {
PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(op_parameter_);

View File

@ -45,13 +45,20 @@ int PowerGradCPUKernel::Execute(int task_id) {
auto dy_addr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
auto x_addr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
auto dx_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
auto size = in_tensors_.at(0)->ElementsNum();
size_t length = in_tensors_.at(0)->ElementsNum();
size_t stride = UP_DIV(length, thread_count_);
size_t count = MSMIN(stride, length - stride * task_id);
size_t start = stride * task_id;
size_t end = start + count;
float exp = power_ - 1;
Power(x_addr, &exp, dx_addr, size, scale_, shift_, true);
ElementMul(dx_addr, dy_addr, dx_addr, size);
Power(&(x_addr[start]), &exp, &(dx_addr[start]), count, scale_, shift_, true);
ElementMul(&(dx_addr[start]), &(dy_addr[start]), &(dx_addr[start]), count);
float scale = scale_ * power_;
for (int i = 0; i < size; i++) {
for (size_t i = start; i < end; i++) {
dx_addr[i] *= scale;
}
@ -69,7 +76,7 @@ int PowerGradRun(void *cdata, int task_id) {
}
int PowerGradCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, PowerGradRun, this, 1);
int error_code = ParallelLaunch(this->context_->thread_pool_, PowerGradRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "power grad function error error_code[" << error_code << "]";
return RET_ERROR;

View File

@ -27,7 +27,7 @@ class PowerGradCPUKernel : public LiteKernel {
PowerGradCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(param, inputs, outputs, ctx, primitive) {
: LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
PowerParameter *power_param = reinterpret_cast<PowerParameter *>(param);
power_ = power_param->power_;
scale_ = power_param->scale_;
@ -41,6 +41,7 @@ class PowerGradCPUKernel : public LiteKernel {
int Execute(int task_id);
private:
int thread_count_;
float power_;
float scale_;
float shift_;

View File

@ -16,6 +16,7 @@
*/
#include "src/runtime/kernel/arm/fp32_grad/sgd.h"
#include <algorithm>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
@ -37,36 +38,42 @@ int SgdCPUKernel::Execute(int task_id) {
float learning_rate = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData())[0];
auto gradient = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
float moment = reinterpret_cast<float *>(in_tensors_.at(4)->MutableData())[0];
size_t elem_num = in_tensors_.at(0)->ElementsNum();
auto stat = reinterpret_cast<float *>(in_tensors_.at(5)->MutableData());
size_t length = in_tensors_.at(0)->ElementsNum();
if (stat[0] > 0) {
stat[0] = 0;
memcpy(accumulate, gradient, elem_num * sizeof(float));
size_t stride = UP_DIV(length, thread_count_);
size_t count = MSMIN(stride, length - stride * task_id);
size_t start = stride * task_id;
size_t end = start + count;
if (stat[task_id] > 0) {
stat[task_id] = 0; // Haim Please approve this
std::copy(&(gradient[start]), &(gradient[end]), &(accumulate[start]));
if (sgd_param_->use_nesterov_) {
for (size_t i = 0; i < elem_num; ++i) {
for (size_t i = start; i < end; ++i) {
weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate;
}
} else {
for (size_t i = 0; i < elem_num; ++i) {
for (size_t i = start; i < end; ++i) {
weight[i] -= accumulate[i] * learning_rate;
}
}
} else {
if (moment > 0.f) {
if (sgd_param_->use_nesterov_) {
for (size_t i = 0; i < elem_num; ++i) {
for (size_t i = start; i < end; ++i) {
accumulate[i] = accumulate[i] * moment + gradient[i] * (1.f - sgd_param_->dampening_);
weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate;
}
} else {
for (size_t i = 0; i < elem_num; ++i) {
for (size_t i = start; i < end; ++i) {
accumulate[i] = accumulate[i] * moment + gradient[i] * (1.f - sgd_param_->dampening_);
weight[i] -= accumulate[i] * learning_rate;
}
}
} else {
for (size_t i = 0; i < elem_num; ++i) {
for (size_t i = start; i < end; ++i) {
weight[i] -= gradient[i] * learning_rate;
}
}
@ -85,7 +92,7 @@ int SgdRun(void *cdata, int task_id) {
}
int SgdCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, SgdRun, this, 1);
int error_code = ParallelLaunch(this->context_->thread_pool_, SgdRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "SGD function error error_code[" << error_code << "]";
return RET_ERROR;
@ -114,6 +121,17 @@ int SgdCPUKernel::Init() {
return RET_OK;
}
int SgdCPUKernel::SetLearningRate(float lr) {
auto learning_rate_tensor = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
learning_rate_tensor[0] = lr;
return RET_OK;
}
float SgdCPUKernel::GetLearningRate() {
auto learning_rate_tensor = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
return learning_rate_tensor[0];
}
kernel::LiteKernel *CpuSgdFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,

View File

@ -18,16 +18,18 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_SGD_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/train/optimizer_kernel.h"
#include "nnacl/fp32_grad/optimizer.h"
namespace mindspore::kernel {
class SgdCPUKernel : public LiteKernel {
class SgdCPUKernel : public OptimizerKernel {
public:
explicit SgdCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), sgd_param_(nullptr) {
: OptimizerKernel(parameter, inputs, outputs, ctx, primitive),
thread_count_(ctx->thread_num_),
sgd_param_(nullptr) {
sgd_param_ = reinterpret_cast<SgdParameter *>(parameter);
}
~SgdCPUKernel() override {}
@ -35,8 +37,11 @@ class SgdCPUKernel : public LiteKernel {
int ReSize() override;
int Run() override;
int Execute(int task_id);
int SetLearningRate(float lr) override;
float GetLearningRate() override;
private:
int thread_count_;
SgdParameter *sgd_param_;
};
} // namespace mindspore::kernel

View File

@ -35,12 +35,19 @@ int SmoothL1LossCPUKernel::Execute(int task_id) {
auto target = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
auto *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
const size_t tensor_len = in_tensors_.at(0)->ElementsNum();
const size_t length = in_tensors_.at(0)->ElementsNum();
size_t stride = UP_DIV(length, thread_count_);
int count = MSMIN(stride, length - stride * task_id);
size_t start = stride * task_id;
size_t end = start + count;
const float zero = 0.0f;
const float half = 0.5f;
const float beta = smooth_l1_loss_param->beta_;
for (uint64_t i = 0; i < tensor_len; ++i) {
for (uint64_t i = start; i < end; ++i) {
float diff = predict[i] - target[i];
if (diff < zero) {
diff = -diff;
@ -66,7 +73,7 @@ int SmoothL1LossRun(void *cdata, int task_id) {
}
int SmoothL1LossCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, SmoothL1LossRun, this, 1);
int error_code = ParallelLaunch(this->context_->thread_pool_, SmoothL1LossRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "SmoothL1Loss function error error_code[" << error_code << "]";
return RET_ERROR;

View File

@ -27,7 +27,9 @@ class SmoothL1LossCPUKernel : public LiteKernel {
explicit SmoothL1LossCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), smooth_l1_param_(nullptr) {
: LiteKernel(parameter, inputs, outputs, ctx, primitive),
smooth_l1_param_(nullptr),
thread_count_(ctx->thread_num_) {
smooth_l1_param_ = reinterpret_cast<SmoothL1LossParameter *>(parameter);
}
~SmoothL1LossCPUKernel() override {}
@ -38,6 +40,7 @@ class SmoothL1LossCPUKernel : public LiteKernel {
private:
SmoothL1LossParameter *smooth_l1_param_;
int thread_count_ = 1;
};
} // namespace mindspore::kernel

View File

@ -36,10 +36,17 @@ int SmoothL1LossGradCPUKernel::Execute(int task_id) {
auto d_loss = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
auto *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
const size_t tensor_len = in_tensors_.at(0)->ElementsNum();
const size_t length = in_tensors_.at(0)->ElementsNum();
size_t stride = UP_DIV(length, thread_count_);
size_t count = MSMIN(stride, length - stride * task_id);
size_t start = stride * task_id;
size_t end = start + count;
const float beta = smooth_l1_loss_param->beta_;
for (uint64_t i = 0; i < tensor_len; ++i) {
for (uint64_t i = start; i < end; ++i) {
float diff = predict[i] - target[i];
if (diff > beta) {
out[i] = d_loss[i];
@ -63,7 +70,7 @@ int SmoothL1LossGradRun(void *cdata, int task_id) {
}
int SmoothL1LossGradCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, SmoothL1LossGradRun, this, 1);
int error_code = ParallelLaunch(this->context_->thread_pool_, SmoothL1LossGradRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "SmoothL1LossGrad function error error_code[" << error_code << "]";
return RET_ERROR;

View File

@ -27,7 +27,9 @@ class SmoothL1LossGradCPUKernel : public LiteKernel {
explicit SmoothL1LossGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), smooth_l1_param_(nullptr) {
: LiteKernel(parameter, inputs, outputs, ctx, primitive),
smooth_l1_param_(nullptr),
thread_count_(ctx->thread_num_) {
smooth_l1_param_ = reinterpret_cast<SmoothL1LossParameter *>(parameter);
}
~SmoothL1LossGradCPUKernel() override {}
@ -38,6 +40,7 @@ class SmoothL1LossGradCPUKernel : public LiteKernel {
private:
SmoothL1LossParameter *smooth_l1_param_;
int thread_count_;
};
} // namespace mindspore::kernel

View File

@ -29,7 +29,7 @@ using mindspore::schema::PrimitiveType_SoftmaxCrossEntropy;
namespace mindspore::kernel {
int SoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return RET_OK; }
int SoftmaxCrossEntropyWithLogitsCPUKernel::Init() { return ReSize(); }
void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *labels, const float *logits, float *grads,
float *output2) const {
@ -100,7 +100,7 @@ int SoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
return RET_OK;
}
int SoftmaxCrossEntropyWithLogitsCPUKernel::Init() {
int SoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() {
auto dims = in_tensors_.at(0)->shape();
param_->n_dim_ = 2;
param_->number_of_classes_ = dims.at(1);

View File

@ -15,6 +15,7 @@
*/
#include <vector>
#include <algorithm>
#include "src/runtime/kernel/arm/fp32_grad/tuple_getitem.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
@ -47,7 +48,15 @@ int TupleGetItemCPUKernel::Execute(int task_id) {
auto in = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
auto out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
memcpy(out, in, in_tensors_.at(0)->Size());
size_t length = in_tensors_.at(0)->ElementsNum();
size_t stride = UP_DIV(length, thread_count_);
size_t count = MSMIN(stride, length - stride * task_id);
size_t start = stride * task_id;
size_t end = start + count;
std::copy(&(in[start]), &(in[end]), &(out[start]));
return RET_OK;
}
@ -62,7 +71,7 @@ int TupleRun(void *cdata, int task_id) {
}
int TupleGetItemCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, TupleRun, this, 1);
int error_code = ParallelLaunch(this->context_->thread_pool_, TupleRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "tuple function error error_code[" << error_code << "]";
return RET_ERROR;

View File

@ -27,7 +27,7 @@ class TupleGetItemCPUKernel : public LiteKernel {
explicit TupleGetItemCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
param = parameter;
}
~TupleGetItemCPUKernel() override = default;
@ -38,6 +38,7 @@ class TupleGetItemCPUKernel : public LiteKernel {
int Execute(int task_id);
private:
int thread_count_ = 1;
OpParameter *param;
};
} // namespace mindspore::kernel

View File

@ -0,0 +1,98 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/train/classification_train_accuracy_monitor.h"
#include <sys/stat.h>
#include <algorithm>
#include <utility>
#include <vector>
#include <iostream>
#include <fstream>
#include <memory>
#include "include/errorcode.h"
#include "include/train_session.h"
#include "src/common/utils.h"
#include "src/tensor.h"
#include "src/train/loss_kernel.h"
#include "src/train/optimizer_kernel.h"
#include "src/sub_graph_kernel.h"
#include "src/train/train_populate_parameter.h"
#include "src/runtime/runtime_api.h"
#include "src/executor.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/fp32_grad/convolution.h"
namespace mindspore {
namespace lite {
void ClassificationTrainAccuracyMonitor::Begin(const session::TrainLoopCallBackData &cb_data) {
if (cb_data.epoch_ == 0) accuracies_.clear();
}
void ClassificationTrainAccuracyMonitor::EpochBegin(const session::TrainLoopCallBackData &cb_data) {
if (accuracies_.size() != cb_data.epoch_) {
MS_LOG(WARNING) << "Accuracies array does not match epoch number";
} else {
accuracies_.push_back(std::make_pair(cb_data.epoch_, 0.0));
}
}
int ClassificationTrainAccuracyMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
if (cb_data.step_ > 0) accuracies_.at(cb_data.epoch_).second /= static_cast<float>(cb_data.step_);
if ((cb_data.epoch_ + 1) % print_every_n_ == 0) {
std::cout << cb_data.epoch_ + 1 << ":\tTraining Accuracy is " << accuracies_.at(cb_data.epoch_).second << std::endl;
}
return mindspore::session::RET_CONTINUE;
}
void ClassificationTrainAccuracyMonitor::StepEnd(const session::TrainLoopCallBackData &cb_data) {
auto inputs = cb_data.session_->GetInputs();
auto outputs = cb_data.session_->GetPredictions();
auto labels = reinterpret_cast<float *>(inputs.at(1)->MutableData());
for (auto it = outputs.begin(); it != outputs.end(); ++it) {
if (it->second->ElementsNum() == inputs.at(1)->ElementsNum()) {
int batch_size = inputs.at(1)->shape().at(0);
int num_of_classes = inputs.at(1)->shape().at(1);
auto predictions = reinterpret_cast<float *>(it->second->MutableData());
float accuracy = 0.0;
for (int b = 0; b < batch_size; b++) {
int label = 0;
int max_idx = 0;
float max_label_score = labels[num_of_classes * b];
float max_score = predictions[num_of_classes * b];
for (int c = 1; c < num_of_classes; c++) {
if (predictions[num_of_classes * b + c] > max_score) {
max_score = predictions[num_of_classes * b + c];
max_idx = c;
}
if (labels[num_of_classes * b + c] > max_label_score) {
max_label_score = labels[num_of_classes * b + c];
label = c;
}
}
if (label == max_idx) accuracy += 1.0;
}
accuracy /= static_cast<float>(batch_size);
accuracies_.at(cb_data.epoch_).second = accuracy;
return;
}
}
MS_LOG(WARNING) << "Model does not have a loss output tensor of size 1";
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,66 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/train/loss_monitor.h"
#include <sys/stat.h>
#include <algorithm>
#include <utility>
#include <vector>
#include <iostream>
#include <fstream>
#include <memory>
#include "include/errorcode.h"
#include "include/train_session.h"
#include "src/common/utils.h"
#include "src/tensor.h"
namespace mindspore {
namespace lite {
void LossMonitor::Begin(const session::TrainLoopCallBackData &cb_data) {
if (cb_data.epoch_ == 0) losses_.clear();
}
void LossMonitor::EpochBegin(const session::TrainLoopCallBackData &cb_data) {
if (losses_.size() != cb_data.epoch_) {
MS_LOG(WARNING) << "losses array does not match epoch number";
} else {
losses_.push_back(std::make_pair(cb_data.epoch_, 0.0));
}
}
int LossMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
if (cb_data.step_ > 0) losses_.at(cb_data.epoch_).second /= static_cast<float>(cb_data.step_);
if ((cb_data.epoch_ + 1) % print_every_n_ == 0) {
std::cout << cb_data.epoch_ + 1 << ":\tLoss is " << losses_.at(cb_data.epoch_).second << std::endl;
}
return mindspore::session::RET_CONTINUE;
}
void LossMonitor::StepEnd(const session::TrainLoopCallBackData &cb_data) {
auto outputs = cb_data.session_->GetOutputs();
for (auto it = outputs.begin(); it != outputs.end(); ++it) {
if (it->second->ElementsNum() == 1) {
auto loss = reinterpret_cast<float *>(it->second->MutableData());
losses_.at(cb_data.epoch_).second += loss[0];
return;
}
}
MS_LOG(WARNING) << "Model does not have a loss output tensor of size 1";
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,75 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/train/lr_scheduler.h"
#include <sys/stat.h>
#include <algorithm>
#include <utility>
#include <vector>
#include <iostream>
#include <fstream>
#include <memory>
#include "include/errorcode.h"
#include "include/train_session.h"
#include "src/common/utils.h"
#include "src/tensor.h"
namespace mindspore {
namespace lite {
int MultiplicativeLRLambda(float *lr, int epoch, void *lr_cb_data) {
if ((lr == nullptr) || (lr_cb_data == nullptr)) {
MS_LOG(ERROR) << "nullptr passed as input to MultiplicativeLRLambda";
return DONT_UPDATE_LR;
}
float mult = *(static_cast<float *>(lr_cb_data));
*lr = *lr * mult;
return UPDATE_LR;
}
int StepLRLambda(float *lr, int epoch, void *lr_cb_data) {
if ((lr == nullptr) || (lr_cb_data == nullptr)) {
MS_LOG(ERROR) << "nullptr passed as input to MultiplicativeLRLambda";
return DONT_UPDATE_LR;
}
struct StepLRLambda *step_lr_data = (static_cast<struct StepLRLambda *>(lr_cb_data));
if (((epoch + 1) % step_lr_data->step_size) == 0) {
*lr = *lr * step_lr_data->gamma;
return UPDATE_LR;
}
return DONT_UPDATE_LR;
}
LRScheduler::LRScheduler(LR_Lambda lambda_func, void *lr_cb_data, int step)
: lambda_func_(lambda_func), lr_data_(lr_cb_data), step_(step) {}
int LRScheduler::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
if (((cb_data.epoch_ + 1) % step_) == 0) {
float lr = cb_data.session_->GetLearningRate();
int update = lambda_func_(&lr, cb_data.epoch_, lr_data_);
if (update == UPDATE_LR) {
int ret = cb_data.session_->SetLearningRate(lr);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Error setting Leraning rate in train session";
return mindspore::session::RET_EXIT;
}
}
}
return mindspore::session::RET_CONTINUE;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,35 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_
#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_
#include <vector>
#include "src/lite_kernel.h"
namespace mindspore::kernel {
class OptimizerKernel : public LiteKernel {
public:
OptimizerKernel() = default;
OptimizerKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~OptimizerKernel() = default;
virtual int SetLearningRate(float lr) = 0;
virtual float GetLearningRate() = 0;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_

View File

@ -0,0 +1,99 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/train/train_loop.h"
#include <sys/stat.h>
#include <algorithm>
#include <utility>
#include <vector>
#include <iostream>
#include <fstream>
#include <memory>
#include "include/errorcode.h"
#include "include/train_session.h"
#include "src/common/utils.h"
#include "src/tensor.h"
#include "src/train/loss_kernel.h"
#include "src/train/optimizer_kernel.h"
#include "src/sub_graph_kernel.h"
#include "src/train/train_populate_parameter.h"
#include "src/runtime/runtime_api.h"
#include "src/executor.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/fp32_grad/convolution.h"
namespace mindspore {
namespace lite {
using session::RET_CONTINUE;
using session::RET_EXIT;
using session::RET_STOP_TRAINING;
TrainLoop::~TrainLoop() {
if (train_session_ != nullptr) delete train_session_;
}
int TrainLoop::Train(int epochs, std::vector<session::TrainLoopCallBack *> cbs) {
train_session_->Train();
session::TrainLoopCallBackData cb_data(true, epoch_, train_session_, this);
for (auto cb : cbs) cb->Begin(cb_data);
int steps_in_epoch = 1; // should be data_size/batch_size
for (int i = 0; i < epochs; i++) {
cb_data.epoch_ = epoch_++;
for (auto cb : cbs) cb->EpochBegin(cb_data);
for (int s = 0; s < steps_in_epoch; s++) {
cb_data.step_ = s;
for (auto cb : cbs) cb->StepBegin(cb_data);
train_session_->RunGraph(before_cb_, after_cb_);
for (auto cb : cbs) cb->StepEnd(cb_data);
}
int break_loop = false;
for (auto cb : cbs) {
int ret = cb->EpochEnd(cb_data);
if (ret != RET_CONTINUE) {
if (ret == RET_EXIT) {
MS_LOG(ERROR) << "Error in TrainLoop callback";
return RET_ERROR;
}
if (ret == RET_STOP_TRAINING) {
break_loop = true;
}
}
}
if (break_loop) {
break;
}
}
for (auto cb : cbs) cb->End(cb_data);
return RET_OK;
}
} // namespace lite
session::TrainLoop *session::TrainLoop::CreateTrainLoop(const std::string &model_filename, lite::Context *context,
int batch_size) {
auto train_session = session::TrainSession::CreateSession(model_filename, context);
auto loop = new (std::nothrow) lite::TrainLoop(train_session);
return loop;
}
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_LOOP_H_
#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_LOOP_H_
#include <vector>
#include <string>
#include <tuple>
#include <unordered_map>
#include "src/ops/primitive_c.h"
#include "include/train/train_loop.h"
#include "include/train_session.h"
namespace mindspore {
namespace lite {
class TrainLoop : virtual public session::TrainLoop {
public:
explicit TrainLoop(session::TrainSession *session) : train_session_(session) {}
session::TrainSession *train_session() override { return train_session_; }
int Reset() override {
epoch_ = 0;
return RET_OK;
}
virtual ~TrainLoop();
int SetKernelCallBack(const KernelCallBack &before, const KernelCallBack &after) override {
before_cb_ = before;
after_cb_ = after;
return RET_OK;
}
int Train(int epochs, std::vector<session::TrainLoopCallBack *> cbs) override;
protected:
session::TrainSession *train_session_ = nullptr;
unsigned int epoch_ = 0;
KernelCallBack before_cb_ = nullptr;
KernelCallBack after_cb_ = nullptr;
int batch_size;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_LOOP_H_

View File

@ -15,6 +15,7 @@
*/
#include "src/train/train_populate_parameter.h"
#include <algorithm>
#include "src/ops/populate/populate_register.h"
#include "src/ops/pooling_grad.h"
#include "nnacl/pooling_parameter.h"
@ -517,12 +518,15 @@ OpParameter *PopulateArithmeticGradParameter(const mindspore::lite::PrimitiveC *
arithmetic_param->broadcasting_ = ((lite::ArithmeticGrad *)primitive)->Broadcasting();
arithmetic_param->ndim_ = ((lite::ArithmeticGrad *)primitive)->NDims();
auto tmp_shape = ((lite::ArithmeticGrad *)primitive)->x1Shape();
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::ArithmeticGrad *)primitive)->x2Shape();
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::ArithmeticGrad *)primitive)->dyShape();
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
auto shape = ((lite::ArithmeticGrad *)primitive)->x1Shape();
auto source = static_cast<int *>(shape.data());
std::copy(source, source + shape.size(), arithmetic_param->in_shape0_);
shape = ((lite::ArithmeticGrad *)primitive)->x2Shape();
source = static_cast<int *>(shape.data());
std::copy(source, source + shape.size(), arithmetic_param->in_shape1_);
shape = ((lite::ArithmeticGrad *)primitive)->dyShape();
source = static_cast<int *>(shape.data());
std::copy(source, source + shape.size(), arithmetic_param->out_shape_);
return reinterpret_cast<OpParameter *>(arithmetic_param);
}

View File

@ -26,6 +26,7 @@
#include "src/common/utils.h"
#include "src/tensor.h"
#include "src/train/loss_kernel.h"
#include "src/train/optimizer_kernel.h"
#include "src/sub_graph_kernel.h"
#include "src/train/train_populate_parameter.h"
#include "src/runtime/runtime_api.h"
@ -49,10 +50,8 @@ TrainSession::TrainSession() { kernel::PopulateTrainParameters(); }
std::vector<CreatorOp> TrainSession::ReplaceOps() {
const std::vector<CreatorOp> replace = {
{{mindspore::kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_Conv2D},
mindspore::kernel::CpuConvTrainFp32KernelCreator},
{{mindspore::kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_DepthwiseConv2D},
mindspore::kernel::CpuConvTrainFp32KernelCreator}};
// currently no ops are Hijacked by TrainSession
};
mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance();
std::vector<CreatorOp> results;
for (auto v : replace) {
@ -98,7 +97,7 @@ int TrainSession::CompileTrainGraph(mindspore::lite::TrainModel *model) {
RestoreOps(restore);
CompileTrainKernels(); // Prepare a list of train kernels
CompileInferenceKernels(); // Prepare a list of eval kernels
CompileOptimizedKernels(); // Prepare a list of kenels which are optimized (weight update step)
CompileOptimizedKernels(); // Prepare a list of kernels which are optimized (weight update step)
CompileTrainOutputs(); // prepare outputs in train mode
CompileEvalOutputs(); // prepare outputs in eval mode
AllocWorkSpace();
@ -302,6 +301,30 @@ void TrainSession::CompileOptimizedKernels() {
}
}
int TrainSession::SetLearningRate(float learning_rate) {
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
auto ret = optimizer->SetLearningRate(learning_rate);
if (ret != RET_OK) {
MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
return RET_ERROR;
}
}
}
return RET_OK;
}
float TrainSession::GetLearningRate() {
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
return optimizer->GetLearningRate();
}
}
return 0.0;
}
bool TrainSession::IsLossKernel(const kernel::LiteKernel *kernel) const {
return (kernel->Type() == schema::PrimitiveType_SoftmaxCrossEntropy ||
kernel->Type() == schema::PrimitiveType_SparseSoftmaxCrossEntropy ||

View File

@ -42,7 +42,6 @@
namespace mindspore {
namespace lite {
using CreatorOp = std::tuple<mindspore::kernel::KernelKey, mindspore::kernel::KernelCreator>;
class TrainSession : virtual public session::TrainSession, virtual public lite::LiteSession {
public:
@ -59,6 +58,8 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
int Train() override;
int Eval() override;
int SetLearningRate(float learning_rate) override;
float GetLearningRate() override;
void BindThread(bool if_bind) override { return lite::LiteSession::BindThread(if_bind); }
std::vector<tensor::MSTensor *> GetInputs() const override { return lite::LiteSession::GetInputs(); }
@ -80,6 +81,10 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
return lite::RET_ERROR;
}
std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetPredictions() const override {
return eval_output_tensor_map_;
}
protected:
void AllocWorkSpace();
bool IsLossKernel(const kernel::LiteKernel *kernel) const;

View File

@ -1,11 +1,13 @@
mini_alexnet
#mobilenetv1
# mobilenetv1
mobilenetv2
mobilenetv3
lenet
effnet
effnet_tune
# effnet_tune
# lenetv1
# resnet
# effnetv1
# googlenet
# densenet
# one_net
#LAST

View File

@ -83,7 +83,7 @@ function Run_x86() {
--inDataFile=${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin \
--expectedDataFile=${train_io_path}/${model_name}_outputs.bin \
--exportFile=${ms_models_path}/${model_name}_train_exported.ms >> "${run_x86_log_file}" \
--epochs=${epoch_num}
--epochs=${epoch_num} --numThreads=${threads}
if [ $? = 0 ]; then
run_result='x86: '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file}
else
@ -178,7 +178,8 @@ function Run_arm() {
--modelFile=${model_name}_train.ms \
--inDataFile=${tmp_dir}/${model_name}_input1.bin,${tmp_dir}/${model_name}_input2.bin \
--expectedDataFile=${tmp_dir}/${model_name}_outputs.bin \
--exportFile=${tmp_dir}/${model_name}_train_exported.ms
--exportFile=${tmp_dir}/${model_name}_train_exported.ms \
--numThreads=${threads}
ENDM
)
echo "${adb_cmd}" >> ${run_arm_log_file}
@ -221,8 +222,9 @@ echo ${basepath}
# Example:run_benchmark_train.sh -r /home/emir/Work/TestingEnv/release -m /home/emir/Work/TestingEnv/train_models -i /home/emir/Work/TestingEnv/train_io -d "8KE5T19620002408"
# For running on arm64, use -t to set platform tools path (for using adb commands)
epoch_num=1
threads=1
train_io_path=""
while getopts "r:m:d:i:e:vt:" opt; do
while getopts "r:m:d:i:e:vt:q:" opt; do
case ${opt} in
r)
release_path=${OPTARG}
@ -249,9 +251,13 @@ while getopts "r:m:d:i:e:vt:" opt; do
run_valgrind="valgrind --log-file=valgrind.log "
echo "Run x86 with valgrind"
;;
q)
threads=${OPTARG}
echo "threads=${threads}"
;;
t)
epoch_num=${OPTARG}
echo "train epoch num is ${OPTARG}"
echo "train epoch num is ${epoch_num}"
;;
?)
echo "unknown para"

View File

@ -511,6 +511,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
auto valueNode = input_anode->cast<ValueNodePtr>();
auto paramTensor = std::make_unique<schema::TensorT>();
auto value = valueNode->value();
#ifdef SUPPORT_TRAIN
paramTensor->name = valueNode->fullname_with_scope();
#endif
if (value->isa<tensor::Tensor>()) {
auto valueAbstract = valueNode->abstract();
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
@ -527,7 +530,6 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
paramTensor->dims = dims;
#ifdef SUPPORT_TRAIN
if (paramTensor->dims.size() == 0) paramTensor->dims = {1};
paramTensor->name = valueNode->fullname_with_scope();
#endif
paramTensor->nodeType = schema::NodeType::NodeType_ValueNode;
auto data = value->cast<tensor::TensorPtr>();

View File

@ -135,6 +135,7 @@ int NetTrain::ReadCalibData() {
MS_LOG(INFO) << "Start reading calibData file";
std::string tensor_name;
while (!in_file.eof()) {
getline(in_file, line);
std::stringstream string_line1(line);
@ -189,7 +190,6 @@ int NetTrain::CompareOutput() {
MS_ASSERT(tensor->MutableData() != nullptr);
auto outputs = tensor->MutableData();
float bias = CompareData<float>(node_or_tensor_name, tensor->shape(), reinterpret_cast<float *>(outputs));
if (bias >= 0) {
total_bias += bias;
total_size++;
@ -228,7 +228,7 @@ int NetTrain::CompareOutput() {
int NetTrain::MarkPerformance() {
MS_LOG(INFO) << "Running train loops...";
std::cout << "Running train loops..." << std::endl;
uint64_t time_min = 1000000;
uint64_t time_min = 0xFFFFFFFFFFFFFFFF;
uint64_t time_max = 0;
uint64_t time_avg = 0;

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_NET_TRAIN_NET_TRAIN_H_
#define MINDSPORE_LITE_TOOLS_NET_TRAIN_NET_TRAIN_H_
#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_H_
#define MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_H_
#include <getopt.h>
#include <signal.h>
@ -59,6 +59,7 @@ class MS_API NetTrainFlags : public virtual FlagParser {
AddFlag(&NetTrainFlags::warm_up_loop_count_, "warmUpLoopCount", "Run warm up loop", 0);
AddFlag(&NetTrainFlags::time_profiling_, "timeProfiling", "Run time profiling", false);
AddFlag(&NetTrainFlags::epochs_, "epochs", "Number of training epochs to run", 1);
AddFlag(&NetTrainFlags::num_threads_, "numThreads", "Run threads number", 1);
// MarkAccuracy
AddFlag(&NetTrainFlags::data_file_, "expectedDataFile", "Expected results data file path", "");
AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", "");
@ -239,4 +240,4 @@ class MS_API NetTrain {
int MS_API RunNetTrain(int argc, const char **argv);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_NET_TRAIN_NET_TRAIN_H_
#endif // MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_H_

View File

@ -136,10 +136,10 @@ static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveT
static const std::vector<schema::PrimitiveType> needInsertOpList = {
#ifdef SUPPORT_TRAIN
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Split,
schema::PrimitiveType_Slice, schema::PrimitiveType_Crop, schema::PrimitiveType_Mul,
schema::PrimitiveType_Add
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Split,
schema::PrimitiveType_Crop, schema::PrimitiveType_Mul, schema::PrimitiveType_Add,
schema::PrimitiveType_ActivationGrad
#else
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add,