!7761 change GetInputsByName to GetInputByTensorName

Merge pull request !7761 from yeyunpeng2020/java
This commit is contained in:
mindspore-ci-bot 2020-10-26 20:28:33 +08:00 committed by Gitee
commit 9a6be0f48a
2 changed files with 15 additions and 23 deletions

View File

@ -66,14 +66,13 @@ public class LiteSession {
return tensors;
}
public List<MSTensor> getInputsByName(String nodeName) {
List<Long> ret = this.getInputsByName(this.sessionPtr, nodeName);
ArrayList<MSTensor> tensors = new ArrayList<>();
for (Long msTensorAddr : ret) {
MSTensor msTensor = new MSTensor(msTensorAddr);
tensors.add(msTensor);
public MSTensor getInputsByTensorName(String tensorName) {
Long tensor_addr = this.getInputsByTensorName(this.sessionPtr, tensorName);
if(tensor_addr == null){
return null;
}
return tensors;
MSTensor msTensor = new MSTensor(tensor_addr);
return msTensor;
}
public List<MSTensor> getOutputsByNodeName(String nodeName) {
@ -104,6 +103,9 @@ public class LiteSession {
public MSTensor getOutputByTensorName(String tensorName) {
Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName);
if(tensor_addr == null){
return null;
}
return new MSTensor(tensor_addr);
}
@ -130,7 +132,7 @@ public class LiteSession {
private native List<Long> getInputs(long sessionPtr);
private native List<Long> getInputsByName(long sessionPtr, String nodeName);
private native Long getInputsByTensorName(long sessionPtr, String tensorName);
private native List<Long> getOutputsByNodeName(long sessionPtr, String nodeName);

View File

@ -102,27 +102,17 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInpu
return ret;
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInputsByTensorName(JNIEnv *env,
jobject thiz,
jlong session_ptr,
jstring tensor_name) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getInputsByTensorName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring tensor_name) {
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return ret;
return jlong(nullptr);
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto input = lite_session_ptr->GetInputsByTensorName(JstringToChar(env, tensor_name));
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
return ret;
return jlong(input);
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputsByNodeName(JNIEnv *env, jobject thiz,