java jni codecheck

This commit is contained in:
albert-yan 2023-02-20 20:58:15 +08:00
parent fa7e16269a
commit 11c7ef527a
17 changed files with 433 additions and 88 deletions

View File

@ -18,6 +18,11 @@ package com.mindspore;
import com.mindspore.config.MindsporeLite; import com.mindspore.config.MindsporeLite;
/**
* Graph Class
*
* @since v1.0
*/
public class Graph { public class Graph {
static { static {
MindsporeLite.init(); MindsporeLite.init();

View File

@ -22,6 +22,11 @@ import java.nio.ByteBuffer;
import java.lang.reflect.Array; import java.lang.reflect.Array;
import java.util.HashMap; import java.util.HashMap;
/**
* The MSTensor class defines a tensor in MindSpore.
*
* @since v1.0
*/
public class MSTensor { public class MSTensor {
static { static {
MindsporeLite.init(); MindsporeLite.init();

View File

@ -25,6 +25,11 @@ import java.nio.MappedByteBuffer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
/**
* The Model class is used to define a MindSpore model, facilitating computational graph management.
*
* @since v1.0
*/
public class Model { public class Model {
static { static {
MindsporeLite.init(); MindsporeLite.init();
@ -61,16 +66,20 @@ public class Model {
* @param buffer model buffer. * @param buffer model buffer.
* @param modelType model type. * @param modelType model type.
* @param context model build context. * @param context model build context.
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16. * @param decKey define the key used to decrypt the ciphertext model. The key length is 16.
* @param dec_mode define the decryption mode. Options: AES-GCM. * @param decMode define the decryption mode. Options: AES-GCM.
* @param cropto_lib_path define the openssl library path. * @param croptoLibPath define the openssl library path.
* @return model build status. * @return model build status.
*/ */
public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) { public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] decKey, String decMode,
if (context == null || buffer == null || dec_key == null || dec_mode == null) { String croptoLibPath) {
boolean isValid = (context != null && buffer != null && decKey != null && decMode != null &&
croptoLibPath != null);
if (!isValid) {
return false; return false;
} }
return this.buildByBuffer(modelPtr, buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path); return this.buildByBuffer(modelPtr, buffer, modelType, context.getMSContextPtr(), decKey, decMode,
croptoLibPath);
} }
/** /**
@ -95,16 +104,19 @@ public class Model {
* @param modelPath model path. * @param modelPath model path.
* @param modelType model type. * @param modelType model type.
* @param context model build context. * @param context model build context.
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16. * @param decKey define the key used to decrypt the ciphertext model. The key length is 16.
* @param dec_mode define the decryption mode. Options: AES-GCM. * @param decMode define the decryption mode. Options: AES-GCM.
* @param cropto_lib_path define the openssl library path. * @param croptoLibPath define the openssl library path.
* @return model build status. * @return model build status.
*/ */
public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) { public boolean build(String modelPath, int modelType, MSContext context, char[] decKey, String decMode,
if (context == null || modelPath == null || dec_key == null || dec_mode == null) { String croptoLibPath) {
boolean isValid = (context != null && modelPath != null && decKey != null && decMode != null &&
croptoLibPath != null);
if (!isValid) {
return false; return false;
} }
return this.buildByPath(modelPtr, modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path); return this.buildByPath(modelPtr, modelPath, modelType, context.getMSContextPtr(), decKey, decMode, croptoLibPath);
} }
/** /**
@ -164,9 +176,9 @@ public class Model {
* @return input tensors. * @return input tensors.
*/ */
public List<MSTensor> getInputs() { public List<MSTensor> getInputs() {
List<Long> ret = this.getInputs(this.modelPtr); List<Long> tensorAddrs = this.getInputs(this.modelPtr);
List<MSTensor> tensors = new ArrayList<>(); List<MSTensor> tensors = new ArrayList<>(tensorAddrs.size());
for (Long msTensorAddr : ret) { for (Long msTensorAddr : tensorAddrs) {
MSTensor msTensor = new MSTensor(msTensorAddr); MSTensor msTensor = new MSTensor(msTensorAddr);
tensors.add(msTensor); tensors.add(msTensor);
} }
@ -179,9 +191,9 @@ public class Model {
* @return model outputs tensor. * @return model outputs tensor.
*/ */
public List<MSTensor> getOutputs() { public List<MSTensor> getOutputs() {
List<Long> ret = this.getOutputs(this.modelPtr); List<Long> tensorAddrs = this.getOutputs(this.modelPtr);
List<MSTensor> tensors = new ArrayList<>(); List<MSTensor> tensors = new ArrayList<>(tensorAddrs.size());
for (Long msTensorAddr : ret) { for (Long msTensorAddr : tensorAddrs) {
MSTensor msTensor = new MSTensor(msTensorAddr); MSTensor msTensor = new MSTensor(msTensorAddr);
tensors.add(msTensor); tensors.add(msTensor);
} }
@ -224,11 +236,11 @@ public class Model {
*/ */
public List<MSTensor> getOutputsByNodeName(String nodeName) { public List<MSTensor> getOutputsByNodeName(String nodeName) {
if (nodeName == null) { if (nodeName == null) {
return null; return new ArrayList<>();
} }
List<Long> ret = this.getOutputsByNodeName(this.modelPtr, nodeName); List<Long> tensorAddrs = this.getOutputsByNodeName(this.modelPtr, nodeName);
List<MSTensor> tensors = new ArrayList<>(); List<MSTensor> tensors = new ArrayList<>(tensorAddrs.size());
for (Long msTensorAddr : ret) { for (Long msTensorAddr : tensorAddrs) {
MSTensor msTensor = new MSTensor(msTensorAddr); MSTensor msTensor = new MSTensor(msTensorAddr);
tensors.add(msTensor); tensors.add(msTensor);
} }
@ -296,9 +308,9 @@ public class Model {
* @return FeaturesMap Tensor list. * @return FeaturesMap Tensor list.
*/ */
public List<MSTensor> getFeatureMaps() { public List<MSTensor> getFeatureMaps() {
List<Long> ret = this.getFeatureMaps(this.modelPtr); List<Long> tensorAddrs = this.getFeatureMaps(this.modelPtr);
ArrayList<MSTensor> tensors = new ArrayList<>(); ArrayList<MSTensor> tensors = new ArrayList<>(tensorAddrs.size());
for (Long msTensorAddr : ret) { for (Long msTensorAddr : tensorAddrs) {
MSTensor msTensor = new MSTensor(msTensorAddr); MSTensor msTensor = new MSTensor(msTensorAddr);
tensors.add(msTensor); tensors.add(msTensor);
} }

View File

@ -22,7 +22,13 @@ package com.mindspore.config;
* @since v1.0 * @since v1.0
*/ */
public class CpuBindMode { public class CpuBindMode {
// bind mind cpu
public static final int MID_CPU = 2; public static final int MID_CPU = 2;
// bind high cpu
public static final int HIGHER_CPU = 1; public static final int HIGHER_CPU = 1;
// no bind
public static final int NO_BIND = 0; public static final int NO_BIND = 0;
} }

View File

@ -21,8 +21,16 @@ package com.mindspore.config;
* @since v1.0 * @since v1.0
*/ */
public class DeviceType { public class DeviceType {
// support cpu
public static final int DT_CPU = 0; public static final int DT_CPU = 0;
// support gpu
public static final int DT_GPU = 1; public static final int DT_GPU = 1;
// support npu
public static final int DT_NPU = 2; public static final int DT_NPU = 2;
// support ascend
public static final int DT_ASCEND = 3; public static final int DT_ASCEND = 3;
} }

View File

@ -20,17 +20,22 @@ import java.util.ArrayList;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
/**
* Context is used to store environment variables during execution.
*
* @since v1.0
*/
public class MSContext { public class MSContext {
private static Logger LOGGER = MindsporeLite.GetLogger(); private static Logger LOGGER = MindsporeLite.GetLogger();
static { static {
MindsporeLite.init(); MindsporeLite.init();
} }
private long msContextPtr; private static final long EMPTY_CONTEXT_PTR_VALUE = 0L;
private static final long EMPTY_CONTEXT_PTR_VALUE = 0;
private static final int ERROR_VALUE = -1; private static final int ERROR_VALUE = -1;
private static final String NULLPTR_ERROR_MESSAGE="Context pointer from java is nullptr.\n"; private static final String NULLPTR_ERROR_MESSAGE="Context pointer from java is nullptr.";
private long msContextPtr;
/** /**
* Construct function. * Construct function.
@ -148,13 +153,13 @@ public class MSContext {
* @return The current thread number setting. * @return The current thread number setting.
*/ */
public int getThreadNum() { public int getThreadNum() {
int ret_val = ERROR_VALUE; int retVal = ERROR_VALUE;
if (isInitialized()) { if (isInitialized()) {
ret_val = getThreadNum(this.msContextPtr); retVal = getThreadNum(this.msContextPtr);
} else { } else {
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE); LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
} }
return ret_val; return retVal;
} }
/** /**
@ -178,13 +183,13 @@ public class MSContext {
* @return The current operators parallel number setting. * @return The current operators parallel number setting.
*/ */
public int getInterOpParallelNum() { public int getInterOpParallelNum() {
int ret_val = ERROR_VALUE; int retVal = ERROR_VALUE;
if (isInitialized()) { if (isInitialized()) {
ret_val = getInterOpParallelNum(this.msContextPtr); retVal = getInterOpParallelNum(this.msContextPtr);
} else { } else {
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE); LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
} }
return ret_val; return retVal;
} }
/** /**
@ -209,13 +214,13 @@ public class MSContext {
* @return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first * @return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first
*/ */
public int getThreadAffinityMode() { public int getThreadAffinityMode() {
int ret_val = ERROR_VALUE; int retVal = ERROR_VALUE;
if (isInitialized()) { if (isInitialized()) {
ret_val = getThreadAffinityMode(this.msContextPtr); retVal = getThreadAffinityMode(this.msContextPtr);
} else { } else {
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE); LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
} }
return ret_val; return retVal;
} }
/** /**
@ -229,11 +234,11 @@ public class MSContext {
public void setThreadAffinity(ArrayList<Integer> coreList) { public void setThreadAffinity(ArrayList<Integer> coreList) {
if (isInitialized()) { if (isInitialized()) {
int len = coreList.size(); int len = coreList.size();
int[] coreList_array = new int[len]; int[] coreListArray = new int[len];
for (int i = 0; i < len; i++) { for (int i = 0; i < len; i++) {
coreList_array[i] = coreList.get(i); coreListArray[i] = coreList.get(i);
} }
setThreadAffinity(this.msContextPtr, coreList_array); setThreadAffinity(this.msContextPtr, coreListArray);
} else { } else {
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE); LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
} }
@ -247,13 +252,13 @@ public class MSContext {
*/ */
public ArrayList<Integer> getThreadAffinityCoreList() { public ArrayList<Integer> getThreadAffinityCoreList() {
ArrayList<Integer> ret_val = new ArrayList<>(); ArrayList<Integer> retVal = new ArrayList<>();
if (isInitialized()) { if (isInitialized()) {
ret_val = getThreadAffinityCoreList(this.msContextPtr); retVal = getThreadAffinityCoreList(this.msContextPtr);
} else { } else {
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE); LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
} }
return ret_val; return retVal;
} }
/** /**
@ -277,13 +282,13 @@ public class MSContext {
* @return boolean value that indicates whether in parallel. * @return boolean value that indicates whether in parallel.
*/ */
public boolean getEnableParallel() { public boolean getEnableParallel() {
boolean ret_val = false; boolean retVal = false;
if (isInitialized()) { if (isInitialized()) {
ret_val = getEnableParallel(this.msContextPtr); retVal = getEnableParallel(this.msContextPtr);
} else { } else {
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE); LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
} }
return ret_val; return retVal;
} }

View File

@ -1,7 +1,28 @@
/**
* Copyright 2022-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.config; package com.mindspore.config;
import java.util.logging.Logger; import java.util.logging.Logger;
/**
* MSLite Init Class
*
* @since v1.0
*/
public final class MindsporeLite { public final class MindsporeLite {
private static final Object lock = new Object(); private static final Object lock = new Object();
private static Logger LOGGER = GetLogger(); private static Logger LOGGER = GetLogger();

View File

@ -1,5 +1,5 @@
/* /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2022-2023 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.config; package com.mindspore.config;
/** /**
@ -21,9 +22,19 @@ package com.mindspore.config;
* @since v1.0 * @since v1.0
*/ */
public class ModelType { public class ModelType {
// mindir type
public static final int MT_MINDIR = 0; public static final int MT_MINDIR = 0;
// air type
public static final int MT_AIR = 1; public static final int MT_AIR = 1;
// om type
public static final int MT_OM = 2; public static final int MT_OM = 2;
// onnx type
public static final int MT_ONNX = 3; public static final int MT_ONNX = 3;
// mindir opt type
public static final int MT_MINDIR_OPT = 4; public static final int MT_MINDIR_OPT = 4;
} }

View File

@ -1,3 +1,18 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.config; package com.mindspore.config;
import java.io.File; import java.io.File;
@ -5,7 +20,13 @@ import java.io.FileOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.logging.Logger; import java.util.logging.Logger;
import java.util.Locale;
/**
* NativeLibrary Class
*
* @since v1.0
*/
public class NativeLibrary { public class NativeLibrary {
private static final Logger LOGGER = MindsporeLite.GetLogger(); private static final Logger LOGGER = MindsporeLite.GetLogger();
@ -52,7 +73,7 @@ public class NativeLibrary {
* libmsplugin-ge-litert * libmsplugin-ge-litert
* libruntime_convert_plugin * libruntime_convert_plugin
*/ */
public static void loadLibs() { private static void loadLibs() {
loadLib(makeResourceName("lib" + GLOG_LIBNAME + ".so")); loadLib(makeResourceName("lib" + GLOG_LIBNAME + ".so"));
loadLib(makeResourceName("lib" + OPENCV_CORE_LIBNAME + ".so")); loadLib(makeResourceName("lib" + OPENCV_CORE_LIBNAME + ".so"));
loadLib(makeResourceName("lib" + OPENCV_IMGPROC_LIBNAME + ".so")); loadLib(makeResourceName("lib" + OPENCV_IMGPROC_LIBNAME + ".so"));
@ -83,49 +104,49 @@ public class NativeLibrary {
try { try {
System.loadLibrary(MINDSPORE_LITE_JNI_LIBNAME); System.loadLibrary(MINDSPORE_LITE_JNI_LIBNAME);
loadSuccess = true; loadSuccess = true;
LOGGER.info("loadLibrary " + MINDSPORE_LITE_JNI_LIBNAME + ": success"); LOGGER.info("loadLibrary mindspore-lite-jni success");
} catch (UnsatisfiedLinkError e) { } catch (UnsatisfiedLinkError e) {
LOGGER.info(String.format("tryLoadLibrary " + MINDSPORE_LITE_JNI_LIBNAME + " failed: %s", e.toString())); LOGGER.info(String.format(Locale.ENGLISH, "tryLoadLibrary mindspore-lite-jni failed: %s", e));
} }
try { try {
System.loadLibrary(MINDSPORE_LITE_TRAIN_JNI_LIBNAME); System.loadLibrary(MINDSPORE_LITE_TRAIN_JNI_LIBNAME);
loadSuccess = true; loadSuccess = true;
LOGGER.info("loadLibrary " + MINDSPORE_LITE_TRAIN_JNI_LIBNAME + ": success."); LOGGER.info("loadLibrary mindspore-lite-train-jni success.");
} catch (UnsatisfiedLinkError e) { } catch (UnsatisfiedLinkError e) {
LOGGER.info(String.format("tryLoadLibrary " + MINDSPORE_LITE_TRAIN_JNI_LIBNAME + " failed: %s", e.toString())); LOGGER.info(String.format(Locale.ENGLISH, "tryLoadLibrary mindspore-lite-train-jni failed: %s", e));
} }
return loadSuccess; return loadSuccess;
} }
private static void loadLib(String libResourceName) { private static void loadLib(String libResourceName) {
LOGGER.info("start load libResourceName: " + libResourceName); LOGGER.info(String.format(Locale.ENGLISH,"start load libResourceName: %s.", libResourceName));
final InputStream libResource = NativeLibrary.class.getClassLoader().getResourceAsStream(libResourceName); final InputStream libResource = NativeLibrary.class.getClassLoader().getResourceAsStream(libResourceName);
if (libResource == null) { if (libResource == null) {
LOGGER.warning(String.format("lib file: %s not exist.", libResourceName)); LOGGER.warning(String.format(Locale.ENGLISH,"lib file: %s not exist.", libResourceName));
return; return;
} }
try { try {
final File tmpDir = mkTmpDir(); final File tmpDir = mkTmpDir();
String libName = libResourceName.substring(libResourceName.lastIndexOf("/") + 1); String libName = libResourceName.substring(libResourceName.lastIndexOf('/') + 1);
tmpDir.deleteOnExit(); tmpDir.deleteOnExit();
//copy file to tmpFile // copy file to tmpFile
final File tmpFile = new File(tmpDir.getCanonicalPath(), libName); final File tmpFile = new File(tmpDir.getCanonicalPath(), libName);
tmpFile.deleteOnExit(); tmpFile.deleteOnExit();
LOGGER.info(String.format("extract %d bytes to %s", copyLib(libResource, tmpFile), tmpFile)); LOGGER.info(String.format(Locale.ENGLISH,"extract %d bytes to %s", copyLib(libResource, tmpFile),
LOGGER.info(String.format("libName %s", libName)); tmpFile));
if (libName.equals("lib" + MINDSPORE_LITE_LIBNAME + ".so")) { if (("lib" + MINDSPORE_LITE_LIBNAME + ".so").equals(libName)) {
extractLib(makeResourceName("lib" + MSPLUGIN_GE_LITERT_LIBNAME + ".so"), tmpDir); extractLib(makeResourceName("lib" + MSPLUGIN_GE_LITERT_LIBNAME + ".so"), tmpDir);
extractLib(makeResourceName("lib" + RUNTIME_CONVERT_PLUGIN_LIBNAME + ".so"), tmpDir); extractLib(makeResourceName("lib" + RUNTIME_CONVERT_PLUGIN_LIBNAME + ".so"), tmpDir);
} }
System.load(tmpFile.toString()); System.load(tmpFile.toString());
} catch (IOException e) { } catch (IOException e) {
throw new UnsatisfiedLinkError( throw new UnsatisfiedLinkError(
String.format("extract library into tmp file (%s) failed.", e.toString())); String.format(Locale.ENGLISH,"extract library into tmp file (%s) failed.", e));
} }
} }
private static long copyLib(InputStream libResource, File tmpFile) throws IOException{ private static long copyLib(InputStream libResource, File tmpFile) throws IOException {
try (FileOutputStream outputStream = new FileOutputStream(tmpFile);) { try (FileOutputStream outputStream = new FileOutputStream(tmpFile);) {
// 1MB // 1MB
byte[] buffer = new byte[1 << 20]; byte[] buffer = new byte[1 << 20];
@ -143,9 +164,8 @@ public class NativeLibrary {
private static File mkTmpDir() { private static File mkTmpDir() {
final String MINDSPORE_LITE_LIBS = "mindspore_lite_libs-";
Long timestamp = System.currentTimeMillis(); Long timestamp = System.currentTimeMillis();
String dirName = MINDSPORE_LITE_LIBS + timestamp + "-"; String dirName = "mindspore_lite_libs-" + timestamp + "-";
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
File tmpDir = new File(new File(System.getProperty("java.io.tmpdir")), dirName + i); File tmpDir = new File(new File(System.getProperty("java.io.tmpdir")), dirName + i);
if (tmpDir.mkdir()) { if (tmpDir.mkdir()) {
@ -167,7 +187,7 @@ public class NativeLibrary {
private static String architecture() { private static String architecture() {
final String arch = System.getProperty("os.arch").toLowerCase(); final String arch = System.getProperty("os.arch").toLowerCase();
return (arch.equals("amd64")) ? "x86_64" : arch; return ("amd64".equals(arch)) ? "x86_64" : arch;
} }
private static void extractLib(String libResourceName, File targetDir) { private static void extractLib(String libResourceName, File targetDir) {

View File

@ -16,6 +16,11 @@
package com.mindspore.config; package com.mindspore.config;
/**
* TrainCfg Class
*
* @since v1.0
*/
public class TrainCfg { public class TrainCfg {
static { static {
try { try {

View File

@ -37,9 +37,8 @@ public class Version {
LOGGER.info("Version init load ..."); LOGGER.info("Version init load ...");
try { try {
NativeLibrary.load(); NativeLibrary.load();
} catch (Exception e) { } catch (UnsatisfiedLinkError e) {
LOGGER.severe("Failed to load MindSporLite native library."); LOGGER.severe("Failed to load MindSporLite native library.");
throw e;
} }
} }

View File

@ -26,8 +26,15 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Graph_loadModel(JNIEnv *en
MS_LOG(ERROR) << "Model new failed"; MS_LOG(ERROR) << "Model new failed";
return jlong(nullptr); return jlong(nullptr);
} }
if (ms_file == nullptr) {
MS_LOG(ERROR) << "ms_file from java is nullptr.";
delete graph;
return jlong(nullptr);
}
auto c_ms_file = env->GetStringUTFChars(ms_file, JNI_FALSE);
auto status = auto status =
mindspore::Serialization::Load(env->GetStringUTFChars(ms_file, JNI_FALSE), mindspore::ModelType::kMindIR, graph); mindspore::Serialization::Load(c_ms_file, mindspore::ModelType::kMindIR, graph);
env->ReleaseStringUTFChars(ms_file, c_ms_file);
if (status != mindspore::kSuccess) { if (status != mindspore::kSuccess) {
MS_LOG(ERROR) << "Load graph from file failed"; MS_LOG(ERROR) << "Load graph from file failed";
delete graph; delete graph;

View File

@ -111,23 +111,34 @@ extern "C" JNIEXPORT bool JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv
return false; return false;
} }
context.reset(c_context_ptr); context.reset(c_context_ptr);
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
mindspore::Status status; mindspore::Status status;
if (key_str != NULL) { if (key_str != NULL) {
jchar *key_array = env->GetCharArrayElements(key_str, NULL);
auto key_len = static_cast<size_t>(env->GetArrayLength(key_str)); auto key_len = static_cast<size_t>(env->GetArrayLength(key_str));
char *dec_key_data = new (std::nothrow) char[key_len]; char *dec_key_data = new (std::nothrow) char[key_len];
if (dec_key_data == nullptr) { if (dec_key_data == nullptr) {
MS_LOG(ERROR) << "Dec key new failed"; MS_LOG(ERROR) << "Dec key new failed";
return false; return false;
} }
jchar *key_array = env->GetCharArrayElements(key_str, NULL);
if (key_array == nullptr) {
MS_LOG(ERROR) << "key_array is nullptr.";
return false;
}
for (size_t i = 0; i < key_len; i++) { for (size_t i = 0; i < key_len; i++) {
dec_key_data[i] = key_array[i]; dec_key_data[i] = key_array[i];
} }
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT); env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
mindspore::Key dec_key{dec_key_data, key_len}; mindspore::Key dec_key{dec_key_data, key_len};
if (cropto_lib_path == nullptr || dec_mod == nullptr) {
MS_LOG(ERROR) << "cropto_lib_path or dec_mod from java is nullptr.";
return jlong(nullptr);
}
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE); auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE);
status = lite_model_ptr->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path); status = lite_model_ptr->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
env->ReleaseStringUTFChars(cropto_lib_path, c_cropto_lib_path);
env->ReleaseStringUTFChars(dec_mod, c_dec_mod);
delete[] dec_key_data;
} else { } else {
status = lite_model_ptr->Build(model_buf, buffer_len, c_model_type, context); status = lite_model_ptr->Build(model_buf, buffer_len, c_model_type, context);
} }
@ -167,26 +178,43 @@ extern "C" JNIEXPORT bool JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *e
return false; return false;
} }
context.reset(c_context_ptr); context.reset(c_context_ptr);
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
mindspore::Status status; mindspore::Status status;
if (key_str != NULL) { if (key_str != NULL) {
jchar *key_array = env->GetCharArrayElements(key_str, NULL);
auto key_len = static_cast<size_t>(env->GetArrayLength(key_str)); auto key_len = static_cast<size_t>(env->GetArrayLength(key_str));
char *dec_key_data = new (std::nothrow) char[key_len]; char *dec_key_data = new (std::nothrow) char[key_len];
if (dec_key_data == nullptr) { if (dec_key_data == nullptr) {
MS_LOG(ERROR) << "Dec key new failed"; MS_LOG(ERROR) << "Dec key new failed";
env->ReleaseStringUTFChars(model_path, c_model_path);
return false; return false;
} }
jchar *key_array = env->GetCharArrayElements(key_str, NULL);
if (key_array == nullptr) {
MS_LOG(ERROR) << "GetCharArrayElements failed.";
env->ReleaseStringUTFChars(model_path, c_model_path);
return jlong(nullptr);
}
for (size_t i = 0; i < key_len; i++) { for (size_t i = 0; i < key_len; i++) {
dec_key_data[i] = key_array[i]; dec_key_data[i] = key_array[i];
} }
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT); env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
mindspore::Key dec_key{dec_key_data, key_len}; mindspore::Key dec_key{dec_key_data, key_len};
if (dec_mod == nullptr || cropto_lib_path == nullptr) {
MS_LOG(ERROR) << "dec_mod, cropto_lib_path from java is nullptr.";
env->ReleaseStringUTFChars(model_path, c_model_path);
return jlong(nullptr);
}
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE); auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE);
status = lite_model_ptr->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path); status = lite_model_ptr->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
env->ReleaseStringUTFChars(dec_mod, c_dec_mod);
env->ReleaseStringUTFChars(cropto_lib_path, c_cropto_lib_path);
delete[] dec_key_data;
} else { } else {
status = lite_model_ptr->Build(c_model_path, c_model_type, context); status = lite_model_ptr->Build(c_model_path, c_model_type, context);
} }
env->ReleaseStringUTFChars(model_path, c_model_path);
if (status != mindspore::kSuccess) { if (status != mindspore::kSuccess) {
MS_LOG(ERROR) << "Error status " << static_cast<int>(status) << " during build of model"; MS_LOG(ERROR) << "Error status " << static_cast<int>(status) << " during build of model";
return false; return false;
@ -205,6 +233,8 @@ jobject GetInOrOutTensors(JNIEnv *env, jobject thiz, jlong model_ptr, bool is_in
auto *pointer = reinterpret_cast<mindspore::Model *>(model_ptr); auto *pointer = reinterpret_cast<mindspore::Model *>(model_ptr);
if (pointer == nullptr) { if (pointer == nullptr) {
MS_LOG(ERROR) << "Model pointer from java is nullptr"; MS_LOG(ERROR) << "Model pointer from java is nullptr";
env->DeleteLocalRef(array_list);
env->DeleteLocalRef(long_object);
return ret; return ret;
} }
std::vector<mindspore::MSTensor> tensors; std::vector<mindspore::MSTensor> tensors;
@ -221,7 +251,10 @@ jobject GetInOrOutTensors(JNIEnv *env, jobject thiz, jlong model_ptr, bool is_in
} }
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(tensor_ptr.release())); jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(tensor_ptr.release()));
env->CallBooleanMethod(ret, array_list_add, tensor_addr); env->CallBooleanMethod(ret, array_list_add, tensor_addr);
env->DeleteLocalRef(tensor_addr);
} }
env->DeleteLocalRef(array_list);
env->DeleteLocalRef(long_object);
return ret; return ret;
} }
@ -233,11 +266,17 @@ jlong GetTensorByInOutName(JNIEnv *env, jlong model_ptr, jstring tensor_name, bo
} }
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer); auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
mindspore::MSTensor tensor; mindspore::MSTensor tensor;
if (is_input) { if (tensor_name == nullptr) {
tensor = lite_model_ptr->GetInputByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE)); MS_LOG(ERROR) << "tensor_name from java is nullptr.";
} else { return jlong(nullptr);
tensor = lite_model_ptr->GetOutputByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE));
} }
auto c_tensor_name = env->GetStringUTFChars(tensor_name, JNI_FALSE);
if (is_input) {
tensor = lite_model_ptr->GetInputByTensorName(c_tensor_name);
} else {
tensor = lite_model_ptr->GetOutputByTensorName(c_tensor_name);
}
env->ReleaseStringUTFChars(tensor_name, c_tensor_name);
if (tensor.impl() == nullptr) { if (tensor.impl() == nullptr) {
return jlong(nullptr); return jlong(nullptr);
} }
@ -283,8 +322,11 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputTensorNam
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer); auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
auto output_names = lite_model_ptr->GetOutputTensorNames(); auto output_names = lite_model_ptr->GetOutputTensorNames();
for (const auto &output_name : output_names) { for (const auto &output_name : output_names) {
env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str())); auto output_name_jstring = env->NewStringUTF(output_name.c_str());
env->CallBooleanMethod(ret, array_list_add, output_name_jstring);
env->DeleteLocalRef(output_name_jstring);
} }
env->DeleteLocalRef(array_list);
return ret; return ret;
} }
@ -303,7 +345,13 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputsByNodeNa
return ret; return ret;
} }
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer); auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
auto tensors = lite_model_ptr->GetOutputsByNodeName(env->GetStringUTFChars(node_name, JNI_FALSE)); if (node_name == nullptr) {
MS_LOG(ERROR) << "node_name from java is nullptr";
return ret;
}
auto c_node_name = env->GetStringUTFChars(node_name, JNI_FALSE);
auto tensors = lite_model_ptr->GetOutputsByNodeName(c_node_name);
env->ReleaseStringUTFChars(node_name, c_node_name);
for (auto &tensor : tensors) { for (auto &tensor : tensors) {
auto tensor_ptr = std::make_unique<mindspore::MSTensor>(tensor); auto tensor_ptr = std::make_unique<mindspore::MSTensor>(tensor);
if (tensor_ptr == nullptr) { if (tensor_ptr == nullptr) {
@ -312,7 +360,10 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputsByNodeNa
} }
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(tensor_ptr.release())); jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(tensor_ptr.release()));
env->CallBooleanMethod(ret, array_list_add, tensor_addr); env->CallBooleanMethod(ret, array_list_add, tensor_addr);
env->DeleteLocalRef(tensor_addr);
} }
env->DeleteLocalRef(array_list);
env->DeleteLocalRef(long_object);
return ret; return ret;
} }
@ -351,18 +402,24 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_runStep(JNIEnv *e
} }
std::vector<mindspore::MSTensor> convertArrayToVector(JNIEnv *env, jlongArray inputs) { std::vector<mindspore::MSTensor> convertArrayToVector(JNIEnv *env, jlongArray inputs) {
std::vector<mindspore::MSTensor> c_inputs;
if (inputs == nullptr) {
MS_LOG(ERROR) << "inputs from java is nullptr";
return c_inputs;
}
auto input_size = static_cast<int>(env->GetArrayLength(inputs)); auto input_size = static_cast<int>(env->GetArrayLength(inputs));
jlong *input_data = env->GetLongArrayElements(inputs, nullptr); jlong *input_data = env->GetLongArrayElements(inputs, nullptr);
std::vector<mindspore::MSTensor> c_inputs;
for (int i = 0; i < input_size; i++) { for (int i = 0; i < input_size; i++) {
auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]); auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
if (tensor_pointer == nullptr) { if (tensor_pointer == nullptr) {
MS_LOG(ERROR) << "Tensor pointer from java is nullptr"; MS_LOG(ERROR) << "Tensor pointer from java is nullptr";
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
return c_inputs; return c_inputs;
} }
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(tensor_pointer); auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(tensor_pointer);
c_inputs.push_back(*ms_tensor_ptr); c_inputs.push_back(*ms_tensor_ptr);
} }
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
return c_inputs; return c_inputs;
} }
@ -389,14 +446,22 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en
return (jboolean) false; return (jboolean) false;
} }
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer); auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
if (inputs == nullptr || dims == nullptr) {
MS_LOG(ERROR) << "inputs or dims from java is nullptr";
return (jboolean) false;
}
auto input_size = static_cast<int>(env->GetArrayLength(inputs)); auto input_size = static_cast<int>(env->GetArrayLength(inputs));
jlong *input_data = env->GetLongArrayElements(inputs, nullptr); jlong *input_data = env->GetLongArrayElements(inputs, nullptr);
if (input_data == nullptr) {
MS_LOG(ERROR) << "input_data is nullptr";
return (jboolean) false;
}
std::vector<mindspore::MSTensor> c_inputs; std::vector<mindspore::MSTensor> c_inputs;
for (int i = 0; i < input_size; i++) { for (int i = 0; i < input_size; i++) {
auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]); auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
if (tensor_pointer == nullptr) { if (tensor_pointer == nullptr) {
MS_LOG(ERROR) << "Tensor pointer from java is nullptr"; MS_LOG(ERROR) << "Tensor pointer from java is nullptr";
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
return (jboolean) false; return (jboolean) false;
} }
auto &ms_tensor = *static_cast<mindspore::MSTensor *>(tensor_pointer); auto &ms_tensor = *static_cast<mindspore::MSTensor *>(tensor_pointer);
@ -405,8 +470,19 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en
auto tensor_size = static_cast<int>(env->GetArrayLength(dims)); auto tensor_size = static_cast<int>(env->GetArrayLength(dims));
for (int i = 0; i < tensor_size; i++) { for (int i = 0; i < tensor_size; i++) {
auto array = static_cast<jintArray>(env->GetObjectArrayElement(dims, i)); auto array = static_cast<jintArray>(env->GetObjectArrayElement(dims, i));
if (array == nullptr) {
MS_LOG(ERROR) << "Tensor pointer from java is nullptr";
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
return (jboolean) false;
}
auto dim_size = static_cast<int>(env->GetArrayLength(array)); auto dim_size = static_cast<int>(env->GetArrayLength(array));
jint *dim_data = env->GetIntArrayElements(array, nullptr); jint *dim_data = env->GetIntArrayElements(array, nullptr);
if (dim_data == nullptr) {
MS_LOG(ERROR) << "dim_data is nullptr";
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
env->DeleteLocalRef(array);
return (jboolean) false;
}
std::vector<int64_t> tensor_dims(dim_size); std::vector<int64_t> tensor_dims(dim_size);
for (int j = 0; j < dim_size; j++) { for (int j = 0; j < dim_size; j++) {
tensor_dims[j] = dim_data[j]; tensor_dims[j] = dim_data[j];
@ -416,16 +492,17 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en
env->DeleteLocalRef(array); env->DeleteLocalRef(array);
} }
auto ret = lite_model_ptr->Resize(c_inputs, c_dims); auto ret = lite_model_ptr->Resize(c_inputs, c_dims);
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
return (jboolean)(ret.IsOk()); return (jboolean)(ret.IsOk());
} }
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_loadConfig(JNIEnv *env, jobject thiz, jstring model_ptr, extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_loadConfig(JNIEnv *env, jobject thiz, jstring model_ptr,
jstring config_path) { jstring config_path) {
auto *model_pointer = reinterpret_cast<void *>(model_ptr); if (model_ptr == nullptr || config_path == nullptr) {
if (model_pointer == nullptr) { MS_LOG(ERROR) << "input params from java is nullptr";
MS_LOG(ERROR) << "Model pointer from java is nullptr";
return (jboolean) false; return (jboolean) false;
} }
auto *model_pointer = reinterpret_cast<void *>(model_ptr);
auto *lite_model_ptr = static_cast<mindspore::Model *>(model_pointer); auto *lite_model_ptr = static_cast<mindspore::Model *>(model_pointer);
const char *c_config_path = env->GetStringUTFChars(config_path, nullptr); const char *c_config_path = env->GetStringUTFChars(config_path, nullptr);
std::string str_config_path(c_config_path, env->GetStringLength(config_path)); std::string str_config_path(c_config_path, env->GetStringLength(config_path));

View File

@ -119,6 +119,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_config_MSContext_addDev
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadNum(JNIEnv *env, jobject thiz, extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadNum(JNIEnv *env, jobject thiz,
jlong context_ptr, jint thread_num) { jlong context_ptr, jint thread_num) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr)); auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return;
}
c_context_ptr->SetThreadNum(thread_num); c_context_ptr->SetThreadNum(thread_num);
} }
@ -130,6 +134,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadN
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getThreadNum(JNIEnv *env, jobject thiz, extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getThreadNum(JNIEnv *env, jobject thiz,
jlong context_ptr) { jlong context_ptr) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr)); auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return 0;
}
int32_t thread_num = c_context_ptr->GetThreadNum(); int32_t thread_num = c_context_ptr->GetThreadNum();
return thread_num; return thread_num;
} }
@ -143,6 +151,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setInterOp
jlong context_ptr, jlong context_ptr,
jint op_parallel_num) { jint op_parallel_num) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr)); auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return;
}
c_context_ptr->SetInterOpParallelNum((int32_t)op_parallel_num); c_context_ptr->SetInterOpParallelNum((int32_t)op_parallel_num);
} }
@ -154,6 +166,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setInterOp
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getInterOpParallelNum(JNIEnv *env, jobject thiz, extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getInterOpParallelNum(JNIEnv *env, jobject thiz,
jlong context_ptr) { jlong context_ptr) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr)); auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return 0;
}
auto inter_op_parallel_num = c_context_ptr->GetInterOpParallelNum(); auto inter_op_parallel_num = c_context_ptr->GetInterOpParallelNum();
return inter_op_parallel_num; return inter_op_parallel_num;
} }
@ -167,6 +183,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadA
jlong context_ptr, jlong context_ptr,
jint thread_affinity) { jint thread_affinity) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr)); auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return;
}
c_context_ptr->SetThreadAffinity(thread_affinity); c_context_ptr->SetThreadAffinity(thread_affinity);
} }
@ -178,6 +198,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadA
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getThreadAffinityMode(JNIEnv *env, jobject thiz, extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getThreadAffinityMode(JNIEnv *env, jobject thiz,
jlong context_ptr) { jlong context_ptr) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr)); auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return 0;
}
auto thread_affinity_mode = c_context_ptr->GetThreadAffinityMode(); auto thread_affinity_mode = c_context_ptr->GetThreadAffinityMode();
return thread_affinity_mode; return thread_affinity_mode;
} }
@ -190,7 +214,15 @@ extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getThreadA
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadAffinity__J_3I(JNIEnv *env, jobject thiz, extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadAffinity__J_3I(JNIEnv *env, jobject thiz,
jlong context_ptr, jlong context_ptr,
jintArray core_list) { jintArray core_list) {
if (core_list == nullptr) {
MS_LOG(ERROR) << "core_list from java is nullptr";
return;
}
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr)); auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return;
}
int32_t array_len = env->GetArrayLength(core_list); int32_t array_len = env->GetArrayLength(core_list);
jboolean is_copy = JNI_FALSE; jboolean is_copy = JNI_FALSE;
int *core_value = env->GetIntArrayElements(core_list, &is_copy); int *core_value = env->GetIntArrayElements(core_list, &is_copy);
@ -209,6 +241,10 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_config_MSContext_getThre
jobject thiz, jobject thiz,
jlong context_ptr) { jlong context_ptr) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr)); auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return nullptr;
}
std::vector<int32_t> core_list_tmp = c_context_ptr->GetThreadAffinityCoreList(); std::vector<int32_t> core_list_tmp = c_context_ptr->GetThreadAffinityCoreList();
jobject core_list = newObjectArrayList<int32_t>(env, core_list_tmp, "java/lang/Integer", "(I)V"); jobject core_list = newObjectArrayList<int32_t>(env, core_list_tmp, "java/lang/Integer", "(I)V");
return core_list; return core_list;
@ -223,6 +259,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setEnableP
jlong context_ptr, jlong context_ptr,
jboolean is_parallel) { jboolean is_parallel) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr)); auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return;
}
c_context_ptr->SetEnableParallel(static_cast<bool>(is_parallel)); c_context_ptr->SetEnableParallel(static_cast<bool>(is_parallel));
} }
@ -234,6 +274,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setEnableP
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_config_MSContext_getEnableParallel(JNIEnv *env, jobject thiz, extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_config_MSContext_getEnableParallel(JNIEnv *env, jobject thiz,
jlong context_ptr) { jlong context_ptr) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr)); auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return (jboolean) false;
}
bool is_parallel = c_context_ptr->GetEnableParallel(); bool is_parallel = c_context_ptr->GetEnableParallel();
return (jboolean)is_parallel; return (jboolean)is_parallel;
} }
@ -241,5 +285,9 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_config_MSContext_getEna
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_free(JNIEnv *env, jobject thiz, extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_free(JNIEnv *env, jobject thiz,
jlong context_ptr) { jlong context_ptr) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr)); auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return;
}
delete (c_context_ptr); delete (c_context_ptr);
} }

