forked from mindspore-Ecosystem/mindspore
!7761 change GetInputsByName to GetInputByTensorName
Merge pull request !7761 from yeyunpeng2020/java
This commit is contained in:
commit
9a6be0f48a
|
@ -66,14 +66,13 @@ public class LiteSession {
|
|||
return tensors;
|
||||
}
|
||||
|
||||
public List<MSTensor> getInputsByName(String nodeName) {
|
||||
List<Long> ret = this.getInputsByName(this.sessionPtr, nodeName);
|
||||
ArrayList<MSTensor> tensors = new ArrayList<>();
|
||||
for (Long msTensorAddr : ret) {
|
||||
MSTensor msTensor = new MSTensor(msTensorAddr);
|
||||
tensors.add(msTensor);
|
||||
public MSTensor getInputsByTensorName(String tensorName) {
|
||||
Long tensor_addr = this.getInputsByTensorName(this.sessionPtr, tensorName);
|
||||
if(tensor_addr == null){
|
||||
return null;
|
||||
}
|
||||
return tensors;
|
||||
MSTensor msTensor = new MSTensor(tensor_addr);
|
||||
return msTensor;
|
||||
}
|
||||
|
||||
public List<MSTensor> getOutputsByNodeName(String nodeName) {
|
||||
|
@ -104,6 +103,9 @@ public class LiteSession {
|
|||
|
||||
public MSTensor getOutputByTensorName(String tensorName) {
|
||||
Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName);
|
||||
if(tensor_addr == null){
|
||||
return null;
|
||||
}
|
||||
return new MSTensor(tensor_addr);
|
||||
}
|
||||
|
||||
|
@ -130,7 +132,7 @@ public class LiteSession {
|
|||
|
||||
private native List<Long> getInputs(long sessionPtr);
|
||||
|
||||
private native List<Long> getInputsByName(long sessionPtr, String nodeName);
|
||||
private native Long getInputsByTensorName(long sessionPtr, String tensorName);
|
||||
|
||||
private native List<Long> getOutputsByNodeName(long sessionPtr, String nodeName);
|
||||
|
||||
|
|
|
@ -102,27 +102,17 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInpu
|
|||
return ret;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInputsByTensorName(JNIEnv *env,
|
||||
jobject thiz,
|
||||
jlong session_ptr,
|
||||
jstring tensor_name) {
|
||||
jclass array_list = env->FindClass("java/util/ArrayList");
|
||||
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
|
||||
jobject ret = env->NewObject(array_list, array_list_construct);
|
||||
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
|
||||
|
||||
jclass long_object = env->FindClass("java/lang/Long");
|
||||
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getInputsByTensorName(JNIEnv *env, jobject thiz,
|
||||
jlong session_ptr,
|
||||
jstring tensor_name) {
|
||||
auto *pointer = reinterpret_cast<void *>(session_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Session pointer from java is nullptr");
|
||||
return ret;
|
||||
return jlong(nullptr);
|
||||
}
|
||||
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
|
||||
auto input = lite_session_ptr->GetInputsByTensorName(JstringToChar(env, tensor_name));
|
||||
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
|
||||
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
|
||||
return ret;
|
||||
return jlong(input);
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputsByNodeName(JNIEnv *env, jobject thiz,
|
||||
|
|
Loading…
Reference in New Issue