forked from mindspore-Ecosystem/mindspore
java model api
This commit is contained in:
parent
42799a9e2a
commit
05f3eb768a
|
@ -36,7 +36,7 @@ public class Model {
|
|||
* Construct function.
|
||||
*/
|
||||
public Model() {
|
||||
this.modelPtr = 0;
|
||||
this.modelPtr = this.createModel();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -52,8 +52,7 @@ public class Model {
|
|||
return false;
|
||||
}
|
||||
long cfgPtr = cfg != null ? cfg.getTrainCfgPtr() : 0;
|
||||
modelPtr = this.buildByGraph(graph.getGraphPtr(), context.getMSContextPtr(), cfgPtr);
|
||||
return modelPtr != 0;
|
||||
return this.buildByGraph(modelPtr, graph.getGraphPtr(), context.getMSContextPtr(), cfgPtr);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -71,8 +70,7 @@ public class Model {
|
|||
if (context == null || buffer == null || dec_key == null || dec_mode == null) {
|
||||
return false;
|
||||
}
|
||||
modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
|
||||
return modelPtr != 0;
|
||||
return this.buildByBuffer(modelPtr, buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -87,8 +85,7 @@ public class Model {
|
|||
if (context == null || buffer == null) {
|
||||
return false;
|
||||
}
|
||||
modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), null, "", "");
|
||||
return modelPtr != 0;
|
||||
return this.buildByBuffer(modelPtr, buffer, modelType, context.getMSContextPtr(), null, "", "");
|
||||
}
|
||||
|
||||
|
||||
|
@ -107,8 +104,7 @@ public class Model {
|
|||
if (context == null || modelPath == null || dec_key == null || dec_mode == null) {
|
||||
return false;
|
||||
}
|
||||
modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
|
||||
return modelPtr != 0;
|
||||
return this.buildByPath(modelPtr, modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -123,8 +119,7 @@ public class Model {
|
|||
if (context == null || modelPath == null) {
|
||||
return false;
|
||||
}
|
||||
modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), null, "", "");
|
||||
return modelPtr != 0;
|
||||
return this.buildByPath(modelPtr, modelPath, modelType, context.getMSContextPtr(), null, "", "");
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -375,14 +370,16 @@ public class Model {
|
|||
this.free(modelPtr);
|
||||
}
|
||||
|
||||
private native long createModel();
|
||||
|
||||
private native void free(long modelPtr);
|
||||
|
||||
private native long buildByGraph(long graphPtr, long contextPtr, long cfgPtr);
|
||||
private native boolean buildByGraph(long modelPtr, long graphPtr, long contextPtr, long cfgPtr);
|
||||
|
||||
private native long buildByPath(String modelPath, int modelType, long contextPtr,
|
||||
private native boolean buildByPath(long modelPtr, String modelPath, int modelType, long contextPtr,
|
||||
char[] dec_key, String dec_mod, String cropto_lib_path);
|
||||
|
||||
private native long buildByBuffer(MappedByteBuffer buffer, int modelType, long contextPtr,
|
||||
private native boolean buildByBuffer(long modelPtr, MappedByteBuffer buffer, int modelType, long contextPtr,
|
||||
char[] dec_key, String dec_mod, String cropto_lib_path);
|
||||
|
||||
private native List<Long> getInputs(long modelPtr);
|
||||
|
|
|
@ -19,23 +19,40 @@
|
|||
#include "common/log_adapter.h"
|
||||
#include "include/api/serialization.h"
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByGraph(JNIEnv *env, jobject thiz, jlong graph_ptr,
|
||||
jlong context_ptr, jlong cfg_ptr) {
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_createModel(JNIEnv *env, jobject thiz) {
|
||||
auto model = new (std::nothrow) mindspore::Model();
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "createModel failed";
|
||||
return jlong(nullptr);
|
||||
}
|
||||
return jlong(model);
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT bool JNICALL Java_com_mindspore_Model_buildByGraph(JNIEnv *env, jobject thiz, jlong model_ptr,
|
||||
jlong graph_ptr, jlong context_ptr,
|
||||
jlong cfg_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(model_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOG(ERROR) << "Session pointer from java is nullptr";
|
||||
return false;
|
||||
}
|
||||
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
|
||||
|
||||
auto *c_graph_ptr = reinterpret_cast<mindspore::Graph *>(graph_ptr);
|
||||
if (c_graph_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Graph pointer from java is nullptr";
|
||||
return jlong(nullptr);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto *c_context_ptr = reinterpret_cast<mindspore::Context *>(context_ptr);
|
||||
if (c_context_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Context pointer from java is nullptr";
|
||||
return jlong(nullptr);
|
||||
return false;
|
||||
}
|
||||
auto context = std::make_shared<mindspore::Context>();
|
||||
if (context == nullptr) {
|
||||
MS_LOG(ERROR) << "Make context failed";
|
||||
return jlong(nullptr);
|
||||
return false;
|
||||
}
|
||||
context.reset(c_context_ptr);
|
||||
|
||||
|
@ -43,42 +60,42 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByGraph(JNIEnv
|
|||
auto cfg = std::make_shared<mindspore::TrainCfg>();
|
||||
if (cfg == nullptr) {
|
||||
MS_LOG(ERROR) << "Make train config failed";
|
||||
return jlong(nullptr);
|
||||
return false;
|
||||
}
|
||||
if (c_cfg_ptr != nullptr) {
|
||||
cfg.reset(c_cfg_ptr);
|
||||
} else {
|
||||
cfg.reset();
|
||||
}
|
||||
auto model = new (std::nothrow) mindspore::Model();
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "Model new failed";
|
||||
return jlong(nullptr);
|
||||
}
|
||||
|
||||
auto status = model->Build(mindspore::GraphCell(*c_graph_ptr), context, cfg);
|
||||
auto status = lite_model_ptr->Build(mindspore::GraphCell(*c_graph_ptr), context, cfg);
|
||||
if (status != mindspore::kSuccess) {
|
||||
MS_LOG(ERROR) << "Error status " << static_cast<int>(status) << " during build of model";
|
||||
delete model;
|
||||
return jlong(nullptr);
|
||||
return false;
|
||||
}
|
||||
return jlong(model);
|
||||
return true;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv *env, jobject thiz,
|
||||
jobject model_buffer, jint model_type,
|
||||
jlong context_ptr, jcharArray key_str,
|
||||
jstring dec_mod, jstring cropto_lib_path) {
|
||||
extern "C" JNIEXPORT bool JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv *env, jobject thiz, jlong model_ptr,
|
||||
jobject model_buffer, jint model_type,
|
||||
jlong context_ptr, jcharArray key_str,
|
||||
jstring dec_mod, jstring cropto_lib_path) {
|
||||
auto *pointer = reinterpret_cast<void *>(model_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOG(ERROR) << "Session pointer from java is nullptr";
|
||||
return false;
|
||||
}
|
||||
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
|
||||
|
||||
if (model_buffer == nullptr) {
|
||||
MS_LOG(ERROR) << "Buffer from java is nullptr";
|
||||
return reinterpret_cast<jlong>(nullptr);
|
||||
return false;
|
||||
}
|
||||
mindspore::ModelType c_model_type;
|
||||
if (model_type >= static_cast<int>(mindspore::kMindIR) && model_type <= static_cast<int>(mindspore::kMindIR_Lite)) {
|
||||
c_model_type = static_cast<mindspore::ModelType>(model_type);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Invalid model type : " << model_type;
|
||||
return (jlong) nullptr;
|
||||
return false;
|
||||
}
|
||||
jlong buffer_len = env->GetDirectBufferCapacity(model_buffer);
|
||||
auto *model_buf = static_cast<char *>(env->GetDirectBufferAddress(model_buffer));
|
||||
|
@ -86,20 +103,14 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv
|
|||
auto *c_context_ptr = reinterpret_cast<mindspore::Context *>(context_ptr);
|
||||
if (c_context_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Context pointer from java is nullptr";
|
||||
return jlong(nullptr);
|
||||
return false;
|
||||
}
|
||||
auto context = std::make_shared<mindspore::Context>();
|
||||
if (context == nullptr) {
|
||||
MS_LOG(ERROR) << "Make context failed";
|
||||
return jlong(nullptr);
|
||||
return false;
|
||||
}
|
||||
context.reset(c_context_ptr);
|
||||
|
||||
auto model = new (std::nothrow) mindspore::Model();
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "Model new failed";
|
||||
return jlong(nullptr);
|
||||
}
|
||||
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
|
||||
mindspore::Status status;
|
||||
if (key_str != NULL) {
|
||||
|
@ -108,8 +119,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv
|
|||
char *dec_key_data = new (std::nothrow) char[key_len];
|
||||
if (dec_key_data == nullptr) {
|
||||
MS_LOG(ERROR) << "Dec key new failed";
|
||||
delete model;
|
||||
return jlong(nullptr);
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < key_len; i++) {
|
||||
dec_key_data[i] = key_array[i];
|
||||
|
@ -117,47 +127,46 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv
|
|||
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
|
||||
mindspore::Key dec_key{dec_key_data, key_len};
|
||||
auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE);
|
||||
status = model->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
|
||||
status = lite_model_ptr->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
|
||||
} else {
|
||||
status = model->Build(model_buf, buffer_len, c_model_type, context);
|
||||
status = lite_model_ptr->Build(model_buf, buffer_len, c_model_type, context);
|
||||
}
|
||||
if (status != mindspore::kSuccess) {
|
||||
MS_LOG(ERROR) << "Error status " << static_cast<int>(status) << " during build of model";
|
||||
delete model;
|
||||
return jlong(nullptr);
|
||||
return false;
|
||||
}
|
||||
return jlong(model);
|
||||
return true;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *env, jobject thiz, jstring model_path,
|
||||
jint model_type, jlong context_ptr,
|
||||
jcharArray key_str, jstring dec_mod,
|
||||
jstring cropto_lib_path) {
|
||||
extern "C" JNIEXPORT bool JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *env, jobject thiz, jlong model_ptr,
|
||||
jstring model_path, jint model_type,
|
||||
jlong context_ptr, jcharArray key_str,
|
||||
jstring dec_mod, jstring cropto_lib_path) {
|
||||
auto *pointer = reinterpret_cast<void *>(model_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOG(ERROR) << "Session pointer from java is nullptr";
|
||||
return false;
|
||||
}
|
||||
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
|
||||
auto c_model_path = env->GetStringUTFChars(model_path, JNI_FALSE);
|
||||
mindspore::ModelType c_model_type;
|
||||
if (model_type >= static_cast<int>(mindspore::kMindIR) && model_type <= static_cast<int>(mindspore::kMindIR_Lite)) {
|
||||
c_model_type = static_cast<mindspore::ModelType>(model_type);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Invalid model type : " << model_type;
|
||||
return (jlong) nullptr;
|
||||
return false;
|
||||
}
|
||||
auto *c_context_ptr = reinterpret_cast<mindspore::Context *>(context_ptr);
|
||||
if (c_context_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Context pointer from java is nullptr";
|
||||
return jlong(nullptr);
|
||||
return false;
|
||||
}
|
||||
auto context = std::make_shared<mindspore::Context>();
|
||||
if (context == nullptr) {
|
||||
MS_LOG(ERROR) << "Make context failed";
|
||||
return jlong(nullptr);
|
||||
return false;
|
||||
}
|
||||
context.reset(c_context_ptr);
|
||||
|
||||
auto model = new (std::nothrow) mindspore::Model();
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "Model new failed";
|
||||
return jlong(nullptr);
|
||||
}
|
||||
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
|
||||
mindspore::Status status;
|
||||
if (key_str != NULL) {
|
||||
|
@ -166,8 +175,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *
|
|||
char *dec_key_data = new (std::nothrow) char[key_len];
|
||||
if (dec_key_data == nullptr) {
|
||||
MS_LOG(ERROR) << "Dec key new failed";
|
||||
delete model;
|
||||
return jlong(nullptr);
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < key_len; i++) {
|
||||
dec_key_data[i] = key_array[i];
|
||||
|
@ -175,16 +183,15 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *
|
|||
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
|
||||
mindspore::Key dec_key{dec_key_data, key_len};
|
||||
auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE);
|
||||
status = model->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
|
||||
status = lite_model_ptr->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
|
||||
} else {
|
||||
status = model->Build(c_model_path, c_model_type, context);
|
||||
status = lite_model_ptr->Build(c_model_path, c_model_type, context);
|
||||
}
|
||||
if (status != mindspore::kSuccess) {
|
||||
MS_LOG(ERROR) << "Error status " << static_cast<int>(status) << " during build of model";
|
||||
delete model;
|
||||
return jlong(nullptr);
|
||||
return false;
|
||||
}
|
||||
return jlong(model);
|
||||
return true;
|
||||
}
|
||||
|
||||
jobject GetInOrOutTensors(JNIEnv *env, jobject thiz, jlong model_ptr, bool is_input) {
|
||||
|
|
Loading…
Reference in New Issue