forked from mindspore-Ecosystem/mindspore
!49104 java jni codecheck
Merge pull request !49104 from liyan2022/master
This commit is contained in:
commit
0742b9c1a6
|
@ -18,6 +18,11 @@ package com.mindspore;
|
|||
|
||||
import com.mindspore.config.MindsporeLite;
|
||||
|
||||
/**
|
||||
* Graph Class
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class Graph {
|
||||
static {
|
||||
MindsporeLite.init();
|
||||
|
|
|
@ -22,6 +22,11 @@ import java.nio.ByteBuffer;
|
|||
import java.lang.reflect.Array;
|
||||
import java.util.HashMap;
|
||||
|
||||
/**
|
||||
* The MSTensor class defines a tensor in MindSpore.
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class MSTensor {
|
||||
static {
|
||||
MindsporeLite.init();
|
||||
|
|
|
@ -25,6 +25,11 @@ import java.nio.MappedByteBuffer;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* The Model class is used to define a MindSpore model, facilitating computational graph management.
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class Model {
|
||||
static {
|
||||
MindsporeLite.init();
|
||||
|
@ -61,16 +66,20 @@ public class Model {
|
|||
* @param buffer model buffer.
|
||||
* @param modelType model type.
|
||||
* @param context model build context.
|
||||
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16.
|
||||
* @param dec_mode define the decryption mode. Options: AES-GCM.
|
||||
* @param cropto_lib_path define the openssl library path.
|
||||
* @param decKey define the key used to decrypt the ciphertext model. The key length is 16.
|
||||
* @param decMode define the decryption mode. Options: AES-GCM.
|
||||
* @param croptoLibPath define the openssl library path.
|
||||
* @return model build status.
|
||||
*/
|
||||
public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) {
|
||||
if (context == null || buffer == null || dec_key == null || dec_mode == null) {
|
||||
public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] decKey, String decMode,
|
||||
String croptoLibPath) {
|
||||
boolean isValid = (context != null && buffer != null && decKey != null && decMode != null &&
|
||||
croptoLibPath != null);
|
||||
if (!isValid) {
|
||||
return false;
|
||||
}
|
||||
return this.buildByBuffer(modelPtr, buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
|
||||
return this.buildByBuffer(modelPtr, buffer, modelType, context.getMSContextPtr(), decKey, decMode,
|
||||
croptoLibPath);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -95,16 +104,19 @@ public class Model {
|
|||
* @param modelPath model path.
|
||||
* @param modelType model type.
|
||||
* @param context model build context.
|
||||
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16.
|
||||
* @param dec_mode define the decryption mode. Options: AES-GCM.
|
||||
* @param cropto_lib_path define the openssl library path.
|
||||
* @param decKey define the key used to decrypt the ciphertext model. The key length is 16.
|
||||
* @param decMode define the decryption mode. Options: AES-GCM.
|
||||
* @param croptoLibPath define the openssl library path.
|
||||
* @return model build status.
|
||||
*/
|
||||
public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) {
|
||||
if (context == null || modelPath == null || dec_key == null || dec_mode == null) {
|
||||
public boolean build(String modelPath, int modelType, MSContext context, char[] decKey, String decMode,
|
||||
String croptoLibPath) {
|
||||
boolean isValid = (context != null && modelPath != null && decKey != null && decMode != null &&
|
||||
croptoLibPath != null);
|
||||
if (!isValid) {
|
||||
return false;
|
||||
}
|
||||
return this.buildByPath(modelPtr, modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
|
||||
return this.buildByPath(modelPtr, modelPath, modelType, context.getMSContextPtr(), decKey, decMode, croptoLibPath);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -164,9 +176,9 @@ public class Model {
|
|||
* @return input tensors.
|
||||
*/
|
||||
public List<MSTensor> getInputs() {
|
||||
List<Long> ret = this.getInputs(this.modelPtr);
|
||||
List<MSTensor> tensors = new ArrayList<>();
|
||||
for (Long msTensorAddr : ret) {
|
||||
List<Long> tensorAddrs = this.getInputs(this.modelPtr);
|
||||
List<MSTensor> tensors = new ArrayList<>(tensorAddrs.size());
|
||||
for (Long msTensorAddr : tensorAddrs) {
|
||||
MSTensor msTensor = new MSTensor(msTensorAddr);
|
||||
tensors.add(msTensor);
|
||||
}
|
||||
|
@ -179,9 +191,9 @@ public class Model {
|
|||
* @return model outputs tensor.
|
||||
*/
|
||||
public List<MSTensor> getOutputs() {
|
||||
List<Long> ret = this.getOutputs(this.modelPtr);
|
||||
List<MSTensor> tensors = new ArrayList<>();
|
||||
for (Long msTensorAddr : ret) {
|
||||
List<Long> tensorAddrs = this.getOutputs(this.modelPtr);
|
||||
List<MSTensor> tensors = new ArrayList<>(tensorAddrs.size());
|
||||
for (Long msTensorAddr : tensorAddrs) {
|
||||
MSTensor msTensor = new MSTensor(msTensorAddr);
|
||||
tensors.add(msTensor);
|
||||
}
|
||||
|
@ -224,11 +236,11 @@ public class Model {
|
|||
*/
|
||||
public List<MSTensor> getOutputsByNodeName(String nodeName) {
|
||||
if (nodeName == null) {
|
||||
return null;
|
||||
return new ArrayList<>();
|
||||
}
|
||||
List<Long> ret = this.getOutputsByNodeName(this.modelPtr, nodeName);
|
||||
List<MSTensor> tensors = new ArrayList<>();
|
||||
for (Long msTensorAddr : ret) {
|
||||
List<Long> tensorAddrs = this.getOutputsByNodeName(this.modelPtr, nodeName);
|
||||
List<MSTensor> tensors = new ArrayList<>(tensorAddrs.size());
|
||||
for (Long msTensorAddr : tensorAddrs) {
|
||||
MSTensor msTensor = new MSTensor(msTensorAddr);
|
||||
tensors.add(msTensor);
|
||||
}
|
||||
|
@ -296,9 +308,9 @@ public class Model {
|
|||
* @return FeaturesMap Tensor list.
|
||||
*/
|
||||
public List<MSTensor> getFeatureMaps() {
|
||||
List<Long> ret = this.getFeatureMaps(this.modelPtr);
|
||||
ArrayList<MSTensor> tensors = new ArrayList<>();
|
||||
for (Long msTensorAddr : ret) {
|
||||
List<Long> tensorAddrs = this.getFeatureMaps(this.modelPtr);
|
||||
ArrayList<MSTensor> tensors = new ArrayList<>(tensorAddrs.size());
|
||||
for (Long msTensorAddr : tensorAddrs) {
|
||||
MSTensor msTensor = new MSTensor(msTensorAddr);
|
||||
tensors.add(msTensor);
|
||||
}
|
||||
|
|
|
@ -22,7 +22,13 @@ package com.mindspore.config;
|
|||
* @since v1.0
|
||||
*/
|
||||
public class CpuBindMode {
|
||||
|
||||
// bind mind cpu
|
||||
public static final int MID_CPU = 2;
|
||||
|
||||
// bind high cpu
|
||||
public static final int HIGHER_CPU = 1;
|
||||
|
||||
// no bind
|
||||
public static final int NO_BIND = 0;
|
||||
}
|
|
@ -21,8 +21,16 @@ package com.mindspore.config;
|
|||
* @since v1.0
|
||||
*/
|
||||
public class DeviceType {
|
||||
|
||||
// support cpu
|
||||
public static final int DT_CPU = 0;
|
||||
|
||||
// support gpu
|
||||
public static final int DT_GPU = 1;
|
||||
|
||||
// support npu
|
||||
public static final int DT_NPU = 2;
|
||||
|
||||
// support ascend
|
||||
public static final int DT_ASCEND = 3;
|
||||
}
|
|
@ -20,17 +20,22 @@ import java.util.ArrayList;
|
|||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
|
||||
/**
|
||||
* Context is used to store environment variables during execution.
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class MSContext {
|
||||
private static Logger LOGGER = MindsporeLite.GetLogger();
|
||||
static {
|
||||
MindsporeLite.init();
|
||||
}
|
||||
|
||||
private long msContextPtr;
|
||||
private static final long EMPTY_CONTEXT_PTR_VALUE = 0;
|
||||
private static final long EMPTY_CONTEXT_PTR_VALUE = 0L;
|
||||
private static final int ERROR_VALUE = -1;
|
||||
private static final String NULLPTR_ERROR_MESSAGE="Context pointer from java is nullptr.\n";
|
||||
private static final String NULLPTR_ERROR_MESSAGE="Context pointer from java is nullptr.";
|
||||
|
||||
private long msContextPtr;
|
||||
|
||||
/**
|
||||
* Construct function.
|
||||
|
@ -148,13 +153,13 @@ public class MSContext {
|
|||
* @return The current thread number setting.
|
||||
*/
|
||||
public int getThreadNum() {
|
||||
int ret_val = ERROR_VALUE;
|
||||
int retVal = ERROR_VALUE;
|
||||
if (isInitialized()) {
|
||||
ret_val = getThreadNum(this.msContextPtr);
|
||||
retVal = getThreadNum(this.msContextPtr);
|
||||
} else {
|
||||
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
|
||||
}
|
||||
return ret_val;
|
||||
return retVal;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -178,13 +183,13 @@ public class MSContext {
|
|||
* @return The current operators parallel number setting.
|
||||
*/
|
||||
public int getInterOpParallelNum() {
|
||||
int ret_val = ERROR_VALUE;
|
||||
int retVal = ERROR_VALUE;
|
||||
if (isInitialized()) {
|
||||
ret_val = getInterOpParallelNum(this.msContextPtr);
|
||||
retVal = getInterOpParallelNum(this.msContextPtr);
|
||||
} else {
|
||||
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
|
||||
}
|
||||
return ret_val;
|
||||
return retVal;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -209,13 +214,13 @@ public class MSContext {
|
|||
* @return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first
|
||||
*/
|
||||
public int getThreadAffinityMode() {
|
||||
int ret_val = ERROR_VALUE;
|
||||
int retVal = ERROR_VALUE;
|
||||
if (isInitialized()) {
|
||||
ret_val = getThreadAffinityMode(this.msContextPtr);
|
||||
retVal = getThreadAffinityMode(this.msContextPtr);
|
||||
} else {
|
||||
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
|
||||
}
|
||||
return ret_val;
|
||||
return retVal;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -229,11 +234,11 @@ public class MSContext {
|
|||
public void setThreadAffinity(ArrayList<Integer> coreList) {
|
||||
if (isInitialized()) {
|
||||
int len = coreList.size();
|
||||
int[] coreList_array = new int[len];
|
||||
int[] coreListArray = new int[len];
|
||||
for (int i = 0; i < len; i++) {
|
||||
coreList_array[i] = coreList.get(i);
|
||||
coreListArray[i] = coreList.get(i);
|
||||
}
|
||||
setThreadAffinity(this.msContextPtr, coreList_array);
|
||||
setThreadAffinity(this.msContextPtr, coreListArray);
|
||||
} else {
|
||||
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
|
||||
}
|
||||
|
@ -247,13 +252,13 @@ public class MSContext {
|
|||
*/
|
||||
|
||||
public ArrayList<Integer> getThreadAffinityCoreList() {
|
||||
ArrayList<Integer> ret_val = new ArrayList<>();
|
||||
ArrayList<Integer> retVal = new ArrayList<>();
|
||||
if (isInitialized()) {
|
||||
ret_val = getThreadAffinityCoreList(this.msContextPtr);
|
||||
retVal = getThreadAffinityCoreList(this.msContextPtr);
|
||||
} else {
|
||||
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
|
||||
}
|
||||
return ret_val;
|
||||
return retVal;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -277,13 +282,13 @@ public class MSContext {
|
|||
* @return boolean value that indicates whether in parallel.
|
||||
*/
|
||||
public boolean getEnableParallel() {
|
||||
boolean ret_val = false;
|
||||
boolean retVal = false;
|
||||
if (isInitialized()) {
|
||||
ret_val = getEnableParallel(this.msContextPtr);
|
||||
retVal = getEnableParallel(this.msContextPtr);
|
||||
} else {
|
||||
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
|
||||
}
|
||||
return ret_val;
|
||||
return retVal;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,28 @@
|
|||
/**
|
||||
* Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.config;
|
||||
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* MSLite Init Class
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public final class MindsporeLite {
|
||||
private static final Object lock = new Object();
|
||||
private static Logger LOGGER = GetLogger();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/**
|
||||
* Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,6 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.config;
|
||||
|
||||
/**
|
||||
|
@ -21,9 +22,19 @@ package com.mindspore.config;
|
|||
* @since v1.0
|
||||
*/
|
||||
public class ModelType {
|
||||
|
||||
// mindir type
|
||||
public static final int MT_MINDIR = 0;
|
||||
|
||||
// air type
|
||||
public static final int MT_AIR = 1;
|
||||
|
||||
// om type
|
||||
public static final int MT_OM = 2;
|
||||
|
||||
// onnx type
|
||||
public static final int MT_ONNX = 3;
|
||||
|
||||
// mindir opt type
|
||||
public static final int MT_MINDIR_OPT = 4;
|
||||
}
|
|
@ -1,3 +1,18 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.mindspore.config;
|
||||
|
||||
import java.io.File;
|
||||
|
@ -5,7 +20,13 @@ import java.io.FileOutputStream;
|
|||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.logging.Logger;
|
||||
import java.util.Locale;
|
||||
|
||||
/**
|
||||
* NativeLibrary Class
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class NativeLibrary {
|
||||
private static final Logger LOGGER = MindsporeLite.GetLogger();
|
||||
|
||||
|
@ -52,7 +73,7 @@ public class NativeLibrary {
|
|||
* libmsplugin-ge-litert
|
||||
* libruntime_convert_plugin
|
||||
*/
|
||||
public static void loadLibs() {
|
||||
private static void loadLibs() {
|
||||
loadLib(makeResourceName("lib" + GLOG_LIBNAME + ".so"));
|
||||
loadLib(makeResourceName("lib" + OPENCV_CORE_LIBNAME + ".so"));
|
||||
loadLib(makeResourceName("lib" + OPENCV_IMGPROC_LIBNAME + ".so"));
|
||||
|
@ -83,45 +104,45 @@ public class NativeLibrary {
|
|||
try {
|
||||
System.loadLibrary(MINDSPORE_LITE_JNI_LIBNAME);
|
||||
loadSuccess = true;
|
||||
LOGGER.info("loadLibrary " + MINDSPORE_LITE_JNI_LIBNAME + ": success");
|
||||
LOGGER.info("loadLibrary mindspore-lite-jni success");
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
LOGGER.info(String.format("tryLoadLibrary " + MINDSPORE_LITE_JNI_LIBNAME + " failed: %s", e.toString()));
|
||||
LOGGER.info(String.format(Locale.ENGLISH, "tryLoadLibrary mindspore-lite-jni failed: %s", e));
|
||||
}
|
||||
try {
|
||||
System.loadLibrary(MINDSPORE_LITE_TRAIN_JNI_LIBNAME);
|
||||
loadSuccess = true;
|
||||
LOGGER.info("loadLibrary " + MINDSPORE_LITE_TRAIN_JNI_LIBNAME + ": success.");
|
||||
LOGGER.info("loadLibrary mindspore-lite-train-jni success.");
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
LOGGER.info(String.format("tryLoadLibrary " + MINDSPORE_LITE_TRAIN_JNI_LIBNAME + " failed: %s", e.toString()));
|
||||
LOGGER.info(String.format(Locale.ENGLISH, "tryLoadLibrary mindspore-lite-train-jni failed: %s", e));
|
||||
}
|
||||
return loadSuccess;
|
||||
}
|
||||
|
||||
private static void loadLib(String libResourceName) {
|
||||
LOGGER.info("start load libResourceName: " + libResourceName);
|
||||
LOGGER.info(String.format(Locale.ENGLISH,"start load libResourceName: %s.", libResourceName));
|
||||
final InputStream libResource = NativeLibrary.class.getClassLoader().getResourceAsStream(libResourceName);
|
||||
if (libResource == null) {
|
||||
LOGGER.warning(String.format("lib file: %s not exist.", libResourceName));
|
||||
LOGGER.warning(String.format(Locale.ENGLISH,"lib file: %s not exist.", libResourceName));
|
||||
return;
|
||||
}
|
||||
try {
|
||||
final File tmpDir = mkTmpDir();
|
||||
String libName = libResourceName.substring(libResourceName.lastIndexOf("/") + 1);
|
||||
String libName = libResourceName.substring(libResourceName.lastIndexOf('/') + 1);
|
||||
tmpDir.deleteOnExit();
|
||||
|
||||
// copy file to tmpFile
|
||||
final File tmpFile = new File(tmpDir.getCanonicalPath(), libName);
|
||||
tmpFile.deleteOnExit();
|
||||
LOGGER.info(String.format("extract %d bytes to %s", copyLib(libResource, tmpFile), tmpFile));
|
||||
LOGGER.info(String.format("libName %s", libName));
|
||||
if (libName.equals("lib" + MINDSPORE_LITE_LIBNAME + ".so")) {
|
||||
LOGGER.info(String.format(Locale.ENGLISH,"extract %d bytes to %s", copyLib(libResource, tmpFile),
|
||||
tmpFile));
|
||||
if (("lib" + MINDSPORE_LITE_LIBNAME + ".so").equals(libName)) {
|
||||
extractLib(makeResourceName("lib" + MSPLUGIN_GE_LITERT_LIBNAME + ".so"), tmpDir);
|
||||
extractLib(makeResourceName("lib" + RUNTIME_CONVERT_PLUGIN_LIBNAME + ".so"), tmpDir);
|
||||
}
|
||||
System.load(tmpFile.toString());
|
||||
} catch (IOException e) {
|
||||
throw new UnsatisfiedLinkError(
|
||||
String.format("extract library into tmp file (%s) failed.", e.toString()));
|
||||
String.format(Locale.ENGLISH,"extract library into tmp file (%s) failed.", e));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -143,9 +164,8 @@ public class NativeLibrary {
|
|||
|
||||
|
||||
private static File mkTmpDir() {
|
||||
final String MINDSPORE_LITE_LIBS = "mindspore_lite_libs-";
|
||||
Long timestamp = System.currentTimeMillis();
|
||||
String dirName = MINDSPORE_LITE_LIBS + timestamp + "-";
|
||||
String dirName = "mindspore_lite_libs-" + timestamp + "-";
|
||||
for (int i = 0; i < 10; i++) {
|
||||
File tmpDir = new File(new File(System.getProperty("java.io.tmpdir")), dirName + i);
|
||||
if (tmpDir.mkdir()) {
|
||||
|
@ -167,7 +187,7 @@ public class NativeLibrary {
|
|||
|
||||
private static String architecture() {
|
||||
final String arch = System.getProperty("os.arch").toLowerCase();
|
||||
return (arch.equals("amd64")) ? "x86_64" : arch;
|
||||
return ("amd64".equals(arch)) ? "x86_64" : arch;
|
||||
}
|
||||
|
||||
private static void extractLib(String libResourceName, File targetDir) {
|
||||
|
|
|
@ -16,6 +16,11 @@
|
|||
|
||||
package com.mindspore.config;
|
||||
|
||||
/**
|
||||
* TrainCfg Class
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class TrainCfg {
|
||||
static {
|
||||
try {
|
||||
|
|
|
@ -37,9 +37,8 @@ public class Version {
|
|||
LOGGER.info("Version init load ...");
|
||||
try {
|
||||
NativeLibrary.load();
|
||||
} catch (Exception e) {
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
LOGGER.severe("Failed to load MindSporLite native library.");
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -26,8 +26,15 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Graph_loadModel(JNIEnv *en
|
|||
MS_LOG(ERROR) << "Model new failed";
|
||||
return jlong(nullptr);
|
||||
}
|
||||
if (ms_file == nullptr) {
|
||||
MS_LOG(ERROR) << "ms_file from java is nullptr.";
|
||||
delete graph;
|
||||
return jlong(nullptr);
|
||||
}
|
||||
auto c_ms_file = env->GetStringUTFChars(ms_file, JNI_FALSE);
|
||||
auto status =
|
||||
mindspore::Serialization::Load(env->GetStringUTFChars(ms_file, JNI_FALSE), mindspore::ModelType::kMindIR, graph);
|
||||
mindspore::Serialization::Load(c_ms_file, mindspore::ModelType::kMindIR, graph);
|
||||
env->ReleaseStringUTFChars(ms_file, c_ms_file);
|
||||
if (status != mindspore::kSuccess) {
|
||||
MS_LOG(ERROR) << "Load graph from file failed";
|
||||
delete graph;
|
||||
|
|
|
@ -111,23 +111,34 @@ extern "C" JNIEXPORT bool JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv
|
|||
return false;
|
||||
}
|
||||
context.reset(c_context_ptr);
|
||||
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
|
||||
mindspore::Status status;
|
||||
if (key_str != NULL) {
|
||||
jchar *key_array = env->GetCharArrayElements(key_str, NULL);
|
||||
auto key_len = static_cast<size_t>(env->GetArrayLength(key_str));
|
||||
char *dec_key_data = new (std::nothrow) char[key_len];
|
||||
if (dec_key_data == nullptr) {
|
||||
MS_LOG(ERROR) << "Dec key new failed";
|
||||
return false;
|
||||
}
|
||||
jchar *key_array = env->GetCharArrayElements(key_str, NULL);
|
||||
if (key_array == nullptr) {
|
||||
MS_LOG(ERROR) << "key_array is nullptr.";
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < key_len; i++) {
|
||||
dec_key_data[i] = key_array[i];
|
||||
}
|
||||
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
|
||||
mindspore::Key dec_key{dec_key_data, key_len};
|
||||
if (cropto_lib_path == nullptr || dec_mod == nullptr) {
|
||||
MS_LOG(ERROR) << "cropto_lib_path or dec_mod from java is nullptr.";
|
||||
return jlong(nullptr);
|
||||
}
|
||||
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
|
||||
auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE);
|
||||
status = lite_model_ptr->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
|
||||
env->ReleaseStringUTFChars(cropto_lib_path, c_cropto_lib_path);
|
||||
env->ReleaseStringUTFChars(dec_mod, c_dec_mod);
|
||||
delete[] dec_key_data;
|
||||
} else {
|
||||
status = lite_model_ptr->Build(model_buf, buffer_len, c_model_type, context);
|
||||
}
|
||||
|
@ -167,26 +178,43 @@ extern "C" JNIEXPORT bool JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *e
|
|||
return false;
|
||||
}
|
||||
context.reset(c_context_ptr);
|
||||
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
|
||||
mindspore::Status status;
|
||||
if (key_str != NULL) {
|
||||
jchar *key_array = env->GetCharArrayElements(key_str, NULL);
|
||||
auto key_len = static_cast<size_t>(env->GetArrayLength(key_str));
|
||||
char *dec_key_data = new (std::nothrow) char[key_len];
|
||||
if (dec_key_data == nullptr) {
|
||||
MS_LOG(ERROR) << "Dec key new failed";
|
||||
env->ReleaseStringUTFChars(model_path, c_model_path);
|
||||
return false;
|
||||
}
|
||||
|
||||
jchar *key_array = env->GetCharArrayElements(key_str, NULL);
|
||||
if (key_array == nullptr) {
|
||||
MS_LOG(ERROR) << "GetCharArrayElements failed.";
|
||||
env->ReleaseStringUTFChars(model_path, c_model_path);
|
||||
return jlong(nullptr);
|
||||
}
|
||||
for (size_t i = 0; i < key_len; i++) {
|
||||
dec_key_data[i] = key_array[i];
|
||||
}
|
||||
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
|
||||
mindspore::Key dec_key{dec_key_data, key_len};
|
||||
|
||||
if (dec_mod == nullptr || cropto_lib_path == nullptr) {
|
||||
MS_LOG(ERROR) << "dec_mod, cropto_lib_path from java is nullptr.";
|
||||
env->ReleaseStringUTFChars(model_path, c_model_path);
|
||||
return jlong(nullptr);
|
||||
}
|
||||
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
|
||||
auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE);
|
||||
status = lite_model_ptr->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
|
||||
env->ReleaseStringUTFChars(dec_mod, c_dec_mod);
|
||||
env->ReleaseStringUTFChars(cropto_lib_path, c_cropto_lib_path);
|
||||
delete[] dec_key_data;
|
||||
} else {
|
||||
status = lite_model_ptr->Build(c_model_path, c_model_type, context);
|
||||
}
|
||||
env->ReleaseStringUTFChars(model_path, c_model_path);
|
||||
if (status != mindspore::kSuccess) {
|
||||
MS_LOG(ERROR) << "Error status " << static_cast<int>(status) << " during build of model";
|
||||
return false;
|
||||
|
@ -205,6 +233,8 @@ jobject GetInOrOutTensors(JNIEnv *env, jobject thiz, jlong model_ptr, bool is_in
|
|||
auto *pointer = reinterpret_cast<mindspore::Model *>(model_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOG(ERROR) << "Model pointer from java is nullptr";
|
||||
env->DeleteLocalRef(array_list);
|
||||
env->DeleteLocalRef(long_object);
|
||||
return ret;
|
||||
}
|
||||
std::vector<mindspore::MSTensor> tensors;
|
||||
|
@ -221,7 +251,10 @@ jobject GetInOrOutTensors(JNIEnv *env, jobject thiz, jlong model_ptr, bool is_in
|
|||
}
|
||||
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(tensor_ptr.release()));
|
||||
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
|
||||
env->DeleteLocalRef(tensor_addr);
|
||||
}
|
||||
env->DeleteLocalRef(array_list);
|
||||
env->DeleteLocalRef(long_object);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -233,11 +266,17 @@ jlong GetTensorByInOutName(JNIEnv *env, jlong model_ptr, jstring tensor_name, bo
|
|||
}
|
||||
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
|
||||
mindspore::MSTensor tensor;
|
||||
if (is_input) {
|
||||
tensor = lite_model_ptr->GetInputByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE));
|
||||
} else {
|
||||
tensor = lite_model_ptr->GetOutputByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE));
|
||||
if (tensor_name == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_name from java is nullptr.";
|
||||
return jlong(nullptr);
|
||||
}
|
||||
auto c_tensor_name = env->GetStringUTFChars(tensor_name, JNI_FALSE);
|
||||
if (is_input) {
|
||||
tensor = lite_model_ptr->GetInputByTensorName(c_tensor_name);
|
||||
} else {
|
||||
tensor = lite_model_ptr->GetOutputByTensorName(c_tensor_name);
|
||||
}
|
||||
env->ReleaseStringUTFChars(tensor_name, c_tensor_name);
|
||||
if (tensor.impl() == nullptr) {
|
||||
return jlong(nullptr);
|
||||
}
|
||||
|
@ -283,8 +322,11 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputTensorNam
|
|||
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
|
||||
auto output_names = lite_model_ptr->GetOutputTensorNames();
|
||||
for (const auto &output_name : output_names) {
|
||||
env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str()));
|
||||
auto output_name_jstring = env->NewStringUTF(output_name.c_str());
|
||||
env->CallBooleanMethod(ret, array_list_add, output_name_jstring);
|
||||
env->DeleteLocalRef(output_name_jstring);
|
||||
}
|
||||
env->DeleteLocalRef(array_list);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -303,7 +345,13 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputsByNodeNa
|
|||
return ret;
|
||||
}
|
||||
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
|
||||
auto tensors = lite_model_ptr->GetOutputsByNodeName(env->GetStringUTFChars(node_name, JNI_FALSE));
|
||||
if (node_name == nullptr) {
|
||||
MS_LOG(ERROR) << "node_name from java is nullptr";
|
||||
return ret;
|
||||
}
|
||||
auto c_node_name = env->GetStringUTFChars(node_name, JNI_FALSE);
|
||||
auto tensors = lite_model_ptr->GetOutputsByNodeName(c_node_name);
|
||||
env->ReleaseStringUTFChars(node_name, c_node_name);
|
||||
for (auto &tensor : tensors) {
|
||||
auto tensor_ptr = std::make_unique<mindspore::MSTensor>(tensor);
|
||||
if (tensor_ptr == nullptr) {
|
||||
|
@ -312,7 +360,10 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputsByNodeNa
|
|||
}
|
||||
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(tensor_ptr.release()));
|
||||
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
|
||||
env->DeleteLocalRef(tensor_addr);
|
||||
}
|
||||
env->DeleteLocalRef(array_list);
|
||||
env->DeleteLocalRef(long_object);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -351,18 +402,24 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_runStep(JNIEnv *e
|
|||
}
|
||||
|
||||
std::vector<mindspore::MSTensor> convertArrayToVector(JNIEnv *env, jlongArray inputs) {
|
||||
std::vector<mindspore::MSTensor> c_inputs;
|
||||
if (inputs == nullptr) {
|
||||
MS_LOG(ERROR) << "inputs from java is nullptr";
|
||||
return c_inputs;
|
||||
}
|
||||
auto input_size = static_cast<int>(env->GetArrayLength(inputs));
|
||||
jlong *input_data = env->GetLongArrayElements(inputs, nullptr);
|
||||
std::vector<mindspore::MSTensor> c_inputs;
|
||||
for (int i = 0; i < input_size; i++) {
|
||||
auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
|
||||
if (tensor_pointer == nullptr) {
|
||||
MS_LOG(ERROR) << "Tensor pointer from java is nullptr";
|
||||
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
|
||||
return c_inputs;
|
||||
}
|
||||
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(tensor_pointer);
|
||||
c_inputs.push_back(*ms_tensor_ptr);
|
||||
}
|
||||
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
|
||||
return c_inputs;
|
||||
}
|
||||
|
||||
|
@ -389,14 +446,22 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en
|
|||
return (jboolean) false;
|
||||
}
|
||||
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
|
||||
|
||||
if (inputs == nullptr || dims == nullptr) {
|
||||
MS_LOG(ERROR) << "inputs or dims from java is nullptr";
|
||||
return (jboolean) false;
|
||||
}
|
||||
auto input_size = static_cast<int>(env->GetArrayLength(inputs));
|
||||
jlong *input_data = env->GetLongArrayElements(inputs, nullptr);
|
||||
if (input_data == nullptr) {
|
||||
MS_LOG(ERROR) << "input_data is nullptr";
|
||||
return (jboolean) false;
|
||||
}
|
||||
std::vector<mindspore::MSTensor> c_inputs;
|
||||
for (int i = 0; i < input_size; i++) {
|
||||
auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
|
||||
if (tensor_pointer == nullptr) {
|
||||
MS_LOG(ERROR) << "Tensor pointer from java is nullptr";
|
||||
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
|
||||
return (jboolean) false;
|
||||
}
|
||||
auto &ms_tensor = *static_cast<mindspore::MSTensor *>(tensor_pointer);
|
||||
|
@ -405,8 +470,19 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en
|
|||
auto tensor_size = static_cast<int>(env->GetArrayLength(dims));
|
||||
for (int i = 0; i < tensor_size; i++) {
|
||||
auto array = static_cast<jintArray>(env->GetObjectArrayElement(dims, i));
|
||||
if (array == nullptr) {
|
||||
MS_LOG(ERROR) << "Tensor pointer from java is nullptr";
|
||||
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
|
||||
return (jboolean) false;
|
||||
}
|
||||
auto dim_size = static_cast<int>(env->GetArrayLength(array));
|
||||
jint *dim_data = env->GetIntArrayElements(array, nullptr);
|
||||
if (dim_data == nullptr) {
|
||||
MS_LOG(ERROR) << "dim_data is nullptr";
|
||||
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
|
||||
env->DeleteLocalRef(array);
|
||||
return (jboolean) false;
|
||||
}
|
||||
std::vector<int64_t> tensor_dims(dim_size);
|
||||
for (int j = 0; j < dim_size; j++) {
|
||||
tensor_dims[j] = dim_data[j];
|
||||
|
@ -416,16 +492,17 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en
|
|||
env->DeleteLocalRef(array);
|
||||
}
|
||||
auto ret = lite_model_ptr->Resize(c_inputs, c_dims);
|
||||
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
|
||||
return (jboolean)(ret.IsOk());
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_loadConfig(JNIEnv *env, jobject thiz, jstring model_ptr,
|
||||
jstring config_path) {
|
||||
auto *model_pointer = reinterpret_cast<void *>(model_ptr);
|
||||
if (model_pointer == nullptr) {
|
||||
MS_LOG(ERROR) << "Model pointer from java is nullptr";
|
||||
if (model_ptr == nullptr || config_path == nullptr) {
|
||||
MS_LOG(ERROR) << "input params from java is nullptr";
|
||||
return (jboolean) false;
|
||||
}
|
||||
auto *model_pointer = reinterpret_cast<void *>(model_ptr);
|
||||
auto *lite_model_ptr = static_cast<mindspore::Model *>(model_pointer);
|
||||
const char *c_config_path = env->GetStringUTFChars(config_path, nullptr);
|
||||
std::string str_config_path(c_config_path, env->GetStringLength(config_path));
|
||||
|
|
|
@ -121,6 +121,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_config_MSContext_addDev
|
|||
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadNum(JNIEnv *env, jobject thiz,
|
||||
jlong context_ptr, jint thread_num) {
|
||||
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
|
||||
if (c_context_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Context pointer from java is nullptr";
|
||||
return;
|
||||
}
|
||||
c_context_ptr->SetThreadNum(thread_num);
|
||||
}
|
||||
|
||||
|
@ -132,6 +136,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadN
|
|||
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getThreadNum(JNIEnv *env, jobject thiz,
|
||||
jlong context_ptr) {
|
||||
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
|
||||
if (c_context_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Context pointer from java is nullptr";
|
||||
return 0;
|
||||
}
|
||||
int32_t thread_num = c_context_ptr->GetThreadNum();
|
||||
return thread_num;
|
||||
}
|
||||
|
@ -145,6 +153,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setInterOp
|
|||
jlong context_ptr,
|
||||
jint op_parallel_num) {
|
||||
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
|
||||
if (c_context_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Context pointer from java is nullptr";
|
||||
return;
|
||||
}
|
||||
c_context_ptr->SetInterOpParallelNum((int32_t)op_parallel_num);
|
||||
}
|
||||
|
||||
|
@ -156,6 +168,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setInterOp
|
|||
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getInterOpParallelNum(JNIEnv *env, jobject thiz,
|
||||
jlong context_ptr) {
|
||||
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
|
||||
if (c_context_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Context pointer from java is nullptr";
|
||||
return 0;
|
||||
}
|
||||
auto inter_op_parallel_num = c_context_ptr->GetInterOpParallelNum();
|
||||
return inter_op_parallel_num;
|
||||
}
|
||||
|
@ -169,6 +185,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadA
|
|||
jlong context_ptr,
|
||||
jint thread_affinity) {
|
||||
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
|
||||
if (c_context_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Context pointer from java is nullptr";
|
||||
return;
|
||||
}
|
||||
c_context_ptr->SetThreadAffinity(thread_affinity);
|
||||
}
|
||||
|
||||
|
@ -180,6 +200,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadA
|
|||
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getThreadAffinityMode(JNIEnv *env, jobject thiz,
|
||||
jlong context_ptr) {
|
||||
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
|
||||
if (c_context_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Context pointer from java is nullptr";
|
||||
return 0;
|
||||
}
|
||||
auto thread_affinity_mode = c_context_ptr->GetThreadAffinityMode();
|
||||
return thread_affinity_mode;
|
||||
}
|
||||
|
@ -192,7 +216,15 @@ extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getThreadA
|
|||
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadAffinity__J_3I(JNIEnv *env, jobject thiz,
|
||||
jlong context_ptr,
|
||||
jintArray core_list) {
|
||||
if (core_list == nullptr) {
|
||||
MS_LOG(ERROR) << "core_list from java is nullptr";
|
||||
return;
|
||||
}
|
||||
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
|
||||
if (c_context_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Context pointer from java is nullptr";
|
||||
return;
|
||||
}
|
||||
int32_t array_len = env->GetArrayLength(core_list);
|
||||
jboolean is_copy = JNI_FALSE;
|
||||
int *core_value = env->GetIntArrayElements(core_list, &is_copy);
|
||||
|
@ -211,6 +243,10 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_config_MSContext_getThre
|
|||
jobject thiz,
|
||||
jlong context_ptr) {
|
||||
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
|
||||
if (c_context_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Context pointer from java is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<int32_t> core_list_tmp = c_context_ptr->GetThreadAffinityCoreList();
|
||||
jobject core_list = newObjectArrayList<int32_t>(env, core_list_tmp, "java/lang/Integer", "(I)V");
|
||||
return core_list;
|
||||
|
@ -225,6 +261,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setEnableP
|
|||
jlong context_ptr,
|
||||
jboolean is_parallel) {
|
||||
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
|
||||
if (c_context_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Context pointer from java is nullptr";
|
||||
return;
|
||||
}
|
||||
c_context_ptr->SetEnableParallel(static_cast<bool>(is_parallel));
|
||||
}
|
||||
|
||||
|
@ -236,6 +276,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setEnableP
|
|||
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_config_MSContext_getEnableParallel(JNIEnv *env, jobject thiz,
|
||||
jlong context_ptr) {
|
||||
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
|
||||
if (c_context_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Context pointer from java is nullptr";
|
||||
return (jboolean) false;
|
||||
}
|
||||
bool is_parallel = c_context_ptr->GetEnableParallel();
|
||||
return (jboolean)is_parallel;
|
||||
}
|
||||
|
@ -243,5 +287,9 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_config_MSContext_getEna
|
|||
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_free(JNIEnv *env, jobject thiz,
|
||||
jlong context_ptr) {
|
||||
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
|
||||
if (c_context_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Context pointer from java is nullptr";
|
||||
return;
|
||||
}
|
||||
delete (c_context_ptr);
|
||||
}
|
||||
|
|
|
@ -30,6 +30,10 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_MSTensor_getShape(JNIE
|
|||
auto local_shape = ms_tensor_ptr->Shape();
|
||||
auto shape_size = local_shape.size();
|
||||
jintArray shape = env->NewIntArray(shape_size);
|
||||
if (shape == nullptr) {
|
||||
MS_LOG(ERROR) << "new intArray failed.";
|
||||
return env->NewIntArray(0);
|
||||
}
|
||||
auto *tmp = new jint[shape_size];
|
||||
for (size_t i = 0; i < shape_size; i++) {
|
||||
tmp[i] = static_cast<int>(local_shape.at(i));
|
||||
|
@ -69,6 +73,10 @@ extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_MSTensor_getByteData(
|
|||
return env->NewByteArray(0);
|
||||
}
|
||||
auto ret = env->NewByteArray(local_size);
|
||||
if (ret == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc failed.";
|
||||
return env->NewByteArray(0);
|
||||
}
|
||||
env->SetByteArrayRegion(ret, 0, local_size, local_data);
|
||||
return ret;
|
||||
}
|
||||
|
@ -99,6 +107,10 @@ extern "C" JNIEXPORT jlongArray JNICALL Java_com_mindspore_MSTensor_getLongData(
|
|||
return env->NewLongArray(0);
|
||||
}
|
||||
auto ret = env->NewLongArray(local_element_num);
|
||||
if (ret == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc failed.";
|
||||
return env->NewLongArray(0);
|
||||
}
|
||||
env->SetLongArrayRegion(ret, 0, local_element_num, local_data);
|
||||
return ret;
|
||||
}
|
||||
|
@ -129,6 +141,10 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_MSTensor_getIntData(JN
|
|||
return env->NewIntArray(0);
|
||||
}
|
||||
auto ret = env->NewIntArray(local_element_num);
|
||||
if (ret == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc failed.";
|
||||
return env->NewIntArray(0);
|
||||
}
|
||||
env->SetIntArrayRegion(ret, 0, local_element_num, local_data);
|
||||
return ret;
|
||||
}
|
||||
|
@ -159,6 +175,10 @@ extern "C" JNIEXPORT jfloatArray JNICALL Java_com_mindspore_MSTensor_getFloatDat
|
|||
return env->NewFloatArray(0);
|
||||
}
|
||||
auto ret = env->NewFloatArray(local_element_num);
|
||||
if (ret == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc failed.";
|
||||
return env->NewFloatArray(0);
|
||||
}
|
||||
env->SetFloatArrayRegion(ret, 0, local_element_num, local_data);
|
||||
return ret;
|
||||
}
|
||||
|
@ -177,8 +197,18 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setByteData(JN
|
|||
return static_cast<jboolean>(false);
|
||||
}
|
||||
jboolean is_copy = false;
|
||||
|
||||
if (data == nullptr) {
|
||||
MS_LOG(ERROR) << "data from java is nullptr.";
|
||||
return static_cast<jboolean>(false);
|
||||
}
|
||||
auto *data_arr = env->GetByteArrayElements(data, &is_copy);
|
||||
auto *local_data = ms_tensor_ptr->MutableData();
|
||||
if (data_arr == nullptr || local_data == nullptr) {
|
||||
MS_LOG(ERROR) << "data_arr or local_data is nullptr.";
|
||||
env->ReleaseByteArrayElements(data, data_arr, JNI_ABORT);
|
||||
return static_cast<jboolean>(false);
|
||||
}
|
||||
memcpy(local_data, data_arr, data_len);
|
||||
env->ReleaseByteArrayElements(data, data_arr, JNI_ABORT);
|
||||
return static_cast<jboolean>(true);
|
||||
|
@ -208,6 +238,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setFloatData(J
|
|||
MS_LOG(ERROR) << "malloc memory failed.";
|
||||
return static_cast<jboolean>(false);
|
||||
}
|
||||
if (data == nullptr) {
|
||||
MS_LOG(ERROR) << "data from java is nullptr";
|
||||
return static_cast<jboolean>(false);
|
||||
}
|
||||
env->GetFloatArrayRegion(data, 0, static_cast<jsize>(data_len), local_data);
|
||||
return static_cast<jboolean>(true);
|
||||
}
|
||||
|
@ -236,6 +270,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setIntData(JNI
|
|||
MS_LOG(ERROR) << "malloc memory failed.";
|
||||
return static_cast<jboolean>(false);
|
||||
}
|
||||
if (data == nullptr) {
|
||||
MS_LOG(ERROR) << "data from java is nullptr";
|
||||
return static_cast<jboolean>(false);
|
||||
}
|
||||
env->GetIntArrayRegion(data, 0, static_cast<jsize>(data_len), local_data);
|
||||
return static_cast<jboolean>(true);
|
||||
}
|
||||
|
@ -262,6 +300,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setLongData(JN
|
|||
MS_LOG(ERROR) << "malloc memory failed.";
|
||||
return static_cast<jboolean>(false);
|
||||
}
|
||||
if (data == nullptr) {
|
||||
MS_LOG(ERROR) << "data from java is nullptr";
|
||||
return static_cast<jboolean>(false);
|
||||
}
|
||||
env->GetLongArrayRegion(data, 0, static_cast<jsize>(data_len), local_data);
|
||||
return static_cast<jboolean>(true);
|
||||
}
|
||||
|
@ -287,6 +329,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setByteBufferD
|
|||
return static_cast<jboolean>(false);
|
||||
}
|
||||
auto *local_data = ms_tensor_ptr->MutableData();
|
||||
if (local_data == nullptr) {
|
||||
MS_LOG(ERROR) << "get MutableData nullptr.";
|
||||
return static_cast<jboolean>(false);
|
||||
}
|
||||
memcpy(local_data, p_data, data_len);
|
||||
return static_cast<jboolean>(true);
|
||||
}
|
||||
|
@ -304,8 +350,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_MSTensor_size(JNIEnv *env,
|
|||
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setShape(JNIEnv *env, jobject thiz, jlong tensor_ptr,
|
||||
jintArray tensor_shape) {
|
||||
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOG(ERROR) << "Tensor pointer from java is nullptr";
|
||||
if (pointer == nullptr || tensor_shape == nullptr) {
|
||||
MS_LOG(ERROR) << "input params from java is nullptr";
|
||||
return static_cast<jboolean>(false);
|
||||
}
|
||||
|
||||
|
@ -357,6 +403,11 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_MSTensor_createTensorByNat
|
|||
jstring tensor_name, jint data_type,
|
||||
jintArray tensor_shape,
|
||||
jobject buffer) {
|
||||
// check inputs
|
||||
if (buffer == nullptr || tensor_name == nullptr || tensor_shape == nullptr) {
|
||||
MS_LOG(ERROR) << "input param from java is nullptr";
|
||||
return 0;
|
||||
}
|
||||
auto *p_data = reinterpret_cast<jbyte *>(env->GetDirectBufferAddress(buffer));
|
||||
jlong data_len = env->GetDirectBufferCapacity(buffer);
|
||||
if (p_data == nullptr) {
|
||||
|
@ -370,11 +421,11 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_MSTensor_createTensorByNat
|
|||
for (int i = 0; i < size; i++) {
|
||||
c_shape[i] = static_cast<int64_t>(shape_pointer[i]);
|
||||
}
|
||||
env->ReleaseIntArrayElements(tensor_shape, shape_pointer, JNI_ABORT);
|
||||
const char *c_tensor_name = env->GetStringUTFChars(tensor_name, nullptr);
|
||||
std::string str_tensor_name(c_tensor_name, env->GetStringLength(tensor_name));
|
||||
auto tensor = mindspore::MSTensor::CreateTensor(str_tensor_name, static_cast<mindspore::DataType>(data_type), c_shape,
|
||||
p_data, data_len);
|
||||
env->ReleaseIntArrayElements(tensor_shape, shape_pointer, JNI_ABORT);
|
||||
env->ReleaseStringUTFChars(tensor_name, c_tensor_name);
|
||||
return jlong(tensor);
|
||||
}
|
||||
|
|
|
@ -51,8 +51,10 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_config_TrainCfg_createTrai
|
|||
}
|
||||
if (loss_name != nullptr) {
|
||||
std::vector<std::string> traincfg_loss_name = traincfg_ptr->GetLossName();
|
||||
traincfg_loss_name.emplace_back(env->GetStringUTFChars(loss_name, JNI_FALSE));
|
||||
auto c_loss_name = env->GetStringUTFChars(loss_name, JNI_FALSE);
|
||||
traincfg_loss_name.emplace_back(c_loss_name);
|
||||
traincfg_ptr->SetLossName(traincfg_loss_name);
|
||||
env->ReleaseStringUTFChars(loss_name, c_loss_name);
|
||||
}
|
||||
traincfg_ptr->optimization_level_ = ol;
|
||||
traincfg_ptr->accumulate_gradients_ = accmulateGrads;
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore;
|
||||
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import com.mindspore.config.DataType;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* Model Test
|
||||
*
|
||||
* @since 2023.2
|
||||
*/
|
||||
@RunWith(JUnit4.class)
|
||||
public class MSTensorTest {
|
||||
@Test
|
||||
public void testCreateTensor() {
|
||||
int[] tensorShape = {6, 5, 5, 1};
|
||||
float[] tensorData = new float[6 * 5 * 5];
|
||||
Arrays.fill(tensorData, 0);
|
||||
ByteBuffer byteBuf = ByteBuffer.allocateDirect(6 * 5 * 5 * 4);
|
||||
FloatBuffer floatBuf = byteBuf.asFloatBuffer();
|
||||
floatBuf.put(tensorData);
|
||||
MSTensor newTensor = MSTensor.createTensor("conv1.weight", DataType.kNumberTypeFloat32, tensorShape,
|
||||
byteBuf);
|
||||
assertNotNull(newTensor);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSetData() {
|
||||
int[] tensorShape = {6, 5, 5, 1};
|
||||
ByteBuffer byteBuf = ByteBuffer.allocateDirect(6 * 5 * 5 * 4);
|
||||
MSTensor newTensor = MSTensor.createTensor("conv1.weight", DataType.kNumberTypeFloat32, tensorShape,
|
||||
byteBuf);
|
||||
assertNotNull(newTensor);
|
||||
float[] tensorData = new float[6 * 5 * 5];
|
||||
Arrays.fill(tensorData, 0);
|
||||
assertTrue(newTensor.setData(tensorData));
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue