!17856 [MS][LITE][TOD] Fixed train_lenet_java example

Merge pull request !17856 from ehaleva/train_lenet_java_example
This commit is contained in:
i-robot 2021-06-09 14:22:41 +08:00 committed by Gitee
commit c13adb6e33
12 changed files with 338 additions and 37 deletions

View File

@ -526,6 +526,11 @@ build_lite_x86_64_jni_and_jar()
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/ cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/ cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/inference/lib/ cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/inference/lib/
if [[ "X$is_train" = "Xon" ]]; then
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
cp ./libmindspore-lite-train-jni.so ${BASEPATH}/output/tmp/${pkg_name}/inference/lib/
fi
# build java common # build java common
cd ${LITE_JAVA_PATH}/java/common cd ${LITE_JAVA_PATH}/java/common
@ -695,6 +700,10 @@ build_lite_arm64_and_jni() {
fi fi
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
if [[ "X$is_train" = "Xon" ]]; then
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
fi
} }
build_lite_arm32_and_jni() { build_lite_arm32_and_jni() {
@ -736,6 +745,10 @@ build_lite_arm32_and_jni() {
fi fi
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
if [[ "X$is_train" = "Xon" ]]; then
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
fi
} }
check_java_home() { check_java_home() {

View File

@ -0,0 +1,69 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
display_usage()
{
echo -e "\nUsage: build.sh [-r release.tar.gz]\n"
}
checkopts()
{
TARBALL=""
while getopts 'r:' opt
do
case "${opt}" in
r)
TARBALL=$OPTARG
;;
*)
echo "Unknown option ${opt}!"
display_usage
exit 1
esac
done
}
checkopts "$@"
BASEPATH=$(cd "$(dirname $0)" || exit; pwd)
get_version() {
VERSION_MAJOR=$(grep "const int ms_version_major =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
VERSION_MINOR=$(grep "const int ms_version_minor =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
VERSION_REVISION=$(grep "const int ms_version_revision =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
VERSION_STR=${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION}
}
get_version
MINDSPORE_FILE_NAME="mindspore-lite-${VERSION_STR}-linux-x64"
MINDSPORE_FILE="${MINDSPORE_FILE_NAME}.tar.gz"
MINDSPORE_LITE_DOWNLOAD_URL="https://ms-release.obs.cn-north-4.myhuaweicloud.com/${VERSION_STR}/MindSpore/lite/release/linux/${MINDSPORE_FILE}"
mkdir -p build/${MINDSPORE_FILE_NAME}
mkdir -p lib
if [ -n "$TARBALL" ]; then
cp ${TARBALL} ${BASEPATH}/build/${MINDSPORE_FILE}
fi
if [ ! -e ${BASEPATH}/build/${MINDSPORE_FILE} ]; then
wget -c -O ${BASEPATH}/build/${MINDSPORE_FILE} --no-check-certificate ${MINDSPORE_LITE_DOWNLOAD_URL}
fi
tar xzvf ${BASEPATH}/build/${MINDSPORE_FILE} -C ${BASEPATH}/build/${MINDSPORE_FILE_NAME} --strip-components=1
cp -r ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/inference/lib/* ${BASEPATH}/lib
cd ${BASEPATH}/ || exit
mvn package

View File

@ -0,0 +1,34 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lenet_export."""
import numpy as np
from mindspore import context, Tensor
import mindspore.common.dtype as mstype
from mindspore.train.serialization import export
from lenet import LeNet5
from train_utils import TrainWrap
n = LeNet5()
n.set_train()
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU", save_graphs=False)
BATCH_SIZE = 4
x = Tensor(np.ones((BATCH_SIZE, 1, 32, 32)), mstype.float32)
label = Tensor(np.zeros([BATCH_SIZE]).astype(np.int32))
net = TrainWrap(n)
export(net, x, label, file_name="lenet_tod", file_format='MINDIR')
print("finished exporting")

View File

@ -0,0 +1,24 @@
#!/bin/bash
echo "============Exporting=========="
if [ -n "$1" ]; then
DOCKER_IMG=$1
docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER_IMG} /bin/bash -c "PYTHONPATH=../../../../../model_zoo/official/cv/lenet/src python lenet_export.py; chmod 444 lenet_tod.mindir; rm -rf __pycache__"
else
echo "MindSpore docker was not provided, attempting to run locally"
PYTHONPATH=../../../../../model_zoo/official/cv/lenet/src python lenet_export.py
fi
if [ ! -f "$CONVERTER" ]; then
echo "converter_lite could not be found in MindSpore build directory nor in system path"
exit 1
fi
echo "============Converting========="
QUANT_OPTIONS=""
if [[ ! -z ${QUANTIZE} ]]; then
echo "Quantizing weights"
QUANT_OPTIONS="--quantType=WeightQuant --bitNum=8 --quantWeightSize=100 --quantWeightChannel=15"
fi
$CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod $QUANT_OPTIONS

View File

@ -0,0 +1,34 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_utils."""
import mindspore.nn as nn
from mindspore.common.parameter import ParameterTuple
def TrainWrap(net, loss_fn=None, optimizer=None, weights=None):
"""
TrainWrap
"""
if loss_fn is None:
loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction='mean', sparse=True)
loss_net = nn.WithLossCell(net, loss_fn)
loss_net.set_train()
if weights is None:
weights = ParameterTuple(net.trainable_params())
if optimizer is None:
optimizer = nn.Adam(weights, learning_rate=0.003, beta1=0.9, beta2=0.999, eps=1e-5, use_locking=False,
use_nesterov=False, weight_decay=4e-5, loss_scale=1.0)
train_net = nn.TrainOneStepCell(loss_net, optimizer)
return train_net

View File

@ -0,0 +1,71 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#!/bin/bash
display_usage()
{
echo -e "\nUsage: prepare_and_run.sh -D dataset_path [-d mindspore_docker] [-r release.tar.gz]\n"
}
checkopts()
{
DOCKER=""
MNIST_DATA_PATH=""
while getopts 'D:d:r:' opt
do
case "${opt}" in
D)
MNIST_DATA_PATH=$OPTARG
;;
d)
DOCKER=$OPTARG
;;
r)
TARBALL="-r $OPTARG"
;;
*)
echo "Unknown option ${opt}!"
display_usage
exit 1
esac
done
}
checkopts "$@"
if [ "$MNIST_DATA_PATH" == "" ]; then
echo "MNIST Dataset directory path was not provided"
display_usage
exit 1
fi
./build.sh $TARBALL
BASEPATH=$(cd "$(dirname $0)" || exit; pwd)
cd model/ || exit 1
MSLITE_LINUX=$(ls -d ${BASEPATH}/build/mindspore-lite-*-linux-x64)
CONVERTER=${MSLITE_LINUX}/tools/converter/converter/converter_lite
rm -f *.ms
LD_LIBRARY_PATH=${MSLITE_LINUX}/tools/converter/lib/:${MSLITE_LINUX}/tools/converter/third_party/glog/lib
LD_LIBRARY_PATH=${LD_LIBRARY_PATH} CONVERTER=${CONVERTER} ./prepare_model.sh $DOCKER || exit 1
cd ../
cd target || exit 1
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:../lib/
java -Djava.library.path=../lib/ -classpath .:./train_lenet_java.jar:../lib/mindspore-lite-java.jar com.mindspore.lite.train_lenet.Main ../model/lenet_tod.ms $MNIST_DATA_PATH
cd -

View File

@ -1,5 +1,20 @@
package com.mindspore.lite.train_lenet; /**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.lite.train_lenet;
import java.io.BufferedInputStream; import java.io.BufferedInputStream;
import java.io.FileInputStream; import java.io.FileInputStream;

View File

@ -1,9 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.lite.train_lenet; package com.mindspore.lite.train_lenet;
import com.mindspore.lite.Version; import com.mindspore.lite.Version;
public class Main { public class Main {
public static void main(String[] args) { public static void main(String[] args) {
System.loadLibrary("mindspore-lite-train-jni");
System.out.println(Version.version()); System.out.println(Version.version());
if (args.length < 2) { if (args.length < 2) {
System.err.println("model path and dataset path must be provided."); System.err.println("model path and dataset path must be provided.");

View File

@ -1,7 +1,23 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.lite.train_lenet; package com.mindspore.lite.train_lenet;
import com.mindspore.lite.MSTensor; import com.mindspore.lite.MSTensor;
import com.mindspore.lite.TrainSession; import com.mindspore.lite.LiteSession;
import com.mindspore.lite.config.MSConfig; import com.mindspore.lite.config.MSConfig;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -13,13 +29,14 @@ import java.util.Vector;
public class NetRunner { public class NetRunner {
private int dataIndex = 0; private int dataIndex = 0;
private int labelIndex = 1; private int labelIndex = 1;
private TrainSession session; private LiteSession session;
private long batchSize; private long batchSize;
private long dataSize; // one input data size, in byte private long dataSize; // one input data size, in byte
private DataSet ds = new DataSet(); private DataSet ds = new DataSet();
private long numOfClasses; private long numOfClasses;
private long cycles = 3500; private long cycles = 3500;
private int idx = 1; private int idx = 1;
private int virtualBatch = 16;
private String trainedFilePath = "trained.ms"; private String trainedFilePath = "trained.ms";
public void initAndFigureInputs(String modelPath) { public void initAndFigureInputs(String modelPath) {
@ -29,9 +46,10 @@ public class NetRunner {
// arg 2: cpuBindMode:NO_BIND -> 0 // arg 2: cpuBindMode:NO_BIND -> 0
// arg 3: enable_fp16 -> false // arg 3: enable_fp16 -> false
msConfig.init(0, 2, 0, false); msConfig.init(0, 2, 0, false);
session = new TrainSession(); session = new LiteSession();
session.init(modelPath, msConfig); System.out.println("Model path is " + modelPath);
session.setLearningRate(0.01f); session = session.createTrainSession(modelPath, msConfig, false);
session.setupVirtualBatch(virtualBatch, 0.01f, 1.00f);
List<MSTensor> inputs = session.getInputs(); List<MSTensor> inputs = session.getInputs();
if (inputs.size() <= 1) { if (inputs.size() <= 1) {
@ -44,6 +62,7 @@ public class NetRunner {
batchSize = inputs.get(dataIndex).getShape()[0]; batchSize = inputs.get(dataIndex).getShape()[0];
dataSize = inputs.get(dataIndex).size() / batchSize; dataSize = inputs.get(dataIndex).size() / batchSize;
System.out.println("batch_size: " + batchSize); System.out.println("batch_size: " + batchSize);
System.out.println("virtual batch multiplier: " + virtualBatch);
int index = modelPath.lastIndexOf(".ms"); int index = modelPath.lastIndexOf(".ms");
if (index == -1) { if (index == -1) {
System.out.println("The model " + modelPath + " should be named *.ms"); System.out.println("The model " + modelPath + " should be named *.ms");
@ -92,20 +111,21 @@ public class NetRunner {
float min_loss = 1000; float min_loss = 1000;
float max_acc = 0; float max_acc = 0;
for (int i = 0; i < cycles; i++) { for (int i = 0; i < cycles; i++) {
fillInputData(ds.getTrainData(), false); for (int b = 0; b < virtualBatch; b++) {
session.runGraph(); fillInputData(ds.getTrainData(), false);
float loss = getLoss(); session.runGraph();
if (min_loss > loss) { float loss = getLoss();
min_loss = loss; if (min_loss > loss) {
} min_loss = loss;
if ((i + 1) % 500 == 0) { }
float acc = calculateAccuracy(10); // only test 10 batch size if ((b == 0) && ((i + 1) % 500 == 0)) {
if (max_acc < acc) { float acc = calculateAccuracy(10); // only test 10 batch size
max_acc = acc; if (max_acc < acc) {
max_acc = acc;
}
System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_accc=" + max_acc);
} }
System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_accc=" + max_acc);
} }
} }
return 0; return 0;
} }
@ -208,7 +228,10 @@ public class NetRunner {
System.out.println("accuracy = " + acc); System.out.println("accuracy = " + acc);
if (cycles > 0) { if (cycles > 0) {
if (session.saveToFile(trainedFilePath)) { // arg 0: FileName
// arg 1: model type MT_TRAIN -> 0
// arg 2: quantization type QT_DEFAULT -> 0
if (session.export(trainedFilePath, 0, 0)) {
System.out.println("Trained model successfully saved: " + trainedFilePath); System.out.println("Trained model successfully saved: " + trainedFilePath);
} else { } else {
System.err.println("Save model error."); System.err.println("Save model error.");

View File

@ -59,14 +59,6 @@ set(JNI_SRC
set(LITE_SO_NAME mindspore-lite) set(LITE_SO_NAME mindspore-lite)
if(SUPPORT_TRAIN)
set(JNI_SRC
${JNI_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp
)
endif()
add_library(mindspore-lite-jni SHARED ${JNI_SRC}) add_library(mindspore-lite-jni SHARED ${JNI_SRC})
if(PLATFORM_ARM64 OR PLATFORM_ARM32) if(PLATFORM_ARM64 OR PLATFORM_ARM32)
@ -75,3 +67,16 @@ if(PLATFORM_ARM64 OR PLATFORM_ARM32)
else() else()
target_link_libraries(mindspore-lite-jni ${LITE_SO_NAME}) target_link_libraries(mindspore-lite-jni ${LITE_SO_NAME})
endif() endif()
if(SUPPORT_TRAIN)
set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite)
set(JNI_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp)
add_library(mindspore-lite-train-jni SHARED ${JNI_TRAIN_SRC})
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
find_library(log-lib log)
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME} ${log-lib})
else()
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME})
endif()
endif()

View File

@ -443,11 +443,7 @@ void LiteSession::FreePackOpWeight(const std::vector<kernel::LiteKernel *> &kern
bool LiteSession::IfUseMindrtExecutor() { bool LiteSession::IfUseMindrtExecutor() {
bool use_mindrt_run = true; bool use_mindrt_run = true;
#ifdef ENABLE_MINDRT #ifdef ENABLE_MINDRT
#ifdef SUPPORT_TRAIN use_mindrt_run = (is_train_session_) ? false : true;
use_mindrt_run = false;
#else
use_mindrt_run = true;
#endif
#else #else
use_mindrt_run = false; use_mindrt_run = false;
#endif #endif

View File

@ -334,7 +334,7 @@ int TrainSession::SetLearningRate(float learning_rate) {
} }
for (auto kernel : this->train_kernels_) { for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) { if (IsOptimizer(kernel)) {
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel); auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
auto ret = optimizer->SetLearningRate(learning_rate); auto ret = optimizer->SetLearningRate(learning_rate);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << kernel->name() << " failed to set learning rate"; MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
@ -348,7 +348,7 @@ int TrainSession::SetLearningRate(float learning_rate) {
float TrainSession::GetLearningRate() { float TrainSession::GetLearningRate() {
for (auto kernel : this->train_kernels_) { for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) { if (IsOptimizer(kernel)) {
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel); auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
return optimizer->GetLearningRate(); return optimizer->GetLearningRate();
} }
} }
@ -363,7 +363,7 @@ int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr,
for (auto kernel : this->train_kernels_) { for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) { if (IsOptimizer(kernel)) {
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel); auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
auto ret = optimizer->SetOptimizerMode(mod); auto ret = optimizer->SetOptimizerMode(mod);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode"; MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode";
@ -382,7 +382,7 @@ int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr,
} }
if (IsBN(kernel) && kernel->is_trainable()) { if (IsBN(kernel) && kernel->is_trainable()) {
auto batchnorm = reinterpret_cast<kernel::BatchnormCPUKernel *>(kernel); auto batchnorm = static_cast<kernel::BatchnormCPUKernel *>(kernel->kernel());
auto ret = RET_OK; auto ret = RET_OK;
if (mod == kernel::OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) { if (mod == kernel::OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
momentum = (momentum < 0.0f) ? (batchnorm->get_momentum() / virtual_batch_multiplier_) : momentum; momentum = (momentum < 0.0f) ? (batchnorm->get_momentum() / virtual_batch_multiplier_) : momentum;
@ -409,7 +409,7 @@ int TrainSession::SetupVirtualBatch(int virtual_batch_multiplier, float lr, floa
int TrainSession::OptimizerStep() { int TrainSession::OptimizerStep() {
for (auto kernel : this->train_kernels_) { for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) { if (IsOptimizer(kernel)) {
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel); auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
auto ret = optimizer->OptimizerStep(); auto ret = optimizer->OptimizerStep();
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << kernel->name() << " failed to do optimize step"; MS_LOG(ERROR) << kernel->name() << " failed to do optimize step";