From 0d53477912de0d46fd78be439ebff98b8cb0e852 Mon Sep 17 00:00:00 2001 From: Emir Haleva Date: Sun, 6 Jun 2021 18:15:16 +0300 Subject: [PATCH] train_lenet_jav to meet new API --- build.sh | 13 ++++ .../lite/examples/train_lenet_java/build.sh | 69 ++++++++++++++++++ .../train_lenet_java/model/lenet_export.py | 34 +++++++++ .../train_lenet_java/model/prepare_model.sh | 24 +++++++ .../train_lenet_java/model/train_utils.py | 34 +++++++++ .../train_lenet_java/prepare_and_run.sh | 71 +++++++++++++++++++ .../mindspore/lite/train_lenet/DataSet.java | 17 ++++- .../com/mindspore/lite/train_lenet/Main.java | 17 +++++ .../mindspore/lite/train_lenet/NetRunner.java | 59 ++++++++++----- mindspore/lite/java/native/CMakeLists.txt | 21 +++--- mindspore/lite/src/lite_session.cc | 6 +- mindspore/lite/src/train/train_session.cc | 10 +-- 12 files changed, 338 insertions(+), 37 deletions(-) create mode 100755 mindspore/lite/examples/train_lenet_java/build.sh create mode 100644 mindspore/lite/examples/train_lenet_java/model/lenet_export.py create mode 100755 mindspore/lite/examples/train_lenet_java/model/prepare_model.sh create mode 100644 mindspore/lite/examples/train_lenet_java/model/train_utils.py create mode 100755 mindspore/lite/examples/train_lenet_java/prepare_and_run.sh diff --git a/build.sh b/build.sh index 3cc04904854..b63aa1b52e6 100755 --- a/build.sh +++ b/build.sh @@ -526,6 +526,11 @@ build_lite_x86_64_jni_and_jar() cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/ cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/ cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/inference/lib/ + if [[ "X$is_train" = "Xon" ]]; then + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/ + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/ + cp ./libmindspore-lite-train-jni.so ${BASEPATH}/output/tmp/${pkg_name}/inference/lib/ + fi # build java common cd ${LITE_JAVA_PATH}/java/common @@ -677,6 +682,10 @@ build_lite_arm64_and_jni() { fi cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ + if [[ "X$is_train" = "Xon" ]]; then + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ + fi } build_lite_arm32_and_jni() { @@ -718,6 +727,10 @@ build_lite_arm32_and_jni() { fi cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ + if [[ "X$is_train" = "Xon" ]]; then + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ + fi } check_java_home() { diff --git a/mindspore/lite/examples/train_lenet_java/build.sh b/mindspore/lite/examples/train_lenet_java/build.sh new file mode 100755 index 00000000000..79af76b0ef9 --- /dev/null +++ b/mindspore/lite/examples/train_lenet_java/build.sh @@ -0,0 +1,69 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +display_usage() +{ + echo -e "\nUsage: build.sh [-r release.tar.gz]\n" +} + +checkopts() +{ + TARBALL="" + while getopts 'r:' opt + do + case "${opt}" in + r) + TARBALL=$OPTARG + ;; + *) + echo "Unknown option ${opt}!" + display_usage + exit 1 + esac + done +} + +checkopts "$@" + +BASEPATH=$(cd "$(dirname $0)" || exit; pwd) +get_version() { + VERSION_MAJOR=$(grep "const int ms_version_major =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]") + VERSION_MINOR=$(grep "const int ms_version_minor =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]") + VERSION_REVISION=$(grep "const int ms_version_revision =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]") + VERSION_STR=${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION} +} +get_version +MINDSPORE_FILE_NAME="mindspore-lite-${VERSION_STR}-linux-x64" +MINDSPORE_FILE="${MINDSPORE_FILE_NAME}.tar.gz" +MINDSPORE_LITE_DOWNLOAD_URL="https://ms-release.obs.cn-north-4.myhuaweicloud.com/${VERSION_STR}/MindSpore/lite/release/linux/${MINDSPORE_FILE}" + +mkdir -p build/${MINDSPORE_FILE_NAME} +mkdir -p lib + +if [ -n "$TARBALL" ]; then + cp ${TARBALL} ${BASEPATH}/build/${MINDSPORE_FILE} +fi + +if [ ! -e ${BASEPATH}/build/${MINDSPORE_FILE} ]; then + wget -c -O ${BASEPATH}/build/${MINDSPORE_FILE} --no-check-certificate ${MINDSPORE_LITE_DOWNLOAD_URL} +fi + +tar xzvf ${BASEPATH}/build/${MINDSPORE_FILE} -C ${BASEPATH}/build/${MINDSPORE_FILE_NAME} --strip-components=1 + +cp -r ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/inference/lib/* ${BASEPATH}/lib +cd ${BASEPATH}/ || exit + +mvn package diff --git a/mindspore/lite/examples/train_lenet_java/model/lenet_export.py b/mindspore/lite/examples/train_lenet_java/model/lenet_export.py new file mode 100644 index 00000000000..15bdd21fd47 --- /dev/null +++ b/mindspore/lite/examples/train_lenet_java/model/lenet_export.py @@ -0,0 +1,34 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""lenet_export.""" + +import numpy as np +from mindspore import context, Tensor +import mindspore.common.dtype as mstype +from mindspore.train.serialization import export +from lenet import LeNet5 +from train_utils import TrainWrap + +n = LeNet5() +n.set_train() +context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU", save_graphs=False) + +BATCH_SIZE = 4 +x = Tensor(np.ones((BATCH_SIZE, 1, 32, 32)), mstype.float32) +label = Tensor(np.zeros([BATCH_SIZE]).astype(np.int32)) +net = TrainWrap(n) +export(net, x, label, file_name="lenet_tod", file_format='MINDIR') + +print("finished exporting") diff --git a/mindspore/lite/examples/train_lenet_java/model/prepare_model.sh b/mindspore/lite/examples/train_lenet_java/model/prepare_model.sh new file mode 100755 index 00000000000..e1ea1ab252f --- /dev/null +++ b/mindspore/lite/examples/train_lenet_java/model/prepare_model.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +echo "============Exporting==========" +if [ -n "$1" ]; then + DOCKER_IMG=$1 + docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER_IMG} /bin/bash -c "PYTHONPATH=../../../../../model_zoo/official/cv/lenet/src python lenet_export.py; chmod 444 lenet_tod.mindir; rm -rf __pycache__" +else + echo "MindSpore docker was not provided, attempting to run locally" + PYTHONPATH=../../../../../model_zoo/official/cv/lenet/src python lenet_export.py +fi + +if [ ! -f "$CONVERTER" ]; then + echo "converter_lite could not be found in MindSpore build directory nor in system path" + exit 1 +fi + +echo "============Converting=========" +QUANT_OPTIONS="" +if [[ ! -z ${QUANTIZE} ]]; then + echo "Quantizing weights" + QUANT_OPTIONS="--quantType=WeightQuant --bitNum=8 --quantWeightSize=100 --quantWeightChannel=15" +fi +$CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod $QUANT_OPTIONS + diff --git a/mindspore/lite/examples/train_lenet_java/model/train_utils.py b/mindspore/lite/examples/train_lenet_java/model/train_utils.py new file mode 100644 index 00000000000..0e422baac16 --- /dev/null +++ b/mindspore/lite/examples/train_lenet_java/model/train_utils.py @@ -0,0 +1,34 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""train_utils.""" + +import mindspore.nn as nn +from mindspore.common.parameter import ParameterTuple + +def TrainWrap(net, loss_fn=None, optimizer=None, weights=None): + """ + TrainWrap + """ + if loss_fn is None: + loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction='mean', sparse=True) + loss_net = nn.WithLossCell(net, loss_fn) + loss_net.set_train() + if weights is None: + weights = ParameterTuple(net.trainable_params()) + if optimizer is None: + optimizer = nn.Adam(weights, learning_rate=0.003, beta1=0.9, beta2=0.999, eps=1e-5, use_locking=False, + use_nesterov=False, weight_decay=4e-5, loss_scale=1.0) + train_net = nn.TrainOneStepCell(loss_net, optimizer) + return train_net diff --git a/mindspore/lite/examples/train_lenet_java/prepare_and_run.sh b/mindspore/lite/examples/train_lenet_java/prepare_and_run.sh new file mode 100755 index 00000000000..806464177ea --- /dev/null +++ b/mindspore/lite/examples/train_lenet_java/prepare_and_run.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +#!/bin/bash + +display_usage() +{ + echo -e "\nUsage: prepare_and_run.sh -D dataset_path [-d mindspore_docker] [-r release.tar.gz]\n" +} + +checkopts() +{ + DOCKER="" + MNIST_DATA_PATH="" + while getopts 'D:d:r:' opt + do + case "${opt}" in + D) + MNIST_DATA_PATH=$OPTARG + ;; + d) + DOCKER=$OPTARG + ;; + r) + TARBALL="-r $OPTARG" + ;; + *) + echo "Unknown option ${opt}!" 
+ display_usage + exit 1 + esac + done +} + +checkopts "$@" +if [ "$MNIST_DATA_PATH" == "" ]; then + echo "MNIST Dataset directory path was not provided" + display_usage + exit 1 +fi + +./build.sh $TARBALL + +BASEPATH=$(cd "$(dirname $0)" || exit; pwd) + +cd model/ || exit 1 +MSLITE_LINUX=$(ls -d ${BASEPATH}/build/mindspore-lite-*-linux-x64) +CONVERTER=${MSLITE_LINUX}/tools/converter/converter/converter_lite +rm -f *.ms +LD_LIBRARY_PATH=${MSLITE_LINUX}/tools/converter/lib/:${MSLITE_LINUX}/tools/converter/third_party/glog/lib +LD_LIBRARY_PATH=${LD_LIBRARY_PATH} CONVERTER=${CONVERTER} ./prepare_model.sh $DOCKER || exit 1 +cd ../ + +cd target || exit 1 +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:../lib/ +java -Djava.library.path=../lib/ -classpath .:./train_lenet_java.jar:../lib/mindspore-lite-java.jar com.mindspore.lite.train_lenet.Main ../model/lenet_tod.ms $MNIST_DATA_PATH +cd - + diff --git a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/DataSet.java b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/DataSet.java index 47cb469c2e4..09eb3aca5f9 100644 --- a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/DataSet.java +++ b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/DataSet.java @@ -1,5 +1,20 @@ -package com.mindspore.lite.train_lenet; +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mindspore.lite.train_lenet; import java.io.BufferedInputStream; import java.io.FileInputStream; diff --git a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/Main.java b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/Main.java index 4f70d451fc2..39e2fd5af85 100644 --- a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/Main.java +++ b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/Main.java @@ -1,9 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package com.mindspore.lite.train_lenet; import com.mindspore.lite.Version; public class Main { public static void main(String[] args) { + System.loadLibrary("mindspore-lite-train-jni"); System.out.println(Version.version()); if (args.length < 2) { System.err.println("model path and dataset path must be provided."); diff --git a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java index 4fbcfaeed24..cbe9b7eb225 100644 --- a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java +++ b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java @@ -1,7 +1,23 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package com.mindspore.lite.train_lenet; import com.mindspore.lite.MSTensor; -import com.mindspore.lite.TrainSession; +import com.mindspore.lite.LiteSession; import com.mindspore.lite.config.MSConfig; import java.nio.ByteBuffer; @@ -13,13 +29,14 @@ import java.util.Vector; public class NetRunner { private int dataIndex = 0; private int labelIndex = 1; - private TrainSession session; + private LiteSession session; private long batchSize; private long dataSize; // one input data size, in byte private DataSet ds = new DataSet(); private long numOfClasses; private long cycles = 3500; private int idx = 1; + private int virtualBatch = 16; private String trainedFilePath = "trained.ms"; public void initAndFigureInputs(String modelPath) { @@ -29,9 +46,10 @@ public class NetRunner { // arg 2: cpuBindMode:NO_BIND -> 0 // arg 3: enable_fp16 -> false msConfig.init(0, 2, 0, false); - session = new TrainSession(); - session.init(modelPath, msConfig); - session.setLearningRate(0.01f); + session = new LiteSession(); + System.out.println("Model path is " + modelPath); + session = session.createTrainSession(modelPath, msConfig, false); + session.setupVirtualBatch(virtualBatch, 0.01f, 1.00f); List inputs = session.getInputs(); if (inputs.size() <= 1) { @@ -44,6 +62,7 @@ public class NetRunner { batchSize = inputs.get(dataIndex).getShape()[0]; dataSize = inputs.get(dataIndex).size() / batchSize; System.out.println("batch_size: " + batchSize); + System.out.println("virtual batch multiplier: " + virtualBatch); int index = modelPath.lastIndexOf(".ms"); if (index == -1) { System.out.println("The model " + modelPath + " should be named *.ms"); @@ -92,20 +111,21 @@ public class NetRunner { float min_loss = 1000; float max_acc = 0; for (int i = 0; i < cycles; i++) { - fillInputData(ds.getTrainData(), false); - session.runGraph(); - float loss = getLoss(); - if (min_loss > loss) { - min_loss = loss; - } - if ((i + 1) % 500 == 0) { - float acc = calculateAccuracy(10); // only test 10 batch size - if (max_acc < acc) { - max_acc = acc; + for (int b = 0; b < virtualBatch; b++) { + fillInputData(ds.getTrainData(), 
false); + session.runGraph(); + float loss = getLoss(); + if (min_loss > loss) { + min_loss = loss; + } + if ((b == 0) && ((i + 1) % 500 == 0)) { + float acc = calculateAccuracy(10); // only test 10 batch size + if (max_acc < acc) { + max_acc = acc; + } + System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_accc=" + max_acc); } - System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_accc=" + max_acc); } - } return 0; } @@ -208,7 +228,10 @@ public class NetRunner { System.out.println("accuracy = " + acc); if (cycles > 0) { - if (session.saveToFile(trainedFilePath)) { + // arg 0: FileName + // arg 1: model type MT_TRAIN -> 0 + // arg 2: quantization type QT_DEFAULT -> 0 + if (session.export(trainedFilePath, 0, 0)) { System.out.println("Trained model successfully saved: " + trainedFilePath); } else { System.err.println("Save model error."); diff --git a/mindspore/lite/java/native/CMakeLists.txt b/mindspore/lite/java/native/CMakeLists.txt index c2399e40d6c..34d12b0bd99 100644 --- a/mindspore/lite/java/native/CMakeLists.txt +++ b/mindspore/lite/java/native/CMakeLists.txt @@ -59,14 +59,6 @@ set(JNI_SRC set(LITE_SO_NAME mindspore-lite) -if(SUPPORT_TRAIN) - set(JNI_SRC - ${JNI_SRC} - ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp - ) - -endif() - add_library(mindspore-lite-jni SHARED ${JNI_SRC}) if(PLATFORM_ARM64 OR PLATFORM_ARM32) @@ -75,3 +67,16 @@ if(PLATFORM_ARM64 OR PLATFORM_ARM32) else() target_link_libraries(mindspore-lite-jni ${LITE_SO_NAME}) endif() + +if(SUPPORT_TRAIN) + set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite) + set(JNI_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp) + add_library(mindspore-lite-train-jni SHARED ${JNI_TRAIN_SRC}) + if(PLATFORM_ARM64 OR PLATFORM_ARM32) + find_library(log-lib log) + target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME} ${log-lib}) + else() + target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME}) + endif() +endif() + diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index cd9dab68760..4175fdcc01a 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -440,11 +440,7 @@ void LiteSession::FreePackOpWeight(const std::vector &kern bool LiteSession::IfUseMindrtExecutor() { bool use_mindrt_run = true; #ifdef ENABLE_MINDRT -#ifdef SUPPORT_TRAIN - use_mindrt_run = false; -#else - use_mindrt_run = true; -#endif + use_mindrt_run = (is_train_session_) ? 
false : true;
 #else
   use_mindrt_run = false;
 #endif
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
index 856d39dc68b..e520f457d65 100644
--- a/mindspore/lite/src/train/train_session.cc
+++ b/mindspore/lite/src/train/train_session.cc
@@ -334,7 +334,7 @@ int TrainSession::SetLearningRate(float learning_rate) {
   }
   for (auto kernel : this->train_kernels_) {
     if (IsOptimizer(kernel)) {
-      auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
+      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
       auto ret = optimizer->SetLearningRate(learning_rate);
       if (ret != RET_OK) {
         MS_LOG(ERROR) << kernel->name() << " failed to set learning rate";
@@ -348,7 +348,7 @@ float TrainSession::GetLearningRate() {
   for (auto kernel : this->train_kernels_) {
     if (IsOptimizer(kernel)) {
-      auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
+      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
       return optimizer->GetLearningRate();
     }
   }
@@ -363,7 +363,7 @@ int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr,
   for (auto kernel : this->train_kernels_) {
     if (IsOptimizer(kernel)) {
-      auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
+      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
       auto ret = optimizer->SetOptimizerMode(mod);
       if (ret != RET_OK) {
         MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode";
@@ -382,7 +382,7 @@ int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr,
     }
     if (IsBN(kernel) && kernel->is_trainable()) {
-      auto batchnorm = reinterpret_cast<kernel::BatchnormCPUKernel *>(kernel);
+      auto batchnorm = static_cast<kernel::BatchnormCPUKernel *>(kernel->kernel());
       auto ret = RET_OK;
       if (mod == kernel::OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
         momentum = (momentum < 0.0f) ? (batchnorm->get_momentum() / virtual_batch_multiplier_) : momentum;
@@ -409,7 +409,7 @@ int TrainSession::SetupVirtualBatch(int virtual_batch_multiplier, float lr, floa
 int TrainSession::OptimizerStep() {
   for (auto kernel : this->train_kernels_) {
     if (IsOptimizer(kernel)) {
-      auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
+      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
       auto ret = optimizer->OptimizerStep();
       if (ret != RET_OK) {
         MS_LOG(ERROR) << kernel->name() << " failed to do optimize step";
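
Note: the Java training flow that this patch moves the example onto reduces to roughly the sketch below. Only the calls themselves (System.loadLibrary, MSConfig.init, LiteSession.createTrainSession, setupVirtualBatch, runGraph, export) are taken from the diff above; the class name, model path, and surrounding scaffolding are illustrative placeholders, not part of the patch.

// Hypothetical sketch mirroring the calls used in NetRunner.java above; paths and values are placeholders.
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.config.MSConfig;

public class TrainSketch {
    public static void main(String[] args) {
        // The train JNI library is now a separate artifact (see the CMakeLists.txt change above).
        System.loadLibrary("mindspore-lite-train-jni");

        MSConfig msConfig = new MSConfig();
        msConfig.init(0, 2, 0, false);  // DeviceType DT_CPU, ThreadNum 2, cpuBindMode NO_BIND, enable_fp16 false

        // The example now obtains a training session via LiteSession instead of the old TrainSession wrapper.
        LiteSession session = new LiteSession();
        session = session.createTrainSession("lenet_tod.ms", msConfig, false);
        session.setupVirtualBatch(16, 0.01f, 1.00f);  // virtual batch multiplier, lr, momentum

        // ... fill the input tensors from the dataset, then step the graph ...
        session.runGraph();

        // arg 1: model type MT_TRAIN -> 0, arg 2: quantization type QT_DEFAULT -> 0
        if (session.export("trained.ms", 0, 0)) {
            System.out.println("Trained model saved");
        }
    }
}

The example's inner loop now calls fillInputData/runGraph virtualBatch times per reported step, matching the virtual-batch (gradient accumulation) mode configured by setupVirtualBatch, which is also why the old setLearningRate call disappears from NetRunner.java.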