View File

@ -30,6 +30,10 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_MSTensor_getShape(JNIE
auto local_shape = ms_tensor_ptr->Shape(); auto local_shape = ms_tensor_ptr->Shape();
auto shape_size = local_shape.size(); auto shape_size = local_shape.size();
jintArray shape = env->NewIntArray(shape_size); jintArray shape = env->NewIntArray(shape_size);
if (shape == nullptr) {
MS_LOG(ERROR) << "new intArray failed.";
return env->NewIntArray(0);
}
auto *tmp = new jint[shape_size]; auto *tmp = new jint[shape_size];
for (size_t i = 0; i < shape_size; i++) { for (size_t i = 0; i < shape_size; i++) {
tmp[i] = static_cast<int>(local_shape.at(i)); tmp[i] = static_cast<int>(local_shape.at(i));
@ -69,6 +73,10 @@ extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_MSTensor_getByteData(
return env->NewByteArray(0); return env->NewByteArray(0);
} }
auto ret = env->NewByteArray(local_size); auto ret = env->NewByteArray(local_size);
if (ret == nullptr) {
MS_LOG(ERROR) << "malloc failed.";
return env->NewByteArray(0);
}
env->SetByteArrayRegion(ret, 0, local_size, local_data); env->SetByteArrayRegion(ret, 0, local_size, local_data);
return ret; return ret;
} }
@ -99,6 +107,10 @@ extern "C" JNIEXPORT jlongArray JNICALL Java_com_mindspore_MSTensor_getLongData(
return env->NewLongArray(0); return env->NewLongArray(0);
} }
auto ret = env->NewLongArray(local_element_num); auto ret = env->NewLongArray(local_element_num);
if (ret == nullptr) {
MS_LOG(ERROR) << "malloc failed.";
return env->NewLongArray(0);
}
env->SetLongArrayRegion(ret, 0, local_element_num, local_data); env->SetLongArrayRegion(ret, 0, local_element_num, local_data);
return ret; return ret;
} }
@ -129,6 +141,10 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_MSTensor_getIntData(JN
return env->NewIntArray(0); return env->NewIntArray(0);
} }
auto ret = env->NewIntArray(local_element_num); auto ret = env->NewIntArray(local_element_num);
if (ret == nullptr) {
MS_LOG(ERROR) << "malloc failed.";
return env->NewIntArray(0);
}
env->SetIntArrayRegion(ret, 0, local_element_num, local_data); env->SetIntArrayRegion(ret, 0, local_element_num, local_data);
return ret; return ret;
} }
@ -159,6 +175,10 @@ extern "C" JNIEXPORT jfloatArray JNICALL Java_com_mindspore_MSTensor_getFloatDat
return env->NewFloatArray(0); return env->NewFloatArray(0);
} }
auto ret = env->NewFloatArray(local_element_num); auto ret = env->NewFloatArray(local_element_num);
if (ret == nullptr) {
MS_LOG(ERROR) << "malloc failed.";
return env->NewFloatArray(0);
}
env->SetFloatArrayRegion(ret, 0, local_element_num, local_data); env->SetFloatArrayRegion(ret, 0, local_element_num, local_data);
return ret; return ret;
} }
@ -177,8 +197,18 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setByteData(JN
return static_cast<jboolean>(false); return static_cast<jboolean>(false);
} }
jboolean is_copy = false; jboolean is_copy = false;
if (data == nullptr) {
MS_LOG(ERROR) << "data from java is nullptr.";
return static_cast<jboolean>(false);
}
auto *data_arr = env->GetByteArrayElements(data, &is_copy); auto *data_arr = env->GetByteArrayElements(data, &is_copy);
auto *local_data = ms_tensor_ptr->MutableData(); auto *local_data = ms_tensor_ptr->MutableData();
if (data_arr == nullptr || local_data == nullptr) {
MS_LOG(ERROR) << "data_arr or local_data is nullptr.";
env->ReleaseByteArrayElements(data, data_arr, JNI_ABORT);
return static_cast<jboolean>(false);
}
memcpy(local_data, data_arr, data_len); memcpy(local_data, data_arr, data_len);
env->ReleaseByteArrayElements(data, data_arr, JNI_ABORT); env->ReleaseByteArrayElements(data, data_arr, JNI_ABORT);
return static_cast<jboolean>(true); return static_cast<jboolean>(true);
@ -208,6 +238,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setFloatData(J
MS_LOG(ERROR) << "malloc memory failed."; MS_LOG(ERROR) << "malloc memory failed.";
return static_cast<jboolean>(false); return static_cast<jboolean>(false);
} }
if (data == nullptr) {
MS_LOG(ERROR) << "data from java is nullptr";
return static_cast<jboolean>(false);
}
env->GetFloatArrayRegion(data, 0, static_cast<jsize>(data_len), local_data); env->GetFloatArrayRegion(data, 0, static_cast<jsize>(data_len), local_data);
return static_cast<jboolean>(true); return static_cast<jboolean>(true);
} }
@ -236,6 +270,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setIntData(JNI
MS_LOG(ERROR) << "malloc memory failed."; MS_LOG(ERROR) << "malloc memory failed.";
return static_cast<jboolean>(false); return static_cast<jboolean>(false);
} }
if (data == nullptr) {
MS_LOG(ERROR) << "data from java is nullptr";
return static_cast<jboolean>(false);
}
env->GetIntArrayRegion(data, 0, static_cast<jsize>(data_len), local_data); env->GetIntArrayRegion(data, 0, static_cast<jsize>(data_len), local_data);
return static_cast<jboolean>(true); return static_cast<jboolean>(true);
} }
@ -262,6 +300,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setLongData(JN
MS_LOG(ERROR) << "malloc memory failed."; MS_LOG(ERROR) << "malloc memory failed.";
return static_cast<jboolean>(false); return static_cast<jboolean>(false);
} }
if (data == nullptr) {
MS_LOG(ERROR) << "data from java is nullptr";
return static_cast<jboolean>(false);
}
env->GetLongArrayRegion(data, 0, static_cast<jsize>(data_len), local_data); env->GetLongArrayRegion(data, 0, static_cast<jsize>(data_len), local_data);
return static_cast<jboolean>(true); return static_cast<jboolean>(true);
} }
@ -287,6 +329,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setByteBufferD
return static_cast<jboolean>(false); return static_cast<jboolean>(false);
} }
auto *local_data = ms_tensor_ptr->MutableData(); auto *local_data = ms_tensor_ptr->MutableData();
if (local_data == nullptr) {
MS_LOG(ERROR) << "get MutableData nullptr.";
return static_cast<jboolean>(false);
}
memcpy(local_data, p_data, data_len); memcpy(local_data, p_data, data_len);
return static_cast<jboolean>(true); return static_cast<jboolean>(true);
} }
@ -304,8 +350,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_MSTensor_size(JNIEnv *env,
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setShape(JNIEnv *env, jobject thiz, jlong tensor_ptr, extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setShape(JNIEnv *env, jobject thiz, jlong tensor_ptr,
jintArray tensor_shape) { jintArray tensor_shape) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr); auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) { if (pointer == nullptr || tensor_shape == nullptr) {
MS_LOG(ERROR) << "Tensor pointer from java is nullptr"; MS_LOG(ERROR) << "input params from java is nullptr";
return static_cast<jboolean>(false); return static_cast<jboolean>(false);
} }
@ -357,6 +403,11 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_MSTensor_createTensorByNat
jstring tensor_name, jint data_type, jstring tensor_name, jint data_type,
jintArray tensor_shape, jintArray tensor_shape,
jobject buffer) { jobject buffer) {
// check inputs
if (buffer == nullptr || tensor_name == nullptr || tensor_shape == nullptr) {
MS_LOG(ERROR) << "input param from java is nullptr";
return 0;
}
auto *p_data = reinterpret_cast<jbyte *>(env->GetDirectBufferAddress(buffer)); auto *p_data = reinterpret_cast<jbyte *>(env->GetDirectBufferAddress(buffer));
jlong data_len = env->GetDirectBufferCapacity(buffer); jlong data_len = env->GetDirectBufferCapacity(buffer);
if (p_data == nullptr) { if (p_data == nullptr) {
@ -370,11 +421,11 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_MSTensor_createTensorByNat
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
c_shape[i] = static_cast<int64_t>(shape_pointer[i]); c_shape[i] = static_cast<int64_t>(shape_pointer[i]);
} }
env->ReleaseIntArrayElements(tensor_shape, shape_pointer, JNI_ABORT);
const char *c_tensor_name = env->GetStringUTFChars(tensor_name, nullptr); const char *c_tensor_name = env->GetStringUTFChars(tensor_name, nullptr);
std::string str_tensor_name(c_tensor_name, env->GetStringLength(tensor_name)); std::string str_tensor_name(c_tensor_name, env->GetStringLength(tensor_name));
auto tensor = mindspore::MSTensor::CreateTensor(str_tensor_name, static_cast<mindspore::DataType>(data_type), c_shape, auto tensor = mindspore::MSTensor::CreateTensor(str_tensor_name, static_cast<mindspore::DataType>(data_type), c_shape,
p_data, data_len); p_data, data_len);
env->ReleaseIntArrayElements(tensor_shape, shape_pointer, JNI_ABORT);
env->ReleaseStringUTFChars(tensor_name, c_tensor_name); env->ReleaseStringUTFChars(tensor_name, c_tensor_name);
return jlong(tensor); return jlong(tensor);
} }

