forked from mindspore-Ecosystem/mindspore
!30316 [MS][LITE][MODELPOOL] java api to master
Merge pull request !30316 from yefeng/237-java_api_for_master
This commit is contained in:
commit
271e3757da
|
@ -78,6 +78,13 @@ set(JNI_SRC
|
|||
${NEW_NATIVE_DIR}/version.cpp
|
||||
)
|
||||
|
||||
if(MSLITE_ENABLE_SERVER_INFERENCE)
|
||||
set(JNI_TRAIN_SRC
|
||||
${NEW_NATIVE_DIR}/runner_config.cpp
|
||||
${NEW_NATIVE_DIR}/model_parallel_runner.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
set(LITE_SO_NAME mindspore-lite)
|
||||
|
||||
add_library(mindspore-lite-jni SHARED ${JNI_SRC})
|
||||
|
|
|
@ -0,0 +1,156 @@
|
|||
/*
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore;
|
||||
|
||||
import com.mindspore.config.RunnerConfig;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* ModelParallelRunner is used to define a MindSpore ModelPoolManager, facilitating Model management.
|
||||
*
|
||||
* @since v1.6
|
||||
*/
|
||||
public class ModelParallelRunner {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
}
|
||||
|
||||
private long modelParallelRunnerPtr;
|
||||
|
||||
/**
|
||||
* Construct function.
|
||||
*/
|
||||
public ModelParallelRunner() {
|
||||
this.modelParallelRunnerPtr = 0L;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get modelParallelRunnerPtr pointer.
|
||||
*
|
||||
* @return modelParallelRunnerPtr pointer.
|
||||
*/
|
||||
public long getModelParallelRunnerPtr() {
|
||||
return this.modelParallelRunnerPtr;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a model runner from model path so that it can run on a device. Only valid for Lite.
|
||||
*
|
||||
* @param modelPath the model path.
|
||||
* @param runnerConfig the RunnerConfig Object.
|
||||
* @return init status.
|
||||
*/
|
||||
public boolean init(String modelPath, RunnerConfig runnerConfig) {
|
||||
if (runnerConfig == null || modelPath == null) {
|
||||
return false;
|
||||
}
|
||||
modelParallelRunnerPtr = this.init(modelPath, runnerConfig.getRunnerConfigPtr());
|
||||
return modelParallelRunnerPtr != 0L;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a model runner from model path so that it can run on a device. Only valid for Lite.
|
||||
*
|
||||
* @param modelPath the model path.
|
||||
* @return init status.
|
||||
*/
|
||||
public boolean init(String modelPath) {
|
||||
if (modelPath == null) {
|
||||
return false;
|
||||
}
|
||||
modelParallelRunnerPtr = this.init(modelPath, 0L);
|
||||
return modelParallelRunnerPtr != 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a model runner from model path so that it can run on a device. Only valid for Lite.
|
||||
*
|
||||
* @param inputs inputs A vector where model inputs are arranged in sequence.
|
||||
* @param outputs outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence.
|
||||
* @return init status.
|
||||
*/
|
||||
public boolean predict(List<MSTensor> inputs, List<MSTensor> outputs) {
|
||||
long[] inputsPtrArray = new long[inputs.size()];
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
inputsPtrArray[i] = inputs.get(i).getMSTensorPtr();
|
||||
}
|
||||
List<Long> outputPtrs = predict(modelParallelRunnerPtr, inputsPtrArray);
|
||||
if (outputPtrs.isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < outputPtrs.size(); i++) {
|
||||
if (outputPtrs.get(i) == 0L) {
|
||||
return false;
|
||||
}
|
||||
MSTensor msTensor = new MSTensor(outputPtrs.get(i));
|
||||
outputs.add(msTensor);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtains all input tensors of the model.
|
||||
*
|
||||
* @return The vector that includes all input tensors.
|
||||
*/
|
||||
public List<MSTensor> getInputs() {
|
||||
List<Long> ret = this.getInputs(this.modelParallelRunnerPtr);
|
||||
List<MSTensor> tensors = new ArrayList<>();
|
||||
for (Long msTensorAddr : ret) {
|
||||
MSTensor msTensor = new MSTensor(msTensorAddr);
|
||||
tensors.add(msTensor);
|
||||
}
|
||||
return tensors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtains all output tensors of the model.
|
||||
*
|
||||
* @return The vector that includes all input tensors.
|
||||
*/
|
||||
public List<MSTensor> getOutputs() {
|
||||
List<Long> ret = this.getOutputs(this.modelParallelRunnerPtr);
|
||||
List<MSTensor> tensors = new ArrayList<>();
|
||||
for (Long msTensorAddr : ret) {
|
||||
MSTensor msTensor = new MSTensor(msTensorAddr);
|
||||
tensors.add(msTensor);
|
||||
}
|
||||
return tensors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Free model
|
||||
*/
|
||||
public void free() {
|
||||
if (modelParallelRunnerPtr != 0L) {
|
||||
this.free(modelParallelRunnerPtr);
|
||||
modelParallelRunnerPtr = 0L;
|
||||
}
|
||||
}
|
||||
|
||||
private native long init(String modelPath, long runnerConfigPtr);
|
||||
|
||||
private native List<Long> predict(long modelParallelRunnerPtr, long[] inputs);
|
||||
|
||||
private native List<Long> getInputs(long modelParallelRunnerPtr);
|
||||
|
||||
private native List<Long> getOutputs(long modelParallelRunnerPtr);
|
||||
|
||||
private native void free(long modelParallelRunnerPtr);
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.config;
|
||||
|
||||
/**
|
||||
* Configuration for ModelParallelRunner.
|
||||
*
|
||||
* @since v1.6
|
||||
*/
|
||||
public class RunnerConfig {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
}
|
||||
|
||||
private long runnerConfigPtr;
|
||||
|
||||
/**
|
||||
* Construct function.
|
||||
*/
|
||||
public RunnerConfig() {
|
||||
this.runnerConfigPtr = 0L;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Init RunnerConfig
|
||||
*
|
||||
* @param msContext MSContext Object.
|
||||
* @return init status.
|
||||
*/
|
||||
public boolean init(MSContext msContext) {
|
||||
this.runnerConfigPtr = createRunnerConfig(msContext.getMSContextPtr());
|
||||
return this.runnerConfigPtr != 0L;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set num models
|
||||
*
|
||||
* @param numModel The number of parallel models.
|
||||
*/
|
||||
public void setNumModel(int numModel) {
|
||||
setNumModel(runnerConfigPtr, numModel);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get RunnerConfig pointer.
|
||||
*
|
||||
* @return RunnerConfig pointer.
|
||||
*/
|
||||
public long getRunnerConfigPtr() {
|
||||
return runnerConfigPtr;
|
||||
}
|
||||
|
||||
private native long createRunnerConfig(long msContextPtr);
|
||||
|
||||
private native void setNumModel(long runnerConfigPtr, int numModel);
|
||||
|
||||
}
|
|
@ -0,0 +1,136 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <jni.h>
|
||||
#include "common/ms_log.h"
|
||||
#include "include/api/model_parallel_runner.h"
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_ModelParallelRunner_init(JNIEnv *env, jobject thiz,
|
||||
jstring model_path,
|
||||
jlong runner_config_ptr) {
|
||||
auto runner = new (std::nothrow) mindspore::ModelParallelRunner();
|
||||
if (runner == nullptr) {
|
||||
MS_LOGE("Make ModelParallelRunner failed");
|
||||
return (jlong) nullptr;
|
||||
}
|
||||
auto model_path_str = env->GetStringUTFChars(model_path, JNI_FALSE);
|
||||
if (runner_config_ptr == 0L) {
|
||||
runner->Init(model_path_str);
|
||||
} else {
|
||||
auto *c_runner_config = reinterpret_cast<mindspore::RunnerConfig *>(runner_config_ptr);
|
||||
auto runner_config = std::make_shared<mindspore::RunnerConfig>();
|
||||
if (runner_config == nullptr) {
|
||||
delete runner;
|
||||
MS_LOGE("Make RunnerConfig failed");
|
||||
return (jlong) nullptr;
|
||||
}
|
||||
runner_config.reset(c_runner_config);
|
||||
runner->Init(model_path_str, runner_config);
|
||||
}
|
||||
return (jlong)runner;
|
||||
}
|
||||
|
||||
jobject GetParallelInOrOutTensors(JNIEnv *env, jobject thiz, jlong model_parallel_runner_ptr, bool is_input) {
|
||||
jclass array_list = env->FindClass("java/util/ArrayList");
|
||||
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
|
||||
jobject ret = env->NewObject(array_list, array_list_construct);
|
||||
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
|
||||
|
||||
jclass long_object = env->FindClass("java/lang/Long");
|
||||
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
|
||||
auto *pointer = reinterpret_cast<mindspore::ModelParallelRunner *>(model_parallel_runner_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Model pointer from java is nullptr");
|
||||
return ret;
|
||||
}
|
||||
std::vector<mindspore::MSTensor> tensors;
|
||||
if (is_input) {
|
||||
tensors = pointer->GetInputs();
|
||||
} else {
|
||||
tensors = pointer->GetOutputs();
|
||||
}
|
||||
for (auto &tensor : tensors) {
|
||||
auto tensor_ptr = std::make_unique<mindspore::MSTensor>(tensor);
|
||||
if (tensor_ptr == nullptr) {
|
||||
MS_LOGE("Make ms tensor failed");
|
||||
return ret;
|
||||
}
|
||||
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(tensor_ptr.release()));
|
||||
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_ModelParallelRunner_getInputs(JNIEnv *env, jobject thiz,
|
||||
jlong model_parallel_runner_ptr) {
|
||||
return GetParallelInOrOutTensors(env, thiz, model_parallel_runner_ptr, true);
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jobject JNICALL
|
||||
Java_com_mindspore_ModelParallelRunner_getOutputs(JNIEnv *env, jobject thiz, jlong model_parallel_runner_ptr) {
|
||||
return GetParallelInOrOutTensors(env, thiz, model_parallel_runner_ptr, false);
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_ModelParallelRunner_predict(JNIEnv *env, jobject thiz,
|
||||
jlong model_parallel_runner_ptr,
|
||||
jlongArray inputs) {
|
||||
jclass array_list = env->FindClass("java/util/ArrayList");
|
||||
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
|
||||
jobject ret = env->NewObject(array_list, array_list_construct);
|
||||
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
|
||||
auto *pointer = reinterpret_cast<mindspore::ModelParallelRunner *>(model_parallel_runner_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Model pointer from java is nullptr");
|
||||
return ret;
|
||||
}
|
||||
|
||||
auto input_size = static_cast<int>(env->GetArrayLength(inputs));
|
||||
jlong *input_data = env->GetLongArrayElements(inputs, nullptr);
|
||||
std::vector<mindspore::MSTensor> c_inputs;
|
||||
for (int i = 0; i < input_size; i++) {
|
||||
auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
|
||||
if (tensor_pointer == nullptr) {
|
||||
MS_LOGE("Tensor pointer from java is nullptr");
|
||||
return ret;
|
||||
}
|
||||
auto &ms_tensor = *static_cast<mindspore::MSTensor *>(tensor_pointer);
|
||||
c_inputs.push_back(ms_tensor);
|
||||
}
|
||||
std::vector<mindspore::MSTensor> outputs;
|
||||
pointer->Predict(c_inputs, &outputs);
|
||||
for (auto &tensor : outputs) {
|
||||
auto tensor_ptr = std::make_unique<mindspore::MSTensor>(tensor);
|
||||
if (tensor_ptr == nullptr) {
|
||||
MS_LOGE("Make ms tensor failed");
|
||||
return ret;
|
||||
}
|
||||
jclass long_object = env->FindClass("java/lang/Long");
|
||||
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
|
||||
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(tensor_ptr.release()));
|
||||
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_ModelParallelRunner_free(JNIEnv *env, jobject thiz,
|
||||
jlong model_parallel_runner_ptr) {
|
||||
auto *pointer = reinterpret_cast<mindspore::ModelParallelRunner *>(model_parallel_runner_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("ModelParallelRunner pointer from java is nullptr");
|
||||
return;
|
||||
}
|
||||
delete pointer;
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <jni.h>
|
||||
#include "common/ms_log.h"
|
||||
#include "include/api/model_parallel_runner.h"
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_config_RunnerConfig_createRunnerConfig(JNIEnv *env, jobject thiz,
|
||||
jlong context_ptr) {
|
||||
auto runner_config = new (std::nothrow) mindspore::RunnerConfig();
|
||||
if (runner_config == nullptr) {
|
||||
MS_LOGE("new RunnerConfig fail!");
|
||||
return (jlong) nullptr;
|
||||
}
|
||||
auto *c_context_ptr = reinterpret_cast<mindspore::Context *>(context_ptr);
|
||||
if (c_context_ptr == nullptr) {
|
||||
delete runner_config;
|
||||
MS_LOGE("Context pointer from java is nullptr");
|
||||
return (jlong) nullptr;
|
||||
}
|
||||
auto context = std::make_shared<mindspore::Context>();
|
||||
if (context == nullptr) {
|
||||
delete runner_config;
|
||||
MS_LOGE("Make context failed");
|
||||
return (jlong) nullptr;
|
||||
}
|
||||
context.reset(c_context_ptr);
|
||||
runner_config->model_ctx = context;
|
||||
return (jlong)runner_config;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_RunnerConfig_setNumModel(JNIEnv *env, jobject thiz,
|
||||
jstring runner_config_ptr,
|
||||
jint num_model) {
|
||||
auto *pointer = reinterpret_cast<mindspore::RunnerConfig *>(runner_config_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("runner config pointer from java is nullptr");
|
||||
return;
|
||||
}
|
||||
pointer->num_model = num_model;
|
||||
}
|
Loading…
Reference in New Issue