From 5d5675335b2f48b2cdfdea6c8e59cf32afb314b6 Mon Sep 17 00:00:00 2001 From: zhengjun10 Date: Sat, 3 Jul 2021 16:18:05 +0800 Subject: [PATCH] fix lite java inference bug --- build.sh | 15 +------ .../com/mindspore/lite/train_lenet/Main.java | 2 +- .../java/com/mindspore/lite/LiteSession.java | 21 +++++----- .../java/com.mindspore.lite/LiteSession.java | 25 ++++++------ mindspore/lite/java/native/CMakeLists.txt | 6 +-- .../lite/java/native/runtime/lite_session.cpp | 27 +++++++++++-- .../java/native/runtime/train_session.cpp | 39 ------------------- 7 files changed, 50 insertions(+), 85 deletions(-) delete mode 100644 mindspore/lite/java/native/runtime/train_session.cpp diff --git a/build.sh b/build.sh index a4c1039d8b3..ac91ebd205f 100755 --- a/build.sh +++ b/build.sh @@ -526,11 +526,6 @@ build_lite_x86_64_jni_and_jar() cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/ cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/ cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/ - if [[ "X$is_train" = "Xon" ]]; then - cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/ - cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/ - cp ./libmindspore-lite-train-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/ - fi # build java common cd ${LITE_JAVA_PATH}/java/common @@ -715,10 +710,6 @@ build_lite_arm64_and_jni() { fi cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ - if [[ "X$is_train" = "Xon" ]]; then - cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ - cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ - fi } build_lite_arm32_and_jni() { @@ -760,10 +751,6 @@ build_lite_arm32_and_jni() { fi cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ - if [[ "X$is_train" = "Xon" ]]; then - cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ - cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ - fi } check_java_home() { @@ -794,7 +781,7 @@ build_aar() { # build java fl_client local is_train=on - local train_so=${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/libmindspore-lite-train-jni.so + local train_so=${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/libmindspore-lite-train.so if [ ! -f "$train_so" ]; then echo "not exist" is_train=off diff --git a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/Main.java b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/Main.java index 39e2fd5af85..57ae0fe2998 100644 --- a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/Main.java +++ b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/Main.java @@ -20,7 +20,7 @@ import com.mindspore.lite.Version; public class Main { public static void main(String[] args) { - System.loadLibrary("mindspore-lite-train-jni"); + System.loadLibrary("mindspore-lite-jni"); System.out.println(Version.version()); if (args.length < 2) { System.err.println("model path and dataset path must be provided."); diff --git a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java index da727138a69..331da0954c7 100644 --- a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java +++ b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java @@ -28,7 +28,6 @@ import com.mindspore.lite.config.MSConfig; public class LiteSession { static { System.loadLibrary("mindspore-lite-jni"); - System.loadLibrary("mindspore-lite-train-jni"); } private long sessionPtr = 0; @@ -64,9 +63,9 @@ public class LiteSession { } } - public static LiteSession createTrainSession(String modelname, final MSConfig config, boolean train_mode) { + public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) { LiteSession liteSession = new LiteSession(); - liteSession.sessionPtr = liteSession.createTrainSession(modelname, config.getMSConfigPtr(), train_mode, 0); + liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0); if (liteSession.sessionPtr == 0) { return null; } else { @@ -78,8 +77,8 @@ public class LiteSession { return sessionPtr; } - public void bindThread(boolean if_bind) { - this.bindThread(this.sessionPtr, if_bind); + public void bindThread(boolean ifBind) { + this.bindThread(this.sessionPtr, ifBind); } public boolean compileGraph(Model model) { @@ -156,8 +155,8 @@ public class LiteSession { return this.resize(this.sessionPtr, inputsArray, dims); } - public boolean export(String modelFilename, int model_type, int quantization_type) { - return this.export(this.sessionPtr, modelFilename, model_type, quantization_type); + public boolean export(String modelFileName, int modelType, int quantizationType) { + return this.export(this.sessionPtr, modelFileName, modelType, quantizationType); } public boolean train() { @@ -205,11 +204,11 @@ public class LiteSession { private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr); - private native long createTrainSession(String filename, long msContextPtr, boolean train_mode, long msTrainCfgPtr); + private native long createTrainSession(String filename, long msContextPtr, boolean trainMode, long msTrainCfgPtr); private native boolean compileGraph(long sessionPtr, long modelPtr); - private native void bindThread(long sessionPtr, boolean if_bind); + private native void bindThread(long sessionPtr, boolean ifBind); private native boolean runGraph(long sessionPtr); @@ -229,7 +228,7 @@ public class LiteSession { private native boolean resize(long sessionPtr, long[] inputs, int[][] dims); - private native boolean export(long sessionPtr, String modelFilename, int model_type, int quantization_type); + private native boolean export(long sessionPtr, String modelFileName, int modelType, int quantizationType); private native boolean train(long sessionPtr); @@ -239,7 +238,7 @@ public class LiteSession { private native boolean isEval(long sessionPtr); - private native boolean setLearningRate(long sessionPtr, float learning_rate); + private native boolean setLearningRate(long sessionPtr, float learningRate); private native boolean setupVirtualBatch(long sessionPtr, int virtualBatchMultiplier, float learningRate, float momentum); diff --git a/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java b/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java index da727138a69..a49e94078e9 100644 --- a/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java +++ b/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java @@ -28,7 +28,6 @@ import com.mindspore.lite.config.MSConfig; public class LiteSession { static { System.loadLibrary("mindspore-lite-jni"); - System.loadLibrary("mindspore-lite-train-jni"); } private long sessionPtr = 0; @@ -64,9 +63,9 @@ public class LiteSession { } } - public static LiteSession createTrainSession(String modelname, final MSConfig config, boolean train_mode) { + public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) { LiteSession liteSession = new LiteSession(); - liteSession.sessionPtr = liteSession.createTrainSession(modelname, config.getMSConfigPtr(), train_mode, 0); + liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0); if (liteSession.sessionPtr == 0) { return null; } else { @@ -78,8 +77,8 @@ public class LiteSession { return sessionPtr; } - public void bindThread(boolean if_bind) { - this.bindThread(this.sessionPtr, if_bind); + public void bindThread(boolean ifBind) { + this.bindThread(this.sessionPtr, ifBind); } public boolean compileGraph(Model model) { @@ -156,8 +155,8 @@ public class LiteSession { return this.resize(this.sessionPtr, inputsArray, dims); } - public boolean export(String modelFilename, int model_type, int quantization_type) { - return this.export(this.sessionPtr, modelFilename, model_type, quantization_type); + public boolean export(String modelFileName, int modelType, int quantizationType) { + return this.export(this.sessionPtr, modelFileName, modelType, quantizationType); } public boolean train() { @@ -176,8 +175,8 @@ public class LiteSession { return this.isEval(this.sessionPtr); } - public boolean setLearningRate(float learning_rate) { - return this.setLearningRate(this.sessionPtr, learning_rate); + public boolean setLearningRate(float learningRate) { + return this.setLearningRate(this.sessionPtr, learningRate); } public boolean setupVirtualBatch(int virtualBatchMultiplier, float learningRate, float momentum) { @@ -205,11 +204,11 @@ public class LiteSession { private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr); - private native long createTrainSession(String filename, long msContextPtr, boolean train_mode, long msTrainCfgPtr); + private native long createTrainSession(String fileName, long msContextPtr, boolean trainMode, long msTrainCfgPtr); private native boolean compileGraph(long sessionPtr, long modelPtr); - private native void bindThread(long sessionPtr, boolean if_bind); + private native void bindThread(long sessionPtr, boolean ifBind); private native boolean runGraph(long sessionPtr); @@ -229,7 +228,7 @@ public class LiteSession { private native boolean resize(long sessionPtr, long[] inputs, int[][] dims); - private native boolean export(long sessionPtr, String modelFilename, int model_type, int quantization_type); + private native boolean export(long sessionPtr, String modelFilename, int modelType, int quantizationType); private native boolean train(long sessionPtr); @@ -239,7 +238,7 @@ public class LiteSession { private native boolean isEval(long sessionPtr); - private native boolean setLearningRate(long sessionPtr, float learning_rate); + private native boolean setLearningRate(long sessionPtr, float learningRate); private native boolean setupVirtualBatch(long sessionPtr, int virtualBatchMultiplier, float learningRate, float momentum); diff --git a/mindspore/lite/java/native/CMakeLists.txt b/mindspore/lite/java/native/CMakeLists.txt index 239005f5e95..2353b1d43bb 100644 --- a/mindspore/lite/java/native/CMakeLists.txt +++ b/mindspore/lite/java/native/CMakeLists.txt @@ -103,13 +103,11 @@ endif() if(SUPPORT_TRAIN) set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite) - set(JNI_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp) - add_library(mindspore-lite-train-jni SHARED ${JNI_TRAIN_SRC}) if(PLATFORM_ARM64 OR PLATFORM_ARM32) find_library(log-lib log) - target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME} ${log-lib}) + target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME} ${log-lib}) else() - target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME}) + target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME}) endif() endif() diff --git a/mindspore/lite/java/native/runtime/lite_session.cpp b/mindspore/lite/java/native/runtime/lite_session.cpp index f33a5576bc7..0d2af8a9e6f 100644 --- a/mindspore/lite/java/native/runtime/lite_session.cpp +++ b/mindspore/lite/java/native/runtime/lite_session.cpp @@ -356,8 +356,8 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setLea extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setupVirtualBatch(JNIEnv *env, jobject thiz, jlong session_ptr, - jint virtualBatchMultiplier, - jfloat learningRate, + jint virtual_batch_factor, + jfloat learning_rate, jfloat momentum) { auto *session_pointer = reinterpret_cast(session_ptr); if (session_pointer == nullptr) { @@ -365,7 +365,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setupV return (jboolean) false; } auto *lite_session_ptr = static_cast(session_pointer); - auto ret = lite_session_ptr->SetupVirtualBatch(virtualBatchMultiplier, learningRate, momentum); + auto ret = lite_session_ptr->SetupVirtualBatch(virtual_batch_factor, learning_rate, momentum); return (jboolean)(ret == mindspore::lite::RET_OK); } @@ -411,3 +411,24 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getFeat } return ret; } + +extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createTrainSession(JNIEnv *env, jobject thiz, + jstring file_name, + jlong ms_context_ptr, + jboolean train_mode, + jlong train_config_ptr) { + auto *pointer = reinterpret_cast(ms_context_ptr); + if (pointer == nullptr) { + MS_LOGE("Context pointer from java is nullptr"); + return jlong(nullptr); + } + auto *lite_context_ptr = static_cast(pointer); + + auto session = mindspore::session::LiteSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE), + lite_context_ptr, train_mode, nullptr); + if (session == nullptr) { + MS_LOGE("CreateTrainSession failed"); + return jlong(nullptr); + } + return jlong(session); +} diff --git a/mindspore/lite/java/native/runtime/train_session.cpp b/mindspore/lite/java/native/runtime/train_session.cpp deleted file mode 100644 index a8381470c72..00000000000 --- a/mindspore/lite/java/native/runtime/train_session.cpp +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "common/ms_log.h" -#include "include/lite_session.h" -#include "include/train/train_cfg.h" -#include "include/errorcode.h" - -extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createTrainSession( - JNIEnv *env, jobject thiz, jstring file_name, jlong ms_context_ptr, jboolean train_mode, jlong train_config_ptr) { - auto *pointer = reinterpret_cast(ms_context_ptr); - if (pointer == nullptr) { - MS_LOGE("Context pointer from java is nullptr"); - return jlong(nullptr); - } - auto *lite_context_ptr = static_cast(pointer); - - auto session = mindspore::session::LiteSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE), - lite_context_ptr, train_mode, nullptr); - if (session == nullptr) { - MS_LOGE("CreateSession failed"); - return jlong(nullptr); - } - return jlong(session); -}