View File

@ -51,8 +51,10 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_config_TrainCfg_createTrai
} }
if (loss_name != nullptr) { if (loss_name != nullptr) {
std::vector<std::string> traincfg_loss_name = traincfg_ptr->GetLossName(); std::vector<std::string> traincfg_loss_name = traincfg_ptr->GetLossName();
traincfg_loss_name.emplace_back(env->GetStringUTFChars(loss_name, JNI_FALSE)); auto c_loss_name = env->GetStringUTFChars(loss_name, JNI_FALSE);
traincfg_loss_name.emplace_back(c_loss_name);
traincfg_ptr->SetLossName(traincfg_loss_name); traincfg_ptr->SetLossName(traincfg_loss_name);
env->ReleaseStringUTFChars(loss_name, c_loss_name);
} }
traincfg_ptr->optimization_level_ = ol; traincfg_ptr->optimization_level_ = ol;
traincfg_ptr->accumulate_gradients_ = accmulateGrads; traincfg_ptr->accumulate_gradients_ = accmulateGrads;

View File

@ -0,0 +1,63 @@
/*
* Copyright 2022-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import com.mindspore.config.DataType;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.Arrays;
/**
* Model Test
*
* @since 2023.2
*/
@RunWith(JUnit4.class)
public class MSTensorTest {
@Test
public void testCreateTensor() {
int[] tensorShape = {6, 5, 5, 1};
float[] tensorData = new float[6 * 5 * 5];
Arrays.fill(tensorData, 0);
ByteBuffer byteBuf = ByteBuffer.allocateDirect(6 * 5 * 5 * 4);
FloatBuffer floatBuf = byteBuf.asFloatBuffer();
floatBuf.put(tensorData);
MSTensor newTensor = MSTensor.createTensor("conv1.weight", DataType.kNumberTypeFloat32, tensorShape,
byteBuf);
assertNotNull(newTensor);
}
@Test
public void testSetData() {
int[] tensorShape = {6, 5, 5, 1};
ByteBuffer byteBuf = ByteBuffer.allocateDirect(6 * 5 * 5 * 4);
MSTensor newTensor = MSTensor.createTensor("conv1.weight", DataType.kNumberTypeFloat32, tensorShape,
byteBuf);
assertNotNull(newTensor);
float[] tensorData = new float[6 * 5 * 5];
Arrays.fill(tensorData, 0);
assertTrue(newTensor.setData(tensorData));
}
}