fix demo bug

This commit is contained in:
yeyunpeng2020 2021-03-25 09:52:46 +08:00
parent 5865639c7a
commit a7970b8320
3 changed files with 21 additions and 19 deletions

View File

@ -107,7 +107,7 @@ public class Main {
return false;
}
msgSb.append(" and out data:");
for (int i = 0; i < 10 && i < outTensor.elementsNum(); i++) {
for (int i = 0; i < 50 && i < outTensor.elementsNum(); i++) {
msgSb.append(" ").append(result[i]);
}
System.out.println(msgSb.toString());

View File

@ -16,7 +16,6 @@
#include <jni.h>
#include "common/ms_log.h"
#include "common/jni_utils.h"
#include "include/lite_session.h"
#include "include/errorcode.h"
@ -111,7 +110,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getInputs
return jlong(nullptr);
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto input = lite_session_ptr->GetInputsByTensorName(JstringToChar(env, tensor_name));
auto input = lite_session_ptr->GetInputsByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE));
return jlong(input);
}
@ -131,10 +130,11 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
return ret;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto inputs = lite_session_ptr->GetOutputsByNodeName(JstringToChar(env, node_name));
auto inputs = lite_session_ptr->GetOutputsByNodeName(env->GetStringUTFChars(node_name, JNI_FALSE));
for (auto input : inputs) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
env->DeleteLocalRef(tensor_addr);
}
return ret;
}
@ -155,11 +155,12 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
auto outputs = lite_session_ptr->GetOutputs();
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
for (auto output_iter : outputs) {
for (const auto &output_iter : outputs) {
auto node_name = output_iter.first;
auto ms_tensor = output_iter.second;
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(ms_tensor));
env->CallObjectMethod(hash_map, hash_map_put, env->NewStringUTF(node_name.c_str()), tensor_addr);
env->DeleteLocalRef(tensor_addr);
}
return hash_map;
}
@ -178,7 +179,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto output_names = lite_session_ptr->GetOutputTensorNames();
for (auto output_name : output_names) {
for (const auto &output_name : output_names) {
env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str()));
}
return ret;
@ -193,7 +194,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getOutput
return jlong(nullptr);
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto output = lite_session_ptr->GetOutputByTensorName(JstringToChar(env, tensor_name));
auto output = lite_session_ptr->GetOutputByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE));
return jlong(output);
}
@ -219,7 +220,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_resize
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
jsize input_size = static_cast<int>(env->GetArrayLength(inputs));
auto input_size = static_cast<int>(env->GetArrayLength(inputs));
jlong *input_data = env->GetLongArrayElements(inputs, nullptr);
std::vector<mindspore::tensor::MSTensor *> c_inputs;
for (int i = 0; i < input_size; i++) {
@ -231,16 +232,18 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_resize
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(tensor_pointer);
c_inputs.push_back(ms_tensor_ptr);
}
jsize tensor_size = static_cast<int>(env->GetArrayLength(dims));
auto tensor_size = static_cast<int>(env->GetArrayLength(dims));
for (int i = 0; i < tensor_size; i++) {
jintArray array = static_cast<jintArray>(env->GetObjectArrayElement(dims, i));
jsize dim_size = static_cast<int>(env->GetArrayLength(array));
auto array = static_cast<jintArray>(env->GetObjectArrayElement(dims, i));
auto dim_size = static_cast<int>(env->GetArrayLength(array));
jint *dim_data = env->GetIntArrayElements(array, nullptr);
std::vector<int> tensor_dims;
std::vector<int> tensor_dims(dim_size);
for (int j = 0; j < dim_size; j++) {
tensor_dims.push_back(dim_data[j]);
tensor_dims[j] = dim_data[j];
}
c_dims.push_back(tensor_dims);
env->ReleaseIntArrayElements(array, dim_data, JNI_ABORT);
env->DeleteLocalRef(array);
}
int ret = lite_session_ptr->Resize(c_inputs, c_dims);
return (jboolean)(ret == mindspore::lite::RET_OK);

View File

@ -17,7 +17,6 @@
#include <jni.h>
#include <fstream>
#include "common/ms_log.h"
#include "common/jni_utils.h"
#include "include/model.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModel(JNIEnv *env, jobject thiz, jobject buffer) {
@ -38,7 +37,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModel(JNIEn
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModelByPath(JNIEnv *env, jobject thiz,
jstring model_path) {
auto model_path_char = JstringToChar(env, model_path);
auto model_path_char = env->GetStringUTFChars(model_path, JNI_FALSE);
if (nullptr == model_path_char) {
MS_LOGE("model_path_char is nullptr");
return reinterpret_cast<jlong>(nullptr);
@ -56,7 +55,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModelByPath
ifs.seekg(0, std::ios::end);
auto size = ifs.tellg();
std::unique_ptr<char[]> buf(new (std::nothrow) char[size]);
auto buf = new (std::nothrow) char[size];
if (buf == nullptr) {
MS_LOGE("malloc buf failed, file: %s", model_path_char);
ifs.close();
@ -64,10 +63,10 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModelByPath
}
ifs.seekg(0, std::ios::beg);
ifs.read(buf.get(), size);
ifs.read(buf, size);
ifs.close();
delete[](model_path_char);
auto model = mindspore::lite::Model::Import(buf.get(), size);
auto model = mindspore::lite::Model::Import(buf, size);
delete[] buf;
if (model == nullptr) {
MS_LOGE("Import model failed");
return reinterpret_cast<jlong>(nullptr);