diff --git a/mindspore/lite/build_lite.sh b/mindspore/lite/build_lite.sh index be87708fd80..b6adc876475 100755 --- a/mindspore/lite/build_lite.sh +++ b/mindspore/lite/build_lite.sh @@ -73,6 +73,11 @@ build_lite_x86_64_jni_and_jar() { cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/ cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/ cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/ + if [[ "X$is_train" = "Xon" ]]; then + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/ + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/ + cp ./libmindspore-lite-train-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/ + fi cd ${LITE_JAVA_PATH}/java rm -rf gradle .gradle gradlew gradlew.bat @@ -256,6 +261,10 @@ build_lite_arm64_and_jni() { fi cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ + if [[ "X$is_train" = "Xon" ]]; then + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ + fi } build_lite_arm32_and_jni() { @@ -296,6 +305,10 @@ build_lite_arm32_and_jni() { fi cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ + if [[ "X$is_train" = "Xon" ]]; then + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ + fi } build_aar() { diff --git a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java index cbe9b7eb225..971e0729f5d 100644 --- a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java +++ b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java @@ -18,6 +18,7 @@ package com.mindspore.lite.train_lenet; import com.mindspore.lite.MSTensor; import com.mindspore.lite.LiteSession; +import com.mindspore.lite.TrainSession; import com.mindspore.lite.config.MSConfig; import java.nio.ByteBuffer; @@ -48,7 +49,7 @@ public class NetRunner { msConfig.init(0, 2, 0, false); session = new LiteSession(); System.out.println("Model path is " + modelPath); - session = session.createTrainSession(modelPath, msConfig, false); + session = TrainSession.createTrainSession(modelPath, msConfig, false); session.setupVirtualBatch(virtualBatch, 0.01f, 1.00f); List inputs = session.getInputs(); diff --git a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java index 331da0954c7..7088b77eb4a 100644 --- a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java +++ b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java @@ -63,20 +63,14 @@ public class LiteSession { } } - public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) { - LiteSession liteSession = new LiteSession(); - liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0); - if (liteSession.sessionPtr == 0) { - return null; - } else { - return liteSession; - } - } - public long getSessionPtr() { return sessionPtr; } + public void setSessionPtr(long sessionPtr) { + this.sessionPtr = sessionPtr; + } + public void bindThread(boolean ifBind) { this.bindThread(this.sessionPtr, ifBind); } @@ -204,8 +198,6 @@ public class LiteSession { private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr); - private native long createTrainSession(String filename, long msContextPtr, boolean trainMode, long msTrainCfgPtr); - private native boolean compileGraph(long sessionPtr, long modelPtr); private native void bindThread(long sessionPtr, boolean ifBind); diff --git a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/TrainSession.java b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/TrainSession.java new file mode 100644 index 00000000000..13af2b4d74b --- /dev/null +++ b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/TrainSession.java @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.lite; + +import com.mindspore.lite.LiteSession; +import com.mindspore.lite.config.MSConfig; + +public class TrainSession { + static { + System.loadLibrary("mindspore-lite-train-jni"); + } + public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) { + LiteSession liteSession = new LiteSession(); + long sessionPtr = createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0); + if (sessionPtr == 0) { + return null; + } else { + liteSession.setSessionPtr(sessionPtr); + return liteSession; + } + } + + private static native long createTrainSession(String fileName, long msContextPtr, boolean trainMode, + long msTrainCfgPtr); +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/SessionUtil.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/SessionUtil.java index 513b59d36ee..0b219096f40 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/SessionUtil.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/SessionUtil.java @@ -18,6 +18,7 @@ package com.mindspore.flclient.model; import com.mindspore.flclient.Common; import com.mindspore.lite.LiteSession; +import com.mindspore.lite.TrainSession; import com.mindspore.lite.MSTensor; import com.mindspore.lite.config.MSConfig; import mindspore.schema.FeatureMap; @@ -84,7 +85,7 @@ public class SessionUtil { // arg 2: cpuBindMode:NO_BIND -> 0 // arg 3: enable_fp16 -> false msConfig.init(0, 1, 0, false); - LiteSession trainSession = LiteSession.createTrainSession(modelPath, msConfig,false); + LiteSession trainSession = TrainSession.createTrainSession(modelPath, msConfig,false); if (trainSession == null) { logger.severe(Common.addTag("init session failed,please check model path:" + modelPath)); return null; diff --git a/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java b/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java index a49e94078e9..859c021997e 100644 --- a/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java +++ b/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java @@ -63,20 +63,14 @@ public class LiteSession { } } - public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) { - LiteSession liteSession = new LiteSession(); - liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0); - if (liteSession.sessionPtr == 0) { - return null; - } else { - return liteSession; - } - } - public long getSessionPtr() { return sessionPtr; } + public void setSessionPtr(long sessionPtr) { + this.sessionPtr = sessionPtr; + } + public void bindThread(boolean ifBind) { this.bindThread(this.sessionPtr, ifBind); } @@ -204,8 +198,6 @@ public class LiteSession { private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr); - private native long createTrainSession(String fileName, long msContextPtr, boolean trainMode, long msTrainCfgPtr); - private native boolean compileGraph(long sessionPtr, long modelPtr); private native void bindThread(long sessionPtr, boolean ifBind); diff --git a/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/TrainSession.java b/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/TrainSession.java new file mode 100644 index 00000000000..13af2b4d74b --- /dev/null +++ b/mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/TrainSession.java @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.lite; + +import com.mindspore.lite.LiteSession; +import com.mindspore.lite.config.MSConfig; + +public class TrainSession { + static { + System.loadLibrary("mindspore-lite-train-jni"); + } + public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) { + LiteSession liteSession = new LiteSession(); + long sessionPtr = createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0); + if (sessionPtr == 0) { + return null; + } else { + liteSession.setSessionPtr(sessionPtr); + return liteSession; + } + } + + private static native long createTrainSession(String fileName, long msContextPtr, boolean trainMode, + long msTrainCfgPtr); +} diff --git a/mindspore/lite/java/native/CMakeLists.txt b/mindspore/lite/java/native/CMakeLists.txt index 1696c0b600a..f2b3990cafe 100644 --- a/mindspore/lite/java/native/CMakeLists.txt +++ b/mindspore/lite/java/native/CMakeLists.txt @@ -92,12 +92,6 @@ set(JNI_SRC set(LITE_SO_NAME mindspore-lite) -if(SUPPORT_TRAIN) - set(JNI_SRC - ${JNI_SRC} - ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp - ) -endif() add_library(mindspore-lite-jni SHARED ${JNI_SRC}) if(PLATFORM_ARM64 OR PLATFORM_ARM32) @@ -108,13 +102,15 @@ else() endif() if(SUPPORT_TRAIN) - set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite) - if(PLATFORM_ARM64 OR PLATFORM_ARM32) - find_library(log-lib log) - target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME} ${log-lib}) - else() - target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME}) - endif() + set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite) + set(JNI_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp) + add_library(mindspore-lite-train-jni SHARED ${JNI_TRAIN_SRC}) + if(PLATFORM_ARM64 OR PLATFORM_ARM32) + find_library(log-lib log) + target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME} ${log-lib}) + else() + target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME}) + endif() endif() set(NDK_STRIP diff --git a/mindspore/lite/java/native/runtime/train_session.cpp b/mindspore/lite/java/native/runtime/train_session.cpp index 071e8306144..d959b76b58d 100644 --- a/mindspore/lite/java/native/runtime/train_session.cpp +++ b/mindspore/lite/java/native/runtime/train_session.cpp @@ -20,7 +20,7 @@ #include "include/train/train_cfg.h" #include "include/errorcode.h" -extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createTrainSession(JNIEnv *env, jobject thiz, +extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_createTrainSession(JNIEnv *env, jobject thiz, jstring file_name, jlong ms_context_ptr, jboolean train_mode,