add cxx native api

This commit is contained in:
zhengjun10 2021-12-16 17:29:26 +08:00
parent 48c3be7953
commit c5953652d4
7 changed files with 919 additions and 0 deletions

View File

@ -48,6 +48,7 @@ endif()
set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
set(LITE_DIR ${TOP_DIR}/mindspore/lite)
set(NEW_NATIVE_DIR ${LITE_DIR}/java/src/main/native)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${LITE_DIR}) ## lite include
@ -69,6 +70,12 @@ set(JNI_SRC
${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/lite_session.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common/jni_utils.cpp
${NEW_NATIVE_DIR}/graph.cpp
${NEW_NATIVE_DIR}/model.cpp
${NEW_NATIVE_DIR}/ms_context.cpp
${NEW_NATIVE_DIR}/ms_tensor.cpp
${NEW_NATIVE_DIR}/train_config.cpp
${NEW_NATIVE_DIR}/version.cpp
)
set(LITE_SO_NAME mindspore-lite)

View File

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <jni.h>
#include "common/ms_log.h"
#include "include/api/graph.h"
#include "include/api/serialization.h"
#include "include/api/types.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Graph_load(JNIEnv *env, jobject thiz, jstring ms_file) {
auto graph = new (std::nothrow) mindspore::Graph();
if (graph == nullptr) {
MS_LOGE("Model new failed");
return jlong(nullptr);
}
auto status =
mindspore::Serialization::Load(env->GetStringUTFChars(ms_file, JNI_FALSE), mindspore::ModelType::kMindIR, graph);
if (status != mindspore::kSuccess) {
MS_LOGE("Load graph from file failed");
delete graph;
return jlong(nullptr);
}
return jlong(graph);
}

View File

