!17856 [MS][LITE][TOD] Fixed train_lenet_java example
Merge pull request !17856 from ehaleva/train_lenet_java_example
This commit is contained in:
commit
c13adb6e33
13
build.sh
13
build.sh
|
@ -526,6 +526,11 @@ build_lite_x86_64_jni_and_jar()
|
||||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
|
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
|
||||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
|
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
|
||||||
cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/inference/lib/
|
cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/inference/lib/
|
||||||
|
if [[ "X$is_train" = "Xon" ]]; then
|
||||||
|
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
|
||||||
|
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
|
||||||
|
cp ./libmindspore-lite-train-jni.so ${BASEPATH}/output/tmp/${pkg_name}/inference/lib/
|
||||||
|
fi
|
||||||
|
|
||||||
# build java common
|
# build java common
|
||||||
cd ${LITE_JAVA_PATH}/java/common
|
cd ${LITE_JAVA_PATH}/java/common
|
||||||
|
@ -695,6 +700,10 @@ build_lite_arm64_and_jni() {
|
||||||
fi
|
fi
|
||||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
|
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
|
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
|
||||||
|
if [[ "X$is_train" = "Xon" ]]; then
|
||||||
|
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||||
|
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
|
||||||
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
build_lite_arm32_and_jni() {
|
build_lite_arm32_and_jni() {
|
||||||
|
@ -736,6 +745,10 @@ build_lite_arm32_and_jni() {
|
||||||
fi
|
fi
|
||||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
|
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
|
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
|
||||||
|
if [[ "X$is_train" = "Xon" ]]; then
|
||||||
|
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||||
|
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
|
||||||
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
check_java_home() {
|
check_java_home() {
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
display_usage()
|
||||||
|
{
|
||||||
|
echo -e "\nUsage: build.sh [-r release.tar.gz]\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
checkopts()
|
||||||
|
{
|
||||||
|
TARBALL=""
|
||||||
|
while getopts 'r:' opt
|
||||||
|
do
|
||||||
|
case "${opt}" in
|
||||||
|
r)
|
||||||
|
TARBALL=$OPTARG
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option ${opt}!"
|
||||||
|
display_usage
|
||||||
|
exit 1
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
checkopts "$@"
|
||||||
|
|
||||||
|
BASEPATH=$(cd "$(dirname $0)" || exit; pwd)
|
||||||
|
get_version() {
|
||||||
|
VERSION_MAJOR=$(grep "const int ms_version_major =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
|
||||||
|
VERSION_MINOR=$(grep "const int ms_version_minor =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
|
||||||
|
VERSION_REVISION=$(grep "const int ms_version_revision =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
|
||||||
|
VERSION_STR=${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION}
|
||||||
|
}
|
||||||
|
get_version
|
||||||
|
MINDSPORE_FILE_NAME="mindspore-lite-${VERSION_STR}-linux-x64"
|
||||||
|
MINDSPORE_FILE="${MINDSPORE_FILE_NAME}.tar.gz"
|
||||||
|
MINDSPORE_LITE_DOWNLOAD_URL="https://ms-release.obs.cn-north-4.myhuaweicloud.com/${VERSION_STR}/MindSpore/lite/release/linux/${MINDSPORE_FILE}"
|
||||||
|
|
||||||
|
mkdir -p build/${MINDSPORE_FILE_NAME}
|
||||||
|
mkdir -p lib
|
||||||
|
|
||||||
|
if [ -n "$TARBALL" ]; then
|
||||||
|
cp ${TARBALL} ${BASEPATH}/build/${MINDSPORE_FILE}
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -e ${BASEPATH}/build/${MINDSPORE_FILE} ]; then
|
||||||
|
wget -c -O ${BASEPATH}/build/${MINDSPORE_FILE} --no-check-certificate ${MINDSPORE_LITE_DOWNLOAD_URL}
|
||||||
|
fi
|
||||||
|
|
||||||
|
tar xzvf ${BASEPATH}/build/${MINDSPORE_FILE} -C ${BASEPATH}/build/${MINDSPORE_FILE_NAME} --strip-components=1
|
||||||
|
|
||||||
|
cp -r ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/inference/lib/* ${BASEPATH}/lib
|
||||||
|
cd ${BASEPATH}/ || exit
|
||||||
|
|
||||||
|
mvn package
|
|
@ -0,0 +1,34 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""lenet_export."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from mindspore import context, Tensor
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore.train.serialization import export
|
||||||
|
from lenet import LeNet5
|
||||||
|
from train_utils import TrainWrap
|
||||||
|
|
||||||
|
n = LeNet5()
|
||||||
|
n.set_train()
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU", save_graphs=False)
|
||||||
|
|
||||||
|
BATCH_SIZE = 4
|
||||||
|
x = Tensor(np.ones((BATCH_SIZE, 1, 32, 32)), mstype.float32)
|
||||||
|
label = Tensor(np.zeros([BATCH_SIZE]).astype(np.int32))
|
||||||
|
net = TrainWrap(n)
|
||||||
|
export(net, x, label, file_name="lenet_tod", file_format='MINDIR')
|
||||||
|
|
||||||
|
print("finished exporting")
|
|
@ -0,0 +1,24 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
echo "============Exporting=========="
|
||||||
|
if [ -n "$1" ]; then
|
||||||
|
DOCKER_IMG=$1
|
||||||
|
docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER_IMG} /bin/bash -c "PYTHONPATH=../../../../../model_zoo/official/cv/lenet/src python lenet_export.py; chmod 444 lenet_tod.mindir; rm -rf __pycache__"
|
||||||
|
else
|
||||||
|
echo "MindSpore docker was not provided, attempting to run locally"
|
||||||
|
PYTHONPATH=../../../../../model_zoo/official/cv/lenet/src python lenet_export.py
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f "$CONVERTER" ]; then
|
||||||
|
echo "converter_lite could not be found in MindSpore build directory nor in system path"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "============Converting========="
|
||||||
|
QUANT_OPTIONS=""
|
||||||
|
if [[ ! -z ${QUANTIZE} ]]; then
|
||||||
|
echo "Quantizing weights"
|
||||||
|
QUANT_OPTIONS="--quantType=WeightQuant --bitNum=8 --quantWeightSize=100 --quantWeightChannel=15"
|
||||||
|
fi
|
||||||
|
$CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod $QUANT_OPTIONS
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""train_utils."""
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.common.parameter import ParameterTuple
|
||||||
|
|
||||||
|
def TrainWrap(net, loss_fn=None, optimizer=None, weights=None):
|
||||||
|
"""
|
||||||
|
TrainWrap
|
||||||
|
"""
|
||||||
|
if loss_fn is None:
|
||||||
|
loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction='mean', sparse=True)
|
||||||
|
loss_net = nn.WithLossCell(net, loss_fn)
|
||||||
|
loss_net.set_train()
|
||||||
|
if weights is None:
|
||||||
|
weights = ParameterTuple(net.trainable_params())
|
||||||
|
if optimizer is None:
|
||||||
|
optimizer = nn.Adam(weights, learning_rate=0.003, beta1=0.9, beta2=0.999, eps=1e-5, use_locking=False,
|
||||||
|
use_nesterov=False, weight_decay=4e-5, loss_scale=1.0)
|
||||||
|
train_net = nn.TrainOneStepCell(loss_net, optimizer)
|
||||||
|
return train_net
|
|
@ -0,0 +1,71 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
display_usage()
|
||||||
|
{
|
||||||
|
echo -e "\nUsage: prepare_and_run.sh -D dataset_path [-d mindspore_docker] [-r release.tar.gz]\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
checkopts()
|
||||||
|
{
|
||||||
|
DOCKER=""
|
||||||
|
MNIST_DATA_PATH=""
|
||||||
|
while getopts 'D:d:r:' opt
|
||||||
|
do
|
||||||
|
case "${opt}" in
|
||||||
|
D)
|
||||||
|
MNIST_DATA_PATH=$OPTARG
|
||||||
|
;;
|
||||||
|
d)
|
||||||
|
DOCKER=$OPTARG
|
||||||
|
;;
|
||||||
|
r)
|
||||||
|
TARBALL="-r $OPTARG"
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option ${opt}!"
|
||||||
|
display_usage
|
||||||
|
exit 1
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
checkopts "$@"
|
||||||
|
if [ "$MNIST_DATA_PATH" == "" ]; then
|
||||||
|
echo "MNIST Dataset directory path was not provided"
|
||||||
|
display_usage
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
./build.sh $TARBALL
|
||||||
|
|
||||||
|
BASEPATH=$(cd "$(dirname $0)" || exit; pwd)
|
||||||
|
|
||||||
|
cd model/ || exit 1
|
||||||
|
MSLITE_LINUX=$(ls -d ${BASEPATH}/build/mindspore-lite-*-linux-x64)
|
||||||
|
CONVERTER=${MSLITE_LINUX}/tools/converter/converter/converter_lite
|
||||||
|
rm -f *.ms
|
||||||
|
LD_LIBRARY_PATH=${MSLITE_LINUX}/tools/converter/lib/:${MSLITE_LINUX}/tools/converter/third_party/glog/lib
|
||||||
|
LD_LIBRARY_PATH=${LD_LIBRARY_PATH} CONVERTER=${CONVERTER} ./prepare_model.sh $DOCKER || exit 1
|
||||||
|
cd ../
|
||||||
|
|
||||||
|
cd target || exit 1
|
||||||
|
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:../lib/
|
||||||
|
java -Djava.library.path=../lib/ -classpath .:./train_lenet_java.jar:../lib/mindspore-lite-java.jar com.mindspore.lite.train_lenet.Main ../model/lenet_tod.ms $MNIST_DATA_PATH
|
||||||
|
cd -
|
||||||
|
|
|
@ -1,5 +1,20 @@
|
||||||
package com.mindspore.lite.train_lenet;
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.mindspore.lite.train_lenet;
|
||||||
|
|
||||||
import java.io.BufferedInputStream;
|
import java.io.BufferedInputStream;
|
||||||
import java.io.FileInputStream;
|
import java.io.FileInputStream;
|
||||||
|
|
|
@ -1,9 +1,26 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
package com.mindspore.lite.train_lenet;
|
package com.mindspore.lite.train_lenet;
|
||||||
|
|
||||||
import com.mindspore.lite.Version;
|
import com.mindspore.lite.Version;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
|
System.loadLibrary("mindspore-lite-train-jni");
|
||||||
System.out.println(Version.version());
|
System.out.println(Version.version());
|
||||||
if (args.length < 2) {
|
if (args.length < 2) {
|
||||||
System.err.println("model path and dataset path must be provided.");
|
System.err.println("model path and dataset path must be provided.");
|
||||||
|
|
|
@ -1,7 +1,23 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
package com.mindspore.lite.train_lenet;
|
package com.mindspore.lite.train_lenet;
|
||||||
|
|
||||||
import com.mindspore.lite.MSTensor;
|
import com.mindspore.lite.MSTensor;
|
||||||
import com.mindspore.lite.TrainSession;
|
import com.mindspore.lite.LiteSession;
|
||||||
import com.mindspore.lite.config.MSConfig;
|
import com.mindspore.lite.config.MSConfig;
|
||||||
|
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
|
@ -13,13 +29,14 @@ import java.util.Vector;
|
||||||
public class NetRunner {
|
public class NetRunner {
|
||||||
private int dataIndex = 0;
|
private int dataIndex = 0;
|
||||||
private int labelIndex = 1;
|
private int labelIndex = 1;
|
||||||
private TrainSession session;
|
private LiteSession session;
|
||||||
private long batchSize;
|
private long batchSize;
|
||||||
private long dataSize; // one input data size, in byte
|
private long dataSize; // one input data size, in byte
|
||||||
private DataSet ds = new DataSet();
|
private DataSet ds = new DataSet();
|
||||||
private long numOfClasses;
|
private long numOfClasses;
|
||||||
private long cycles = 3500;
|
private long cycles = 3500;
|
||||||
private int idx = 1;
|
private int idx = 1;
|
||||||
|
private int virtualBatch = 16;
|
||||||
private String trainedFilePath = "trained.ms";
|
private String trainedFilePath = "trained.ms";
|
||||||
|
|
||||||
public void initAndFigureInputs(String modelPath) {
|
public void initAndFigureInputs(String modelPath) {
|
||||||
|
@ -29,9 +46,10 @@ public class NetRunner {
|
||||||
// arg 2: cpuBindMode:NO_BIND -> 0
|
// arg 2: cpuBindMode:NO_BIND -> 0
|
||||||
// arg 3: enable_fp16 -> false
|
// arg 3: enable_fp16 -> false
|
||||||
msConfig.init(0, 2, 0, false);
|
msConfig.init(0, 2, 0, false);
|
||||||
session = new TrainSession();
|
session = new LiteSession();
|
||||||
session.init(modelPath, msConfig);
|
System.out.println("Model path is " + modelPath);
|
||||||
session.setLearningRate(0.01f);
|
session = session.createTrainSession(modelPath, msConfig, false);
|
||||||
|
session.setupVirtualBatch(virtualBatch, 0.01f, 1.00f);
|
||||||
|
|
||||||
List<MSTensor> inputs = session.getInputs();
|
List<MSTensor> inputs = session.getInputs();
|
||||||
if (inputs.size() <= 1) {
|
if (inputs.size() <= 1) {
|
||||||
|
@ -44,6 +62,7 @@ public class NetRunner {
|
||||||
batchSize = inputs.get(dataIndex).getShape()[0];
|
batchSize = inputs.get(dataIndex).getShape()[0];
|
||||||
dataSize = inputs.get(dataIndex).size() / batchSize;
|
dataSize = inputs.get(dataIndex).size() / batchSize;
|
||||||
System.out.println("batch_size: " + batchSize);
|
System.out.println("batch_size: " + batchSize);
|
||||||
|
System.out.println("virtual batch multiplier: " + virtualBatch);
|
||||||
int index = modelPath.lastIndexOf(".ms");
|
int index = modelPath.lastIndexOf(".ms");
|
||||||
if (index == -1) {
|
if (index == -1) {
|
||||||
System.out.println("The model " + modelPath + " should be named *.ms");
|
System.out.println("The model " + modelPath + " should be named *.ms");
|
||||||
|
@ -92,20 +111,21 @@ public class NetRunner {
|
||||||
float min_loss = 1000;
|
float min_loss = 1000;
|
||||||
float max_acc = 0;
|
float max_acc = 0;
|
||||||
for (int i = 0; i < cycles; i++) {
|
for (int i = 0; i < cycles; i++) {
|
||||||
fillInputData(ds.getTrainData(), false);
|
for (int b = 0; b < virtualBatch; b++) {
|
||||||
session.runGraph();
|
fillInputData(ds.getTrainData(), false);
|
||||||
float loss = getLoss();
|
session.runGraph();
|
||||||
if (min_loss > loss) {
|
float loss = getLoss();
|
||||||
min_loss = loss;
|
if (min_loss > loss) {
|
||||||
}
|
min_loss = loss;
|
||||||
if ((i + 1) % 500 == 0) {
|
}
|
||||||
float acc = calculateAccuracy(10); // only test 10 batch size
|
if ((b == 0) && ((i + 1) % 500 == 0)) {
|
||||||
if (max_acc < acc) {
|
float acc = calculateAccuracy(10); // only test 10 batch size
|
||||||
max_acc = acc;
|
if (max_acc < acc) {
|
||||||
|
max_acc = acc;
|
||||||
|
}
|
||||||
|
System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_accc=" + max_acc);
|
||||||
}
|
}
|
||||||
System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_accc=" + max_acc);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -208,7 +228,10 @@ public class NetRunner {
|
||||||
System.out.println("accuracy = " + acc);
|
System.out.println("accuracy = " + acc);
|
||||||
|
|
||||||
if (cycles > 0) {
|
if (cycles > 0) {
|
||||||
if (session.saveToFile(trainedFilePath)) {
|
// arg 0: FileName
|
||||||
|
// arg 1: model type MT_TRAIN -> 0
|
||||||
|
// arg 2: quantization type QT_DEFAULT -> 0
|
||||||
|
if (session.export(trainedFilePath, 0, 0)) {
|
||||||
System.out.println("Trained model successfully saved: " + trainedFilePath);
|
System.out.println("Trained model successfully saved: " + trainedFilePath);
|
||||||
} else {
|
} else {
|
||||||
System.err.println("Save model error.");
|
System.err.println("Save model error.");
|
||||||
|
|
|
@ -59,14 +59,6 @@ set(JNI_SRC
|
||||||
|
|
||||||
set(LITE_SO_NAME mindspore-lite)
|
set(LITE_SO_NAME mindspore-lite)
|
||||||
|
|
||||||
if(SUPPORT_TRAIN)
|
|
||||||
set(JNI_SRC
|
|
||||||
${JNI_SRC}
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp
|
|
||||||
)
|
|
||||||
|
|
||||||
endif()
|
|
||||||
|
|
||||||
add_library(mindspore-lite-jni SHARED ${JNI_SRC})
|
add_library(mindspore-lite-jni SHARED ${JNI_SRC})
|
||||||
|
|
||||||
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
|
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
|
||||||
|
@ -75,3 +67,16 @@ if(PLATFORM_ARM64 OR PLATFORM_ARM32)
|
||||||
else()
|
else()
|
||||||
target_link_libraries(mindspore-lite-jni ${LITE_SO_NAME})
|
target_link_libraries(mindspore-lite-jni ${LITE_SO_NAME})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(SUPPORT_TRAIN)
|
||||||
|
set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite)
|
||||||
|
set(JNI_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp)
|
||||||
|
add_library(mindspore-lite-train-jni SHARED ${JNI_TRAIN_SRC})
|
||||||
|
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
|
||||||
|
find_library(log-lib log)
|
||||||
|
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME} ${log-lib})
|
||||||
|
else()
|
||||||
|
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME})
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
|
|
@ -443,11 +443,7 @@ void LiteSession::FreePackOpWeight(const std::vector<kernel::LiteKernel *> &kern
|
||||||
bool LiteSession::IfUseMindrtExecutor() {
|
bool LiteSession::IfUseMindrtExecutor() {
|
||||||
bool use_mindrt_run = true;
|
bool use_mindrt_run = true;
|
||||||
#ifdef ENABLE_MINDRT
|
#ifdef ENABLE_MINDRT
|
||||||
#ifdef SUPPORT_TRAIN
|
use_mindrt_run = (is_train_session_) ? false : true;
|
||||||
use_mindrt_run = false;
|
|
||||||
#else
|
|
||||||
use_mindrt_run = true;
|
|
||||||
#endif
|
|
||||||
#else
|
#else
|
||||||
use_mindrt_run = false;
|
use_mindrt_run = false;
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -334,7 +334,7 @@ int TrainSession::SetLearningRate(float learning_rate) {
|
||||||
}
|
}
|
||||||
for (auto kernel : this->train_kernels_) {
|
for (auto kernel : this->train_kernels_) {
|
||||||
if (IsOptimizer(kernel)) {
|
if (IsOptimizer(kernel)) {
|
||||||
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
|
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
|
||||||
auto ret = optimizer->SetLearningRate(learning_rate);
|
auto ret = optimizer->SetLearningRate(learning_rate);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
|
MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
|
||||||
|
@ -348,7 +348,7 @@ int TrainSession::SetLearningRate(float learning_rate) {
|
||||||
float TrainSession::GetLearningRate() {
|
float TrainSession::GetLearningRate() {
|
||||||
for (auto kernel : this->train_kernels_) {
|
for (auto kernel : this->train_kernels_) {
|
||||||
if (IsOptimizer(kernel)) {
|
if (IsOptimizer(kernel)) {
|
||||||
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
|
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
|
||||||
return optimizer->GetLearningRate();
|
return optimizer->GetLearningRate();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -363,7 +363,7 @@ int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr,
|
||||||
|
|
||||||
for (auto kernel : this->train_kernels_) {
|
for (auto kernel : this->train_kernels_) {
|
||||||
if (IsOptimizer(kernel)) {
|
if (IsOptimizer(kernel)) {
|
||||||
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
|
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
|
||||||
auto ret = optimizer->SetOptimizerMode(mod);
|
auto ret = optimizer->SetOptimizerMode(mod);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode";
|
MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode";
|
||||||
|
@ -382,7 +382,7 @@ int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (IsBN(kernel) && kernel->is_trainable()) {
|
if (IsBN(kernel) && kernel->is_trainable()) {
|
||||||
auto batchnorm = reinterpret_cast<kernel::BatchnormCPUKernel *>(kernel);
|
auto batchnorm = static_cast<kernel::BatchnormCPUKernel *>(kernel->kernel());
|
||||||
auto ret = RET_OK;
|
auto ret = RET_OK;
|
||||||
if (mod == kernel::OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
|
if (mod == kernel::OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
|
||||||
momentum = (momentum < 0.0f) ? (batchnorm->get_momentum() / virtual_batch_multiplier_) : momentum;
|
momentum = (momentum < 0.0f) ? (batchnorm->get_momentum() / virtual_batch_multiplier_) : momentum;
|
||||||
|
@ -409,7 +409,7 @@ int TrainSession::SetupVirtualBatch(int virtual_batch_multiplier, float lr, floa
|
||||||
int TrainSession::OptimizerStep() {
|
int TrainSession::OptimizerStep() {
|
||||||
for (auto kernel : this->train_kernels_) {
|
for (auto kernel : this->train_kernels_) {
|
||||||
if (IsOptimizer(kernel)) {
|
if (IsOptimizer(kernel)) {
|
||||||
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
|
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
|
||||||
auto ret = optimizer->OptimizerStep();
|
auto ret = optimizer->OptimizerStep();
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << kernel->name() << " failed to do optimize step";
|
MS_LOG(ERROR) << kernel->name() << " failed to do optimize step";
|
||||||
|
|
Loading…
Reference in New Issue