forked from mindspore-Ecosystem/mindspore
add libmindspore-lite-train-jni.so fix lite inference bug
This commit is contained in:
parent
61fbff6ab2
commit
ea788db1dd
|
@ -73,6 +73,11 @@ build_lite_x86_64_jni_and_jar() {
|
|||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
|
||||
cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/
|
||||
if [[ "X$is_train" = "Xon" ]]; then
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
|
||||
cp ./libmindspore-lite-train-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/
|
||||
fi
|
||||
|
||||
cd ${LITE_JAVA_PATH}/java
|
||||
rm -rf gradle .gradle gradlew gradlew.bat
|
||||
|
@ -256,6 +261,10 @@ build_lite_arm64_and_jni() {
|
|||
fi
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
|
||||
if [[ "X$is_train" = "Xon" ]]; then
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
|
||||
fi
|
||||
}
|
||||
|
||||
build_lite_arm32_and_jni() {
|
||||
|
@ -296,6 +305,10 @@ build_lite_arm32_and_jni() {
|
|||
fi
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
|
||||
if [[ "X$is_train" = "Xon" ]]; then
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
|
||||
fi
|
||||
}
|
||||
|
||||
build_aar() {
|
||||
|
|
|
@ -18,6 +18,7 @@ package com.mindspore.lite.train_lenet;
|
|||
|
||||
import com.mindspore.lite.MSTensor;
|
||||
import com.mindspore.lite.LiteSession;
|
||||
import com.mindspore.lite.TrainSession;
|
||||
import com.mindspore.lite.config.MSConfig;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
@ -48,7 +49,7 @@ public class NetRunner {
|
|||
msConfig.init(0, 2, 0, false);
|
||||
session = new LiteSession();
|
||||
System.out.println("Model path is " + modelPath);
|
||||
session = session.createTrainSession(modelPath, msConfig, false);
|
||||
session = TrainSession.createTrainSession(modelPath, msConfig, false);
|
||||
session.setupVirtualBatch(virtualBatch, 0.01f, 1.00f);
|
||||
|
||||
List<MSTensor> inputs = session.getInputs();
|
||||
|
|
|
@ -63,20 +63,14 @@ public class LiteSession {
|
|||
}
|
||||
}
|
||||
|
||||
public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
|
||||
LiteSession liteSession = new LiteSession();
|
||||
liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
|
||||
if (liteSession.sessionPtr == 0) {
|
||||
return null;
|
||||
} else {
|
||||
return liteSession;
|
||||
}
|
||||
}
|
||||
|
||||
public long getSessionPtr() {
|
||||
return sessionPtr;
|
||||
}
|
||||
|
||||
public void setSessionPtr(long sessionPtr) {
|
||||
this.sessionPtr = sessionPtr;
|
||||
}
|
||||
|
||||
public void bindThread(boolean ifBind) {
|
||||
this.bindThread(this.sessionPtr, ifBind);
|
||||
}
|
||||
|
@ -204,8 +198,6 @@ public class LiteSession {
|
|||
|
||||
private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr);
|
||||
|
||||
private native long createTrainSession(String filename, long msContextPtr, boolean trainMode, long msTrainCfgPtr);
|
||||
|
||||
private native boolean compileGraph(long sessionPtr, long modelPtr);
|
||||
|
||||
private native void bindThread(long sessionPtr, boolean ifBind);
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.lite;
|
||||
|
||||
import com.mindspore.lite.LiteSession;
|
||||
import com.mindspore.lite.config.MSConfig;
|
||||
|
||||
public class TrainSession {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-train-jni");
|
||||
}
|
||||
public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
|
||||
LiteSession liteSession = new LiteSession();
|
||||
long sessionPtr = createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
|
||||
if (sessionPtr == 0) {
|
||||
return null;
|
||||
} else {
|
||||
liteSession.setSessionPtr(sessionPtr);
|
||||
return liteSession;
|
||||
}
|
||||
}
|
||||
|
||||
private static native long createTrainSession(String fileName, long msContextPtr, boolean trainMode,
|
||||
long msTrainCfgPtr);
|
||||
}
|
|
@ -18,6 +18,7 @@ package com.mindspore.flclient.model;
|
|||
|
||||
import com.mindspore.flclient.Common;
|
||||
import com.mindspore.lite.LiteSession;
|
||||
import com.mindspore.lite.TrainSession;
|
||||
import com.mindspore.lite.MSTensor;
|
||||
import com.mindspore.lite.config.MSConfig;
|
||||
import mindspore.schema.FeatureMap;
|
||||
|
@ -84,7 +85,7 @@ public class SessionUtil {
|
|||
// arg 2: cpuBindMode:NO_BIND -> 0
|
||||
// arg 3: enable_fp16 -> false
|
||||
msConfig.init(0, 1, 0, false);
|
||||
LiteSession trainSession = LiteSession.createTrainSession(modelPath, msConfig,false);
|
||||
LiteSession trainSession = TrainSession.createTrainSession(modelPath, msConfig,false);
|
||||
if (trainSession == null) {
|
||||
logger.severe(Common.addTag("init session failed,please check model path:" + modelPath));
|
||||
return null;
|
||||
|
|
|
@ -63,20 +63,14 @@ public class LiteSession {
|
|||
}
|
||||
}
|
||||
|
||||
public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
|
||||
LiteSession liteSession = new LiteSession();
|
||||
liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
|
||||
if (liteSession.sessionPtr == 0) {
|
||||
return null;
|
||||
} else {
|
||||
return liteSession;
|
||||
}
|
||||
}
|
||||
|
||||
public long getSessionPtr() {
|
||||
return sessionPtr;
|
||||
}
|
||||
|
||||
public void setSessionPtr(long sessionPtr) {
|
||||
this.sessionPtr = sessionPtr;
|
||||
}
|
||||
|
||||
public void bindThread(boolean ifBind) {
|
||||
this.bindThread(this.sessionPtr, ifBind);
|
||||
}
|
||||
|
@ -204,8 +198,6 @@ public class LiteSession {
|
|||
|
||||
private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr);
|
||||
|
||||
private native long createTrainSession(String fileName, long msContextPtr, boolean trainMode, long msTrainCfgPtr);
|
||||
|
||||
private native boolean compileGraph(long sessionPtr, long modelPtr);
|
||||
|
||||
private native void bindThread(long sessionPtr, boolean ifBind);
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.lite;
|
||||
|
||||
import com.mindspore.lite.LiteSession;
|
||||
import com.mindspore.lite.config.MSConfig;
|
||||
|
||||
public class TrainSession {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-train-jni");
|
||||
}
|
||||
public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
|
||||
LiteSession liteSession = new LiteSession();
|
||||
long sessionPtr = createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
|
||||
if (sessionPtr == 0) {
|
||||
return null;
|
||||
} else {
|
||||
liteSession.setSessionPtr(sessionPtr);
|
||||
return liteSession;
|
||||
}
|
||||
}
|
||||
|
||||
private static native long createTrainSession(String fileName, long msContextPtr, boolean trainMode,
|
||||
long msTrainCfgPtr);
|
||||
}
|
|
@ -92,12 +92,6 @@ set(JNI_SRC
|
|||
|
||||
set(LITE_SO_NAME mindspore-lite)
|
||||
|
||||
if(SUPPORT_TRAIN)
|
||||
set(JNI_SRC
|
||||
${JNI_SRC}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp
|
||||
)
|
||||
endif()
|
||||
add_library(mindspore-lite-jni SHARED ${JNI_SRC})
|
||||
|
||||
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
|
||||
|
@ -108,13 +102,15 @@ else()
|
|||
endif()
|
||||
|
||||
if(SUPPORT_TRAIN)
|
||||
set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite)
|
||||
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
|
||||
find_library(log-lib log)
|
||||
target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME} ${log-lib})
|
||||
else()
|
||||
target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME})
|
||||
endif()
|
||||
set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite)
|
||||
set(JNI_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp)
|
||||
add_library(mindspore-lite-train-jni SHARED ${JNI_TRAIN_SRC})
|
||||
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
|
||||
find_library(log-lib log)
|
||||
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME} ${log-lib})
|
||||
else()
|
||||
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(NDK_STRIP
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include "include/train/train_cfg.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createTrainSession(JNIEnv *env, jobject thiz,
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_createTrainSession(JNIEnv *env, jobject thiz,
|
||||
jstring file_name,
|
||||
jlong ms_context_ptr,
|
||||
jboolean train_mode,
|
||||
|
|
Loading…
Reference in New Issue