fix train example

This commit is contained in:
zhengjun10 2021-05-19 17:48:09 +08:00
parent 8721b1db1f
commit 10d188fd14
3 changed files with 7 additions and 26 deletions

View File

@ -723,8 +723,8 @@ build_lite_java_arm64() {
mkdir -p ${JAVA_PATH}/java/app/libs/arm64-v8a/
mkdir -p ${JAVA_PATH}/native/libs/arm64-v8a/
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite-train.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite-train.so ${JAVA_PATH}/native/libs/arm64-v8a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/native/libs/arm64-v8a/
@ -759,8 +759,8 @@ build_lite_java_arm32() {
mkdir -p ${JAVA_PATH}/java/app/libs/armeabi-v7a/
mkdir -p ${JAVA_PATH}/native/libs/armeabi-v7a/
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite-train.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite-train.so ${JAVA_PATH}/native/libs/armeabi-v7a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/native/libs/armeabi-v7a/
@ -796,8 +796,8 @@ build_lite_java_x86() {
mkdir -p ${JAVA_PATH}/java/linux_x86/libs/
mkdir -p ${JAVA_PATH}/native/libs/linux_x86/
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite-train.so ${JAVA_PATH}/java/linux_x86/libs/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite-train.so ${JAVA_PATH}/native/libs/linux_x86/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/java/linux_x86/libs/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/native/libs/linux_x86/

View File

@ -149,10 +149,6 @@ public class TrainSession {
public boolean setupVirtualBatch(int virtualBatchMultiplier) {
return this.setupVirtualBatch(this.sessionPtr, virtualBatchMultiplier, -1.0f, -1.0f);
}
public boolean setLossName(String lossName) {
return this.setLossName(this.sessionPtr, lossName);
}
private native long createSession(String modelFilename, long msConfigPtr);
@ -190,6 +186,4 @@ public class TrainSession {
private native boolean setLearningRate(long sessionPtr, float learning_rate);
private native boolean setupVirtualBatch(long sessionPtr, int virtualBatchMultiplier, float learningRate, float momentum);
private native boolean setLossName(long sessionPtr, String lossName);
}

View File

@ -239,7 +239,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_saveT
return (jboolean) false;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer);
auto ret = train_session_ptr->SaveToFile(JstringToChar(env, model_file_name));
auto ret = train_session_ptr->Export(JstringToChar(env, model_file_name));
return (jboolean)(ret == 0);
}
@ -318,16 +318,3 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_setup
auto ret = train_session_ptr->SetupVirtualBatch(virtualBatchMultiplier, learningRate, momentum);
return (jboolean)(ret == mindspore::lite::RET_OK);
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_setLossName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring lossName) {
auto *session_pointer = reinterpret_cast<void *>(session_ptr);
if (session_pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer);
auto ret = train_session_ptr->SetLossName(JstringToChar(env, lossName));
return (jboolean)(ret == mindspore::lite::RET_OK);
}