add libmindspore-lite-train-jni.so fix lite inference bug

This commit is contained in:
zhengjun10 2021-07-16 11:17:43 +08:00
parent 61fbff6ab2
commit ea788db1dd
9 changed files with 113 additions and 40 deletions

View File

@ -73,6 +73,11 @@ build_lite_x86_64_jni_and_jar() {
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/
if [[ "X$is_train" = "Xon" ]]; then
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
cp ./libmindspore-lite-train-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/
fi
cd ${LITE_JAVA_PATH}/java
rm -rf gradle .gradle gradlew gradlew.bat
@ -256,6 +261,10 @@ build_lite_arm64_and_jni() {
fi
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
if [[ "X$is_train" = "Xon" ]]; then
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
fi
}
build_lite_arm32_and_jni() {
@ -296,6 +305,10 @@ build_lite_arm32_and_jni() {
fi
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
if [[ "X$is_train" = "Xon" ]]; then
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
fi
}
build_aar() {

View File

@ -18,6 +18,7 @@ package com.mindspore.lite.train_lenet;
import com.mindspore.lite.MSTensor;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.TrainSession;
import com.mindspore.lite.config.MSConfig;
import java.nio.ByteBuffer;
@ -48,7 +49,7 @@ public class NetRunner {
msConfig.init(0, 2, 0, false);
session = new LiteSession();
System.out.println("Model path is " + modelPath);
session = session.createTrainSession(modelPath, msConfig, false);
session = TrainSession.createTrainSession(modelPath, msConfig, false);
session.setupVirtualBatch(virtualBatch, 0.01f, 1.00f);
List<MSTensor> inputs = session.getInputs();

View File

@ -63,20 +63,14 @@ public class LiteSession {
}
}
public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
LiteSession liteSession = new LiteSession();
liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
if (liteSession.sessionPtr == 0) {
return null;
} else {
return liteSession;
}
}
public long getSessionPtr() {
return sessionPtr;
}
public void setSessionPtr(long sessionPtr) {
this.sessionPtr = sessionPtr;
}
public void bindThread(boolean ifBind) {
this.bindThread(this.sessionPtr, ifBind);
}
@ -204,8 +198,6 @@ public class LiteSession {
private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr);
private native long createTrainSession(String filename, long msContextPtr, boolean trainMode, long msTrainCfgPtr);
private native boolean compileGraph(long sessionPtr, long modelPtr);
private native void bindThread(long sessionPtr, boolean ifBind);

View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.lite;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.config.MSConfig;
public class TrainSession {
static {
System.loadLibrary("mindspore-lite-train-jni");
}
public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
LiteSession liteSession = new LiteSession();
long sessionPtr = createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
if (sessionPtr == 0) {
return null;
} else {
liteSession.setSessionPtr(sessionPtr);
return liteSession;
}
}
private static native long createTrainSession(String fileName, long msContextPtr, boolean trainMode,
long msTrainCfgPtr);
}

View File

@ -18,6 +18,7 @@ package com.mindspore.flclient.model;
import com.mindspore.flclient.Common;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.TrainSession;
import com.mindspore.lite.MSTensor;
import com.mindspore.lite.config.MSConfig;
import mindspore.schema.FeatureMap;
@ -84,7 +85,7 @@ public class SessionUtil {
// arg 2: cpuBindMode:NO_BIND -> 0
// arg 3: enable_fp16 -> false
msConfig.init(0, 1, 0, false);
LiteSession trainSession = LiteSession.createTrainSession(modelPath, msConfig,false);
LiteSession trainSession = TrainSession.createTrainSession(modelPath, msConfig,false);
if (trainSession == null) {
logger.severe(Common.addTag("init session failed,please check model path:" + modelPath));
return null;

View File

@ -63,20 +63,14 @@ public class LiteSession {
}
}
public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
LiteSession liteSession = new LiteSession();
liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
if (liteSession.sessionPtr == 0) {
return null;
} else {
return liteSession;
}
}
public long getSessionPtr() {
return sessionPtr;
}
public void setSessionPtr(long sessionPtr) {
this.sessionPtr = sessionPtr;
}
public void bindThread(boolean ifBind) {
this.bindThread(this.sessionPtr, ifBind);
}
@ -204,8 +198,6 @@ public class LiteSession {
private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr);
private native long createTrainSession(String fileName, long msContextPtr, boolean trainMode, long msTrainCfgPtr);
private native boolean compileGraph(long sessionPtr, long modelPtr);
private native void bindThread(long sessionPtr, boolean ifBind);

View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.lite;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.config.MSConfig;
public class TrainSession {
static {
System.loadLibrary("mindspore-lite-train-jni");
}
public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
LiteSession liteSession = new LiteSession();
long sessionPtr = createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
if (sessionPtr == 0) {
return null;
} else {
liteSession.setSessionPtr(sessionPtr);
return liteSession;
}
}
private static native long createTrainSession(String fileName, long msContextPtr, boolean trainMode,
long msTrainCfgPtr);
}

View File

@ -92,12 +92,6 @@ set(JNI_SRC
set(LITE_SO_NAME mindspore-lite)
if(SUPPORT_TRAIN)
set(JNI_SRC
${JNI_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp
)
endif()
add_library(mindspore-lite-jni SHARED ${JNI_SRC})
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
@ -108,13 +102,15 @@ else()
endif()
if(SUPPORT_TRAIN)
set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite)
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
find_library(log-lib log)
target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME} ${log-lib})
else()
target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME})
endif()
set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite)
set(JNI_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp)
add_library(mindspore-lite-train-jni SHARED ${JNI_TRAIN_SRC})
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
find_library(log-lib log)
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME} ${log-lib})
else()
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME})
endif()
endif()
set(NDK_STRIP

View File

@ -20,7 +20,7 @@
#include "include/train/train_cfg.h"
#include "include/errorcode.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createTrainSession(JNIEnv *env, jobject thiz,
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_createTrainSession(JNIEnv *env, jobject thiz,
jstring file_name,
jlong ms_context_ptr,
jboolean train_mode,