diff --git a/mindspore/lite/java/src/main/java/com/mindspore/Graph.java b/mindspore/lite/java/src/main/java/com/mindspore/Graph.java index 06e7713cf6d..a1d003af807 100644 --- a/mindspore/lite/java/src/main/java/com/mindspore/Graph.java +++ b/mindspore/lite/java/src/main/java/com/mindspore/Graph.java @@ -18,6 +18,11 @@ package com.mindspore; import com.mindspore.config.MindsporeLite; +/** + * Graph Class + * + * @since v1.0 + */ public class Graph { static { MindsporeLite.init(); diff --git a/mindspore/lite/java/src/main/java/com/mindspore/MSTensor.java b/mindspore/lite/java/src/main/java/com/mindspore/MSTensor.java index 434a7988a07..cf374870fcd 100644 --- a/mindspore/lite/java/src/main/java/com/mindspore/MSTensor.java +++ b/mindspore/lite/java/src/main/java/com/mindspore/MSTensor.java @@ -22,6 +22,11 @@ import java.nio.ByteBuffer; import java.lang.reflect.Array; import java.util.HashMap; +/** + * The MSTensor class defines a tensor in MindSpore. + * + * @since v1.0 + */ public class MSTensor { static { MindsporeLite.init(); diff --git a/mindspore/lite/java/src/main/java/com/mindspore/Model.java b/mindspore/lite/java/src/main/java/com/mindspore/Model.java index c41d958e576..7de7984e8b8 100644 --- a/mindspore/lite/java/src/main/java/com/mindspore/Model.java +++ b/mindspore/lite/java/src/main/java/com/mindspore/Model.java @@ -25,6 +25,11 @@ import java.nio.MappedByteBuffer; import java.util.ArrayList; import java.util.List; +/** + * The Model class is used to define a MindSpore model, facilitating computational graph management. + * + * @since v1.0 + */ public class Model { static { MindsporeLite.init(); @@ -61,16 +66,20 @@ public class Model { * @param buffer model buffer. * @param modelType model type. * @param context model build context. - * @param dec_key define the key used to decrypt the ciphertext model. The key length is 16. - * @param dec_mode define the decryption mode. Options: AES-GCM. - * @param cropto_lib_path define the openssl library path. + * @param decKey define the key used to decrypt the ciphertext model. The key length is 16. + * @param decMode define the decryption mode. Options: AES-GCM. + * @param croptoLibPath define the openssl library path. * @return model build status. */ - public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) { - if (context == null || buffer == null || dec_key == null || dec_mode == null) { + public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] decKey, String decMode, + String croptoLibPath) { + boolean isValid = (context != null && buffer != null && decKey != null && decMode != null && + croptoLibPath != null); + if (!isValid) { return false; } - return this.buildByBuffer(modelPtr, buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path); + return this.buildByBuffer(modelPtr, buffer, modelType, context.getMSContextPtr(), decKey, decMode, + croptoLibPath); } /** @@ -95,16 +104,19 @@ public class Model { * @param modelPath model path. * @param modelType model type. * @param context model build context. - * @param dec_key define the key used to decrypt the ciphertext model. The key length is 16. - * @param dec_mode define the decryption mode. Options: AES-GCM. - * @param cropto_lib_path define the openssl library path. + * @param decKey define the key used to decrypt the ciphertext model. The key length is 16. + * @param decMode define the decryption mode. Options: AES-GCM. + * @param croptoLibPath define the openssl library path. * @return model build status. */ - public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) { - if (context == null || modelPath == null || dec_key == null || dec_mode == null) { + public boolean build(String modelPath, int modelType, MSContext context, char[] decKey, String decMode, + String croptoLibPath) { + boolean isValid = (context != null && modelPath != null && decKey != null && decMode != null && + croptoLibPath != null); + if (!isValid) { return false; } - return this.buildByPath(modelPtr, modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path); + return this.buildByPath(modelPtr, modelPath, modelType, context.getMSContextPtr(), decKey, decMode, croptoLibPath); } /** @@ -164,9 +176,9 @@ public class Model { * @return input tensors. */ public List getInputs() { - List ret = this.getInputs(this.modelPtr); - List tensors = new ArrayList<>(); - for (Long msTensorAddr : ret) { + List tensorAddrs = this.getInputs(this.modelPtr); + List tensors = new ArrayList<>(tensorAddrs.size()); + for (Long msTensorAddr : tensorAddrs) { MSTensor msTensor = new MSTensor(msTensorAddr); tensors.add(msTensor); } @@ -179,9 +191,9 @@ public class Model { * @return model outputs tensor. */ public List getOutputs() { - List ret = this.getOutputs(this.modelPtr); - List tensors = new ArrayList<>(); - for (Long msTensorAddr : ret) { + List tensorAddrs = this.getOutputs(this.modelPtr); + List tensors = new ArrayList<>(tensorAddrs.size()); + for (Long msTensorAddr : tensorAddrs) { MSTensor msTensor = new MSTensor(msTensorAddr); tensors.add(msTensor); } @@ -224,11 +236,11 @@ public class Model { */ public List getOutputsByNodeName(String nodeName) { if (nodeName == null) { - return null; + return new ArrayList<>(); } - List ret = this.getOutputsByNodeName(this.modelPtr, nodeName); - List tensors = new ArrayList<>(); - for (Long msTensorAddr : ret) { + List tensorAddrs = this.getOutputsByNodeName(this.modelPtr, nodeName); + List tensors = new ArrayList<>(tensorAddrs.size()); + for (Long msTensorAddr : tensorAddrs) { MSTensor msTensor = new MSTensor(msTensorAddr); tensors.add(msTensor); } @@ -296,9 +308,9 @@ public class Model { * @return FeaturesMap Tensor list. */ public List getFeatureMaps() { - List ret = this.getFeatureMaps(this.modelPtr); - ArrayList tensors = new ArrayList<>(); - for (Long msTensorAddr : ret) { + List tensorAddrs = this.getFeatureMaps(this.modelPtr); + ArrayList tensors = new ArrayList<>(tensorAddrs.size()); + for (Long msTensorAddr : tensorAddrs) { MSTensor msTensor = new MSTensor(msTensorAddr); tensors.add(msTensor); } diff --git a/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java b/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java index 97175153338..76d702253e6 100644 --- a/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java +++ b/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java @@ -22,7 +22,13 @@ package com.mindspore.config; * @since v1.0 */ public class CpuBindMode { + + // bind mind cpu public static final int MID_CPU = 2; + + // bind high cpu public static final int HIGHER_CPU = 1; + + // no bind public static final int NO_BIND = 0; } \ No newline at end of file diff --git a/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java b/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java index 4795f597ee4..115b553a2b4 100644 --- a/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java +++ b/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java @@ -21,8 +21,16 @@ package com.mindspore.config; * @since v1.0 */ public class DeviceType { + + // support cpu public static final int DT_CPU = 0; + + // support gpu public static final int DT_GPU = 1; + + // support npu public static final int DT_NPU = 2; + + // support ascend public static final int DT_ASCEND = 3; } \ No newline at end of file diff --git a/mindspore/lite/java/src/main/java/com/mindspore/config/MSContext.java b/mindspore/lite/java/src/main/java/com/mindspore/config/MSContext.java index 8a9b493f0de..e30c8c650c6 100644 --- a/mindspore/lite/java/src/main/java/com/mindspore/config/MSContext.java +++ b/mindspore/lite/java/src/main/java/com/mindspore/config/MSContext.java @@ -20,17 +20,22 @@ import java.util.ArrayList; import java.util.logging.Level; import java.util.logging.Logger; - +/** + * Context is used to store environment variables during execution. + * + * @since v1.0 + */ public class MSContext { private static Logger LOGGER = MindsporeLite.GetLogger(); static { MindsporeLite.init(); } - private long msContextPtr; - private static final long EMPTY_CONTEXT_PTR_VALUE = 0; + private static final long EMPTY_CONTEXT_PTR_VALUE = 0L; private static final int ERROR_VALUE = -1; - private static final String NULLPTR_ERROR_MESSAGE="Context pointer from java is nullptr.\n"; + private static final String NULLPTR_ERROR_MESSAGE="Context pointer from java is nullptr."; + + private long msContextPtr; /** * Construct function. @@ -148,13 +153,13 @@ public class MSContext { * @return The current thread number setting. */ public int getThreadNum() { - int ret_val = ERROR_VALUE; + int retVal = ERROR_VALUE; if (isInitialized()) { - ret_val = getThreadNum(this.msContextPtr); + retVal = getThreadNum(this.msContextPtr); } else { LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE); } - return ret_val; + return retVal; } /** @@ -178,13 +183,13 @@ public class MSContext { * @return The current operators parallel number setting. */ public int getInterOpParallelNum() { - int ret_val = ERROR_VALUE; + int retVal = ERROR_VALUE; if (isInitialized()) { - ret_val = getInterOpParallelNum(this.msContextPtr); + retVal = getInterOpParallelNum(this.msContextPtr); } else { LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE); } - return ret_val; + return retVal; } /** @@ -209,13 +214,13 @@ public class MSContext { * @return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first */ public int getThreadAffinityMode() { - int ret_val = ERROR_VALUE; + int retVal = ERROR_VALUE; if (isInitialized()) { - ret_val = getThreadAffinityMode(this.msContextPtr); + retVal = getThreadAffinityMode(this.msContextPtr); } else { LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE); } - return ret_val; + return retVal; } /** @@ -229,11 +234,11 @@ public class MSContext { public void setThreadAffinity(ArrayList coreList) { if (isInitialized()) { int len = coreList.size(); - int[] coreList_array = new int[len]; + int[] coreListArray = new int[len]; for (int i = 0; i < len; i++) { - coreList_array[i] = coreList.get(i); + coreListArray[i] = coreList.get(i); } - setThreadAffinity(this.msContextPtr, coreList_array); + setThreadAffinity(this.msContextPtr, coreListArray); } else { LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE); } @@ -247,13 +252,13 @@ public class MSContext { */ public ArrayList getThreadAffinityCoreList() { - ArrayList ret_val = new ArrayList<>(); + ArrayList retVal = new ArrayList<>(); if (isInitialized()) { - ret_val = getThreadAffinityCoreList(this.msContextPtr); + retVal = getThreadAffinityCoreList(this.msContextPtr); } else { LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE); } - return ret_val; + return retVal; } /** @@ -277,13 +282,13 @@ public class MSContext { * @return boolean value that indicates whether in parallel. */ public boolean getEnableParallel() { - boolean ret_val = false; + boolean retVal = false; if (isInitialized()) { - ret_val = getEnableParallel(this.msContextPtr); + retVal = getEnableParallel(this.msContextPtr); } else { LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE); } - return ret_val; + return retVal; } diff --git a/mindspore/lite/java/src/main/java/com/mindspore/config/MindsporeLite.java b/mindspore/lite/java/src/main/java/com/mindspore/config/MindsporeLite.java index 59d65a3e328..55ba73bf228 100644 --- a/mindspore/lite/java/src/main/java/com/mindspore/config/MindsporeLite.java +++ b/mindspore/lite/java/src/main/java/com/mindspore/config/MindsporeLite.java @@ -1,7 +1,28 @@ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package com.mindspore.config; import java.util.logging.Logger; +/** + * MSLite Init Class + * + * @since v1.0 + */ public final class MindsporeLite { private static final Object lock = new Object(); private static Logger LOGGER = GetLogger(); diff --git a/mindspore/lite/java/src/main/java/com/mindspore/config/ModelType.java b/mindspore/lite/java/src/main/java/com/mindspore/config/ModelType.java index 88a3e1b6bdd..2936e9c600a 100644 --- a/mindspore/lite/java/src/main/java/com/mindspore/config/ModelType.java +++ b/mindspore/lite/java/src/main/java/com/mindspore/config/ModelType.java @@ -1,5 +1,5 @@ -/* - * Copyright 2021 Huawei Technologies Co., Ltd +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package com.mindspore.config; /** @@ -21,9 +22,19 @@ package com.mindspore.config; * @since v1.0 */ public class ModelType { + + // mindir type public static final int MT_MINDIR = 0; + + // air type public static final int MT_AIR = 1; + + // om type public static final int MT_OM = 2; + + // onnx type public static final int MT_ONNX = 3; + + // mindir opt type public static final int MT_MINDIR_OPT = 4; } \ No newline at end of file diff --git a/mindspore/lite/java/src/main/java/com/mindspore/config/NativeLibrary.java b/mindspore/lite/java/src/main/java/com/mindspore/config/NativeLibrary.java index 16b2ed5412a..67805fc442e 100644 --- a/mindspore/lite/java/src/main/java/com/mindspore/config/NativeLibrary.java +++ b/mindspore/lite/java/src/main/java/com/mindspore/config/NativeLibrary.java @@ -1,3 +1,18 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.mindspore.config; import java.io.File; @@ -5,7 +20,13 @@ import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.util.logging.Logger; +import java.util.Locale; +/** + * NativeLibrary Class + * + * @since v1.0 + */ public class NativeLibrary { private static final Logger LOGGER = MindsporeLite.GetLogger(); @@ -52,7 +73,7 @@ public class NativeLibrary { * libmsplugin-ge-litert * libruntime_convert_plugin */ - public static void loadLibs() { + private static void loadLibs() { loadLib(makeResourceName("lib" + GLOG_LIBNAME + ".so")); loadLib(makeResourceName("lib" + OPENCV_CORE_LIBNAME + ".so")); loadLib(makeResourceName("lib" + OPENCV_IMGPROC_LIBNAME + ".so")); @@ -83,49 +104,49 @@ public class NativeLibrary { try { System.loadLibrary(MINDSPORE_LITE_JNI_LIBNAME); loadSuccess = true; - LOGGER.info("loadLibrary " + MINDSPORE_LITE_JNI_LIBNAME + ": success"); + LOGGER.info("loadLibrary mindspore-lite-jni success"); } catch (UnsatisfiedLinkError e) { - LOGGER.info(String.format("tryLoadLibrary " + MINDSPORE_LITE_JNI_LIBNAME + " failed: %s", e.toString())); + LOGGER.info(String.format(Locale.ENGLISH, "tryLoadLibrary mindspore-lite-jni failed: %s", e)); } try { System.loadLibrary(MINDSPORE_LITE_TRAIN_JNI_LIBNAME); loadSuccess = true; - LOGGER.info("loadLibrary " + MINDSPORE_LITE_TRAIN_JNI_LIBNAME + ": success."); + LOGGER.info("loadLibrary mindspore-lite-train-jni success."); } catch (UnsatisfiedLinkError e) { - LOGGER.info(String.format("tryLoadLibrary " + MINDSPORE_LITE_TRAIN_JNI_LIBNAME + " failed: %s", e.toString())); + LOGGER.info(String.format(Locale.ENGLISH, "tryLoadLibrary mindspore-lite-train-jni failed: %s", e)); } return loadSuccess; } private static void loadLib(String libResourceName) { - LOGGER.info("start load libResourceName: " + libResourceName); + LOGGER.info(String.format(Locale.ENGLISH,"start load libResourceName: %s.", libResourceName)); final InputStream libResource = NativeLibrary.class.getClassLoader().getResourceAsStream(libResourceName); if (libResource == null) { - LOGGER.warning(String.format("lib file: %s not exist.", libResourceName)); + LOGGER.warning(String.format(Locale.ENGLISH,"lib file: %s not exist.", libResourceName)); return; } try { final File tmpDir = mkTmpDir(); - String libName = libResourceName.substring(libResourceName.lastIndexOf("/") + 1); + String libName = libResourceName.substring(libResourceName.lastIndexOf('/') + 1); tmpDir.deleteOnExit(); - //copy file to tmpFile + // copy file to tmpFile final File tmpFile = new File(tmpDir.getCanonicalPath(), libName); tmpFile.deleteOnExit(); - LOGGER.info(String.format("extract %d bytes to %s", copyLib(libResource, tmpFile), tmpFile)); - LOGGER.info(String.format("libName %s", libName)); - if (libName.equals("lib" + MINDSPORE_LITE_LIBNAME + ".so")) { + LOGGER.info(String.format(Locale.ENGLISH,"extract %d bytes to %s", copyLib(libResource, tmpFile), + tmpFile)); + if (("lib" + MINDSPORE_LITE_LIBNAME + ".so").equals(libName)) { extractLib(makeResourceName("lib" + MSPLUGIN_GE_LITERT_LIBNAME + ".so"), tmpDir); extractLib(makeResourceName("lib" + RUNTIME_CONVERT_PLUGIN_LIBNAME + ".so"), tmpDir); } System.load(tmpFile.toString()); } catch (IOException e) { throw new UnsatisfiedLinkError( - String.format("extract library into tmp file (%s) failed.", e.toString())); + String.format(Locale.ENGLISH,"extract library into tmp file (%s) failed.", e)); } } - private static long copyLib(InputStream libResource, File tmpFile) throws IOException{ + private static long copyLib(InputStream libResource, File tmpFile) throws IOException { try (FileOutputStream outputStream = new FileOutputStream(tmpFile);) { // 1MB byte[] buffer = new byte[1 << 20]; @@ -143,9 +164,8 @@ public class NativeLibrary { private static File mkTmpDir() { - final String MINDSPORE_LITE_LIBS = "mindspore_lite_libs-"; Long timestamp = System.currentTimeMillis(); - String dirName = MINDSPORE_LITE_LIBS + timestamp + "-"; + String dirName = "mindspore_lite_libs-" + timestamp + "-"; for (int i = 0; i < 10; i++) { File tmpDir = new File(new File(System.getProperty("java.io.tmpdir")), dirName + i); if (tmpDir.mkdir()) { @@ -167,7 +187,7 @@ public class NativeLibrary { private static String architecture() { final String arch = System.getProperty("os.arch").toLowerCase(); - return (arch.equals("amd64")) ? "x86_64" : arch; + return ("amd64".equals(arch)) ? "x86_64" : arch; } private static void extractLib(String libResourceName, File targetDir) { diff --git a/mindspore/lite/java/src/main/java/com/mindspore/config/TrainCfg.java b/mindspore/lite/java/src/main/java/com/mindspore/config/TrainCfg.java index 7bdc9c577a8..ad0cb959341 100644 --- a/mindspore/lite/java/src/main/java/com/mindspore/config/TrainCfg.java +++ b/mindspore/lite/java/src/main/java/com/mindspore/config/TrainCfg.java @@ -16,6 +16,11 @@ package com.mindspore.config; +/** + * TrainCfg Class + * + * @since v1.0 + */ public class TrainCfg { static { try { diff --git a/mindspore/lite/java/src/main/java/com/mindspore/config/Version.java b/mindspore/lite/java/src/main/java/com/mindspore/config/Version.java index 9b3b119fe5c..6d2ed9fec45 100644 --- a/mindspore/lite/java/src/main/java/com/mindspore/config/Version.java +++ b/mindspore/lite/java/src/main/java/com/mindspore/config/Version.java @@ -37,9 +37,8 @@ public class Version { LOGGER.info("Version init load ..."); try { NativeLibrary.load(); - } catch (Exception e) { + } catch (UnsatisfiedLinkError e) { LOGGER.severe("Failed to load MindSporLite native library."); - throw e; } } diff --git a/mindspore/lite/java/src/main/native/graph.cpp b/mindspore/lite/java/src/main/native/graph.cpp index 70367e76548..012ba3253c7 100644 --- a/mindspore/lite/java/src/main/native/graph.cpp +++ b/mindspore/lite/java/src/main/native/graph.cpp @@ -26,8 +26,15 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Graph_loadModel(JNIEnv *en MS_LOG(ERROR) << "Model new failed"; return jlong(nullptr); } + if (ms_file == nullptr) { + MS_LOG(ERROR) << "ms_file from java is nullptr."; + delete graph; + return jlong(nullptr); + } + auto c_ms_file = env->GetStringUTFChars(ms_file, JNI_FALSE); auto status = - mindspore::Serialization::Load(env->GetStringUTFChars(ms_file, JNI_FALSE), mindspore::ModelType::kMindIR, graph); + mindspore::Serialization::Load(c_ms_file, mindspore::ModelType::kMindIR, graph); + env->ReleaseStringUTFChars(ms_file, c_ms_file); if (status != mindspore::kSuccess) { MS_LOG(ERROR) << "Load graph from file failed"; delete graph; diff --git a/mindspore/lite/java/src/main/native/model.cpp b/mindspore/lite/java/src/main/native/model.cpp index 8508ce11e73..eda5dc7e047 100644 --- a/mindspore/lite/java/src/main/native/model.cpp +++ b/mindspore/lite/java/src/main/native/model.cpp @@ -111,23 +111,34 @@ extern "C" JNIEXPORT bool JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv return false; } context.reset(c_context_ptr); - auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE); mindspore::Status status; if (key_str != NULL) { - jchar *key_array = env->GetCharArrayElements(key_str, NULL); auto key_len = static_cast(env->GetArrayLength(key_str)); char *dec_key_data = new (std::nothrow) char[key_len]; if (dec_key_data == nullptr) { MS_LOG(ERROR) << "Dec key new failed"; return false; } + jchar *key_array = env->GetCharArrayElements(key_str, NULL); + if (key_array == nullptr) { + MS_LOG(ERROR) << "key_array is nullptr."; + return false; + } for (size_t i = 0; i < key_len; i++) { dec_key_data[i] = key_array[i]; } env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT); mindspore::Key dec_key{dec_key_data, key_len}; + if (cropto_lib_path == nullptr || dec_mod == nullptr) { + MS_LOG(ERROR) << "cropto_lib_path or dec_mod from java is nullptr."; + return jlong(nullptr); + } + auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE); auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE); status = lite_model_ptr->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path); + env->ReleaseStringUTFChars(cropto_lib_path, c_cropto_lib_path); + env->ReleaseStringUTFChars(dec_mod, c_dec_mod); + delete[] dec_key_data; } else { status = lite_model_ptr->Build(model_buf, buffer_len, c_model_type, context); } @@ -167,26 +178,43 @@ extern "C" JNIEXPORT bool JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *e return false; } context.reset(c_context_ptr); - auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE); mindspore::Status status; if (key_str != NULL) { - jchar *key_array = env->GetCharArrayElements(key_str, NULL); auto key_len = static_cast(env->GetArrayLength(key_str)); char *dec_key_data = new (std::nothrow) char[key_len]; if (dec_key_data == nullptr) { MS_LOG(ERROR) << "Dec key new failed"; + env->ReleaseStringUTFChars(model_path, c_model_path); return false; } + + jchar *key_array = env->GetCharArrayElements(key_str, NULL); + if (key_array == nullptr) { + MS_LOG(ERROR) << "GetCharArrayElements failed."; + env->ReleaseStringUTFChars(model_path, c_model_path); + return jlong(nullptr); + } for (size_t i = 0; i < key_len; i++) { dec_key_data[i] = key_array[i]; } env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT); mindspore::Key dec_key{dec_key_data, key_len}; + + if (dec_mod == nullptr || cropto_lib_path == nullptr) { + MS_LOG(ERROR) << "dec_mod, cropto_lib_path from java is nullptr."; + env->ReleaseStringUTFChars(model_path, c_model_path); + return jlong(nullptr); + } + auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE); auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE); status = lite_model_ptr->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path); + env->ReleaseStringUTFChars(dec_mod, c_dec_mod); + env->ReleaseStringUTFChars(cropto_lib_path, c_cropto_lib_path); + delete[] dec_key_data; } else { status = lite_model_ptr->Build(c_model_path, c_model_type, context); } + env->ReleaseStringUTFChars(model_path, c_model_path); if (status != mindspore::kSuccess) { MS_LOG(ERROR) << "Error status " << static_cast(status) << " during build of model"; return false; @@ -205,6 +233,8 @@ jobject GetInOrOutTensors(JNIEnv *env, jobject thiz, jlong model_ptr, bool is_in auto *pointer = reinterpret_cast(model_ptr); if (pointer == nullptr) { MS_LOG(ERROR) << "Model pointer from java is nullptr"; + env->DeleteLocalRef(array_list); + env->DeleteLocalRef(long_object); return ret; } std::vector tensors; @@ -221,7 +251,10 @@ jobject GetInOrOutTensors(JNIEnv *env, jobject thiz, jlong model_ptr, bool is_in } jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(tensor_ptr.release())); env->CallBooleanMethod(ret, array_list_add, tensor_addr); + env->DeleteLocalRef(tensor_addr); } + env->DeleteLocalRef(array_list); + env->DeleteLocalRef(long_object); return ret; } @@ -233,11 +266,17 @@ jlong GetTensorByInOutName(JNIEnv *env, jlong model_ptr, jstring tensor_name, bo } auto *lite_model_ptr = static_cast(pointer); mindspore::MSTensor tensor; - if (is_input) { - tensor = lite_model_ptr->GetInputByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE)); - } else { - tensor = lite_model_ptr->GetOutputByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE)); + if (tensor_name == nullptr) { + MS_LOG(ERROR) << "tensor_name from java is nullptr."; + return jlong(nullptr); } + auto c_tensor_name = env->GetStringUTFChars(tensor_name, JNI_FALSE); + if (is_input) { + tensor = lite_model_ptr->GetInputByTensorName(c_tensor_name); + } else { + tensor = lite_model_ptr->GetOutputByTensorName(c_tensor_name); + } + env->ReleaseStringUTFChars(tensor_name, c_tensor_name); if (tensor.impl() == nullptr) { return jlong(nullptr); } @@ -283,8 +322,11 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputTensorNam auto *lite_model_ptr = static_cast(pointer); auto output_names = lite_model_ptr->GetOutputTensorNames(); for (const auto &output_name : output_names) { - env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str())); + auto output_name_jstring = env->NewStringUTF(output_name.c_str()); + env->CallBooleanMethod(ret, array_list_add, output_name_jstring); + env->DeleteLocalRef(output_name_jstring); } + env->DeleteLocalRef(array_list); return ret; } @@ -303,7 +345,13 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputsByNodeNa return ret; } auto *lite_model_ptr = static_cast(pointer); - auto tensors = lite_model_ptr->GetOutputsByNodeName(env->GetStringUTFChars(node_name, JNI_FALSE)); + if (node_name == nullptr) { + MS_LOG(ERROR) << "node_name from java is nullptr"; + return ret; + } + auto c_node_name = env->GetStringUTFChars(node_name, JNI_FALSE); + auto tensors = lite_model_ptr->GetOutputsByNodeName(c_node_name); + env->ReleaseStringUTFChars(node_name, c_node_name); for (auto &tensor : tensors) { auto tensor_ptr = std::make_unique(tensor); if (tensor_ptr == nullptr) { @@ -312,7 +360,10 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputsByNodeNa } jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(tensor_ptr.release())); env->CallBooleanMethod(ret, array_list_add, tensor_addr); + env->DeleteLocalRef(tensor_addr); } + env->DeleteLocalRef(array_list); + env->DeleteLocalRef(long_object); return ret; } @@ -351,18 +402,24 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_runStep(JNIEnv *e } std::vector convertArrayToVector(JNIEnv *env, jlongArray inputs) { + std::vector c_inputs; + if (inputs == nullptr) { + MS_LOG(ERROR) << "inputs from java is nullptr"; + return c_inputs; + } auto input_size = static_cast(env->GetArrayLength(inputs)); jlong *input_data = env->GetLongArrayElements(inputs, nullptr); - std::vector c_inputs; for (int i = 0; i < input_size; i++) { auto *tensor_pointer = reinterpret_cast(input_data[i]); if (tensor_pointer == nullptr) { MS_LOG(ERROR) << "Tensor pointer from java is nullptr"; + env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT); return c_inputs; } auto *ms_tensor_ptr = static_cast(tensor_pointer); c_inputs.push_back(*ms_tensor_ptr); } + env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT); return c_inputs; } @@ -389,14 +446,22 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en return (jboolean) false; } auto *lite_model_ptr = static_cast(pointer); - + if (inputs == nullptr || dims == nullptr) { + MS_LOG(ERROR) << "inputs or dims from java is nullptr"; + return (jboolean) false; + } auto input_size = static_cast(env->GetArrayLength(inputs)); jlong *input_data = env->GetLongArrayElements(inputs, nullptr); + if (input_data == nullptr) { + MS_LOG(ERROR) << "input_data is nullptr"; + return (jboolean) false; + } std::vector c_inputs; for (int i = 0; i < input_size; i++) { auto *tensor_pointer = reinterpret_cast(input_data[i]); if (tensor_pointer == nullptr) { MS_LOG(ERROR) << "Tensor pointer from java is nullptr"; + env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT); return (jboolean) false; } auto &ms_tensor = *static_cast(tensor_pointer); @@ -405,8 +470,19 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en auto tensor_size = static_cast(env->GetArrayLength(dims)); for (int i = 0; i < tensor_size; i++) { auto array = static_cast(env->GetObjectArrayElement(dims, i)); + if (array == nullptr) { + MS_LOG(ERROR) << "Tensor pointer from java is nullptr"; + env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT); + return (jboolean) false; + } auto dim_size = static_cast(env->GetArrayLength(array)); jint *dim_data = env->GetIntArrayElements(array, nullptr); + if (dim_data == nullptr) { + MS_LOG(ERROR) << "dim_data is nullptr"; + env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT); + env->DeleteLocalRef(array); + return (jboolean) false; + } std::vector tensor_dims(dim_size); for (int j = 0; j < dim_size; j++) { tensor_dims[j] = dim_data[j]; @@ -416,16 +492,17 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en env->DeleteLocalRef(array); } auto ret = lite_model_ptr->Resize(c_inputs, c_dims); + env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT); return (jboolean)(ret.IsOk()); } extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_loadConfig(JNIEnv *env, jobject thiz, jstring model_ptr, jstring config_path) { - auto *model_pointer = reinterpret_cast(model_ptr); - if (model_pointer == nullptr) { - MS_LOG(ERROR) << "Model pointer from java is nullptr"; + if (model_ptr == nullptr || config_path == nullptr) { + MS_LOG(ERROR) << "input params from java is nullptr"; return (jboolean) false; } + auto *model_pointer = reinterpret_cast(model_ptr); auto *lite_model_ptr = static_cast(model_pointer); const char *c_config_path = env->GetStringUTFChars(config_path, nullptr); std::string str_config_path(c_config_path, env->GetStringLength(config_path)); diff --git a/mindspore/lite/java/src/main/native/ms_context.cpp b/mindspore/lite/java/src/main/native/ms_context.cpp index f9f0614aac9..bb905d26d6c 100644 --- a/mindspore/lite/java/src/main/native/ms_context.cpp +++ b/mindspore/lite/java/src/main/native/ms_context.cpp @@ -121,6 +121,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_config_MSContext_addDev extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadNum(JNIEnv *env, jobject thiz, jlong context_ptr, jint thread_num) { auto *c_context_ptr = static_cast(reinterpret_cast(context_ptr)); + if (c_context_ptr == nullptr) { + MS_LOG(ERROR) << "Context pointer from java is nullptr"; + return; + } c_context_ptr->SetThreadNum(thread_num); } @@ -132,6 +136,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadN extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getThreadNum(JNIEnv *env, jobject thiz, jlong context_ptr) { auto *c_context_ptr = static_cast(reinterpret_cast(context_ptr)); + if (c_context_ptr == nullptr) { + MS_LOG(ERROR) << "Context pointer from java is nullptr"; + return 0; + } int32_t thread_num = c_context_ptr->GetThreadNum(); return thread_num; } @@ -145,6 +153,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setInterOp jlong context_ptr, jint op_parallel_num) { auto *c_context_ptr = static_cast(reinterpret_cast(context_ptr)); + if (c_context_ptr == nullptr) { + MS_LOG(ERROR) << "Context pointer from java is nullptr"; + return; + } c_context_ptr->SetInterOpParallelNum((int32_t)op_parallel_num); } @@ -156,6 +168,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setInterOp extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getInterOpParallelNum(JNIEnv *env, jobject thiz, jlong context_ptr) { auto *c_context_ptr = static_cast(reinterpret_cast(context_ptr)); + if (c_context_ptr == nullptr) { + MS_LOG(ERROR) << "Context pointer from java is nullptr"; + return 0; + } auto inter_op_parallel_num = c_context_ptr->GetInterOpParallelNum(); return inter_op_parallel_num; } @@ -169,6 +185,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadA jlong context_ptr, jint thread_affinity) { auto *c_context_ptr = static_cast(reinterpret_cast(context_ptr)); + if (c_context_ptr == nullptr) { + MS_LOG(ERROR) << "Context pointer from java is nullptr"; + return; + } c_context_ptr->SetThreadAffinity(thread_affinity); } @@ -180,6 +200,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadA extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getThreadAffinityMode(JNIEnv *env, jobject thiz, jlong context_ptr) { auto *c_context_ptr = static_cast(reinterpret_cast(context_ptr)); + if (c_context_ptr == nullptr) { + MS_LOG(ERROR) << "Context pointer from java is nullptr"; + return 0; + } auto thread_affinity_mode = c_context_ptr->GetThreadAffinityMode(); return thread_affinity_mode; } @@ -192,7 +216,15 @@ extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getThreadA extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadAffinity__J_3I(JNIEnv *env, jobject thiz, jlong context_ptr, jintArray core_list) { + if (core_list == nullptr) { + MS_LOG(ERROR) << "core_list from java is nullptr"; + return; + } auto *c_context_ptr = static_cast(reinterpret_cast(context_ptr)); + if (c_context_ptr == nullptr) { + MS_LOG(ERROR) << "Context pointer from java is nullptr"; + return; + } int32_t array_len = env->GetArrayLength(core_list); jboolean is_copy = JNI_FALSE; int *core_value = env->GetIntArrayElements(core_list, &is_copy); @@ -211,6 +243,10 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_config_MSContext_getThre jobject thiz, jlong context_ptr) { auto *c_context_ptr = static_cast(reinterpret_cast(context_ptr)); + if (c_context_ptr == nullptr) { + MS_LOG(ERROR) << "Context pointer from java is nullptr"; + return nullptr; + } std::vector core_list_tmp = c_context_ptr->GetThreadAffinityCoreList(); jobject core_list = newObjectArrayList(env, core_list_tmp, "java/lang/Integer", "(I)V"); return core_list; @@ -225,6 +261,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setEnableP jlong context_ptr, jboolean is_parallel) { auto *c_context_ptr = static_cast(reinterpret_cast(context_ptr)); + if (c_context_ptr == nullptr) { + MS_LOG(ERROR) << "Context pointer from java is nullptr"; + return; + } c_context_ptr->SetEnableParallel(static_cast(is_parallel)); } @@ -236,6 +276,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setEnableP extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_config_MSContext_getEnableParallel(JNIEnv *env, jobject thiz, jlong context_ptr) { auto *c_context_ptr = static_cast(reinterpret_cast(context_ptr)); + if (c_context_ptr == nullptr) { + MS_LOG(ERROR) << "Context pointer from java is nullptr"; + return (jboolean) false; + } bool is_parallel = c_context_ptr->GetEnableParallel(); return (jboolean)is_parallel; } @@ -243,5 +287,9 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_config_MSContext_getEna extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_free(JNIEnv *env, jobject thiz, jlong context_ptr) { auto *c_context_ptr = static_cast(reinterpret_cast(context_ptr)); + if (c_context_ptr == nullptr) { + MS_LOG(ERROR) << "Context pointer from java is nullptr"; + return; + } delete (c_context_ptr); } diff --git a/mindspore/lite/java/src/main/native/ms_tensor.cpp b/mindspore/lite/java/src/main/native/ms_tensor.cpp index 5d18d003faf..09be849b5b9 100644 --- a/mindspore/lite/java/src/main/native/ms_tensor.cpp +++ b/mindspore/lite/java/src/main/native/ms_tensor.cpp @@ -30,6 +30,10 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_MSTensor_getShape(JNIE auto local_shape = ms_tensor_ptr->Shape(); auto shape_size = local_shape.size(); jintArray shape = env->NewIntArray(shape_size); + if (shape == nullptr) { + MS_LOG(ERROR) << "new intArray failed."; + return env->NewIntArray(0); + } auto *tmp = new jint[shape_size]; for (size_t i = 0; i < shape_size; i++) { tmp[i] = static_cast(local_shape.at(i)); @@ -69,6 +73,10 @@ extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_MSTensor_getByteData( return env->NewByteArray(0); } auto ret = env->NewByteArray(local_size); + if (ret == nullptr) { + MS_LOG(ERROR) << "malloc failed."; + return env->NewByteArray(0); + } env->SetByteArrayRegion(ret, 0, local_size, local_data); return ret; } @@ -99,6 +107,10 @@ extern "C" JNIEXPORT jlongArray JNICALL Java_com_mindspore_MSTensor_getLongData( return env->NewLongArray(0); } auto ret = env->NewLongArray(local_element_num); + if (ret == nullptr) { + MS_LOG(ERROR) << "malloc failed."; + return env->NewLongArray(0); + } env->SetLongArrayRegion(ret, 0, local_element_num, local_data); return ret; } @@ -129,6 +141,10 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_MSTensor_getIntData(JN return env->NewIntArray(0); } auto ret = env->NewIntArray(local_element_num); + if (ret == nullptr) { + MS_LOG(ERROR) << "malloc failed."; + return env->NewIntArray(0); + } env->SetIntArrayRegion(ret, 0, local_element_num, local_data); return ret; } @@ -159,6 +175,10 @@ extern "C" JNIEXPORT jfloatArray JNICALL Java_com_mindspore_MSTensor_getFloatDat return env->NewFloatArray(0); } auto ret = env->NewFloatArray(local_element_num); + if (ret == nullptr) { + MS_LOG(ERROR) << "malloc failed."; + return env->NewFloatArray(0); + } env->SetFloatArrayRegion(ret, 0, local_element_num, local_data); return ret; } @@ -177,8 +197,18 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setByteData(JN return static_cast(false); } jboolean is_copy = false; + + if (data == nullptr) { + MS_LOG(ERROR) << "data from java is nullptr."; + return static_cast(false); + } auto *data_arr = env->GetByteArrayElements(data, &is_copy); auto *local_data = ms_tensor_ptr->MutableData(); + if (data_arr == nullptr || local_data == nullptr) { + MS_LOG(ERROR) << "data_arr or local_data is nullptr."; + env->ReleaseByteArrayElements(data, data_arr, JNI_ABORT); + return static_cast(false); + } memcpy(local_data, data_arr, data_len); env->ReleaseByteArrayElements(data, data_arr, JNI_ABORT); return static_cast(true); @@ -208,6 +238,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setFloatData(J MS_LOG(ERROR) << "malloc memory failed."; return static_cast(false); } + if (data == nullptr) { + MS_LOG(ERROR) << "data from java is nullptr"; + return static_cast(false); + } env->GetFloatArrayRegion(data, 0, static_cast(data_len), local_data); return static_cast(true); } @@ -236,6 +270,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setIntData(JNI MS_LOG(ERROR) << "malloc memory failed."; return static_cast(false); } + if (data == nullptr) { + MS_LOG(ERROR) << "data from java is nullptr"; + return static_cast(false); + } env->GetIntArrayRegion(data, 0, static_cast(data_len), local_data); return static_cast(true); } @@ -262,6 +300,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setLongData(JN MS_LOG(ERROR) << "malloc memory failed."; return static_cast(false); } + if (data == nullptr) { + MS_LOG(ERROR) << "data from java is nullptr"; + return static_cast(false); + } env->GetLongArrayRegion(data, 0, static_cast(data_len), local_data); return static_cast(true); } @@ -287,6 +329,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setByteBufferD return static_cast(false); } auto *local_data = ms_tensor_ptr->MutableData(); + if (local_data == nullptr) { + MS_LOG(ERROR) << "get MutableData nullptr."; + return static_cast(false); + } memcpy(local_data, p_data, data_len); return static_cast(true); } @@ -304,8 +350,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_MSTensor_size(JNIEnv *env, extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setShape(JNIEnv *env, jobject thiz, jlong tensor_ptr, jintArray tensor_shape) { auto *pointer = reinterpret_cast(tensor_ptr); - if (pointer == nullptr) { - MS_LOG(ERROR) << "Tensor pointer from java is nullptr"; + if (pointer == nullptr || tensor_shape == nullptr) { + MS_LOG(ERROR) << "input params from java is nullptr"; return static_cast(false); } @@ -357,6 +403,11 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_MSTensor_createTensorByNat jstring tensor_name, jint data_type, jintArray tensor_shape, jobject buffer) { + // check inputs + if (buffer == nullptr || tensor_name == nullptr || tensor_shape == nullptr) { + MS_LOG(ERROR) << "input param from java is nullptr"; + return 0; + } auto *p_data = reinterpret_cast(env->GetDirectBufferAddress(buffer)); jlong data_len = env->GetDirectBufferCapacity(buffer); if (p_data == nullptr) { @@ -370,11 +421,11 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_MSTensor_createTensorByNat for (int i = 0; i < size; i++) { c_shape[i] = static_cast(shape_pointer[i]); } - env->ReleaseIntArrayElements(tensor_shape, shape_pointer, JNI_ABORT); const char *c_tensor_name = env->GetStringUTFChars(tensor_name, nullptr); std::string str_tensor_name(c_tensor_name, env->GetStringLength(tensor_name)); auto tensor = mindspore::MSTensor::CreateTensor(str_tensor_name, static_cast(data_type), c_shape, p_data, data_len); + env->ReleaseIntArrayElements(tensor_shape, shape_pointer, JNI_ABORT); env->ReleaseStringUTFChars(tensor_name, c_tensor_name); return jlong(tensor); } diff --git a/mindspore/lite/java/src/main/native/train_config.cpp b/mindspore/lite/java/src/main/native/train_config.cpp index e33044222b8..5ff409a14ac 100644 --- a/mindspore/lite/java/src/main/native/train_config.cpp +++ b/mindspore/lite/java/src/main/native/train_config.cpp @@ -51,8 +51,10 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_config_TrainCfg_createTrai } if (loss_name != nullptr) { std::vector traincfg_loss_name = traincfg_ptr->GetLossName(); - traincfg_loss_name.emplace_back(env->GetStringUTFChars(loss_name, JNI_FALSE)); + auto c_loss_name = env->GetStringUTFChars(loss_name, JNI_FALSE); + traincfg_loss_name.emplace_back(c_loss_name); traincfg_ptr->SetLossName(traincfg_loss_name); + env->ReleaseStringUTFChars(loss_name, c_loss_name); } traincfg_ptr->optimization_level_ = ol; traincfg_ptr->accumulate_gradients_ = accmulateGrads; diff --git a/mindspore/lite/java/src/test/java/com/mindspore/MSTensorTest.java b/mindspore/lite/java/src/test/java/com/mindspore/MSTensorTest.java new file mode 100644 index 00000000000..a3684be2fbe --- /dev/null +++ b/mindspore/lite/java/src/test/java/com/mindspore/MSTensorTest.java @@ -0,0 +1,63 @@ +/* + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import com.mindspore.config.DataType; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.nio.ByteBuffer; +import java.nio.FloatBuffer; +import java.util.Arrays; + +/** + * Model Test + * + * @since 2023.2 + */ +@RunWith(JUnit4.class) +public class MSTensorTest { + @Test + public void testCreateTensor() { + int[] tensorShape = {6, 5, 5, 1}; + float[] tensorData = new float[6 * 5 * 5]; + Arrays.fill(tensorData, 0); + ByteBuffer byteBuf = ByteBuffer.allocateDirect(6 * 5 * 5 * 4); + FloatBuffer floatBuf = byteBuf.asFloatBuffer(); + floatBuf.put(tensorData); + MSTensor newTensor = MSTensor.createTensor("conv1.weight", DataType.kNumberTypeFloat32, tensorShape, + byteBuf); + assertNotNull(newTensor); + } + + @Test + public void testSetData() { + int[] tensorShape = {6, 5, 5, 1}; + ByteBuffer byteBuf = ByteBuffer.allocateDirect(6 * 5 * 5 * 4); + MSTensor newTensor = MSTensor.createTensor("conv1.weight", DataType.kNumberTypeFloat32, tensorShape, + byteBuf); + assertNotNull(newTensor); + float[] tensorData = new float[6 * 5 * 5]; + Arrays.fill(tensorData, 0); + assertTrue(newTensor.setData(tensorData)); + } +} \ No newline at end of file