!9307 add mindspore train example

From: @xutianchun
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-02 14:57:04 +08:00 committed by Gitee
commit d91b5c864c
13 changed files with 919 additions and 0 deletions

View File

@ -0,0 +1,134 @@
# Content
<!-- TOC -->
- [Overview](#overview)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Detailed Description](#script-detailed-description)
<!-- /TOC -->
# Overview
This folder holds code for Training-on-Device of a LeNet model. Part of the code runs on a server using MindSpore infrastructure, another part uses MindSpore Lite conversion utility, and the last part is the actual training of the model on some android-based device.
# Model Architecture
LeNet is a very simple network which is composed of only 5 layers, 2 of which are convolutional layers and the remaining 3 are fully connected layers. Such a small network can be fully trained (from scratch) on a device in a short time. Therefore, it is a good example.
# Dataset
In this example we use the MNIST dataset of handwritten digits as published in [THE MNIST DATABASE](<http://yann.lecun.com/exdb/mnist/>)
- Dataset size52.4M60,000 28*28 in 10 classes
- Test10,000 images
- Train60,000 images
- Data formatbinary files
- NoteData will be processed in dataset.cc
- The dataset directory structure is as follows:
```python
mnist/
├── test
│   ├── t10k-images-idx3-ubyte
│   └── t10k-labels-idx1-ubyte
└── train
├── train-images-idx3-ubyte
└── train-labels-idx1-ubyte
```
# Environment Requirements
- Server side
- [MindSpore Framework](https://www.mindspore.cn/install/en): it is recommended to install a docker image
- [MindSpore ToD Framework](https://www.mindspore.cn/tutorial/tod/en/use/prparation.html)
- [Android NDK r20b](https://dl.google.com/android/repository/android-ndk-r20b-linux-x86_64.zip)
- [Android SDK](https://developer.android.com/studio?hl=zh-cn#cmdline-tools)
- A connected Android device
# Quick Start
After installing all the above mentioned, the script in the home directory could be run with the following arguments:
```python
sh ./prepare_and_run.sh DATASET_PATH [MINDSPORE_DOCKER] [RELEASE.tar.gz]
```
where:
- DATASET_PATH is the path to the [dataset](#dataset),
- MINDSPORE_DOCKER is the image name of the docker that runs [MindSpore](#environment-requirements). If not provided MindSpore will be run locally
- and REALEASE.tar.gz is a pointer to the MindSpore ToD release tar ball. If not provided, the script will attempt to find MindSpore ToD compilation output.
# Script Detailed Description
The provided `prepare_and_run.sh` script is performing the followings:
- Prepare the trainable lenet model in a `.ms` format
- Prepare the folder that should be pushed into the device
- Copy this folder into the device and run the scripts on the device
See how to run the script and paramaters definitions in the [Quick Start Section](#quick-start)
## Preparing the model
Within the model folder a `prepare_model.sh` script uses MindSpore infrastructure to export the model into a `.mindir` file. The user can specify a docker image on which MindSpore is installed. Otherwise, the pyhton script will be run locally.
The script then converts the `.mindir` to a `.ms` format using the MindSpore ToD converter.
The script accepts a tar ball where the converter resides. Otherwise, the script will attempt to find the converter in the MindSpore ToD build output directory.
## Preparing the Folder
The `lenet_tod.ms` model file is then copied into the `package` folder as well as scripts, the MindSpore ToD library and the MNIST dataset.
Finally, the code (in src) is compiled for arm64 and the binary is copied into the `package` folder.
### Running the code on the device
To run the code on the device the script first uses `adb` tool to push the `package` folder into the device. It then runs training (which takes some time) and finally runs evaluation of the trained model using the test data.
# Folder Directory tree
``` python
train_lenet/
├── Makefile # Makefile of src code
├── model
│   ├── lenet_export.py # Python script that exports the LeNet model to .mindir
│   ├── prepare_model.sh # script that export model (using docker) then converts it
│   └── train_utils.py # utility function used during the export
├── prepare_and_run.sh # main script that creates model, compiles it and send to device for running
├── README.md # this manual
├── scripts
│   ├── eval.sh # on-device script that load the train model and evaluates its accuracy
│   ├── run_eval.sh # adb script that launches eval.sh
│   ├── run_train.sh # adb script that launches train.sh
│   └── train.sh # on-device script that load the initial model and train it
├── src
│   ├── dataset.cc # dataset handler
│   ├── dataset.h # dataset class header
│   ├── net_runner.cc # program that runs training/evaluation of models
│   └── net_runner.h # net_runner header
```
When the `prepare_and_run.sh` script is run, the following folder is prepared. It is pushed to the device and then training runs
``` python
├── package
│   ├── bin
│   │   └── net_runner # the executable that performs the training/evaluation
│   ├── dataset
│   │   ├── test
│   │   │   ├── t10k-images-idx3-ubyte # test images
│   │   │   └── t10k-labels-idx1-ubyte # test labels
│   │   └── train
│   │   ├── train-images-idx3-ubyte # train images
│   │   └── train-labels-idx1-ubyte # train labels
│   ├── eval.sh # on-device script that load the train model and evaluates its accuracy
│   ├── lib
│   │   └── libmindspore-lite.so # MindSpore Lite library
│   ├── model
│   │   └── lenet_tod.ms # model to train
│   └── train.sh # on-device script that load the initial model and train it
```

View File

@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lenet_export."""
import sys
from mindspore import context, Tensor
import mindspore.common.dtype as mstype
from mindspore.train.serialization import export
from lenet import LeNet5
import numpy as np
from train_utils import TrainWrap
sys.path.append('../../../cv/lenet/src/')
n = LeNet5()
n.set_train()
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
batch_size = 32
x = Tensor(np.ones((batch_size, 1, 32, 32)), mstype.float32)
label = Tensor(np.zeros([batch_size, 10]).astype(np.float32))
net = TrainWrap(n)
export(net, x, label, file_name="lenet_tod.mindir", file_format='MINDIR')
print("finished exporting")

View File

@ -0,0 +1,24 @@
CONVERTER="../../../../../mindspore/lite/build/tools/converter/converter_lite"
if [ ! -f "$CONVERTER" ]; then
if ! command -v converter_lite &> /dev/null
then
echo "converter_lite could not be found in MindSpore build directory nor in system path"
exit
else
CONVERTER=converter_lite
fi
fi
echo "============Exporting=========="
if [ -n "$1" ]; then
DOCKER_IMG=$1
docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER_IMG} /bin/bash -c "python lenet_export.py; chmod 444 lenet_tod.mindir; rm -rf __pycache__"
else
echo "MindSpore docker was not provided, attempting to run locally"
python lenet_export.py
fi
echo "============Converting========="
$CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod

View File

@ -0,0 +1,34 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_utils."""
import mindspore.nn as nn
from mindspore.common.parameter import ParameterTuple
def TrainWrap(net, loss_fn=None, optimizer=None, weights=None):
"""
TrainWrap
"""
if loss_fn is None:
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
loss_net = nn.WithLossCell(net, loss_fn)
loss_net.set_train()
if weights is None:
weights = ParameterTuple(net.trainable_params())
if optimizer is None:
optimizer = nn.Adam(weights, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
use_nesterov=False, weight_decay=0.0, loss_scale=1.0)
train_net = nn.TrainOneStepCell(loss_net, optimizer)
return train_net

View File

@ -0,0 +1,82 @@
#!/bin/bash
display_usage() {
echo -e "\nUsage: prepare_and_run.sh dataset_path [mindspore_docker] [release.tar.gz]\n"
}
if [ -n "$1" ]; then
MNIST_DATA_PATH=$1
else
echo "MNIST Dataset directory path was not provided"
display_usage
exit 0
fi
if [ -n "$2" ]; then
DOCKER=$2
else
DOCKER=""
#echo "MindSpore docker was not provided"
#display_usage
#exit 0
fi
if [ -n "$3" ]; then
TARBALL=$3
else
if [ -f ../../../../output/mindspore-lite-*-runtime-arm64-cpu-train.tar.gz ]; then
TARBALL="../../../../output/mindspore-lite-*-runtime-arm64-cpu-train.tar.gz"
else
echo "release.tar.gz was not found"
display_usage
exit 0
fi
fi
# Prepare the model
cd model/
rm -f *.ms
./prepare_model.sh $DOCKER
cd -
# Copy the .ms model to the package folder
rm -rf package
mkdir -p package/model
cp model/*.ms package/model
# Copy the running script to the package
cp scripts/train.sh package/
cp scripts/eval.sh package/
# Copy the shared MindSpore ToD library
tar -xzvf ${TARBALL} --wildcards --no-anchored libmindspore-lite.so
tar -xzvf ${TARBALL} --wildcards --no-anchored include
mv mindspore-*/lib package/
mkdir msl
mv mindspore-*/* msl/
rm -rf mindspore-*
# Copy the dataset to the package
cp -r ${MNIST_DATA_PATH} package/dataset
# Compile program
make TARGET=arm64
# Copy the executable to the package
mv bin package/
# Push the folder to the device
adb push package /data/local/tmp/
echo "Training on Device"
adb shell < scripts/run_train.sh
echo
echo "Load trained model and evaluate accuracy"
adb shell < scripts/run_eval.sh
echo
#rm -rf src/*.o package model/__pycache__ model/*.ms
#./prepare_and_run.sh /opt/share/dataset/mnist mindspore_dev:5

View File

@ -0,0 +1,19 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# an simple tutorial as follows, more parameters can be setting
DATA_PATH=$1
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/lenet_tod_trained_3000.ms -e 0 -d dataset

View File

@ -0,0 +1,2 @@
cd /data/local/tmp/package
/system/bin/sh eval.sh

View File

@ -0,0 +1,2 @@
cd /data/local/tmp/package
/system/bin/sh train.sh

View File

@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# an simple tutorial as follows, more parameters can be setting
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
DATA_PATH=$1
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/lenet_tod.ms -e 3000 -d dataset

View File

@ -0,0 +1,200 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/dataset.h"
#include <assert.h>
#include <arpa/inet.h>
#include <map>
#include <iostream>
#include <fstream>
#include <memory>
#include <filesystem>
using LabelId = std::map<std::string, int>;
char *ReadFile(const std::string &file, size_t *size) {
assert(size != nullptr);
std::string realPath(file);
std::ifstream ifs(realPath);
if (!ifs.good()) {
std::cerr << "file: " << realPath << " does not exist";
return nullptr;
}
if (!ifs.is_open()) {
std::cerr << "file: " << realPath << " open failed";
return nullptr;
}
ifs.seekg(0, std::ios::end);
*size = ifs.tellg();
std::unique_ptr<char[]> buf(new (std::nothrow) char[*size]);
if (buf == nullptr) {
std::cerr << "malloc buf failed, file: " << realPath;
ifs.close();
return nullptr;
}
ifs.seekg(0, std::ios::beg);
ifs.read(buf.get(), *size);
ifs.close();
return buf.release();
}
DataSet::~DataSet() {
for (auto itr = train_data_.begin(); itr != train_data_.end(); ++itr) {
auto ptr = std::get<0>(*itr);
delete[] ptr;
}
for (auto itr = test_data_.begin(); itr != test_data_.end(); ++itr) {
auto ptr = std::get<0>(*itr);
delete[] ptr;
}
}
int DataSet::Init(const std::string &data_base_directory, database_type type) {
InitializeMNISTDatabase(data_base_directory);
return 0;
}
void DataSet::InitializeMNISTDatabase(std::string dpath) {
// int total_data = 0;
num_of_classes_ = 10;
// total_data +=
ReadMNISTFile(dpath + "/train/train-images-idx3-ubyte", dpath + "/train/train-labels-idx1-ubyte", &train_data_);
// total_data +=
ReadMNISTFile(dpath + "/test/t10k-images-idx3-ubyte", dpath + "/test/t10k-labels-idx1-ubyte", &test_data_);
}
int DataSet::ReadMNISTFile(const std::string &ifile_name, const std::string &lfile_name,
std::vector<DataLabelTuple> *dataset) {
std::ifstream lfile(lfile_name, std::ios::binary);
if (!lfile.is_open()) {
std::cerr << "Cannot open label file " << lfile_name << std::endl;
return 0;
}
std::ifstream ifile(ifile_name, std::ios::binary);
if (!ifile.is_open()) {
std::cerr << "Cannot open data file " << ifile_name << std::endl;
return 0;
}
int magic_number = 0;
lfile.read(reinterpret_cast<char *>(&magic_number), sizeof(magic_number));
magic_number = ntohl(magic_number);
if (magic_number != 2049) {
std::cout << "Invalid MNIST label file!" << std::endl;
return 0;
}
int number_of_labels = 0;
lfile.read(reinterpret_cast<char *>(&number_of_labels), sizeof(number_of_labels));
number_of_labels = ntohl(number_of_labels);
ifile.read(reinterpret_cast<char *>(&magic_number), sizeof(magic_number));
magic_number = ntohl(magic_number);
if (magic_number != 2051) {
std::cout << "Invalid MNIST image file!" << std::endl;
return 0;
}
int number_of_images = 0;
ifile.read(reinterpret_cast<char *>(&number_of_images), sizeof(number_of_images));
number_of_images = ntohl(number_of_images);
int n_rows = 0;
ifile.read(reinterpret_cast<char *>(&n_rows), sizeof(n_rows));
n_rows = ntohl(n_rows);
int n_cols = 0;
ifile.read(reinterpret_cast<char *>(&n_cols), sizeof(n_cols));
n_cols = ntohl(n_cols);
if (number_of_labels != number_of_images) {
std::cout << "number of records in labels and images files does not match" << std::endl;
return 0;
}
int image_size = n_rows * n_cols;
unsigned char labels[number_of_labels];
unsigned char data[image_size];
lfile.read(reinterpret_cast<char *>(labels), number_of_labels);
for (int i = 0; i < number_of_labels; ++i) {
std::unique_ptr<float[]> hwc_bin_image(new (std::nothrow) float[32 * 32]);
ifile.read(reinterpret_cast<char *>(data), image_size);
for (size_t r = 0; r < 32; r++) {
for (size_t c = 0; c < 32; c++) {
if (r < 2 || r > 29 || c < 2 || c > 29)
hwc_bin_image[r * 32 + c] = 0.0;
else
hwc_bin_image[r * 32 + c] = (static_cast<float>(data[(r - 2) * 28 + (c - 2)])) / 255.0;
}
}
DataLabelTuple data_entry = std::make_tuple(reinterpret_cast<char *>(hwc_bin_image.release()), labels[i]);
dataset->push_back(data_entry);
}
return number_of_labels;
}
std::vector<FileTuple> DataSet::ReadFileList(std::string dpath) {
std::vector<FileTuple> vec;
std::ifstream ifs(dpath + "/file_list.txt");
std::string file_name;
if (ifs.is_open()) {
int label;
while (!ifs.eof()) {
ifs >> label >> file_name;
vec.push_back(make_tuple(label, file_name));
}
}
return vec;
}
std::vector<FileTuple> DataSet::ReadDir(const std::string dpath) {
std::filesystem::directory_iterator dir(dpath);
std::vector<FileTuple> vec;
LabelId label_id;
int class_id = 0;
int class_label;
for (const auto p : dir) {
if (p.is_directory()) {
std::string path = p.path().stem().string();
auto label = label_id.find(path);
if (label == label_id.end()) {
label_id[path] = class_id;
class_label = class_id;
class_id++;
num_of_classes_ = class_id;
} else {
class_label = label->second;
}
std::filesystem::directory_iterator ndir(dpath + "/" + path);
for (const auto np : ndir) {
if (np.path().extension().string() == ".bin") {
std::string entry =
dpath + "/" + np.path().parent_path().stem().string() + "/" + np.path().filename().string();
FileTuple ft = make_tuple(class_label, entry);
vec.push_back(ft);
}
}
}
}
return vec;
}

View File

@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MODEL_ZOO_OFFICIAL_TOD_TRAIN_LENET_SRC_DATASET_H_
#define MODEL_ZOO_OFFICIAL_TOD_TRAIN_LENET_SRC_DATASET_H_
#include <tuple>
#include <string>
#include <vector>
using DataLabelTuple = std::tuple<char *, int>;
using FileTuple = std::tuple<int, std::string>;
enum database_type { DS_CIFAR10_BINARY = 0, DS_MNIST_BINARY, DS_OTHER };
char *ReadFile(const std::string &file, size_t *size); // utility function
class DataSet {
public:
DataSet() {}
~DataSet();
int Init(const std::string &data_base_directory, database_type type = DS_OTHER);
const std::vector<DataLabelTuple> &train_data() const { return train_data_; }
const std::vector<DataLabelTuple> &test_data() const { return test_data_; }
unsigned int num_of_classes() { return num_of_classes_; }
void set_expected_data_size(unsigned int expected_data_size) { expected_data_size_ = expected_data_size; }
unsigned int expected_data_size() { return expected_data_size_; }
private:
std::vector<FileTuple> ReadFileList(std::string dpath);
std::vector<FileTuple> ReadDir(const std::string dpath);
int ReadMNISTFile(const std::string &ifile, const std::string &lfile, std::vector<DataLabelTuple> *dataset);
void InitializeMNISTDatabase(std::string dpath);
std::vector<DataLabelTuple> train_data_;
std::vector<DataLabelTuple> test_data_;
unsigned int num_of_classes_ = 0;
unsigned int expected_data_size_ = 0;
};
#endif // MODEL_ZOO_OFFICIAL_TOD_TRAIN_LENET_SRC_DATASET_H_

View File

@ -0,0 +1,247 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/net_runner.h"
#include <math.h>
#include <getopt.h>
#include <iostream>
#include <fstream>
#include "include/context.h"
unsigned int NetRunner::seed_ = time(NULL);
// Definition of callback function after forwarding operator.
bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
const mindspore::CallBackParam &call_param) {
printf("%s\n", call_param.node_name.c_str());
for (size_t i = 0; i < after_inputs.size(); i++) {
int num2p = (after_inputs.at(i)->ElementsNum());
printf("in%zu(%d): ", i, num2p);
if (num2p > 10) num2p = 10;
if (after_inputs.at(i)->data_type() == mindspore::kNumberTypeInt32) {
auto d = reinterpret_cast<int *>(after_inputs.at(i)->MutableData());
for (int j = 0; j < num2p; j++) printf("%d, ", d[j]);
} else {
auto d = reinterpret_cast<float *>(after_inputs.at(i)->MutableData());
for (int j = 0; j < num2p; j++) printf("%f, ", d[j]);
}
printf("\n");
}
for (size_t i = 0; i < after_outputs.size(); i++) {
auto d = reinterpret_cast<float *>(after_outputs.at(i)->MutableData());
int num2p = (after_outputs.at(i)->ElementsNum());
printf("ou%zu(%d): ", i, num2p);
if (num2p > 10) num2p = 10;
for (int j = 0; j < num2p; j++) printf("%f, ", d[j]);
printf("\n");
}
return true;
}
NetRunner::~NetRunner() {
if (session_ != nullptr) delete session_;
}
void NetRunner::InitAndFigureInputs() {
mindspore::lite::Context context;
context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = mindspore::lite::NO_BIND;
context.thread_num_ = 1;
session_ = mindspore::session::TrainSession::CreateSession(ms_file_, &context);
assert(nullptr != session_);
auto inputs = session_->GetInputs();
assert(inputs.size() > 1);
data_index_ = 0;
label_index_ = 1;
batch_size_ = inputs[data_index_]->shape()[0];
data_size_ = inputs[data_index_]->Size() / batch_size_; // in bytes
if (verbose_) {
std::cout << "data size: " << data_size_ << std::endl << "batch size: " << batch_size_ << std::endl;
}
}
mindspore::tensor::MSTensor *NetRunner::SearchOutputsForSize(size_t size) const {
auto outputs = session_->GetOutputs();
for (auto it = outputs.begin(); it != outputs.end(); ++it) {
if (it->second->ElementsNum() == size) return it->second;
}
std::cout << "Model does not have an output tensor with size " << size << std::endl;
return nullptr;
}
std::vector<int> NetRunner::FillInputData(const std::vector<DataLabelTuple> &dataset, bool serially) const {
std::vector<int> labels_vec;
static unsigned int idx = 1;
int total_size = dataset.size();
auto inputs = session_->GetInputs();
char *input_data = reinterpret_cast<char *>(inputs.at(data_index_)->MutableData());
auto labels = reinterpret_cast<float *>(inputs.at(label_index_)->MutableData());
assert(total_size > 0);
assert(input_data != nullptr);
std::fill(labels, labels + inputs.at(label_index_)->ElementsNum(), 0.f);
for (int i = 0; i < batch_size_; i++) {
if (serially) {
idx = ++idx % total_size;
} else {
idx = rand_r(&seed_) % total_size;
}
int label = 0;
char *data = nullptr;
std::tie(data, label) = dataset[idx];
memcpy(input_data + i * data_size_, data, data_size_);
labels[i * num_of_classes_ + label] = 1.0; // Model expects labels in onehot representation
labels_vec.push_back(label);
}
return labels_vec;
}
float NetRunner::CalculateAccuracy(int max_tests) const {
float accuracy = 0.0;
const std::vector<DataLabelTuple> test_set = ds_.test_data();
int tests = test_set.size() / batch_size_;
if (max_tests != -1 && tests < max_tests) tests = max_tests;
session_->Eval();
for (int i = 0; i < tests; i++) {
auto labels = FillInputData(test_set, (max_tests == -1));
session_->RunGraph();
auto outputsv = SearchOutputsForSize(batch_size_ * num_of_classes_);
assert(outputsv != nullptr);
auto scores = reinterpret_cast<float *>(outputsv->MutableData());
for (int b = 0; b < batch_size_; b++) {
int max_idx = 0;
float max_score = scores[num_of_classes_ * b];
for (int c = 0; c < num_of_classes_; c++) {
if (scores[num_of_classes_ * b + c] > max_score) {
max_score = scores[num_of_classes_ * b + c];
max_idx = c;
}
}
if (labels[b] == max_idx) accuracy += 1.0;
}
}
session_->Train();
accuracy /= static_cast<float>(batch_size_ * tests);
return accuracy;
}
int NetRunner::InitDB() {
if (data_size_ != 0) ds_.set_expected_data_size(data_size_);
int ret = ds_.Init(data_dir_, DS_MNIST_BINARY);
num_of_classes_ = ds_.num_of_classes();
if (ds_.test_data().size() == 0) {
std::cout << "No relevant data was found in " << data_dir_ << std::endl;
assert(ds_.test_data().size() != 0);
}
return ret;
}
float NetRunner::GetLoss() const {
auto outputsv = SearchOutputsForSize(1); // Search for Loss which is a single value tensor
assert(outputsv != nullptr);
auto loss = reinterpret_cast<float *>(outputsv->MutableData());
return loss[0];
}
int NetRunner::TrainLoop() {
session_->Train();
float min_loss = 1000.;
float max_acc = 0.;
for (int i = 0; i < cycles_; i++) {
FillInputData(ds_.train_data());
session_->RunGraph(nullptr, verbose_ ? after_callback : nullptr);
float loss = GetLoss();
if (min_loss > loss) min_loss = loss;
if (save_checkpoint_ != 0 && (i + 1) % save_checkpoint_ == 0) {
auto cpkt_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained_" + std::to_string(i + 1) + ".ms";
session_->SaveToFile(cpkt_fn);
}
if ((i + 1) % 100 == 0) {
float acc = CalculateAccuracy(10);
if (max_acc < acc) max_acc = acc;
std::cout << i + 1 << ":\tLoss is " << std::setw(7) << loss << " [min=" << min_loss << "] "
<< " max_acc=" << max_acc << std::endl;
}
}
return 0;
}
int NetRunner::Main() {
InitAndFigureInputs();
InitDB();
TrainLoop();
float acc = CalculateAccuracy();
std::cout << "accuracy = " << acc << std::endl;
if (cycles_ > 0) {
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained_" + std::to_string(cycles_) + ".ms";
session_->SaveToFile(trained_fn);
}
return 0;
}
void NetRunner::Usage() {
std::cout << "Usage: net_runner -f <.ms model file> -d <data_dir> [-c <num of training cycles>] "
<< "[-v (verbose mode)] [-s <save checkpoint every X iterations>]" << std::endl;
}
bool NetRunner::ReadArgs(int argc, char *argv[]) {
int opt;
while ((opt = getopt(argc, argv, "f:e:d:s:ihc:v")) != -1) {
switch (opt) {
case 'f':
ms_file_ = std::string(optarg);
break;
case 'e':
cycles_ = atoi(optarg);
break;
case 'd':
data_dir_ = std::string(optarg);
break;
case 'v':
verbose_ = true;
break;
case 's':
save_checkpoint_ = atoi(optarg);
break;
case 'h':
default:
Usage();
return false;
}
}
return true;
}
int main(int argc, char **argv) {
NetRunner nr;
if (nr.ReadArgs(argc, argv)) {
nr.Main();
} else {
return -1;
}
return 0;
}

View File

@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MODEL_ZOO_OFFICIAL_TOD_TRAIN_LENET_SRC_NET_RUNNER_H_
#define MODEL_ZOO_OFFICIAL_TOD_TRAIN_LENET_SRC_NET_RUNNER_H_
#include <tuple>
#include <filesystem>
#include <map>
#include <vector>
#include <string>
#include "include/train_session.h"
#include "include/ms_tensor.h"
#include "src/dataset.h"
class NetRunner {
public:
int Main();
bool ReadArgs(int argc, char *argv[]);
~NetRunner();
private:
void Usage();
void InitAndFigureInputs();
int InitDB();
int TrainLoop();
std::vector<int> FillInputData(const std::vector<DataLabelTuple> &dataset, bool is_train_set = false) const;
float CalculateAccuracy(int max_tests = -1) const;
float GetLoss() const;
mindspore::tensor::MSTensor *SearchOutputsForSize(size_t size) const;
DataSet ds_;
mindspore::session::TrainSession *session_ = nullptr;
std::string ms_file_ = "";
std::string data_dir_ = "";
size_t data_size_ = 0;
size_t batch_size_ = 0;
unsigned int cycles_ = 100;
int data_index_ = 0;
int label_index_ = -1;
int num_of_classes_ = 0;
bool verbose_ = false;
int save_checkpoint_ = 0;
static unsigned int seed_;
};
#endif // MODEL_ZOO_OFFICIAL_TOD_TRAIN_LENET_SRC_NET_RUNNER_H_