From 10d188fd14310658513e7db951a1183daa03c1f7 Mon Sep 17 00:00:00 2001 From: zhengjun10 Date: Wed, 19 May 2021 17:48:09 +0800 Subject: [PATCH] fix train example --- build.sh | 12 ++++++------ .../java/com/mindspore/lite/TrainSession.java | 6 ------ .../lite/java/native/runtime/train_session.cpp | 15 +-------------- 3 files changed, 7 insertions(+), 26 deletions(-) diff --git a/build.sh b/build.sh index cf8b8c29a8b..a084c2969a4 100755 --- a/build.sh +++ b/build.sh @@ -723,8 +723,8 @@ build_lite_java_arm64() { mkdir -p ${JAVA_PATH}/java/app/libs/arm64-v8a/ mkdir -p ${JAVA_PATH}/native/libs/arm64-v8a/ if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then - cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ - cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/ + cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite-train.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ + cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite-train.so ${JAVA_PATH}/native/libs/arm64-v8a/ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/native/libs/arm64-v8a/ @@ -759,8 +759,8 @@ build_lite_java_arm32() { mkdir -p ${JAVA_PATH}/java/app/libs/armeabi-v7a/ mkdir -p ${JAVA_PATH}/native/libs/armeabi-v7a/ if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then - cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ - cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/ + cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite-train.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ + cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite-train.so ${JAVA_PATH}/native/libs/armeabi-v7a/ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/native/libs/armeabi-v7a/ @@ -796,8 +796,8 @@ build_lite_java_x86() { mkdir -p ${JAVA_PATH}/java/linux_x86/libs/ mkdir -p ${JAVA_PATH}/native/libs/linux_x86/ if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then - cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/ - cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/ + cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite-train.so ${JAVA_PATH}/java/linux_x86/libs/ + cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite-train.so ${JAVA_PATH}/native/libs/linux_x86/ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/java/linux_x86/libs/ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/native/libs/linux_x86/ diff --git a/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/TrainSession.java b/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/TrainSession.java index 36925102af8..b056bd1a3ec 100644 --- a/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/TrainSession.java +++ b/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/TrainSession.java @@ -149,10 +149,6 @@ public class TrainSession { public boolean setupVirtualBatch(int virtualBatchMultiplier) { return this.setupVirtualBatch(this.sessionPtr, virtualBatchMultiplier, -1.0f, -1.0f); } - - public boolean setLossName(String lossName) { - return this.setLossName(this.sessionPtr, lossName); - } private native long createSession(String modelFilename, long msConfigPtr); @@ -190,6 +186,4 @@ public class TrainSession { private native boolean setLearningRate(long sessionPtr, float learning_rate); private native boolean setupVirtualBatch(long sessionPtr, int virtualBatchMultiplier, float learningRate, float momentum); - - private native boolean setLossName(long sessionPtr, String lossName); } diff --git a/mindspore/lite/java/native/runtime/train_session.cpp b/mindspore/lite/java/native/runtime/train_session.cpp index 14d515e7f67..8b8a96cefaf 100644 --- a/mindspore/lite/java/native/runtime/train_session.cpp +++ b/mindspore/lite/java/native/runtime/train_session.cpp @@ -239,7 +239,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_saveT return (jboolean) false; } auto *train_session_ptr = static_cast(session_pointer); - auto ret = train_session_ptr->SaveToFile(JstringToChar(env, model_file_name)); + auto ret = train_session_ptr->Export(JstringToChar(env, model_file_name)); return (jboolean)(ret == 0); } @@ -318,16 +318,3 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_setup auto ret = train_session_ptr->SetupVirtualBatch(virtualBatchMultiplier, learningRate, momentum); return (jboolean)(ret == mindspore::lite::RET_OK); } - -extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_setLossName(JNIEnv *env, jobject thiz, - jlong session_ptr, - jstring lossName) { - auto *session_pointer = reinterpret_cast(session_ptr); - if (session_pointer == nullptr) { - MS_LOGE("Session pointer from java is nullptr"); - return (jboolean) false; - } - auto *train_session_ptr = static_cast(session_pointer); - auto ret = train_session_ptr->SetLossName(JstringToChar(env, lossName)); - return (jboolean)(ret == mindspore::lite::RET_OK); -}