@ -0,0 +1,388 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/api/model.h"
#include <jni.h>
#include "common/ms_log.h"
#include "include/api/serialization.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByGraph(JNIEnv *env, jobject thiz, jlong graph_ptr,
jlong context_ptr, jlong cfg_ptr) {
auto *c_graph_ptr = reinterpret_cast<mindspore::Graph *>(graph_ptr);
if (c_graph_ptr == nullptr) {
MS_LOGE("Graph pointer from java is nullptr");
return jlong(nullptr);
}
auto *c_context_ptr = reinterpret_cast<mindspore::Context *>(context_ptr);
if (c_context_ptr == nullptr) {
MS_LOGE("Context pointer from java is nullptr");
return jlong(nullptr);
}
auto context = std::make_shared<mindspore::Context>();
if (context == nullptr) {
MS_LOGE("Make context failed");
return jlong(nullptr);
}
context.reset(c_context_ptr);
auto *c_cfg_ptr = reinterpret_cast<mindspore::TrainCfg *>(cfg_ptr);
auto cfg = std::make_shared<mindspore::TrainCfg>();
if (cfg == nullptr) {
MS_LOGE("Make train config failed");
return jlong(nullptr);
}
cfg.reset(c_cfg_ptr);
auto model = new (std::nothrow) mindspore::Model();
if (model == nullptr) {
MS_LOGE("Model new failed");
return jlong(nullptr);
}
auto status = model->Build(mindspore::GraphCell(*c_graph_ptr), context, cfg);
if (status != mindspore::kSuccess) {
MS_LOGE("Error (%d) during build of model", static_cast<int>(status));
delete model;
return jlong(nullptr);
}
return jlong(model);
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv *env, jobject thiz,
jobject model_buffer, jint model_type,
jlong context_ptr, jcharArray key_str,
jstring dec_mod) {
if (model_buffer == nullptr) {
MS_LOGE("Buffer from java is nullptr");
return reinterpret_cast<jlong>(nullptr);
}
mindspore::ModelType c_model_type;
if (model_type >= static_cast<int>(mindspore::kMindIR) && model_type <= static_cast<int>(mindspore::kMindIR_Opt)) {
c_model_type = static_cast<mindspore::ModelType>(model_type);
} else {
MS_LOGE("Invalid model type : %d", model_type);
return (jlong) nullptr;
}
jlong buffer_len = env->GetDirectBufferCapacity(model_buffer);
auto *model_buf = static_cast<char *>(env->GetDirectBufferAddress(model_buffer));
auto *c_context_ptr = reinterpret_cast<mindspore::Context *>(context_ptr);
if (c_context_ptr == nullptr) {
MS_LOGE("Context pointer from java is nullptr");
return jlong(nullptr);
}
auto context = std::make_shared<mindspore::Context>();
if (context == nullptr) {
MS_LOGE("Make context failed");
return jlong(nullptr);
}
context.reset(c_context_ptr);
auto model = new (std::nothrow) mindspore::Model();
if (model == nullptr) {
MS_LOGE("Model new failed");
return jlong(nullptr);
}
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
mindspore::Status status;
if (key_str != NULL) {
jchar *key_array = env->GetCharArrayElements(key_str, NULL);
auto key_len = static_cast<size_t>(env->GetArrayLength(key_str));
char *dec_key_data = new (std::nothrow) char[key_len];
if (dec_key_data == nullptr) {
MS_LOGE("Dec key new failed");
delete model;
return jlong(nullptr);
}
for (size_t i = 0; i < key_len; i++) {
dec_key_data[i] = key_array[i];
}
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
mindspore::Key dec_key{dec_key_data, key_len};
status = model->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod);
} else {
status = model->Build(model_buf, buffer_len, c_model_type, context);
}
if (status != mindspore::kSuccess) {
MS_LOGE("Error (%d) during build of model", static_cast<int>(status));
delete model;
return jlong(nullptr);
}
return jlong(model);
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *env, jobject thiz, jstring model_path,
jint model_type, jlong context_ptr,
jcharArray key_str, jstring dec_mod) {
auto c_model_path = env->GetStringUTFChars(model_path, JNI_FALSE);
mindspore::ModelType c_model_type;
if (model_type >= static_cast<int>(mindspore::kMindIR) && model_type <= static_cast<int>(mindspore::kMindIR_Opt)) {
c_model_type = static_cast<mindspore::ModelType>(model_type);
} else {
MS_LOGE("Invalid model type : %d", model_type);
return (jlong) nullptr;
}
auto *c_context_ptr = reinterpret_cast<mindspore::Context *>(context_ptr);
if (c_context_ptr == nullptr) {
MS_LOGE("Context pointer from java is nullptr");
return jlong(nullptr);
}
auto context = std::make_shared<mindspore::Context>();
if (context == nullptr) {
MS_LOGE("Make context failed");
return jlong(nullptr);
}
context.reset(c_context_ptr);
auto model = new (std::nothrow) mindspore::Model();
if (model == nullptr) {
MS_LOGE("Model new failed");
return jlong(nullptr);
}
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
mindspore::Status status;
if (key_str != NULL) {
jchar *key_array = env->GetCharArrayElements(key_str, NULL);
auto key_len = static_cast<size_t>(env->GetArrayLength(key_str));
char *dec_key_data = new (std::nothrow) char[key_len];
if (dec_key_data == nullptr) {
MS_LOGE("Dec key new failed");
delete model;
return jlong(nullptr);
}
for (size_t i = 0; i < key_len; i++) {
dec_key_data[i] = key_array[i];
}
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
mindspore::Key dec_key{dec_key_data, key_len};
status = model->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod);
} else {
status = model->Build(c_model_path, c_model_type, context);
}
if (status != mindspore::kSuccess) {
MS_LOGE("Error (%d) during build of model", static_cast<int>(status));
delete model;
return jlong(nullptr);
}
return jlong(model);
}
jobject GetInOrOutTensors(JNIEnv *env, jobject thiz, jlong model_ptr, bool is_input) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
auto *pointer = reinterpret_cast<mindspore::Model *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return ret;
}
std::vector<mindspore::MSTensor> tensors;
if (is_input) {
tensors = pointer->GetInputs();
} else {
tensors = pointer->GetOutputs();
}
for (auto &tensor : tensors) {
auto tensor_ptr = std::make_unique<mindspore::MSTensor>(tensor);
if (tensor_ptr == nullptr) {
MS_LOGE("Make ms tensor failed");
return ret;
}
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(tensor_ptr.release()));
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
}
return ret;
}
jlong GetTensorByInOutName(JNIEnv *env, jlong model_ptr, jstring tensor_name, bool is_input) {
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return jlong(nullptr);
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
mindspore::MSTensor tensor;
if (is_input) {
tensor = lite_model_ptr->GetInputByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE));
} else {
tensor = lite_model_ptr->GetOutputByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE));
}
auto tensor_ptr = std::make_unique<mindspore::MSTensor>(tensor);
if (tensor_ptr == nullptr) {
MS_LOGE("Make ms tensor failed");
return jlong(nullptr);
}
return jlong(tensor_ptr.release());
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getInputs(JNIEnv *env, jobject thiz, jlong model_ptr) {
return GetInOrOutTensors(env, thiz, model_ptr, true);
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_getInputByTensorName(JNIEnv *env, jobject thiz,
jlong model_ptr, jstring tensor_name) {
return GetTensorByInOutName(env, model_ptr, tensor_name, true);
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputs(JNIEnv *env, jobject thiz, jlong model_ptr) {
return GetInOrOutTensors(env, thiz, model_ptr, false);
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_getOutputByTensorName(JNIEnv *env, jobject thiz,
jlong model_ptr,
jstring tensor_name) {
return GetTensorByInOutName(env, model_ptr, tensor_name, false);
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_getTrainMode(JNIEnv *env, jobject thiz,
jlong model_ptr) {
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return jlong(false);
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
return static_cast<jboolean>(lite_model_ptr->GetTrainMode());
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_setTrainMode(JNIEnv *env, jobject thiz, jlong model_ptr,
jboolean train_mode) {
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return jlong(false);
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
auto status = lite_model_ptr->SetTrainMode(train_mode);
return static_cast<jboolean>(status.IsOk());
}
std::vector<mindspore::MSTensor> convertArrayToVector(JNIEnv *env, jlongArray inputs) {
auto input_size = static_cast<int>(env->GetArrayLength(inputs));
jlong *input_data = env->GetLongArrayElements(inputs, nullptr);
std::vector<mindspore::MSTensor> c_inputs;
for (int i = 0; i < input_size; i++) {
auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
if (tensor_pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return c_inputs;
}
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(tensor_pointer);
c_inputs.push_back(*ms_tensor_ptr);
}
return c_inputs;
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_predict(JNIEnv *env, jobject thiz, jlong model_ptr,
jlongArray inputs, jlongArray outputs) {
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return jlong(false);
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
auto c_inputs = convertArrayToVector(env, inputs);
auto c_outputs = convertArrayToVector(env, outputs);
auto status = lite_model_ptr->Predict(c_inputs, &c_outputs);
return static_cast<jboolean>(status.IsOk());
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *env, jobject thiz, jlong model_ptr,
jlongArray inputs, jobjectArray dims) {
std::vector<std::vector<int64_t>> c_dims;
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return false;
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
auto input_size = static_cast<int>(env->GetArrayLength(inputs));
jlong *input_data = env->GetLongArrayElements(inputs, nullptr);
std::vector<mindspore::MSTensor> c_inputs;
for (int i = 0; i < input_size; i++) {
auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
if (tensor_pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return false;
}
auto &ms_tensor = *static_cast<mindspore::MSTensor *>(tensor_pointer);
c_inputs.push_back(ms_tensor);
}
auto tensor_size = static_cast<int>(env->GetArrayLength(dims));
for (int i = 0; i < tensor_size; i++) {
auto array = static_cast<jintArray>(env->GetObjectArrayElement(dims, i));
auto dim_size = static_cast<int>(env->GetArrayLength(array));
jint *dim_data = env->GetIntArrayElements(array, nullptr);
std::vector<int64_t> tensor_dims(dim_size);
for (int j = 0; j < dim_size; j++) {
tensor_dims[j] = dim_data[j];
}
c_dims.push_back(tensor_dims);
env->ReleaseIntArrayElements(array, dim_data, JNI_ABORT);
env->DeleteLocalRef(array);
}
auto ret = lite_model_ptr->Resize(c_inputs, c_dims);
return (jboolean)(ret.IsOk());
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_export(JNIEnv *env, jobject thiz, jlong model_ptr,
jstring model_name, jint quantization_type,
jboolean export_inference_only,
jobjectArray tensorNames) {
auto *model_pointer = reinterpret_cast<void *>(model_ptr);
if (model_pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(model_pointer);
auto model_path = env->GetStringUTFChars(model_name, JNI_FALSE);
std::vector<std::string> output_tensor_names;
if (tensorNames != nullptr) {
auto tensor_size = static_cast<int>(env->GetArrayLength(tensorNames));
for (int i = 0; i < tensor_size; i++) {
auto tensor_name = static_cast<jstring>(env->GetObjectArrayElement(tensorNames, i));
output_tensor_names[i] = env->GetStringUTFChars(tensor_name, JNI_FALSE);
env->DeleteLocalRef(tensor_name);
}
}
mindspore::QuantizationType quant_type;
if (quantization_type >= static_cast<int>(mindspore::kNoQuant) &&
quantization_type <= static_cast<int>(mindspore::kFullQuant)) {
quant_type = static_cast<mindspore::QuantizationType>(quantization_type);
} else {
MS_LOGE("Invalid quantization_type : %d", quantization_type);
return (jlong) nullptr;
}
auto ret = mindspore::Serialization::ExportModel(*lite_model_ptr, mindspore::kMindIR, model_path, quant_type,
export_inference_only, output_tensor_names);
return (jboolean)(ret.IsOk());
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_Model_free(JNIEnv *env, jobject thiz, jlong model_ptr) {
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return;
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
delete (lite_model_ptr);
}

View File

@ -0,0 +1,99 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <jni.h>
#include "common/ms_log.h"
#include "include/api/context.h"
constexpr const int kNpuDevice = 2;
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_config_MSContext_createMSContext(JNIEnv *env, jobject thiz,
jint thread_num,
jint cpu_bind_mode,
jboolean enable_parallel) {
auto *context = new (std::nothrow) mindspore::Context();
if (context == nullptr) {
MS_LOGE("new Context fail!");
return (jlong) nullptr;
}
context->SetThreadNum(thread_num);
context->SetEnableParallel(enable_parallel);
context->SetThreadAffinity(cpu_bind_mode);
return (jlong)context;
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_addDeviceInfo(JNIEnv *env, jobject thiz,
jlong context_ptr, jint device_type,
jboolean enable_fp16,
jint npu_freq) {
auto *pointer = reinterpret_cast<void *>(context_ptr);
if (pointer == nullptr) {
MS_LOGE("Context pointer from java is nullptr");
return;
}
auto *c_context_ptr = static_cast<mindspore::Context *>(pointer);
auto &device_list = c_context_ptr->MutableDeviceInfo();
switch (device_type) {
case 0: {
auto cpu_device_info = std::make_shared<mindspore::CPUDeviceInfo>();
if (cpu_device_info == nullptr) {
MS_LOGE("cpu device info is nullptr");
delete (c_context_ptr);
return;
}
cpu_device_info->SetEnableFP16(enable_fp16);
device_list.push_back(cpu_device_info);
break;
}
case 1: // DT_GPU
{
auto gpu_device_info = std::make_shared<mindspore::GPUDeviceInfo>();
if (gpu_device_info == nullptr) {
MS_LOGE("gpu device info is nullptr");
delete (c_context_ptr);
return;
}
gpu_device_info->SetEnableFP16(enable_fp16);
device_list.push_back(gpu_device_info);
break;
}
case kNpuDevice: // DT_NPU
{
auto npu_device_info = std::make_shared<mindspore::KirinNPUDeviceInfo>();
if (npu_device_info == nullptr) {
MS_LOGE("npu device info is nullptr");
delete (c_context_ptr);
return;
}
npu_device_info->SetFrequency(npu_freq);
device_list.push_back(npu_device_info);
break;
}
default:
MS_LOGE("Invalid device_type : %d", device_type);
delete (c_context_ptr);
}
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_free(JNIEnv *env, jobject thiz,
jlong context_ptr) {
auto *pointer = reinterpret_cast<void *>(context_ptr);
if (pointer == nullptr) {
MS_LOGE("Context pointer from java is nullptr");
return;
}
auto *c_context_ptr = static_cast<mindspore::Context *>(pointer);
delete (c_context_ptr);
}

View File

@ -0,0 +1,277 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <jni.h>
#include <cstring>
#include "common/ms_log.h"
#include "include/api/types.h"
extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_MSTensor_getShape(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return env->NewIntArray(0);
}
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(pointer);
auto local_shape = ms_tensor_ptr->Shape();
auto shape_size = local_shape.size();
jintArray shape = env->NewIntArray(shape_size);
auto *tmp = new jint[shape_size];
for (size_t i = 0; i < shape_size; i++) {
tmp[i] = static_cast<int>(local_shape.at(i));
}
env->SetIntArrayRegion(shape, 0, shape_size, tmp);
delete[](tmp);
return shape;
}
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_MSTensor_getDataType(JNIEnv *env, jobject thiz, jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return static_cast<jboolean>(false);
}
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(pointer);
return jint(ms_tensor_ptr->DataType());
}
extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_MSTensor_getByteData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return env->NewByteArray(0);
}
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(pointer);
auto *local_data = static_cast<jbyte *>(ms_tensor_ptr->MutableData());
if (local_data == nullptr) {
MS_LOGD("Tensor has no data");
return env->NewByteArray(0);
}
auto local_size = ms_tensor_ptr->DataSize();
if (local_size <= 0) {
MS_LOGE("Size of tensor is negative: %zu", local_size);
return env->NewByteArray(0);
}
auto ret = env->NewByteArray(local_size);
env->SetByteArrayRegion(ret, 0, local_size, local_data);
return ret;
}
extern "C" JNIEXPORT jlongArray JNICALL Java_com_mindspore_MSTensor_getLongData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return env->NewLongArray(0);
}
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(pointer);
auto *local_data = static_cast<jlong *>(ms_tensor_ptr->MutableData());
if (local_data == nullptr) {
MS_LOGD("Tensor has no data");
return env->NewLongArray(0);
}
if (ms_tensor_ptr->DataType() != mindspore::DataType::kNumberTypeInt64) {
MS_LOGE("data type is error : %d", static_cast<int>(ms_tensor_ptr->DataType()));
return env->NewLongArray(0);
}
auto local_element_num = ms_tensor_ptr->ElementNum();
if (local_element_num <= 0) {
MS_LOGE("ElementsNum of tensor is negative: %d", static_cast<int>(local_element_num));
return env->NewLongArray(0);
}
auto ret = env->NewLongArray(local_element_num);
env->SetLongArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_MSTensor_getIntData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return env->NewIntArray(0);
}
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(pointer);
auto *local_data = static_cast<jint *>(ms_tensor_ptr->MutableData());
if (local_data == nullptr) {
MS_LOGD("Tensor has no data");
return env->NewIntArray(0);
}
if (ms_tensor_ptr->DataType() != mindspore::DataType::kNumberTypeInt32) {
MS_LOGE("data type is error : %d", static_cast<int>(ms_tensor_ptr->DataType()));
return env->NewIntArray(0);
}
auto local_element_num = ms_tensor_ptr->ElementNum();
if (local_element_num <= 0) {
MS_LOGE("ElementsNum of tensor is negative: %d", static_cast<int>(local_element_num));
return env->NewIntArray(0);
}
auto ret = env->NewIntArray(local_element_num);
env->SetIntArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
extern "C" JNIEXPORT jfloatArray JNICALL Java_com_mindspore_MSTensor_getFloatData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return env->NewFloatArray(0);
}
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(pointer);
auto *local_data = static_cast<jfloat *>(ms_tensor_ptr->MutableData());
if (local_data == nullptr) {
MS_LOGD("Tensor has no data");
return env->NewFloatArray(0);
}
if (ms_tensor_ptr->DataType() != mindspore::DataType::kNumberTypeFloat32) {
MS_LOGE("data type is error : %d", static_cast<int>(ms_tensor_ptr->DataType()));
return env->NewFloatArray(0);
}
auto local_element_num = ms_tensor_ptr->ElementNum();
if (local_element_num <= 0) {
MS_LOGE("ElementsNum of tensor is negative: %d", static_cast<int>(local_element_num));
return env->NewFloatArray(0);
}
auto ret = env->NewFloatArray(local_element_num);
env->SetFloatArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setData(JNIEnv *env, jobject thiz, jlong tensor_ptr,
jbyteArray data, jlong data_len) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return static_cast<jboolean>(false);
}
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(pointer);
if (static_cast<size_t>(data_len) != ms_tensor_ptr->DataSize()) {
#ifdef ENABLE_ARM32
MS_LOGE("data_len(%lld) not equal to Size of ms_tensor(%zu)", data_len, ms_tensor_ptr->DataSize());
#else
MS_LOGE("data_len(%ld) not equal to Size of ms_tensor(%zu)", data_len, ms_tensor_ptr->DataSize());
#endif
return static_cast<jboolean>(false);
}
jboolean is_copy = false;
auto *data_arr = env->GetByteArrayElements(data, &is_copy);
auto *local_data = ms_tensor_ptr->MutableData();
memcpy(local_data, data_arr, data_len);
env->ReleaseByteArrayElements(data, data_arr, JNI_ABORT);
return static_cast<jboolean>(true);
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setByteBufferData(JNIEnv *env, jobject thiz,
jlong tensor_ptr, jobject buffer) {
auto *p_data = reinterpret_cast<jbyte *>(env->GetDirectBufferAddress(buffer)); // get buffer pointer
jlong data_len = env->GetDirectBufferCapacity(buffer); // get buffer capacity
if (p_data == nullptr) {
MS_LOGE("GetDirectBufferAddress return null");
return false;
}
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return static_cast<jboolean>(false);
}
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(pointer);
if (static_cast<size_t>(data_len) != ms_tensor_ptr->DataSize()) {
#ifdef ENABLE_ARM32
MS_LOGE("data_len(%lld) not equal to Size of ms_tensor(%zu)", data_len, ms_tensor_ptr->DataSize());
#else
MS_LOGE("data_len(%ld) not equal to Size of ms_tensor(%zu)", data_len, ms_tensor_ptr->DataSize());
#endif
return static_cast<jboolean>(false);
}
auto *local_data = ms_tensor_ptr->MutableData();
memcpy(local_data, p_data, data_len);
return static_cast<jboolean>(true);
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_MSTensor_size(JNIEnv *env, jobject thiz, jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return 0;
}
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(pointer);
return ms_tensor_ptr->DataSize();
}
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_MSTensor_elementsNum(JNIEnv *env, jobject thiz, jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return 0;
}
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(pointer);
return ms_tensor_ptr->ElementNum();
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_MSTensor_free(JNIEnv *env, jobject thiz, jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return;
}
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(pointer);
delete (ms_tensor_ptr);
}
extern "C" JNIEXPORT jstring JNICALL Java_com_mindspore_MSTensor_tensorName(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return nullptr;
}
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(pointer);
return env->NewStringUTF(ms_tensor_ptr->Name().c_str());
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_MSTensor_createTensor(JNIEnv *env, jobject thiz,
jstring tensor_name, jobject buffer) {
auto *p_data = reinterpret_cast<jbyte *>(env->GetDirectBufferAddress(buffer)); // get buffer pointer
jlong data_len = env->GetDirectBufferCapacity(buffer); // get buffer capacity
if (p_data == nullptr) {
MS_LOGE("GetDirectBufferAddress return null");
return false;
}
char *tensor_data(new char[data_len]);
memcpy(tensor_data, p_data, data_len);
int tensor_size = static_cast<jint>(data_len / sizeof(float));
std::vector<int64_t> shape = {tensor_size};
auto tensor =
mindspore::MSTensor::CreateTensor(env->GetStringUTFChars(tensor_name, JNI_FALSE),
mindspore::DataType::kNumberTypeFloat32, shape, tensor_data, data_len);
return jlong(tensor);
}

View File

@ -0,0 +1,88 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <jni.h>
#include "common/ms_log.h"
#include "include/api/types.h"
#include "include/api/cfg.h"
constexpr const int kOLKeepBatchNorm = 2;
constexpr const int kOLNotKeepBatchNorm = 3;
constexpr const int kOLAuto = 4;
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_config_TrainCfg_createTrainCfg(JNIEnv *env, jobject thiz,
jstring loss_name,
jint optimizationLevel,
jboolean accmulateGrads) {
mindspore::OptimizationLevel ol;
switch (optimizationLevel) {
case 0:
ol = mindspore::OptimizationLevel::kO0;
break;
case kOLKeepBatchNorm:
ol = mindspore::OptimizationLevel::kO2;
break;
case kOLNotKeepBatchNorm:
ol = mindspore::OptimizationLevel::kO3;
break;
case kOLAuto:
ol = mindspore::OptimizationLevel::kAuto;
break;
default:
MS_LOGE("Invalid optimization_type : %d", optimizationLevel);
return (jlong) nullptr;
}
auto *traincfg_ptr = new (std::nothrow) mindspore::TrainCfg();
if (traincfg_ptr == nullptr) {
MS_LOGE("new train config fail!");
return (jlong) nullptr;
}
if (loss_name != nullptr) {
traincfg_ptr->loss_name_ = env->GetStringUTFChars(loss_name, JNI_FALSE);
}
traincfg_ptr->optimization_level_ = ol;
traincfg_ptr->accumulate_gradients_ = accmulateGrads;
return (jlong)traincfg_ptr;
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_config_TrainCfg_addMixPrecisionCfg(JNIEnv *env, jobject thiz,
jlong train_cfg_ptr,
jboolean dynamic_loss_scale,
jfloat loss_scale,
jint thresholdIterNum) {
mindspore::MixPrecisionCfg mix_precision_cfg;
mix_precision_cfg.dynamic_loss_scale_ = dynamic_loss_scale;
mix_precision_cfg.loss_scale_ = loss_scale;
mix_precision_cfg.num_of_not_nan_iter_th_ = thresholdIterNum;
auto *pointer = reinterpret_cast<void *>(train_cfg_ptr);
if (pointer == nullptr) {
MS_LOGE("Context pointer from java is nullptr");
return jboolean(false);
}
auto *c_train_cfg_ptr = static_cast<mindspore::TrainCfg *>(pointer);
c_train_cfg_ptr->mix_precision_cfg_ = mix_precision_cfg;
return jboolean(true);
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_TrainCfg_free(JNIEnv *env, jobject thiz,
jlong train_cfg_ptr) {
auto *pointer = reinterpret_cast<void *>(train_cfg_ptr);
if (pointer == nullptr) {
MS_LOGE("Context pointer from java is nullptr");
return;
}
auto *c_cfg_ptr = static_cast<mindspore::TrainCfg *>(pointer);
delete (c_cfg_ptr);
}

View File

@ -0,0 +1,23 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <jni.h>
#include "common/ms_log.h"
#include "include/api/types.h"
extern "C" JNIEXPORT jstring JNICALL Java_com_mindspore_config_Version_version(JNIEnv *env, jclass thiz) {
return env->NewStringUTF(mindspore::Version().c_str());
}