forked from mindspore-Ecosystem/mindspore
!9307 add mindspore train example
From: @xutianchun Reviewed-by: Signed-off-by:
commit
d91b5c864c
|
@ -0,0 +1,134 @@
|
|||
# Content
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Detailed Description](#script-detailed-description)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# Overview
|
||||
|
||||
This folder holds the code for Training-on-Device (ToD) of a LeNet model. One part of the code runs on a server using the MindSpore infrastructure, another part uses the MindSpore Lite conversion utility, and the last part performs the actual training of the model on an Android-based device.
|
||||
|
||||
# Model Architecture
|
||||
|
||||
LeNet is a very simple network composed of only 5 layers: 2 convolutional layers followed by 3 fully connected layers. Such a small network can be trained from scratch on a device in a short time, which makes it a good example for Training-on-Device.
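
The layer structure described above can be sanity-checked with a few lines of arithmetic. The sketch below assumes the classic LeNet-5 configuration (5x5 convolutions without padding, 6 and 16 output channels, each followed by 2x2 pooling, then fully connected layers of 120, 84 and 10 units) and the 32x32 single-channel input used by the export script. The helper names are illustrative and are not part of this example's code.

```python
def conv_out(size, kernel, stride=1, pad=0):
    """Spatial output size of a convolution or pooling window."""
    return (size + 2 * pad - kernel) // stride + 1


def lenet_shapes(input_hw=32):
    """Trace feature-map shapes through an assumed classic LeNet-5."""
    shapes = []
    hw = conv_out(input_hw, kernel=5)        # conv1: 5x5, no padding
    shapes.append(("conv1", (6, hw, hw)))
    hw = conv_out(hw, kernel=2, stride=2)    # 2x2 pooling
    shapes.append(("pool1", (6, hw, hw)))
    hw = conv_out(hw, kernel=5)              # conv2: 5x5, no padding
    shapes.append(("conv2", (16, hw, hw)))
    hw = conv_out(hw, kernel=2, stride=2)    # 2x2 pooling
    shapes.append(("pool2", (16, hw, hw)))
    flat = 16 * hw * hw                      # input width of the first dense layer
    for name, units in (("fc1", 120), ("fc2", 84), ("fc3", 10)):
        shapes.append((name, (units,)))
    return flat, shapes


flat, shapes = lenet_shapes()
for name, shape in shapes:
    print(name, shape)
print("flattened features:", flat)
```

Under these assumptions the two convolutional stages reduce a 32x32 input to 16 feature maps of 5x5, which is why the images are padded from 28x28 to 32x32 before being fed to the network.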
|
||||
|
||||
# Dataset
|
||||
|
||||
In this example we use the MNIST dataset of handwritten digits, as published in [THE MNIST DATABASE](<http://yann.lecun.com/exdb/mnist/>).
|
||||
|
||||
- Dataset size: 52.4M, 60,000 28*28 images in 10 classes
    - Test: 10,000 images
    - Train: 60,000 images
- Data format: binary files
    - Note: Data will be processed in dataset.cc
|
||||
|
||||
- The dataset directory structure is as follows:
|
||||
|
||||
```python
|
||||
mnist/
|
||||
├── test
|
||||
│ ├── t10k-images-idx3-ubyte
|
||||
│ └── t10k-labels-idx1-ubyte
|
||||
└── train
|
||||
├── train-images-idx3-ubyte
|
||||
└── train-labels-idx1-ubyte
|
||||
```
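
To make the binary format more concrete, here is a minimal Python sketch of the processing that `dataset.cc` performs in C++: check the big-endian magic numbers (2051 for image files, 2049 for label files), read the records, zero-pad each 28x28 image to 32x32 and scale the pixels to the range [0, 1]. The function names are illustrative, and the sketch works on a small synthetic file built in memory rather than on the real dataset.

```python
import struct


def parse_idx_images(blob):
    """Parse an in-memory IDX3 image file into per-image byte strings."""
    magic, count, rows, cols = struct.unpack(">IIII", blob[:16])
    if magic != 2051:
        raise ValueError("not an IDX3 image file")
    size = rows * cols
    images = [blob[16 + i * size: 16 + (i + 1) * size] for i in range(count)]
    return rows, cols, images


def parse_idx_labels(blob):
    """Parse an in-memory IDX1 label file into a list of integers."""
    magic, count = struct.unpack(">II", blob[:8])
    if magic != 2049:
        raise ValueError("not an IDX1 label file")
    return list(blob[8:8 + count])


def pad_and_scale(image, side=28, out=32):
    """Centre a side x side image in an out x out frame and scale to [0, 1]."""
    border = (out - side) // 2
    padded = [0.0] * (out * out)
    for r in range(side):
        for c in range(side):
            padded[(r + border) * out + (c + border)] = image[r * side + c] / 255.0
    return padded


# Synthetic two-image dataset: one all-white image, one all-black image.
pixels = bytes([255] * 784) + bytes([0] * 784)
image_blob = struct.pack(">IIII", 2051, 2, 28, 28) + pixels
label_blob = struct.pack(">II", 2049, 2) + bytes([7, 3])

rows, cols, images = parse_idx_images(image_blob)
labels = parse_idx_labels(label_blob)
first = pad_and_scale(images[0])
print(rows, cols, len(images), labels)
print(first[0], first[2 * 32 + 2])  # a border pixel, then the first image pixel
```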
|
||||
|
||||
# Environment Requirements
|
||||
|
||||
- Server side
|
||||
- [MindSpore Framework](https://www.mindspore.cn/install/en): it is recommended to install a docker image
|
||||
- [MindSpore ToD Framework](https://www.mindspore.cn/tutorial/tod/en/use/prparation.html)
|
||||
- [Android NDK r20b](https://dl.google.com/android/repository/android-ndk-r20b-linux-x86_64.zip)
|
||||
- [Android SDK](https://developer.android.com/studio?hl=zh-cn#cmdline-tools)
|
||||
- A connected Android device
|
||||
|
||||
# Quick Start
|
||||
|
||||
After installing everything listed above, run the script in this example's top-level directory with the following arguments:
|
||||
|
||||
```bash
|
||||
sh ./prepare_and_run.sh DATASET_PATH [MINDSPORE_DOCKER] [RELEASE.tar.gz]
|
||||
```
|
||||
|
||||
where:
|
||||
|
||||
- DATASET_PATH is the path to the [dataset](#dataset),
- MINDSPORE_DOCKER is the image name of the docker that runs [MindSpore](#environment-requirements). If not provided, MindSpore will be run locally.
- RELEASE.tar.gz is the path to the MindSpore ToD release tarball. If not provided, the script will attempt to find the MindSpore ToD compilation output.
|
||||
|
||||
# Script Detailed Description
|
||||
|
||||
The provided `prepare_and_run.sh` script performs the following steps:
|
||||
|
||||
- Prepare the trainable LeNet model in the `.ms` format
- Prepare the folder that should be pushed to the device
- Copy this folder to the device and run the scripts on the device
|
||||
|
||||
For how to run the script and the definitions of its parameters, see the [Quick Start Section](#quick-start).
|
||||
|
||||
## Preparing the model
|
||||
|
||||
Within the model folder, a `prepare_model.sh` script uses the MindSpore infrastructure to export the model into a `.mindir` file. The user can specify a docker image on which MindSpore is installed. Otherwise, the Python script will be run locally.
|
||||
The script then converts the `.mindir` to a `.ms` format using the MindSpore ToD converter.
|
||||
The script accepts a tar ball where the converter resides. Otherwise, the script will attempt to find the converter in the MindSpore ToD build output directory.
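
The conversion step can be summarised in a short Python sketch. It mirrors what `prepare_model.sh` does in shell: prefer the converter from the build tree, fall back to a `converter_lite` found on the system path, and invoke it with the flags used by the script. The function names and the injectable `which` parameter are illustrative assumptions. Only the flags and file names come from the script itself.

```python
import os
import shutil


def find_converter(build_path, which=shutil.which):
    """Prefer the converter from the build tree, else look it up on PATH."""
    if os.path.isfile(build_path):
        return build_path
    return which("converter_lite")  # None when the tool is not on PATH


def converter_command(converter, mindir="lenet_tod.mindir", output="lenet_tod"):
    """Build the argument list that turns a .mindir file into a trainable .ms file."""
    return [
        converter,
        "--fmk=MINDIR",
        "--trainModel=true",
        f"--modelFile={mindir}",
        f"--outputFile={output}",
    ]


cmd = converter_command("converter_lite")
print(" ".join(cmd))
```

The resulting list can be handed to `subprocess.run(cmd, check=True)` on a machine where the converter is installed; the sketch itself only builds the command.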
|
||||
|
||||
## Preparing the Folder
|
||||
|
||||
The `lenet_tod.ms` model file is then copied into the `package` folder, along with the scripts, the MindSpore ToD library and the MNIST dataset.
Finally, the code (in `src`) is compiled for arm64 and the binary is copied into the `package` folder.
|
||||
|
||||
## Running the code on the device
|
||||
|
||||
To run the code on the device, the script first uses the `adb` tool to push the `package` folder to the device. It then runs training (which takes some time) and finally evaluates the trained model on the test data.
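
The training and evaluation steps are linked by a file name: `net_runner` saves the trained model next to the input model as `<model>_trained_<cycles>.ms`, and `eval.sh` loads exactly that file with zero training cycles. The Python sketch below reproduces this naming rule and the two command lines from `train.sh` and `eval.sh`. The helper names are illustrative and are not part of the example's code.

```python
def trained_model_name(ms_file, cycles):
    """Mirror the runner's naming rule for the model saved after training."""
    dot = ms_file.rfind(".")
    stem = ms_file if dot == -1 else ms_file[:dot]
    return f"{stem}_trained_{cycles}.ms"


def runner_command(model, cycles, data_dir="dataset"):
    """Build the net_runner argument list: -f model, -e cycles, -d data dir."""
    return ["bin/net_runner", "-f", model, "-e", str(cycles), "-d", data_dir]


initial_model = "model/lenet_tod.ms"
train_cmd = runner_command(initial_model, 3000)
eval_cmd = runner_command(trained_model_name(initial_model, 3000), 0)
print(" ".join(train_cmd))
print(" ".join(eval_cmd))
```

If the number of training cycles in `train.sh` is changed, the model path in `eval.sh` has to change with it, since the cycle count is part of the saved file name.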
|
||||
|
||||
# Folder Directory tree
|
||||
|
||||
``` python
|
||||
train_lenet/
|
||||
├── Makefile # Makefile of src code
|
||||
├── model
|
||||
│ ├── lenet_export.py # Python script that exports the LeNet model to .mindir
|
||||
│ ├── prepare_model.sh # script that exports the model (optionally using docker) and then converts it
|
||||
│ └── train_utils.py # utility function used during the export
|
||||
├── prepare_and_run.sh # main script that creates the model, compiles the code and sends the package to the device for running
|
||||
├── README.md # this manual
|
||||
├── scripts
|
||||
│ ├── eval.sh # on-device script that loads the trained model and evaluates its accuracy
|
||||
│ ├── run_eval.sh # adb script that launches eval.sh
|
||||
│ ├── run_train.sh # adb script that launches train.sh
|
||||
│ └── train.sh # on-device script that loads the initial model and trains it
|
||||
├── src
|
||||
│ ├── dataset.cc # dataset handler
|
||||
│ ├── dataset.h # dataset class header
|
||||
│ ├── net_runner.cc # program that runs training/evaluation of models
|
||||
│ └── net_runner.h # net_runner header
|
||||
```
|
||||
|
||||
When the `prepare_and_run.sh` script is run, the following folder is prepared. It is pushed to the device, and training then runs there.
|
||||
|
||||
``` python
|
||||
├── package
|
||||
│ ├── bin
|
||||
│ │ └── net_runner # the executable that performs the training/evaluation
|
||||
│ ├── dataset
|
||||
│ │ ├── test
|
||||
│ │ │ ├── t10k-images-idx3-ubyte # test images
|
||||
│ │ │ └── t10k-labels-idx1-ubyte # test labels
|
||||
│ │ └── train
|
||||
│ │ ├── train-images-idx3-ubyte # train images
|
||||
│ │ └── train-labels-idx1-ubyte # train labels
|
||||
│ ├── eval.sh # on-device script that loads the trained model and evaluates its accuracy
|
||||
│ ├── lib
|
||||
│ │ └── libmindspore-lite.so # MindSpore Lite library
|
||||
│ ├── model
|
||||
│ │ └── lenet_tod.ms # model to train
|
||||
│ └── train.sh # on-device script that loads the initial model and trains it
|
||||
```
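
Before pushing the folder to a device, it can be useful to check that it is complete. The following Python sketch lists any files from the tree above that are missing under a given `package` root. The `EXPECTED` list is transcribed from the tree, while the function name and the temporary-directory demonstration are illustrative assumptions.

```python
import tempfile
from pathlib import Path

# Relative paths taken from the package tree shown above.
EXPECTED = [
    "bin/net_runner",
    "dataset/test/t10k-images-idx3-ubyte",
    "dataset/test/t10k-labels-idx1-ubyte",
    "dataset/train/train-images-idx3-ubyte",
    "dataset/train/train-labels-idx1-ubyte",
    "eval.sh",
    "lib/libmindspore-lite.so",
    "model/lenet_tod.ms",
    "train.sh",
]


def missing_files(package_root):
    """Return the expected files that do not exist under package_root."""
    root = Path(package_root)
    return [rel for rel in EXPECTED if not (root / rel).is_file()]


# Demonstration on a throwaway directory that lacks only train.sh.
with tempfile.TemporaryDirectory() as tmp:
    for rel in EXPECTED[:-1]:
        target = Path(tmp) / rel
        target.parent.mkdir(parents=True, exist_ok=True)
        target.touch()
    missing = missing_files(tmp)
print(missing)
```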
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lenet_export."""
|
||||
|
||||
import sys

import numpy as np
from mindspore import context, Tensor
import mindspore.common.dtype as mstype
from mindspore.train.serialization import export
from train_utils import TrainWrap

# The LeNet5 definition lives in the cv/lenet model folder, so the search path
# must be extended before the import; otherwise the import below fails.
sys.path.append('../../../cv/lenet/src/')
from lenet import LeNet5  # noqa: E402

n = LeNet5()
n.set_train()
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)

batch_size = 32
x = Tensor(np.ones((batch_size, 1, 32, 32)), mstype.float32)
label = Tensor(np.zeros([batch_size, 10]).astype(np.float32))
net = TrainWrap(n)
export(net, x, label, file_name="lenet_tod.mindir", file_format='MINDIR')

print("finished exporting")
|
|
@ -0,0 +1,24 @@
|
|||
CONVERTER="../../../../../mindspore/lite/build/tools/converter/converter_lite"
|
||||
if [ ! -f "$CONVERTER" ]; then
|
||||
if ! command -v converter_lite &> /dev/null
|
||||
then
|
||||
echo "converter_lite could not be found in MindSpore build directory nor in system path"
|
||||
exit 1
|
||||
else
|
||||
CONVERTER=converter_lite
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "============Exporting=========="
|
||||
if [ -n "$1" ]; then
|
||||
DOCKER_IMG=$1
|
||||
docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER_IMG} /bin/bash -c "python lenet_export.py; chmod 444 lenet_tod.mindir; rm -rf __pycache__"
|
||||
else
|
||||
echo "MindSpore docker was not provided, attempting to run locally"
|
||||
python lenet_export.py
|
||||
fi
|
||||
|
||||
|
||||
echo "============Converting========="
|
||||
$CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_utils."""
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
|
||||
def TrainWrap(net, loss_fn=None, optimizer=None, weights=None):
    """
    Wrap a network with a loss function and an optimizer into one trainable cell.

    Defaults: softmax cross-entropy loss and Adam over all trainable parameters.
    """
    if loss_fn is None:
        loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    loss_net = nn.WithLossCell(net, loss_fn)
    loss_net.set_train()
    if weights is None:
        weights = ParameterTuple(net.trainable_params())
    if optimizer is None:
        optimizer = nn.Adam(weights, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                            use_nesterov=False, weight_decay=0.0, loss_scale=1.0)
    train_net = nn.TrainOneStepCell(loss_net, optimizer)
    return train_net
|
|
@ -0,0 +1,82 @@
|
|||
#!/bin/bash
|
||||
|
||||
display_usage() {
|
||||
echo -e "\nUsage: prepare_and_run.sh dataset_path [mindspore_docker] [release.tar.gz]\n"
|
||||
}
|
||||
|
||||
if [ -n "$1" ]; then
|
||||
MNIST_DATA_PATH=$1
|
||||
else
|
||||
echo "MNIST Dataset directory path was not provided"
|
||||
display_usage
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -n "$2" ]; then
|
||||
DOCKER=$2
|
||||
else
|
||||
DOCKER=""
|
||||
#echo "MindSpore docker was not provided"
|
||||
#display_usage
|
||||
#exit 0
|
||||
fi
|
||||
|
||||
if [ -n "$3" ]; then
|
||||
TARBALL=$3
|
||||
else
|
||||
if [ -f ../../../../output/mindspore-lite-*-runtime-arm64-cpu-train.tar.gz ]; then
|
||||
TARBALL="../../../../output/mindspore-lite-*-runtime-arm64-cpu-train.tar.gz"
|
||||
else
|
||||
echo "release.tar.gz was not found"
|
||||
display_usage
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
# Prepare the model
|
||||
cd model/
|
||||
rm -f *.ms
|
||||
./prepare_model.sh $DOCKER
|
||||
cd -
|
||||
|
||||
# Copy the .ms model to the package folder
|
||||
rm -rf package
|
||||
mkdir -p package/model
|
||||
cp model/*.ms package/model
|
||||
|
||||
# Copy the running script to the package
|
||||
cp scripts/train.sh package/
|
||||
cp scripts/eval.sh package/
|
||||
|
||||
# Copy the shared MindSpore ToD library
|
||||
tar -xzvf ${TARBALL} --wildcards --no-anchored libmindspore-lite.so
|
||||
tar -xzvf ${TARBALL} --wildcards --no-anchored include
|
||||
mv mindspore-*/lib package/
|
||||
mkdir msl
|
||||
mv mindspore-*/* msl/
|
||||
rm -rf mindspore-*
|
||||
|
||||
# Copy the dataset to the package
|
||||
cp -r ${MNIST_DATA_PATH} package/dataset
|
||||
|
||||
# Compile program
|
||||
make TARGET=arm64
|
||||
|
||||
# Copy the executable to the package
|
||||
mv bin package/
|
||||
|
||||
# Push the folder to the device
|
||||
adb push package /data/local/tmp/
|
||||
|
||||
echo "Training on Device"
|
||||
adb shell < scripts/run_train.sh
|
||||
|
||||
echo
|
||||
echo "Load trained model and evaluate accuracy"
|
||||
adb shell < scripts/run_eval.sh
|
||||
echo
|
||||
|
||||
#rm -rf src/*.o package model/__pycache__ model/*.ms
|
||||
|
||||
#./prepare_and_run.sh /opt/share/dataset/mnist mindspore_dev:5
|
|
@ -0,0 +1,19 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# A simple example follows; more parameters can be set
|
||||
DATA_PATH=$1
|
||||
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/lenet_tod_trained_3000.ms -e 0 -d dataset
|
|
@ -0,0 +1,2 @@
|
|||
cd /data/local/tmp/package
|
||||
/system/bin/sh eval.sh
|
|
@ -0,0 +1,2 @@
|
|||
cd /data/local/tmp/package
|
||||
/system/bin/sh train.sh
|
|
@ -0,0 +1,21 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# A simple example follows; more parameters can be set
|
||||
script_self=$(readlink -f "$0")
|
||||
self_path=$(dirname "${script_self}")
|
||||
DATA_PATH=$1
|
||||
LD_LIBRARY_PATH=./lib/ bin/net_runner -f model/lenet_tod.ms -e 3000 -d dataset
|
|
@ -0,0 +1,200 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/dataset.h"
|
||||
#include <assert.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <map>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <filesystem>
|
||||
|
||||
using LabelId = std::map<std::string, int>;
|
||||
|
||||
char *ReadFile(const std::string &file, size_t *size) {
|
||||
assert(size != nullptr);
|
||||
std::string realPath(file);
|
||||
std::ifstream ifs(realPath);
|
||||
if (!ifs.good()) {
|
||||
std::cerr << "file: " << realPath << " does not exist";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
std::cerr << "file: " << realPath << " open failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
*size = ifs.tellg();
|
||||
std::unique_ptr<char[]> buf(new (std::nothrow) char[*size]);
|
||||
if (buf == nullptr) {
|
||||
std::cerr << "malloc buf failed, file: " << realPath;
|
||||
ifs.close();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(buf.get(), *size);
|
||||
ifs.close();
|
||||
|
||||
return buf.release();
|
||||
}
|
||||
|
||||
DataSet::~DataSet() {
|
||||
for (auto itr = train_data_.begin(); itr != train_data_.end(); ++itr) {
|
||||
auto ptr = std::get<0>(*itr);
|
||||
delete[] ptr;
|
||||
}
|
||||
for (auto itr = test_data_.begin(); itr != test_data_.end(); ++itr) {
|
||||
auto ptr = std::get<0>(*itr);
|
||||
delete[] ptr;
|
||||
}
|
||||
}
|
||||
|
||||
int DataSet::Init(const std::string &data_base_directory, database_type type) {
|
||||
InitializeMNISTDatabase(data_base_directory);
|
||||
return 0;
|
||||
}
|
||||
|
||||
void DataSet::InitializeMNISTDatabase(std::string dpath) {
|
||||
// int total_data = 0;
|
||||
num_of_classes_ = 10;
|
||||
// total_data +=
|
||||
ReadMNISTFile(dpath + "/train/train-images-idx3-ubyte", dpath + "/train/train-labels-idx1-ubyte", &train_data_);
|
||||
// total_data +=
|
||||
ReadMNISTFile(dpath + "/test/t10k-images-idx3-ubyte", dpath + "/test/t10k-labels-idx1-ubyte", &test_data_);
|
||||
}
|
||||
|
||||
int DataSet::ReadMNISTFile(const std::string &ifile_name, const std::string &lfile_name,
|
||||
std::vector<DataLabelTuple> *dataset) {
|
||||
std::ifstream lfile(lfile_name, std::ios::binary);
|
||||
if (!lfile.is_open()) {
|
||||
std::cerr << "Cannot open label file " << lfile_name << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::ifstream ifile(ifile_name, std::ios::binary);
|
||||
if (!ifile.is_open()) {
|
||||
std::cerr << "Cannot open data file " << ifile_name << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int magic_number = 0;
|
||||
lfile.read(reinterpret_cast<char *>(&magic_number), sizeof(magic_number));
|
||||
magic_number = ntohl(magic_number);
|
||||
if (magic_number != 2049) {
|
||||
std::cout << "Invalid MNIST label file!" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int number_of_labels = 0;
|
||||
lfile.read(reinterpret_cast<char *>(&number_of_labels), sizeof(number_of_labels));
|
||||
number_of_labels = ntohl(number_of_labels);
|
||||
|
||||
ifile.read(reinterpret_cast<char *>(&magic_number), sizeof(magic_number));
|
||||
magic_number = ntohl(magic_number);
|
||||
if (magic_number != 2051) {
|
||||
std::cout << "Invalid MNIST image file!" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int number_of_images = 0;
|
||||
ifile.read(reinterpret_cast<char *>(&number_of_images), sizeof(number_of_images));
|
||||
number_of_images = ntohl(number_of_images);
|
||||
|
||||
int n_rows = 0;
|
||||
ifile.read(reinterpret_cast<char *>(&n_rows), sizeof(n_rows));
|
||||
n_rows = ntohl(n_rows);
|
||||
|
||||
int n_cols = 0;
|
||||
ifile.read(reinterpret_cast<char *>(&n_cols), sizeof(n_cols));
|
||||
n_cols = ntohl(n_cols);
|
||||
|
||||
if (number_of_labels != number_of_images) {
|
||||
std::cout << "number of records in labels and images files does not match" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int image_size = n_rows * n_cols;
|
||||
  // std::vector replaces the non-standard variable-length arrays
  std::vector<unsigned char> labels(number_of_labels);
  std::vector<unsigned char> data(image_size);
  lfile.read(reinterpret_cast<char *>(labels.data()), number_of_labels);

  for (int i = 0; i < number_of_labels; ++i) {
    std::unique_ptr<float[]> hwc_bin_image(new (std::nothrow) float[32 * 32]);
    ifile.read(reinterpret_cast<char *>(data.data()), image_size);
|
||||
|
||||
for (size_t r = 0; r < 32; r++) {
|
||||
for (size_t c = 0; c < 32; c++) {
|
||||
if (r < 2 || r > 29 || c < 2 || c > 29)
|
||||
hwc_bin_image[r * 32 + c] = 0.0;
|
||||
else
|
||||
hwc_bin_image[r * 32 + c] = (static_cast<float>(data[(r - 2) * 28 + (c - 2)])) / 255.0;
|
||||
}
|
||||
}
|
||||
DataLabelTuple data_entry = std::make_tuple(reinterpret_cast<char *>(hwc_bin_image.release()), labels[i]);
|
||||
dataset->push_back(data_entry);
|
||||
}
|
||||
return number_of_labels;
|
||||
}
|
||||
|
||||
std::vector<FileTuple> DataSet::ReadFileList(std::string dpath) {
|
||||
std::vector<FileTuple> vec;
|
||||
std::ifstream ifs(dpath + "/file_list.txt");
|
||||
std::string file_name;
|
||||
if (ifs.is_open()) {
|
||||
int label;
|
||||
    // Test the extraction itself: looping on eof() would push a stale entry after the last line
    while (ifs >> label >> file_name) {
      vec.push_back(make_tuple(label, file_name));
|
||||
}
|
||||
}
|
||||
return vec;
|
||||
}
|
||||
|
||||
std::vector<FileTuple> DataSet::ReadDir(const std::string dpath) {
|
||||
std::filesystem::directory_iterator dir(dpath);
|
||||
std::vector<FileTuple> vec;
|
||||
LabelId label_id;
|
||||
int class_id = 0;
|
||||
int class_label;
|
||||
for (const auto &p : dir) {
|
||||
if (p.is_directory()) {
|
||||
std::string path = p.path().stem().string();
|
||||
auto label = label_id.find(path);
|
||||
if (label == label_id.end()) {
|
||||
label_id[path] = class_id;
|
||||
class_label = class_id;
|
||||
class_id++;
|
||||
num_of_classes_ = class_id;
|
||||
} else {
|
||||
class_label = label->second;
|
||||
}
|
||||
std::filesystem::directory_iterator ndir(dpath + "/" + path);
|
||||
for (const auto &np : ndir) {
|
||||
if (np.path().extension().string() == ".bin") {
|
||||
std::string entry =
|
||||
dpath + "/" + np.path().parent_path().stem().string() + "/" + np.path().filename().string();
|
||||
FileTuple ft = make_tuple(class_label, entry);
|
||||
vec.push_back(ft);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return vec;
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MODEL_ZOO_OFFICIAL_TOD_TRAIN_LENET_SRC_DATASET_H_
|
||||
#define MODEL_ZOO_OFFICIAL_TOD_TRAIN_LENET_SRC_DATASET_H_
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using DataLabelTuple = std::tuple<char *, int>;
|
||||
using FileTuple = std::tuple<int, std::string>;
|
||||
|
||||
enum database_type { DS_CIFAR10_BINARY = 0, DS_MNIST_BINARY, DS_OTHER };
|
||||
|
||||
char *ReadFile(const std::string &file, size_t *size); // utility function
|
||||
|
||||
class DataSet {
|
||||
public:
|
||||
DataSet() {}
|
||||
~DataSet();
|
||||
|
||||
int Init(const std::string &data_base_directory, database_type type = DS_OTHER);
|
||||
|
||||
const std::vector<DataLabelTuple> &train_data() const { return train_data_; }
|
||||
const std::vector<DataLabelTuple> &test_data() const { return test_data_; }
|
||||
unsigned int num_of_classes() { return num_of_classes_; }
|
||||
void set_expected_data_size(unsigned int expected_data_size) { expected_data_size_ = expected_data_size; }
|
||||
unsigned int expected_data_size() { return expected_data_size_; }
|
||||
|
||||
private:
|
||||
std::vector<FileTuple> ReadFileList(std::string dpath);
|
||||
std::vector<FileTuple> ReadDir(const std::string dpath);
|
||||
int ReadMNISTFile(const std::string &ifile, const std::string &lfile, std::vector<DataLabelTuple> *dataset);
|
||||
void InitializeMNISTDatabase(std::string dpath);
|
||||
|
||||
std::vector<DataLabelTuple> train_data_;
|
||||
std::vector<DataLabelTuple> test_data_;
|
||||
unsigned int num_of_classes_ = 0;
|
||||
unsigned int expected_data_size_ = 0;
|
||||
};
|
||||
|
||||
#endif // MODEL_ZOO_OFFICIAL_TOD_TRAIN_LENET_SRC_DATASET_H_
|
|
@ -0,0 +1,247 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/net_runner.h"
|
||||
#include <math.h>
|
||||
#include <getopt.h>
|
||||
#include <iostream>
|
||||
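#include <fstream>
#include <iomanip>    // std::setw
#include <algorithm>  // std::fill
#include <cassert>    // assert
#include <cstdlib>    // rand_r, atoi
#include <cstring>    // memcpy
#include <ctime>      // time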
|
||||
#include "include/context.h"
|
||||
|
||||
unsigned int NetRunner::seed_ = time(NULL);
|
||||
// Definition of callback function after forwarding operator.
|
||||
bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
|
||||
const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
|
||||
const mindspore::CallBackParam &call_param) {
|
||||
printf("%s\n", call_param.node_name.c_str());
|
||||
for (size_t i = 0; i < after_inputs.size(); i++) {
|
||||
int num2p = (after_inputs.at(i)->ElementsNum());
|
||||
printf("in%zu(%d): ", i, num2p);
|
||||
if (num2p > 10) num2p = 10;
|
||||
if (after_inputs.at(i)->data_type() == mindspore::kNumberTypeInt32) {
|
||||
auto d = reinterpret_cast<int *>(after_inputs.at(i)->MutableData());
|
||||
for (int j = 0; j < num2p; j++) printf("%d, ", d[j]);
|
||||
} else {
|
||||
auto d = reinterpret_cast<float *>(after_inputs.at(i)->MutableData());
|
||||
for (int j = 0; j < num2p; j++) printf("%f, ", d[j]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
for (size_t i = 0; i < after_outputs.size(); i++) {
|
||||
auto d = reinterpret_cast<float *>(after_outputs.at(i)->MutableData());
|
||||
int num2p = (after_outputs.at(i)->ElementsNum());
|
||||
printf("ou%zu(%d): ", i, num2p);
|
||||
if (num2p > 10) num2p = 10;
|
||||
for (int j = 0; j < num2p; j++) printf("%f, ", d[j]);
|
||||
printf("\n");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
NetRunner::~NetRunner() {
|
||||
if (session_ != nullptr) delete session_;
|
||||
}
|
||||
|
||||
void NetRunner::InitAndFigureInputs() {
|
||||
mindspore::lite::Context context;
|
||||
context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = mindspore::lite::NO_BIND;
|
||||
context.thread_num_ = 1;
|
||||
|
||||
session_ = mindspore::session::TrainSession::CreateSession(ms_file_, &context);
|
||||
assert(nullptr != session_);
|
||||
|
||||
auto inputs = session_->GetInputs();
|
||||
assert(inputs.size() > 1);
|
||||
data_index_ = 0;
|
||||
label_index_ = 1;
|
||||
batch_size_ = inputs[data_index_]->shape()[0];
|
||||
data_size_ = inputs[data_index_]->Size() / batch_size_; // in bytes
|
||||
if (verbose_) {
|
||||
std::cout << "data size: " << data_size_ << std::endl << "batch size: " << batch_size_ << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
mindspore::tensor::MSTensor *NetRunner::SearchOutputsForSize(size_t size) const {
|
||||
auto outputs = session_->GetOutputs();
|
||||
for (auto it = outputs.begin(); it != outputs.end(); ++it) {
|
||||
if (it->second->ElementsNum() == size) return it->second;
|
||||
}
|
||||
std::cout << "Model does not have an output tensor with size " << size << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<int> NetRunner::FillInputData(const std::vector<DataLabelTuple> &dataset, bool serially) const {
|
||||
std::vector<int> labels_vec;
|
||||
static unsigned int idx = 1;
|
||||
int total_size = dataset.size();
|
||||
|
||||
auto inputs = session_->GetInputs();
|
||||
char *input_data = reinterpret_cast<char *>(inputs.at(data_index_)->MutableData());
|
||||
auto labels = reinterpret_cast<float *>(inputs.at(label_index_)->MutableData());
|
||||
assert(total_size > 0);
|
||||
assert(input_data != nullptr);
|
||||
std::fill(labels, labels + inputs.at(label_index_)->ElementsNum(), 0.f);
|
||||
for (int i = 0; i < batch_size_; i++) {
|
||||
if (serially) {
|
||||
idx = (idx + 1) % total_size;  // avoids modifying idx twice in one expression
|
||||
} else {
|
||||
idx = rand_r(&seed_) % total_size;
|
||||
}
|
||||
int label = 0;
|
||||
char *data = nullptr;
|
||||
std::tie(data, label) = dataset[idx];
|
||||
memcpy(input_data + i * data_size_, data, data_size_);
|
||||
labels[i * num_of_classes_ + label] = 1.0; // Model expects labels in onehot representation
|
||||
labels_vec.push_back(label);
|
||||
}
|
||||
|
||||
return labels_vec;
|
||||
}
|
||||
|
||||
float NetRunner::CalculateAccuracy(int max_tests) const {
|
||||
float accuracy = 0.0;
|
||||
const std::vector<DataLabelTuple> &test_set = ds_.test_data();  // bind by reference to avoid copying the test set
|
||||
int tests = test_set.size() / batch_size_;
|
||||
if (max_tests != -1 && tests > max_tests) tests = max_tests;  // cap the number of test batches
|
||||
|
||||
session_->Eval();
|
||||
for (int i = 0; i < tests; i++) {
|
||||
auto labels = FillInputData(test_set, (max_tests == -1));
|
||||
session_->RunGraph();
|
||||
auto outputsv = SearchOutputsForSize(batch_size_ * num_of_classes_);
|
||||
assert(outputsv != nullptr);
|
||||
auto scores = reinterpret_cast<float *>(outputsv->MutableData());
|
||||
for (int b = 0; b < batch_size_; b++) {
|
||||
int max_idx = 0;
|
||||
float max_score = scores[num_of_classes_ * b];
|
||||
for (int c = 0; c < num_of_classes_; c++) {
|
||||
if (scores[num_of_classes_ * b + c] > max_score) {
|
||||
max_score = scores[num_of_classes_ * b + c];
|
||||
max_idx = c;
|
||||
}
|
||||
}
|
||||
if (labels[b] == max_idx) accuracy += 1.0;
|
||||
}
|
||||
}
|
||||
session_->Train();
|
||||
accuracy /= static_cast<float>(batch_size_ * tests);
|
||||
return accuracy;
|
||||
}
|
||||
|
||||
int NetRunner::InitDB() {
|
||||
if (data_size_ != 0) ds_.set_expected_data_size(data_size_);
|
||||
int ret = ds_.Init(data_dir_, DS_MNIST_BINARY);
|
||||
num_of_classes_ = ds_.num_of_classes();
|
||||
if (ds_.test_data().size() == 0) {
|
||||
std::cout << "No relevant data was found in " << data_dir_ << std::endl;
|
||||
assert(ds_.test_data().size() != 0);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
float NetRunner::GetLoss() const {
|
||||
auto outputsv = SearchOutputsForSize(1); // Search for Loss which is a single value tensor
|
||||
assert(outputsv != nullptr);
|
||||
auto loss = reinterpret_cast<float *>(outputsv->MutableData());
|
||||
return loss[0];
|
||||
}
|
||||
|
||||
int NetRunner::TrainLoop() {
|
||||
session_->Train();
|
||||
float min_loss = 1000.;
|
||||
float max_acc = 0.;
|
||||
for (int i = 0; i < cycles_; i++) {
|
||||
FillInputData(ds_.train_data());
|
||||
session_->RunGraph(nullptr, verbose_ ? after_callback : nullptr);
|
||||
float loss = GetLoss();
|
||||
if (min_loss > loss) min_loss = loss;
|
||||
|
||||
if (save_checkpoint_ != 0 && (i + 1) % save_checkpoint_ == 0) {
|
||||
auto cpkt_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained_" + std::to_string(i + 1) + ".ms";
|
||||
session_->SaveToFile(cpkt_fn);
|
||||
}
|
||||
|
||||
if ((i + 1) % 100 == 0) {
|
||||
float acc = CalculateAccuracy(10);
|
||||
if (max_acc < acc) max_acc = acc;
|
||||
std::cout << i + 1 << ":\tLoss is " << std::setw(7) << loss << " [min=" << min_loss << "] "
|
||||
<< " max_acc=" << max_acc << std::endl;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int NetRunner::Main() {
|
||||
InitAndFigureInputs();
|
||||
|
||||
InitDB();
|
||||
|
||||
TrainLoop();
|
||||
|
||||
float acc = CalculateAccuracy();
|
||||
std::cout << "accuracy = " << acc << std::endl;
|
||||
|
||||
if (cycles_ > 0) {
|
||||
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained_" + std::to_string(cycles_) + ".ms";
|
||||
session_->SaveToFile(trained_fn);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
void NetRunner::Usage() {
|
||||
std::cout << "Usage: net_runner -f <.ms model file> -d <data_dir> [-e <num of training cycles>] "
|
||||
<< "[-v (verbose mode)] [-s <save checkpoint every X iterations>]" << std::endl;
|
||||
}
|
||||
|
||||
bool NetRunner::ReadArgs(int argc, char *argv[]) {
|
||||
int opt;
|
||||
while ((opt = getopt(argc, argv, "f:e:d:s:ihc:v")) != -1) {
|
||||
switch (opt) {
|
||||
case 'f':
|
||||
ms_file_ = std::string(optarg);
|
||||
break;
|
||||
case 'e':
|
||||
cycles_ = atoi(optarg);
|
||||
break;
|
||||
case 'd':
|
||||
data_dir_ = std::string(optarg);
|
||||
break;
|
||||
case 'v':
|
||||
verbose_ = true;
|
||||
break;
|
||||
case 's':
|
||||
save_checkpoint_ = atoi(optarg);
|
||||
break;
|
||||
case 'h':
|
||||
default:
|
||||
Usage();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
NetRunner nr;
|
||||
|
||||
if (nr.ReadArgs(argc, argv)) {
|
||||
nr.Main();
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MODEL_ZOO_OFFICIAL_TOD_TRAIN_LENET_SRC_NET_RUNNER_H_
|
||||
#define MODEL_ZOO_OFFICIAL_TOD_TRAIN_LENET_SRC_NET_RUNNER_H_
|
||||
|
||||
#include <tuple>
|
||||
#include <filesystem>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "include/train_session.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "src/dataset.h"
|
||||
|
||||
class NetRunner {
|
||||
public:
|
||||
int Main();
|
||||
bool ReadArgs(int argc, char *argv[]);
|
||||
~NetRunner();
|
||||
|
||||
private:
|
||||
void Usage();
|
||||
void InitAndFigureInputs();
|
||||
int InitDB();
|
||||
int TrainLoop();
|
||||
std::vector<int> FillInputData(const std::vector<DataLabelTuple> &dataset, bool serially = false) const;
|
||||
float CalculateAccuracy(int max_tests = -1) const;
|
||||
float GetLoss() const;
|
||||
mindspore::tensor::MSTensor *SearchOutputsForSize(size_t size) const;
|
||||
|
||||
DataSet ds_;
|
||||
mindspore::session::TrainSession *session_ = nullptr;
|
||||
|
||||
std::string ms_file_ = "";
|
||||
std::string data_dir_ = "";
|
||||
size_t data_size_ = 0;
|
||||
size_t batch_size_ = 0;
|
||||
unsigned int cycles_ = 100;
|
||||
int data_index_ = 0;
|
||||
int label_index_ = -1;
|
||||
int num_of_classes_ = 0;
|
||||
bool verbose_ = false;
|
||||
int save_checkpoint_ = 0;
|
||||
static unsigned int seed_;
|
||||
};
|
||||
|
||||
#endif // MODEL_ZOO_OFFICIAL_TOD_TRAIN_LENET_SRC_NET_RUNNER_H_
|