!49104 java jni codecheck

Merge pull request !49104 from liyan2022/master
This commit is contained in:
i-robot 2023-03-08 06:30:22 +00:00 committed by Gitee
commit 0742b9c1a6
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
17 changed files with 433 additions and 88 deletions

View File

@ -18,6 +18,11 @@ package com.mindspore;
import com.mindspore.config.MindsporeLite;
/**
* Graph Class
*
* @since v1.0
*/
public class Graph {
static {
MindsporeLite.init();

View File

@ -22,6 +22,11 @@ import java.nio.ByteBuffer;
import java.lang.reflect.Array;
import java.util.HashMap;
/**
* The MSTensor class defines a tensor in MindSpore.
*
* @since v1.0
*/
public class MSTensor {
static {
MindsporeLite.init();

View File

@ -25,6 +25,11 @@ import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.List;
/**
* The Model class is used to define a MindSpore model, facilitating computational graph management.
*
* @since v1.0
*/
public class Model {
static {
MindsporeLite.init();
@ -61,16 +66,20 @@ public class Model {
* @param buffer model buffer.
* @param modelType model type.
* @param context model build context.
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16.
* @param dec_mode define the decryption mode. Options: AES-GCM.
* @param cropto_lib_path define the openssl library path.
* @param decKey define the key used to decrypt the ciphertext model. The key length is 16.
* @param decMode define the decryption mode. Options: AES-GCM.
* @param croptoLibPath define the openssl library path.
* @return model build status.
*/
public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) {
if (context == null || buffer == null || dec_key == null || dec_mode == null) {
public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] decKey, String decMode,
String croptoLibPath) {
boolean isValid = (context != null && buffer != null && decKey != null && decMode != null &&
croptoLibPath != null);
if (!isValid) {
return false;
}
return this.buildByBuffer(modelPtr, buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
return this.buildByBuffer(modelPtr, buffer, modelType, context.getMSContextPtr(), decKey, decMode,
croptoLibPath);
}
/**
@ -95,16 +104,19 @@ public class Model {
* @param modelPath model path.
* @param modelType model type.
* @param context model build context.
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16.
* @param dec_mode define the decryption mode. Options: AES-GCM.
* @param cropto_lib_path define the openssl library path.
* @param decKey define the key used to decrypt the ciphertext model. The key length is 16.
* @param decMode define the decryption mode. Options: AES-GCM.
* @param croptoLibPath define the openssl library path.
* @return model build status.
*/
public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) {
if (context == null || modelPath == null || dec_key == null || dec_mode == null) {
public boolean build(String modelPath, int modelType, MSContext context, char[] decKey, String decMode,
String croptoLibPath) {
boolean isValid = (context != null && modelPath != null && decKey != null && decMode != null &&
croptoLibPath != null);
if (!isValid) {
return false;
}
return this.buildByPath(modelPtr, modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
return this.buildByPath(modelPtr, modelPath, modelType, context.getMSContextPtr(), decKey, decMode, croptoLibPath);
}
/**
@ -164,9 +176,9 @@ public class Model {
* @return input tensors.
*/
public List<MSTensor> getInputs() {
List<Long> ret = this.getInputs(this.modelPtr);
List<MSTensor> tensors = new ArrayList<>();
for (Long msTensorAddr : ret) {
List<Long> tensorAddrs = this.getInputs(this.modelPtr);
List<MSTensor> tensors = new ArrayList<>(tensorAddrs.size());
for (Long msTensorAddr : tensorAddrs) {
MSTensor msTensor = new MSTensor(msTensorAddr);
tensors.add(msTensor);
}
@ -179,9 +191,9 @@ public class Model {
* @return model outputs tensor.
*/
public List<MSTensor> getOutputs() {
List<Long> ret = this.getOutputs(this.modelPtr);
List<MSTensor> tensors = new ArrayList<>();
for (Long msTensorAddr : ret) {
List<Long> tensorAddrs = this.getOutputs(this.modelPtr);
List<MSTensor> tensors = new ArrayList<>(tensorAddrs.size());
for (Long msTensorAddr : tensorAddrs) {
MSTensor msTensor = new MSTensor(msTensorAddr);
tensors.add(msTensor);
}
@ -224,11 +236,11 @@ public class Model {
*/
public List<MSTensor> getOutputsByNodeName(String nodeName) {
if (nodeName == null) {
return null;
return new ArrayList<>();
}
List<Long> ret = this.getOutputsByNodeName(this.modelPtr, nodeName);
List<MSTensor> tensors = new ArrayList<>();
for (Long msTensorAddr : ret) {
List<Long> tensorAddrs = this.getOutputsByNodeName(this.modelPtr, nodeName);
List<MSTensor> tensors = new ArrayList<>(tensorAddrs.size());
for (Long msTensorAddr : tensorAddrs) {
MSTensor msTensor = new MSTensor(msTensorAddr);
tensors.add(msTensor);
}
@ -296,9 +308,9 @@ public class Model {
* @return FeaturesMap Tensor list.
*/
public List<MSTensor> getFeatureMaps() {
List<Long> ret = this.getFeatureMaps(this.modelPtr);
ArrayList<MSTensor> tensors = new ArrayList<>();
for (Long msTensorAddr : ret) {
List<Long> tensorAddrs = this.getFeatureMaps(this.modelPtr);
ArrayList<MSTensor> tensors = new ArrayList<>(tensorAddrs.size());
for (Long msTensorAddr : tensorAddrs) {
MSTensor msTensor = new MSTensor(msTensorAddr);
tensors.add(msTensor);
}

View File

@ -22,7 +22,13 @@ package com.mindspore.config;
* @since v1.0
*/
public class CpuBindMode {
// bind mind cpu
public static final int MID_CPU = 2;
// bind high cpu
public static final int HIGHER_CPU = 1;
// no bind
public static final int NO_BIND = 0;
}

View File

@ -21,8 +21,16 @@ package com.mindspore.config;
* @since v1.0
*/
public class DeviceType {
// support cpu
public static final int DT_CPU = 0;
// support gpu
public static final int DT_GPU = 1;
// support npu
public static final int DT_NPU = 2;
// support ascend
public static final int DT_ASCEND = 3;
}

View File

@ -20,17 +20,22 @@ import java.util.ArrayList;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* Context is used to store environment variables during execution.
*
* @since v1.0
*/
public class MSContext {
private static Logger LOGGER = MindsporeLite.GetLogger();
static {
MindsporeLite.init();
}
private long msContextPtr;
private static final long EMPTY_CONTEXT_PTR_VALUE = 0;
private static final long EMPTY_CONTEXT_PTR_VALUE = 0L;
private static final int ERROR_VALUE = -1;
private static final String NULLPTR_ERROR_MESSAGE="Context pointer from java is nullptr.\n";
private static final String NULLPTR_ERROR_MESSAGE="Context pointer from java is nullptr.";
private long msContextPtr;
/**
* Construct function.
@ -148,13 +153,13 @@ public class MSContext {
* @return The current thread number setting.
*/
public int getThreadNum() {
int ret_val = ERROR_VALUE;
int retVal = ERROR_VALUE;
if (isInitialized()) {
ret_val = getThreadNum(this.msContextPtr);
retVal = getThreadNum(this.msContextPtr);
} else {
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
}
return ret_val;
return retVal;
}
/**
@ -178,13 +183,13 @@ public class MSContext {
* @return The current operators parallel number setting.
*/
public int getInterOpParallelNum() {
int ret_val = ERROR_VALUE;
int retVal = ERROR_VALUE;
if (isInitialized()) {
ret_val = getInterOpParallelNum(this.msContextPtr);
retVal = getInterOpParallelNum(this.msContextPtr);
} else {
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
}
return ret_val;
return retVal;
}
/**
@ -209,13 +214,13 @@ public class MSContext {
* @return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first
*/
public int getThreadAffinityMode() {
int ret_val = ERROR_VALUE;
int retVal = ERROR_VALUE;
if (isInitialized()) {
ret_val = getThreadAffinityMode(this.msContextPtr);
retVal = getThreadAffinityMode(this.msContextPtr);
} else {
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
}
return ret_val;
return retVal;
}
/**
@ -229,11 +234,11 @@ public class MSContext {
public void setThreadAffinity(ArrayList<Integer> coreList) {
if (isInitialized()) {
int len = coreList.size();
int[] coreList_array = new int[len];
int[] coreListArray = new int[len];
for (int i = 0; i < len; i++) {
coreList_array[i] = coreList.get(i);
coreListArray[i] = coreList.get(i);
}
setThreadAffinity(this.msContextPtr, coreList_array);
setThreadAffinity(this.msContextPtr, coreListArray);
} else {
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
}
@ -247,13 +252,13 @@ public class MSContext {
*/
public ArrayList<Integer> getThreadAffinityCoreList() {
ArrayList<Integer> ret_val = new ArrayList<>();
ArrayList<Integer> retVal = new ArrayList<>();
if (isInitialized()) {
ret_val = getThreadAffinityCoreList(this.msContextPtr);
retVal = getThreadAffinityCoreList(this.msContextPtr);
} else {
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
}
return ret_val;
return retVal;
}
/**
@ -277,13 +282,13 @@ public class MSContext {
* @return boolean value that indicates whether in parallel.
*/
public boolean getEnableParallel() {
boolean ret_val = false;
boolean retVal = false;
if (isInitialized()) {
ret_val = getEnableParallel(this.msContextPtr);
retVal = getEnableParallel(this.msContextPtr);
} else {
LOGGER.log(Level.SEVERE, NULLPTR_ERROR_MESSAGE);
}
return ret_val;
return retVal;
}

View File

@ -1,7 +1,28 @@
/**
* Copyright 2022-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.config;
import java.util.logging.Logger;
/**
* MSLite Init Class
*
* @since v1.0
*/
public final class MindsporeLite {
private static final Object lock = new Object();
private static Logger LOGGER = GetLogger();

View File

@ -1,5 +1,5 @@
/*
* Copyright 2021 Huawei Technologies Co., Ltd
/**
* Copyright 2022-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.config;
/**
@ -21,9 +22,19 @@ package com.mindspore.config;
* @since v1.0
*/
public class ModelType {
// mindir type
public static final int MT_MINDIR = 0;
// air type
public static final int MT_AIR = 1;
// om type
public static final int MT_OM = 2;
// onnx type
public static final int MT_ONNX = 3;
// mindir opt type
public static final int MT_MINDIR_OPT = 4;
}

View File

@ -1,3 +1,18 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.config;
import java.io.File;
@ -5,7 +20,13 @@ import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.logging.Logger;
import java.util.Locale;
/**
* NativeLibrary Class
*
* @since v1.0
*/
public class NativeLibrary {
private static final Logger LOGGER = MindsporeLite.GetLogger();
@ -52,7 +73,7 @@ public class NativeLibrary {
* libmsplugin-ge-litert
* libruntime_convert_plugin
*/
public static void loadLibs() {
private static void loadLibs() {
loadLib(makeResourceName("lib" + GLOG_LIBNAME + ".so"));
loadLib(makeResourceName("lib" + OPENCV_CORE_LIBNAME + ".so"));
loadLib(makeResourceName("lib" + OPENCV_IMGPROC_LIBNAME + ".so"));
@ -83,49 +104,49 @@ public class NativeLibrary {
try {
System.loadLibrary(MINDSPORE_LITE_JNI_LIBNAME);
loadSuccess = true;
LOGGER.info("loadLibrary " + MINDSPORE_LITE_JNI_LIBNAME + ": success");
LOGGER.info("loadLibrary mindspore-lite-jni success");
} catch (UnsatisfiedLinkError e) {
LOGGER.info(String.format("tryLoadLibrary " + MINDSPORE_LITE_JNI_LIBNAME + " failed: %s", e.toString()));
LOGGER.info(String.format(Locale.ENGLISH, "tryLoadLibrary mindspore-lite-jni failed: %s", e));
}
try {
System.loadLibrary(MINDSPORE_LITE_TRAIN_JNI_LIBNAME);
loadSuccess = true;
LOGGER.info("loadLibrary " + MINDSPORE_LITE_TRAIN_JNI_LIBNAME + ": success.");
LOGGER.info("loadLibrary mindspore-lite-train-jni success.");
} catch (UnsatisfiedLinkError e) {
LOGGER.info(String.format("tryLoadLibrary " + MINDSPORE_LITE_TRAIN_JNI_LIBNAME + " failed: %s", e.toString()));
LOGGER.info(String.format(Locale.ENGLISH, "tryLoadLibrary mindspore-lite-train-jni failed: %s", e));
}
return loadSuccess;
}
private static void loadLib(String libResourceName) {
LOGGER.info("start load libResourceName: " + libResourceName);
LOGGER.info(String.format(Locale.ENGLISH,"start load libResourceName: %s.", libResourceName));
final InputStream libResource = NativeLibrary.class.getClassLoader().getResourceAsStream(libResourceName);
if (libResource == null) {
LOGGER.warning(String.format("lib file: %s not exist.", libResourceName));
LOGGER.warning(String.format(Locale.ENGLISH,"lib file: %s not exist.", libResourceName));
return;
}
try {
final File tmpDir = mkTmpDir();
String libName = libResourceName.substring(libResourceName.lastIndexOf("/") + 1);
String libName = libResourceName.substring(libResourceName.lastIndexOf('/') + 1);
tmpDir.deleteOnExit();
//copy file to tmpFile
// copy file to tmpFile
final File tmpFile = new File(tmpDir.getCanonicalPath(), libName);
tmpFile.deleteOnExit();
LOGGER.info(String.format("extract %d bytes to %s", copyLib(libResource, tmpFile), tmpFile));
LOGGER.info(String.format("libName %s", libName));
if (libName.equals("lib" + MINDSPORE_LITE_LIBNAME + ".so")) {
LOGGER.info(String.format(Locale.ENGLISH,"extract %d bytes to %s", copyLib(libResource, tmpFile),
tmpFile));
if (("lib" + MINDSPORE_LITE_LIBNAME + ".so").equals(libName)) {
extractLib(makeResourceName("lib" + MSPLUGIN_GE_LITERT_LIBNAME + ".so"), tmpDir);
extractLib(makeResourceName("lib" + RUNTIME_CONVERT_PLUGIN_LIBNAME + ".so"), tmpDir);
}
System.load(tmpFile.toString());
} catch (IOException e) {
throw new UnsatisfiedLinkError(
String.format("extract library into tmp file (%s) failed.", e.toString()));
String.format(Locale.ENGLISH,"extract library into tmp file (%s) failed.", e));
}
}
private static long copyLib(InputStream libResource, File tmpFile) throws IOException{
private static long copyLib(InputStream libResource, File tmpFile) throws IOException {
try (FileOutputStream outputStream = new FileOutputStream(tmpFile);) {
// 1MB
byte[] buffer = new byte[1 << 20];
@ -143,9 +164,8 @@ public class NativeLibrary {
private static File mkTmpDir() {
final String MINDSPORE_LITE_LIBS = "mindspore_lite_libs-";
Long timestamp = System.currentTimeMillis();
String dirName = MINDSPORE_LITE_LIBS + timestamp + "-";
String dirName = "mindspore_lite_libs-" + timestamp + "-";
for (int i = 0; i < 10; i++) {
File tmpDir = new File(new File(System.getProperty("java.io.tmpdir")), dirName + i);
if (tmpDir.mkdir()) {
@ -167,7 +187,7 @@ public class NativeLibrary {
private static String architecture() {
final String arch = System.getProperty("os.arch").toLowerCase();
return (arch.equals("amd64")) ? "x86_64" : arch;
return ("amd64".equals(arch)) ? "x86_64" : arch;
}
private static void extractLib(String libResourceName, File targetDir) {

View File

@ -16,6 +16,11 @@
package com.mindspore.config;
/**
* TrainCfg Class
*
* @since v1.0
*/
public class TrainCfg {
static {
try {

View File

@ -37,9 +37,8 @@ public class Version {
LOGGER.info("Version init load ...");
try {
NativeLibrary.load();
} catch (Exception e) {
} catch (UnsatisfiedLinkError e) {
LOGGER.severe("Failed to load MindSporLite native library.");
throw e;
}
}

View File

@ -26,8 +26,15 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Graph_loadModel(JNIEnv *en
MS_LOG(ERROR) << "Model new failed";
return jlong(nullptr);
}
if (ms_file == nullptr) {
MS_LOG(ERROR) << "ms_file from java is nullptr.";
delete graph;
return jlong(nullptr);
}
auto c_ms_file = env->GetStringUTFChars(ms_file, JNI_FALSE);
auto status =
mindspore::Serialization::Load(env->GetStringUTFChars(ms_file, JNI_FALSE), mindspore::ModelType::kMindIR, graph);
mindspore::Serialization::Load(c_ms_file, mindspore::ModelType::kMindIR, graph);
env->ReleaseStringUTFChars(ms_file, c_ms_file);
if (status != mindspore::kSuccess) {
MS_LOG(ERROR) << "Load graph from file failed";
delete graph;

View File

@ -111,23 +111,34 @@ extern "C" JNIEXPORT bool JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv
return false;
}
context.reset(c_context_ptr);
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
mindspore::Status status;
if (key_str != NULL) {
jchar *key_array = env->GetCharArrayElements(key_str, NULL);
auto key_len = static_cast<size_t>(env->GetArrayLength(key_str));
char *dec_key_data = new (std::nothrow) char[key_len];
if (dec_key_data == nullptr) {
MS_LOG(ERROR) << "Dec key new failed";
return false;
}
jchar *key_array = env->GetCharArrayElements(key_str, NULL);
if (key_array == nullptr) {
MS_LOG(ERROR) << "key_array is nullptr.";
return false;
}
for (size_t i = 0; i < key_len; i++) {
dec_key_data[i] = key_array[i];
}
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
mindspore::Key dec_key{dec_key_data, key_len};
if (cropto_lib_path == nullptr || dec_mod == nullptr) {
MS_LOG(ERROR) << "cropto_lib_path or dec_mod from java is nullptr.";
return jlong(nullptr);
}
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE);
status = lite_model_ptr->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
env->ReleaseStringUTFChars(cropto_lib_path, c_cropto_lib_path);
env->ReleaseStringUTFChars(dec_mod, c_dec_mod);
delete[] dec_key_data;
} else {
status = lite_model_ptr->Build(model_buf, buffer_len, c_model_type, context);
}
@ -167,26 +178,43 @@ extern "C" JNIEXPORT bool JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *e
return false;
}
context.reset(c_context_ptr);
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
mindspore::Status status;
if (key_str != NULL) {
jchar *key_array = env->GetCharArrayElements(key_str, NULL);
auto key_len = static_cast<size_t>(env->GetArrayLength(key_str));
char *dec_key_data = new (std::nothrow) char[key_len];
if (dec_key_data == nullptr) {
MS_LOG(ERROR) << "Dec key new failed";
env->ReleaseStringUTFChars(model_path, c_model_path);
return false;
}
jchar *key_array = env->GetCharArrayElements(key_str, NULL);
if (key_array == nullptr) {
MS_LOG(ERROR) << "GetCharArrayElements failed.";
env->ReleaseStringUTFChars(model_path, c_model_path);
return jlong(nullptr);
}
for (size_t i = 0; i < key_len; i++) {
dec_key_data[i] = key_array[i];
}
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
mindspore::Key dec_key{dec_key_data, key_len};
if (dec_mod == nullptr || cropto_lib_path == nullptr) {
MS_LOG(ERROR) << "dec_mod, cropto_lib_path from java is nullptr.";
env->ReleaseStringUTFChars(model_path, c_model_path);
return jlong(nullptr);
}
auto c_dec_mod = env->GetStringUTFChars(dec_mod, JNI_FALSE);
auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE);
status = lite_model_ptr->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
env->ReleaseStringUTFChars(dec_mod, c_dec_mod);
env->ReleaseStringUTFChars(cropto_lib_path, c_cropto_lib_path);
delete[] dec_key_data;
} else {
status = lite_model_ptr->Build(c_model_path, c_model_type, context);
}
env->ReleaseStringUTFChars(model_path, c_model_path);
if (status != mindspore::kSuccess) {
MS_LOG(ERROR) << "Error status " << static_cast<int>(status) << " during build of model";
return false;
@ -205,6 +233,8 @@ jobject GetInOrOutTensors(JNIEnv *env, jobject thiz, jlong model_ptr, bool is_in
auto *pointer = reinterpret_cast<mindspore::Model *>(model_ptr);
if (pointer == nullptr) {
MS_LOG(ERROR) << "Model pointer from java is nullptr";
env->DeleteLocalRef(array_list);
env->DeleteLocalRef(long_object);
return ret;
}
std::vector<mindspore::MSTensor> tensors;
@ -221,7 +251,10 @@ jobject GetInOrOutTensors(JNIEnv *env, jobject thiz, jlong model_ptr, bool is_in
}
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(tensor_ptr.release()));
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
env->DeleteLocalRef(tensor_addr);
}
env->DeleteLocalRef(array_list);
env->DeleteLocalRef(long_object);
return ret;
}
@ -233,11 +266,17 @@ jlong GetTensorByInOutName(JNIEnv *env, jlong model_ptr, jstring tensor_name, bo
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
mindspore::MSTensor tensor;
if (is_input) {
tensor = lite_model_ptr->GetInputByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE));
} else {
tensor = lite_model_ptr->GetOutputByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE));
if (tensor_name == nullptr) {
MS_LOG(ERROR) << "tensor_name from java is nullptr.";
return jlong(nullptr);
}
auto c_tensor_name = env->GetStringUTFChars(tensor_name, JNI_FALSE);
if (is_input) {
tensor = lite_model_ptr->GetInputByTensorName(c_tensor_name);
} else {
tensor = lite_model_ptr->GetOutputByTensorName(c_tensor_name);
}
env->ReleaseStringUTFChars(tensor_name, c_tensor_name);
if (tensor.impl() == nullptr) {
return jlong(nullptr);
}
@ -283,8 +322,11 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputTensorNam
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
auto output_names = lite_model_ptr->GetOutputTensorNames();
for (const auto &output_name : output_names) {
env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str()));
auto output_name_jstring = env->NewStringUTF(output_name.c_str());
env->CallBooleanMethod(ret, array_list_add, output_name_jstring);
env->DeleteLocalRef(output_name_jstring);
}
env->DeleteLocalRef(array_list);
return ret;
}
@ -303,7 +345,13 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputsByNodeNa
return ret;
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
auto tensors = lite_model_ptr->GetOutputsByNodeName(env->GetStringUTFChars(node_name, JNI_FALSE));
if (node_name == nullptr) {
MS_LOG(ERROR) << "node_name from java is nullptr";
return ret;
}
auto c_node_name = env->GetStringUTFChars(node_name, JNI_FALSE);
auto tensors = lite_model_ptr->GetOutputsByNodeName(c_node_name);
env->ReleaseStringUTFChars(node_name, c_node_name);
for (auto &tensor : tensors) {
auto tensor_ptr = std::make_unique<mindspore::MSTensor>(tensor);
if (tensor_ptr == nullptr) {
@ -312,7 +360,10 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputsByNodeNa
}
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(tensor_ptr.release()));
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
env->DeleteLocalRef(tensor_addr);
}
env->DeleteLocalRef(array_list);
env->DeleteLocalRef(long_object);
return ret;
}
@ -351,18 +402,24 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_runStep(JNIEnv *e
}
std::vector<mindspore::MSTensor> convertArrayToVector(JNIEnv *env, jlongArray inputs) {
std::vector<mindspore::MSTensor> c_inputs;
if (inputs == nullptr) {
MS_LOG(ERROR) << "inputs from java is nullptr";
return c_inputs;
}
auto input_size = static_cast<int>(env->GetArrayLength(inputs));
jlong *input_data = env->GetLongArrayElements(inputs, nullptr);
std::vector<mindspore::MSTensor> c_inputs;
for (int i = 0; i < input_size; i++) {
auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
if (tensor_pointer == nullptr) {
MS_LOG(ERROR) << "Tensor pointer from java is nullptr";
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
return c_inputs;
}
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(tensor_pointer);
c_inputs.push_back(*ms_tensor_ptr);
}
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
return c_inputs;
}
@ -389,14 +446,22 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en
return (jboolean) false;
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
if (inputs == nullptr || dims == nullptr) {
MS_LOG(ERROR) << "inputs or dims from java is nullptr";
return (jboolean) false;
}
auto input_size = static_cast<int>(env->GetArrayLength(inputs));
jlong *input_data = env->GetLongArrayElements(inputs, nullptr);
if (input_data == nullptr) {
MS_LOG(ERROR) << "input_data is nullptr";
return (jboolean) false;
}
std::vector<mindspore::MSTensor> c_inputs;
for (int i = 0; i < input_size; i++) {
auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
if (tensor_pointer == nullptr) {
MS_LOG(ERROR) << "Tensor pointer from java is nullptr";
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
return (jboolean) false;
}
auto &ms_tensor = *static_cast<mindspore::MSTensor *>(tensor_pointer);
@ -405,8 +470,19 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en
auto tensor_size = static_cast<int>(env->GetArrayLength(dims));
for (int i = 0; i < tensor_size; i++) {
auto array = static_cast<jintArray>(env->GetObjectArrayElement(dims, i));
if (array == nullptr) {
MS_LOG(ERROR) << "Tensor pointer from java is nullptr";
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
return (jboolean) false;
}
auto dim_size = static_cast<int>(env->GetArrayLength(array));
jint *dim_data = env->GetIntArrayElements(array, nullptr);
if (dim_data == nullptr) {
MS_LOG(ERROR) << "dim_data is nullptr";
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
env->DeleteLocalRef(array);
return (jboolean) false;
}
std::vector<int64_t> tensor_dims(dim_size);
for (int j = 0; j < dim_size; j++) {
tensor_dims[j] = dim_data[j];
@ -416,16 +492,17 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en
env->DeleteLocalRef(array);
}
auto ret = lite_model_ptr->Resize(c_inputs, c_dims);
env->ReleaseLongArrayElements(inputs, input_data, JNI_ABORT);
return (jboolean)(ret.IsOk());
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_loadConfig(JNIEnv *env, jobject thiz, jstring model_ptr,
jstring config_path) {
auto *model_pointer = reinterpret_cast<void *>(model_ptr);
if (model_pointer == nullptr) {
MS_LOG(ERROR) << "Model pointer from java is nullptr";
if (model_ptr == nullptr || config_path == nullptr) {
MS_LOG(ERROR) << "input params from java is nullptr";
return (jboolean) false;
}
auto *model_pointer = reinterpret_cast<void *>(model_ptr);
auto *lite_model_ptr = static_cast<mindspore::Model *>(model_pointer);
const char *c_config_path = env->GetStringUTFChars(config_path, nullptr);
std::string str_config_path(c_config_path, env->GetStringLength(config_path));

View File

@ -121,6 +121,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_config_MSContext_addDev
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadNum(JNIEnv *env, jobject thiz,
jlong context_ptr, jint thread_num) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return;
}
c_context_ptr->SetThreadNum(thread_num);
}
@ -132,6 +136,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadN
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getThreadNum(JNIEnv *env, jobject thiz,
jlong context_ptr) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return 0;
}
int32_t thread_num = c_context_ptr->GetThreadNum();
return thread_num;
}
@ -145,6 +153,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setInterOp
jlong context_ptr,
jint op_parallel_num) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return;
}
c_context_ptr->SetInterOpParallelNum((int32_t)op_parallel_num);
}
@ -156,6 +168,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setInterOp
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getInterOpParallelNum(JNIEnv *env, jobject thiz,
jlong context_ptr) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return 0;
}
auto inter_op_parallel_num = c_context_ptr->GetInterOpParallelNum();
return inter_op_parallel_num;
}
@ -169,6 +185,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadA
jlong context_ptr,
jint thread_affinity) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return;
}
c_context_ptr->SetThreadAffinity(thread_affinity);
}
@ -180,6 +200,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadA
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getThreadAffinityMode(JNIEnv *env, jobject thiz,
jlong context_ptr) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return 0;
}
auto thread_affinity_mode = c_context_ptr->GetThreadAffinityMode();
return thread_affinity_mode;
}
@ -192,7 +216,15 @@ extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_config_MSContext_getThreadA
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setThreadAffinity__J_3I(JNIEnv *env, jobject thiz,
jlong context_ptr,
jintArray core_list) {
if (core_list == nullptr) {
MS_LOG(ERROR) << "core_list from java is nullptr";
return;
}
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return;
}
int32_t array_len = env->GetArrayLength(core_list);
jboolean is_copy = JNI_FALSE;
int *core_value = env->GetIntArrayElements(core_list, &is_copy);
@ -211,6 +243,10 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_config_MSContext_getThre
jobject thiz,
jlong context_ptr) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return nullptr;
}
std::vector<int32_t> core_list_tmp = c_context_ptr->GetThreadAffinityCoreList();
jobject core_list = newObjectArrayList<int32_t>(env, core_list_tmp, "java/lang/Integer", "(I)V");
return core_list;
@ -225,6 +261,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setEnableP
jlong context_ptr,
jboolean is_parallel) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return;
}
c_context_ptr->SetEnableParallel(static_cast<bool>(is_parallel));
}
@ -236,6 +276,10 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_setEnableP
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_config_MSContext_getEnableParallel(JNIEnv *env, jobject thiz,
jlong context_ptr) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return (jboolean) false;
}
bool is_parallel = c_context_ptr->GetEnableParallel();
return (jboolean)is_parallel;
}
@ -243,5 +287,9 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_config_MSContext_getEna
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_free(JNIEnv *env, jobject thiz,
jlong context_ptr) {
auto *c_context_ptr = static_cast<mindspore::Context *>(reinterpret_cast<void *>(context_ptr));
if (c_context_ptr == nullptr) {
MS_LOG(ERROR) << "Context pointer from java is nullptr";
return;
}
delete (c_context_ptr);
}

View File

@ -30,6 +30,10 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_MSTensor_getShape(JNIE
auto local_shape = ms_tensor_ptr->Shape();
auto shape_size = local_shape.size();
jintArray shape = env->NewIntArray(shape_size);
if (shape == nullptr) {
MS_LOG(ERROR) << "new intArray failed.";
return env->NewIntArray(0);
}
auto *tmp = new jint[shape_size];
for (size_t i = 0; i < shape_size; i++) {
tmp[i] = static_cast<int>(local_shape.at(i));
@ -69,6 +73,10 @@ extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_MSTensor_getByteData(
return env->NewByteArray(0);
}
auto ret = env->NewByteArray(local_size);
if (ret == nullptr) {
MS_LOG(ERROR) << "malloc failed.";
return env->NewByteArray(0);
}
env->SetByteArrayRegion(ret, 0, local_size, local_data);
return ret;
}
@ -99,6 +107,10 @@ extern "C" JNIEXPORT jlongArray JNICALL Java_com_mindspore_MSTensor_getLongData(
return env->NewLongArray(0);
}
auto ret = env->NewLongArray(local_element_num);
if (ret == nullptr) {
MS_LOG(ERROR) << "malloc failed.";
return env->NewLongArray(0);
}
env->SetLongArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
@ -129,6 +141,10 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_MSTensor_getIntData(JN
return env->NewIntArray(0);
}
auto ret = env->NewIntArray(local_element_num);
if (ret == nullptr) {
MS_LOG(ERROR) << "malloc failed.";
return env->NewIntArray(0);
}
env->SetIntArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
@ -159,6 +175,10 @@ extern "C" JNIEXPORT jfloatArray JNICALL Java_com_mindspore_MSTensor_getFloatDat
return env->NewFloatArray(0);
}
auto ret = env->NewFloatArray(local_element_num);
if (ret == nullptr) {
MS_LOG(ERROR) << "malloc failed.";
return env->NewFloatArray(0);
}
env->SetFloatArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
@ -177,8 +197,18 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setByteData(JN
return static_cast<jboolean>(false);
}
jboolean is_copy = false;
if (data == nullptr) {
MS_LOG(ERROR) << "data from java is nullptr.";
return static_cast<jboolean>(false);
}
auto *data_arr = env->GetByteArrayElements(data, &is_copy);
auto *local_data = ms_tensor_ptr->MutableData();
if (data_arr == nullptr || local_data == nullptr) {
MS_LOG(ERROR) << "data_arr or local_data is nullptr.";
env->ReleaseByteArrayElements(data, data_arr, JNI_ABORT);
return static_cast<jboolean>(false);
}
memcpy(local_data, data_arr, data_len);
env->ReleaseByteArrayElements(data, data_arr, JNI_ABORT);
return static_cast<jboolean>(true);
@ -208,6 +238,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setFloatData(J
MS_LOG(ERROR) << "malloc memory failed.";
return static_cast<jboolean>(false);
}
if (data == nullptr) {
MS_LOG(ERROR) << "data from java is nullptr";
return static_cast<jboolean>(false);
}
env->GetFloatArrayRegion(data, 0, static_cast<jsize>(data_len), local_data);
return static_cast<jboolean>(true);
}
@ -236,6 +270,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setIntData(JNI
MS_LOG(ERROR) << "malloc memory failed.";
return static_cast<jboolean>(false);
}
if (data == nullptr) {
MS_LOG(ERROR) << "data from java is nullptr";
return static_cast<jboolean>(false);
}
env->GetIntArrayRegion(data, 0, static_cast<jsize>(data_len), local_data);
return static_cast<jboolean>(true);
}
@ -262,6 +300,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setLongData(JN
MS_LOG(ERROR) << "malloc memory failed.";
return static_cast<jboolean>(false);
}
if (data == nullptr) {
MS_LOG(ERROR) << "data from java is nullptr";
return static_cast<jboolean>(false);
}
env->GetLongArrayRegion(data, 0, static_cast<jsize>(data_len), local_data);
return static_cast<jboolean>(true);
}
@ -287,6 +329,10 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setByteBufferD
return static_cast<jboolean>(false);
}
auto *local_data = ms_tensor_ptr->MutableData();
if (local_data == nullptr) {
MS_LOG(ERROR) << "get MutableData nullptr.";
return static_cast<jboolean>(false);
}
memcpy(local_data, p_data, data_len);
return static_cast<jboolean>(true);
}
@ -304,8 +350,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_MSTensor_size(JNIEnv *env,
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setShape(JNIEnv *env, jobject thiz, jlong tensor_ptr,
jintArray tensor_shape) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOG(ERROR) << "Tensor pointer from java is nullptr";
if (pointer == nullptr || tensor_shape == nullptr) {
MS_LOG(ERROR) << "input params from java is nullptr";
return static_cast<jboolean>(false);
}
@ -357,6 +403,11 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_MSTensor_createTensorByNat
jstring tensor_name, jint data_type,
jintArray tensor_shape,
jobject buffer) {
// check inputs
if (buffer == nullptr || tensor_name == nullptr || tensor_shape == nullptr) {
MS_LOG(ERROR) << "input param from java is nullptr";
return 0;
}
auto *p_data = reinterpret_cast<jbyte *>(env->GetDirectBufferAddress(buffer));
jlong data_len = env->GetDirectBufferCapacity(buffer);
if (p_data == nullptr) {
@ -370,11 +421,11 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_MSTensor_createTensorByNat
for (int i = 0; i < size; i++) {
c_shape[i] = static_cast<int64_t>(shape_pointer[i]);
}
env->ReleaseIntArrayElements(tensor_shape, shape_pointer, JNI_ABORT);
const char *c_tensor_name = env->GetStringUTFChars(tensor_name, nullptr);
std::string str_tensor_name(c_tensor_name, env->GetStringLength(tensor_name));
auto tensor = mindspore::MSTensor::CreateTensor(str_tensor_name, static_cast<mindspore::DataType>(data_type), c_shape,
p_data, data_len);
env->ReleaseIntArrayElements(tensor_shape, shape_pointer, JNI_ABORT);
env->ReleaseStringUTFChars(tensor_name, c_tensor_name);
return jlong(tensor);
}

View File

@ -51,8 +51,10 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_config_TrainCfg_createTrai
}
if (loss_name != nullptr) {
std::vector<std::string> traincfg_loss_name = traincfg_ptr->GetLossName();
traincfg_loss_name.emplace_back(env->GetStringUTFChars(loss_name, JNI_FALSE));
auto c_loss_name = env->GetStringUTFChars(loss_name, JNI_FALSE);
traincfg_loss_name.emplace_back(c_loss_name);
traincfg_ptr->SetLossName(traincfg_loss_name);
env->ReleaseStringUTFChars(loss_name, c_loss_name);
}
traincfg_ptr->optimization_level_ = ol;
traincfg_ptr->accumulate_gradients_ = accmulateGrads;

View File

@ -0,0 +1,63 @@
/*
* Copyright 2022-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import com.mindspore.config.DataType;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.Arrays;
/**
* Model Test
*
* @since 2023.2
*/
@RunWith(JUnit4.class)
public class MSTensorTest {
@Test
public void testCreateTensor() {
int[] tensorShape = {6, 5, 5, 1};
float[] tensorData = new float[6 * 5 * 5];
Arrays.fill(tensorData, 0);
ByteBuffer byteBuf = ByteBuffer.allocateDirect(6 * 5 * 5 * 4);
FloatBuffer floatBuf = byteBuf.asFloatBuffer();
floatBuf.put(tensorData);
MSTensor newTensor = MSTensor.createTensor("conv1.weight", DataType.kNumberTypeFloat32, tensorShape,
byteBuf);
assertNotNull(newTensor);
}
@Test
public void testSetData() {
int[] tensorShape = {6, 5, 5, 1};
ByteBuffer byteBuf = ByteBuffer.allocateDirect(6 * 5 * 5 * 4);
MSTensor newTensor = MSTensor.createTensor("conv1.weight", DataType.kNumberTypeFloat32, tensorShape,
byteBuf);
assertNotNull(newTensor);
float[] tensorData = new float[6 * 5 * 5];
Arrays.fill(tensorData, 0);
assertTrue(newTensor.setData(tensorData));
}
}