!17856 [MS][LITE][TOD] Fixed train_lenet_java example
Merge pull request !17856 from ehaleva/train_lenet_java_example
This commit is contained in:
commit
c13adb6e33
13
build.sh
13
build.sh
|
@ -526,6 +526,11 @@ build_lite_x86_64_jni_and_jar()
|
|||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
|
||||
cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/inference/lib/
|
||||
if [[ "X$is_train" = "Xon" ]]; then
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
|
||||
cp ./libmindspore-lite-train-jni.so ${BASEPATH}/output/tmp/${pkg_name}/inference/lib/
|
||||
fi
|
||||
|
||||
# build java common
|
||||
cd ${LITE_JAVA_PATH}/java/common
|
||||
|
@ -695,6 +700,10 @@ build_lite_arm64_and_jni() {
|
|||
fi
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
|
||||
if [[ "X$is_train" = "Xon" ]]; then
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
|
||||
fi
|
||||
}
|
||||
|
||||
build_lite_arm32_and_jni() {
|
||||
|
@ -736,6 +745,10 @@ build_lite_arm32_and_jni() {
|
|||
fi
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
|
||||
if [[ "X$is_train" = "Xon" ]]; then
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
|
||||
fi
|
||||
}
|
||||
|
||||
check_java_home() {
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
display_usage()
|
||||
{
|
||||
echo -e "\nUsage: build.sh [-r release.tar.gz]\n"
|
||||
}
|
||||
|
||||
checkopts()
|
||||
{
|
||||
TARBALL=""
|
||||
while getopts 'r:' opt
|
||||
do
|
||||
case "${opt}" in
|
||||
r)
|
||||
TARBALL=$OPTARG
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option ${opt}!"
|
||||
display_usage
|
||||
exit 1
|
||||
esac
|
||||
done
|
||||
}
|
||||
|
||||
checkopts "$@"
|
||||
|
||||
BASEPATH=$(cd "$(dirname $0)" || exit; pwd)
|
||||
get_version() {
|
||||
VERSION_MAJOR=$(grep "const int ms_version_major =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
|
||||
VERSION_MINOR=$(grep "const int ms_version_minor =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
|
||||
VERSION_REVISION=$(grep "const int ms_version_revision =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
|
||||
VERSION_STR=${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION}
|
||||
}
|
||||
get_version
|
||||
MINDSPORE_FILE_NAME="mindspore-lite-${VERSION_STR}-linux-x64"
|
||||
MINDSPORE_FILE="${MINDSPORE_FILE_NAME}.tar.gz"
|
||||
MINDSPORE_LITE_DOWNLOAD_URL="https://ms-release.obs.cn-north-4.myhuaweicloud.com/${VERSION_STR}/MindSpore/lite/release/linux/${MINDSPORE_FILE}"
|
||||
|
||||
mkdir -p build/${MINDSPORE_FILE_NAME}
|
||||
mkdir -p lib
|
||||
|
||||
if [ -n "$TARBALL" ]; then
|
||||
cp ${TARBALL} ${BASEPATH}/build/${MINDSPORE_FILE}
|
||||
fi
|
||||
|
||||
if [ ! -e ${BASEPATH}/build/${MINDSPORE_FILE} ]; then
|
||||
wget -c -O ${BASEPATH}/build/${MINDSPORE_FILE} --no-check-certificate ${MINDSPORE_LITE_DOWNLOAD_URL}
|
||||
fi
|
||||
|
||||
tar xzvf ${BASEPATH}/build/${MINDSPORE_FILE} -C ${BASEPATH}/build/${MINDSPORE_FILE_NAME} --strip-components=1
|
||||
|
||||
cp -r ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/inference/lib/* ${BASEPATH}/lib
|
||||
cd ${BASEPATH}/ || exit
|
||||
|
||||
mvn package
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lenet_export."""
|
||||
|
||||
import numpy as np
|
||||
from mindspore import context, Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.serialization import export
|
||||
from lenet import LeNet5
|
||||
from train_utils import TrainWrap
|
||||
|
||||
n = LeNet5()
|
||||
n.set_train()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU", save_graphs=False)
|
||||
|
||||
BATCH_SIZE = 4
|
||||
x = Tensor(np.ones((BATCH_SIZE, 1, 32, 32)), mstype.float32)
|
||||
label = Tensor(np.zeros([BATCH_SIZE]).astype(np.int32))
|
||||
net = TrainWrap(n)
|
||||
export(net, x, label, file_name="lenet_tod", file_format='MINDIR')
|
||||
|
||||
print("finished exporting")
|
|
@ -0,0 +1,24 @@
|
|||
#!/bin/bash
|
||||
|
||||
echo "============Exporting=========="
|
||||
if [ -n "$1" ]; then
|
||||
DOCKER_IMG=$1
|
||||
docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER_IMG} /bin/bash -c "PYTHONPATH=../../../../../model_zoo/official/cv/lenet/src python lenet_export.py; chmod 444 lenet_tod.mindir; rm -rf __pycache__"
|
||||
else
|
||||
echo "MindSpore docker was not provided, attempting to run locally"
|
||||
PYTHONPATH=../../../../../model_zoo/official/cv/lenet/src python lenet_export.py
|
||||
fi
|
||||
|
||||
if [ ! -f "$CONVERTER" ]; then
|
||||
echo "converter_lite could not be found in MindSpore build directory nor in system path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "============Converting========="
|
||||
QUANT_OPTIONS=""
|
||||
if [[ ! -z ${QUANTIZE} ]]; then
|
||||
echo "Quantizing weights"
|
||||
QUANT_OPTIONS="--quantType=WeightQuant --bitNum=8 --quantWeightSize=100 --quantWeightChannel=15"
|
||||
fi
|
||||
$CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod $QUANT_OPTIONS
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_utils."""
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
|
||||
def TrainWrap(net, loss_fn=None, optimizer=None, weights=None):
|
||||
"""
|
||||
TrainWrap
|
||||
"""
|
||||
if loss_fn is None:
|
||||
loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction='mean', sparse=True)
|
||||
loss_net = nn.WithLossCell(net, loss_fn)
|
||||
loss_net.set_train()
|
||||
if weights is None:
|
||||
weights = ParameterTuple(net.trainable_params())
|
||||
if optimizer is None:
|
||||
optimizer = nn.Adam(weights, learning_rate=0.003, beta1=0.9, beta2=0.999, eps=1e-5, use_locking=False,
|
||||
use_nesterov=False, weight_decay=4e-5, loss_scale=1.0)
|
||||
train_net = nn.TrainOneStepCell(loss_net, optimizer)
|
||||
return train_net
|
|
@ -0,0 +1,71 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
#!/bin/bash
|
||||
|
||||
display_usage()
|
||||
{
|
||||
echo -e "\nUsage: prepare_and_run.sh -D dataset_path [-d mindspore_docker] [-r release.tar.gz]\n"
|
||||
}
|
||||
|
||||
checkopts()
|
||||
{
|
||||
DOCKER=""
|
||||
MNIST_DATA_PATH=""
|
||||
while getopts 'D:d:r:' opt
|
||||
do
|
||||
case "${opt}" in
|
||||
D)
|
||||
MNIST_DATA_PATH=$OPTARG
|
||||
;;
|
||||
d)
|
||||
DOCKER=$OPTARG
|
||||
;;
|
||||
r)
|
||||
TARBALL="-r $OPTARG"
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option ${opt}!"
|
||||
display_usage
|
||||
exit 1
|
||||
esac
|
||||
done
|
||||
}
|
||||
|
||||
checkopts "$@"
|
||||
if [ "$MNIST_DATA_PATH" == "" ]; then
|
||||
echo "MNIST Dataset directory path was not provided"
|
||||
display_usage
|
||||
exit 1
|
||||
fi
|
||||
|
||||
./build.sh $TARBALL
|
||||
|
||||
BASEPATH=$(cd "$(dirname $0)" || exit; pwd)
|
||||
|
||||
cd model/ || exit 1
|
||||
MSLITE_LINUX=$(ls -d ${BASEPATH}/build/mindspore-lite-*-linux-x64)
|
||||
CONVERTER=${MSLITE_LINUX}/tools/converter/converter/converter_lite
|
||||
rm -f *.ms
|
||||
LD_LIBRARY_PATH=${MSLITE_LINUX}/tools/converter/lib/:${MSLITE_LINUX}/tools/converter/third_party/glog/lib
|
||||
LD_LIBRARY_PATH=${LD_LIBRARY_PATH} CONVERTER=${CONVERTER} ./prepare_model.sh $DOCKER || exit 1
|
||||
cd ../
|
||||
|
||||
cd target || exit 1
|
||||
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:../lib/
|
||||
java -Djava.library.path=../lib/ -classpath .:./train_lenet_java.jar:../lib/mindspore-lite-java.jar com.mindspore.lite.train_lenet.Main ../model/lenet_tod.ms $MNIST_DATA_PATH
|
||||
cd -
|
||||
|
|
@ -1,5 +1,20 @@
|
|||
package com.mindspore.lite.train_lenet;
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.lite.train_lenet;
|
||||
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.FileInputStream;
|
||||
|
|
|
@ -1,9 +1,26 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.lite.train_lenet;
|
||||
|
||||
import com.mindspore.lite.Version;
|
||||
|
||||
public class Main {
|
||||
public static void main(String[] args) {
|
||||
System.loadLibrary("mindspore-lite-train-jni");
|
||||
System.out.println(Version.version());
|
||||
if (args.length < 2) {
|
||||
System.err.println("model path and dataset path must be provided.");
|
||||
|
|
|
@ -1,7 +1,23 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.lite.train_lenet;
|
||||
|
||||
import com.mindspore.lite.MSTensor;
|
||||
import com.mindspore.lite.TrainSession;
|
||||
import com.mindspore.lite.LiteSession;
|
||||
import com.mindspore.lite.config.MSConfig;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
@ -13,13 +29,14 @@ import java.util.Vector;
|
|||
public class NetRunner {
|
||||
private int dataIndex = 0;
|
||||
private int labelIndex = 1;
|
||||
private TrainSession session;
|
||||
private LiteSession session;
|
||||
private long batchSize;
|
||||
private long dataSize; // one input data size, in byte
|
||||
private DataSet ds = new DataSet();
|
||||
private long numOfClasses;
|
||||
private long cycles = 3500;
|
||||
private int idx = 1;
|
||||
private int virtualBatch = 16;
|
||||
private String trainedFilePath = "trained.ms";
|
||||
|
||||
public void initAndFigureInputs(String modelPath) {
|
||||
|
@ -29,9 +46,10 @@ public class NetRunner {
|
|||
// arg 2: cpuBindMode:NO_BIND -> 0
|
||||
// arg 3: enable_fp16 -> false
|
||||
msConfig.init(0, 2, 0, false);
|
||||
session = new TrainSession();
|
||||
session.init(modelPath, msConfig);
|
||||
session.setLearningRate(0.01f);
|
||||
session = new LiteSession();
|
||||
System.out.println("Model path is " + modelPath);
|
||||
session = session.createTrainSession(modelPath, msConfig, false);
|
||||
session.setupVirtualBatch(virtualBatch, 0.01f, 1.00f);
|
||||
|
||||
List<MSTensor> inputs = session.getInputs();
|
||||
if (inputs.size() <= 1) {
|
||||
|
@ -44,6 +62,7 @@ public class NetRunner {
|
|||
batchSize = inputs.get(dataIndex).getShape()[0];
|
||||
dataSize = inputs.get(dataIndex).size() / batchSize;
|
||||
System.out.println("batch_size: " + batchSize);
|
||||
System.out.println("virtual batch multiplier: " + virtualBatch);
|
||||
int index = modelPath.lastIndexOf(".ms");
|
||||
if (index == -1) {
|
||||
System.out.println("The model " + modelPath + " should be named *.ms");
|
||||
|
@ -92,20 +111,21 @@ public class NetRunner {
|
|||
float min_loss = 1000;
|
||||
float max_acc = 0;
|
||||
for (int i = 0; i < cycles; i++) {
|
||||
fillInputData(ds.getTrainData(), false);
|
||||
session.runGraph();
|
||||
float loss = getLoss();
|
||||
if (min_loss > loss) {
|
||||
min_loss = loss;
|
||||
}
|
||||
if ((i + 1) % 500 == 0) {
|
||||
float acc = calculateAccuracy(10); // only test 10 batch size
|
||||
if (max_acc < acc) {
|
||||
max_acc = acc;
|
||||
for (int b = 0; b < virtualBatch; b++) {
|
||||
fillInputData(ds.getTrainData(), false);
|
||||
session.runGraph();
|
||||
float loss = getLoss();
|
||||
if (min_loss > loss) {
|
||||
min_loss = loss;
|
||||
}
|
||||
if ((b == 0) && ((i + 1) % 500 == 0)) {
|
||||
float acc = calculateAccuracy(10); // only test 10 batch size
|
||||
if (max_acc < acc) {
|
||||
max_acc = acc;
|
||||
}
|
||||
System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_accc=" + max_acc);
|
||||
}
|
||||
System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_accc=" + max_acc);
|
||||
}
|
||||
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
@ -208,7 +228,10 @@ public class NetRunner {
|
|||
System.out.println("accuracy = " + acc);
|
||||
|
||||
if (cycles > 0) {
|
||||
if (session.saveToFile(trainedFilePath)) {
|
||||
// arg 0: FileName
|
||||
// arg 1: model type MT_TRAIN -> 0
|
||||
// arg 2: quantization type QT_DEFAULT -> 0
|
||||
if (session.export(trainedFilePath, 0, 0)) {
|
||||
System.out.println("Trained model successfully saved: " + trainedFilePath);
|
||||
} else {
|
||||
System.err.println("Save model error.");
|
||||
|
|
|
@ -59,14 +59,6 @@ set(JNI_SRC
|
|||
|
||||
set(LITE_SO_NAME mindspore-lite)
|
||||
|
||||
if(SUPPORT_TRAIN)
|
||||
set(JNI_SRC
|
||||
${JNI_SRC}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
add_library(mindspore-lite-jni SHARED ${JNI_SRC})
|
||||
|
||||
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
|
||||
|
@ -75,3 +67,16 @@ if(PLATFORM_ARM64 OR PLATFORM_ARM32)
|
|||
else()
|
||||
target_link_libraries(mindspore-lite-jni ${LITE_SO_NAME})
|
||||
endif()
|
||||
|
||||
if(SUPPORT_TRAIN)
|
||||
set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite)
|
||||
set(JNI_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp)
|
||||
add_library(mindspore-lite-train-jni SHARED ${JNI_TRAIN_SRC})
|
||||
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
|
||||
find_library(log-lib log)
|
||||
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME} ${log-lib})
|
||||
else()
|
||||
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
|
|
@ -443,11 +443,7 @@ void LiteSession::FreePackOpWeight(const std::vector<kernel::LiteKernel *> &kern
|
|||
bool LiteSession::IfUseMindrtExecutor() {
|
||||
bool use_mindrt_run = true;
|
||||
#ifdef ENABLE_MINDRT
|
||||
#ifdef SUPPORT_TRAIN
|
||||
use_mindrt_run = false;
|
||||
#else
|
||||
use_mindrt_run = true;
|
||||
#endif
|
||||
use_mindrt_run = (is_train_session_) ? false : true;
|
||||
#else
|
||||
use_mindrt_run = false;
|
||||
#endif
|
||||
|
|
|
@ -334,7 +334,7 @@ int TrainSession::SetLearningRate(float learning_rate) {
|
|||
}
|
||||
for (auto kernel : this->train_kernels_) {
|
||||
if (IsOptimizer(kernel)) {
|
||||
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
|
||||
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
|
||||
auto ret = optimizer->SetLearningRate(learning_rate);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
|
||||
|
@ -348,7 +348,7 @@ int TrainSession::SetLearningRate(float learning_rate) {
|
|||
float TrainSession::GetLearningRate() {
|
||||
for (auto kernel : this->train_kernels_) {
|
||||
if (IsOptimizer(kernel)) {
|
||||
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
|
||||
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
|
||||
return optimizer->GetLearningRate();
|
||||
}
|
||||
}
|
||||
|
@ -363,7 +363,7 @@ int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr,
|
|||
|
||||
for (auto kernel : this->train_kernels_) {
|
||||
if (IsOptimizer(kernel)) {
|
||||
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
|
||||
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
|
||||
auto ret = optimizer->SetOptimizerMode(mod);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode";
|
||||
|
@ -382,7 +382,7 @@ int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr,
|
|||
}
|
||||
|
||||
if (IsBN(kernel) && kernel->is_trainable()) {
|
||||
auto batchnorm = reinterpret_cast<kernel::BatchnormCPUKernel *>(kernel);
|
||||
auto batchnorm = static_cast<kernel::BatchnormCPUKernel *>(kernel->kernel());
|
||||
auto ret = RET_OK;
|
||||
if (mod == kernel::OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
|
||||
momentum = (momentum < 0.0f) ? (batchnorm->get_momentum() / virtual_batch_multiplier_) : momentum;
|
||||
|
@ -409,7 +409,7 @@ int TrainSession::SetupVirtualBatch(int virtual_batch_multiplier, float lr, floa
|
|||
int TrainSession::OptimizerStep() {
|
||||
for (auto kernel : this->train_kernels_) {
|
||||
if (IsOptimizer(kernel)) {
|
||||
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
|
||||
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
|
||||
auto ret = optimizer->OptimizerStep();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << kernel->name() << " failed to do optimize step";
|
||||
|
|
Loading…
Reference in New Issue