From dd25eba9f1563f36a44b476c5826ff146778f57d Mon Sep 17 00:00:00 2001 From: zhengjun10 Date: Thu, 16 Dec 2021 17:41:54 +0800 Subject: [PATCH] add cxx java api --- .../src/main/java/com/mindspore/Graph.java | 64 ++++ .../src/main/java/com/mindspore/MSTensor.java | 192 ++++++++++++ .../src/main/java/com/mindspore/Model.java | 277 ++++++++++++++++++ .../com/mindspore/config/CpuBindMode.java | 28 ++ .../java/com/mindspore/config/DataType.java | 39 +++ .../java/com/mindspore/config/DeviceType.java | 27 ++ .../java/com/mindspore/config/MSContext.java | 103 +++++++ .../java/com/mindspore/config/ModelType.java | 29 ++ .../java/com/mindspore/config/TrainCfg.java | 89 ++++++ .../java/com/mindspore/config/Version.java | 35 +++ 10 files changed, 883 insertions(+) create mode 100644 mindspore/lite/java/src/main/java/com/mindspore/Graph.java create mode 100644 mindspore/lite/java/src/main/java/com/mindspore/MSTensor.java create mode 100644 mindspore/lite/java/src/main/java/com/mindspore/Model.java create mode 100644 mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java create mode 100644 mindspore/lite/java/src/main/java/com/mindspore/config/DataType.java create mode 100644 mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java create mode 100644 mindspore/lite/java/src/main/java/com/mindspore/config/MSContext.java create mode 100644 mindspore/lite/java/src/main/java/com/mindspore/config/ModelType.java create mode 100644 mindspore/lite/java/src/main/java/com/mindspore/config/TrainCfg.java create mode 100644 mindspore/lite/java/src/main/java/com/mindspore/config/Version.java diff --git a/mindspore/lite/java/src/main/java/com/mindspore/Graph.java b/mindspore/lite/java/src/main/java/com/mindspore/Graph.java new file mode 100644 index 00000000000..7945879204b --- /dev/null +++ b/mindspore/lite/java/src/main/java/com/mindspore/Graph.java @@ -0,0 +1,64 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore; + +public class Graph { + static { + System.loadLibrary("mindspore-lite-jni"); + } + + private long graphPtr; + + /** + * Construct function. + */ + public Graph() { + this.graphPtr = 0; + } + + /** + * Load file. + * + * @param file model file. + * @return load status. + */ + public boolean Load(String file) { + this.graphPtr = load(file); + return this.graphPtr != 0L; + } + + /** + * Get graph pointer. + * + * @return graph pointer. + */ + public long getGraphPtr() { + return this.graphPtr; + } + + /** + * Fre graph pointer. + */ + public void free() { + this.free(graphPtr); + graphPtr = 0; + } + + private native long load(String file); + + private native boolean free(long graphPtr); +} \ No newline at end of file diff --git a/mindspore/lite/java/src/main/java/com/mindspore/MSTensor.java b/mindspore/lite/java/src/main/java/com/mindspore/MSTensor.java new file mode 100644 index 00000000000..92d446cdc90 --- /dev/null +++ b/mindspore/lite/java/src/main/java/com/mindspore/MSTensor.java @@ -0,0 +1,192 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore; + +import java.nio.ByteBuffer; + +public class MSTensor { + static { + System.loadLibrary("mindspore-lite-jni"); + } + private long tensorPtr; + + /** + * MSTensor construct function. + */ + public MSTensor() { + this.tensorPtr = 0; + } + + /** + * MSTensor construct function. + * + * @param tensorPtr tensor pointer. + */ + public MSTensor(long tensorPtr) { + this.tensorPtr = tensorPtr; + } + + /** + * MSTensor construct function. + * + * @param tensorName tensor name + * @param buffer tensor buffer + */ + public MSTensor(String tensorName, ByteBuffer buffer) { + this.tensorPtr = createTensor(tensorName, buffer); + } + + /** + * Get the shape of the MindSpore MSTensor. + * + * @return A array of int as the shape of the MindSpore MSTensor. + */ + public int[] getShape() { + return this.getShape(this.tensorPtr); + } + + /** + * DataType is defined in com.mindspore.DataType. + * + * @return The MindSpore data type of the MindSpore MSTensor class. + */ + public int getDataType() { + return this.getDataType(this.tensorPtr); + } + + /** + * Get output data of MSTensor, the data type is byte. + * + * @return The byte array containing all MSTensor output data. + */ + public byte[] getByteData() { + return this.getByteData(this.tensorPtr); + } + + /** + * Get output data of MSTensor, the data type is float. + * + * @return The float array containing all MSTensor output data. + */ + public float[] getFloatData() { + return this.getFloatData(this.tensorPtr); + } + + /** + * Get output data of MSTensor, the data type is int. + * + * @return The int array containing all MSTensor output data. + */ + public int[] getIntData() { + return this.getIntData(this.tensorPtr); + } + + /** + * Get output data of MSTensor, the data type is long. + * + * @return The long array containing all MSTensor output data. + */ + public long[] getLongData() { + return this.getLongData(this.tensorPtr); + } + + /** + * Set the input data of MSTensor. + * + * @param data Input data of byte[] type. + * @return whether set data success. + */ + public boolean setData(byte[] data) { + return this.setData(this.tensorPtr, data, data.length); + } + + /** + * Set the input data of MSTensor. + * + * @param data data Input data of ByteBuffer type + * @return whether set data success. + */ + public boolean setData(ByteBuffer data) { + return this.setByteBufferData(this.tensorPtr, data); + } + + /** + * Get the size of the data in MSTensor in bytes. + * + * @return The size of the data in MSTensor in bytes. + */ + public long size() { + return this.size(this.tensorPtr); + } + + /** + * Get the number of elements in MSTensor. + * + * @return The number of elements in MSTensor. + */ + public int elementsNum() { + return this.elementsNum(this.tensorPtr); + } + + /** + * Free all temporary memory in MindSpore MSTensor. + */ + public void free() { + this.free(this.tensorPtr); + this.tensorPtr = 0; + } + + /** + * @return Get tensor name + */ + public String tensorName() { + return this.tensorName(this.tensorPtr); + } + + /** + * @return MSTensor pointer + */ + public long getMSTensorPtr() { + return tensorPtr; + } + + private native long createTensor(String tensorName, ByteBuffer buffer); + + private native int[] getShape(long tensorPtr); + + private native int getDataType(long tensorPtr); + + private native byte[] getByteData(long tensorPtr); + + private native long[] getLongData(long tensorPtr); + + private native int[] getIntData(long tensorPtr); + + private native float[] getFloatData(long tensorPtr); + + private native boolean setData(long tensorPtr, byte[] data, long dataLen); + + private native boolean setByteBufferData(long tensorPtr, ByteBuffer buffer); + + private native long size(long tensorPtr); + + private native int elementsNum(long tensorPtr); + + private native void free(long tensorPtr); + + private native String tensorName(long tensorPtr); +} \ No newline at end of file diff --git a/mindspore/lite/java/src/main/java/com/mindspore/Model.java b/mindspore/lite/java/src/main/java/com/mindspore/Model.java new file mode 100644 index 00000000000..0709738a686 --- /dev/null +++ b/mindspore/lite/java/src/main/java/com/mindspore/Model.java @@ -0,0 +1,277 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore; + +import com.mindspore.config.ModelType; +import com.mindspore.config.MSContext; +import com.mindspore.config.TrainCfg; + +import java.nio.MappedByteBuffer; +import java.util.ArrayList; +import java.util.List; + +public class Model { + static { + System.loadLibrary("mindspore-lite-jni"); + } + + private long modelPtr = 0; + + /** + * Construct function. + */ + public Model() { + this.modelPtr = 0; + } + + /** + * Build model by graph. + * + * @param graph graph contains the buffer. + * @param context model build context. + * @param cfg model build train config.used for train. + * @return build status. + */ + public boolean build(Graph graph, MSContext context, TrainCfg cfg) { + modelPtr = this.buildByGraph(graph.getGraphPtr(), context.getMSContextPtr(), cfg.getTrainCfgPtr()); + return modelPtr != 0; + } + + /** + * Build model. + * + * @param buffer model buffer. + * @param modelType model type. + * @param context model build context. + * @param dec_key define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32. + * @param dec_mode define the decryption mode. Options: AES-GCM, AES-CBC. + * @return model build status. + */ + public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key, + String dec_mode) { + modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode); + return modelPtr != 0; + } + + /** + * Build model. + * + * @param buffer model buffer. + * @param modelType model type. + * @param context model build context. + * @return model build status. + */ + public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context) { + modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), null, ""); + return modelPtr != 0; + } + + + /** + * Build model. + * + * @param modelPath model path. + * @param modelType model type. + * @param context model build context. + * @param dec_key define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32. + * @param dec_mode define the decryption mode. Options: AES-GCM, AES-CBC. + * @return model build status. + */ + public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode) { + modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode); + return modelPtr != 0; + } + + /** + * Build model. + * + * @param modelPath model path. + * @param modelType model type. + * @param context model build context. + * @return build status. + */ + public boolean build(String modelPath, int modelType, MSContext context) { + modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), null, ""); + return modelPtr != 0; + } + + /** + * Resize inputs shape. + * + * @param inputs Model inputs. + * @param dims Define the new inputs shape. + * @return Whether the resize is successful. + */ + public boolean resize(List inputs, int[][] dims) { + long[] inputsArray = new long[inputs.size()]; + for (int i = 0; i < inputs.size(); i++) { + inputsArray[i] = inputs.get(i).getMSTensorPtr(); + } + return this.resize(this.modelPtr, inputsArray, dims); + } + + /** + * Get model inputs tensor. + * + * @return input tensors. + */ + public List getInputs() { + List ret = this.getInputs(this.modelPtr); + List tensors = new ArrayList<>(); + for (Long msTensorAddr : ret) { + MSTensor msTensor = new MSTensor(msTensorAddr); + tensors.add(msTensor); + } + return tensors; + } + + /** + * Get model outputs. + * + * @return model outputs tensor. + */ + public List getOutputs() { + List ret = this.getOutputs(this.modelPtr); + List tensors = new ArrayList<>(); + for (Long msTensorAddr : ret) { + MSTensor msTensor = new MSTensor(msTensorAddr); + tensors.add(msTensor); + } + return tensors; + } + + /** + * Get input tensor by tensor name. + * + * @param tensorName name. + * @return input tensor. + */ + public MSTensor getInputByTensorName(String tensorName) { + long tensorAddr = this.getInputByTensorName(this.modelPtr, tensorName); + return new MSTensor(tensorAddr); + } + + /** + * Get output tensor by tensor name. + * + * @param tensorName output tensor name + * @return output tensor + */ + public MSTensor getOutputByTensorName(String tensorName) { + long tensorAddr = this.getOutputByTensorName(this.modelPtr, tensorName); + return new MSTensor(tensorAddr); + } + + /** + * Export the model. + * + * @param fileName Name Model file name. + * @param quantizationType The quant type.0,no_quant,1,weight_quant,2,full_quant. + * @param isOnlyExportInfer if export only inferece. + * @param outputTensorNames tensor name used for export inference graph. + * @return Whether the export is successful. + */ + public boolean export(String fileName, int quantizationType, boolean isOnlyExportInfer, + List outputTensorNames) { + return export(modelPtr, fileName, quantizationType, isOnlyExportInfer, outputTensorNames); + } + + /** + * Get the FeatureMap. + * + * @return FeaturesMap Tensor list. + */ + public List getFeatureMaps() { + List ret = this.getFeatureMaps(this.modelPtr); + ArrayList tensors = new ArrayList<>(); + for (Long msTensorAddr : ret) { + MSTensor msTensor = new MSTensor(msTensorAddr); + tensors.add(msTensor); + } + return tensors; + } + + /** + * Update model Features. + * + * @param features new FeatureMap Tensor List. + * @return Whether the model features is successfully update. + */ + public boolean updateFeatureMaps(List features) { + long[] inputsArray = new long[features.size()]; + for (int i = 0; i < features.size(); i++) { + inputsArray[i] = features.get(i).getMSTensorPtr(); + } + return this.updateFeatureMaps(modelPtr, inputsArray); + } + + /** + * Set model work train mode + * + * @param isTrain is train mode.true work train mode. + * @return set status. + */ + public boolean setTrainMode(boolean isTrain) { + return this.setTrainMode(modelPtr, isTrain); + } + + /** + * Get train mode + * + * @return train mode. + */ + public boolean getTrainMode() { + return this.getTrainMode(modelPtr); + } + + /** + * Free model + */ + public void free() { + this.free(modelPtr); + } + + private native void free(long modelPtr); + + private native long buildByGraph(long graphPtr, long contextPtr, long cfgPtr); + + private native long buildByPath(String modelPath, int modelType, long contextPtr, char[] dec_key, String dec_mod); + + private native long buildByBuffer(MappedByteBuffer buffer, int modelType, long contextPtr, char[] dec_key, + String dec_mod); + + private native List getInputs(long modelPtr); + + private native long getInputByTensorName(long modelPtr, String tensorName); + + private native List getOutputs(long modelPtr); + + private native long getOutputByTensorName(long modelPtr, String tensorName); + + private native boolean setTrainMode(long modelPtr, boolean isTrain); + + private native boolean getTrainMode(long modelPtr); + + private native boolean resize(long modelPtr, long[] inputs, int[][] dims); + + private native boolean export(long modelPtr, String fileName, int quantizationType, boolean isOnlyExportInfer, + List outputTensorNames); + + private native List getFeatureMaps(long modelPtr); + + private native boolean updateFeatureMaps(long modelPtr, long[] newFeatures); +} \ No newline at end of file diff --git a/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java b/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java new file mode 100644 index 00000000000..97175153338 --- /dev/null +++ b/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java @@ -0,0 +1,28 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.config; + +/** + * Define Cpu core bind mode + * + * @since v1.0 + */ +public class CpuBindMode { + public static final int MID_CPU = 2; + public static final int HIGHER_CPU = 1; + public static final int NO_BIND = 0; +} \ No newline at end of file diff --git a/mindspore/lite/java/src/main/java/com/mindspore/config/DataType.java b/mindspore/lite/java/src/main/java/com/mindspore/config/DataType.java new file mode 100644 index 00000000000..1c8232b9721 --- /dev/null +++ b/mindspore/lite/java/src/main/java/com/mindspore/config/DataType.java @@ -0,0 +1,39 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mindspore.config; + +/** + * Define tensor data type. + * + * @since v1.0 + */ +public class DataType { + public static final int kNumberTypeBool = 30; + public static final int kNumberTypeInt = 31; + public static final int kNumberTypeInt8 = 32; + public static final int kNumberTypeInt16 = 33; + public static final int kNumberTypeInt32 = 34; + public static final int kNumberTypeInt64 = 35; + public static final int kNumberTypeUInt = 36; + public static final int kNumberTypeUInt8 = 37; + public static final int kNumberTypeUInt16 = 38; + public static final int kNumberTypeUint32 = 39; + public static final int kNumberTypeUInt64 = 40; + public static final int kNumberTypeFloat = 41; + public static final int kNumberTypeFloat16 = 42; + public static final int kNumberTypeFloat32 = 43; + public static final int kNumberTypeFloat64 = 44; +} \ No newline at end of file diff --git a/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java b/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java new file mode 100644 index 00000000000..389c2428e43 --- /dev/null +++ b/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java @@ -0,0 +1,27 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mindspore.config; + +/** + * define device type + * + * @since v1.0 + */ +public class DeviceType { + public static final int DT_CPU = 0; + public static final int DT_GPU = 1; + public static final int DT_NPU = 2; +} \ No newline at end of file diff --git a/mindspore/lite/java/src/main/java/com/mindspore/config/MSContext.java b/mindspore/lite/java/src/main/java/com/mindspore/config/MSContext.java new file mode 100644 index 00000000000..d4a28b1bc52 --- /dev/null +++ b/mindspore/lite/java/src/main/java/com/mindspore/config/MSContext.java @@ -0,0 +1,103 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.config; + +public class MSContext { + static { + System.loadLibrary("mindspore-lite-jni"); + } + + private long msContextPtr; + + /** + * Construct function. + */ + public MSContext() { + this.msContextPtr = 0; + } + + /** + * Add device info to context. + * + * @param deviceType support cpu,npu and gpu. + * @param isEnableFloat16 whether to use float16 operator for priority. + * @param npuFreq npu frequency used for npu device. + * @return add status. + */ + public boolean addDeviceInfo(int deviceType, boolean isEnableFloat16, int npuFreq) { + return addDeviceInfo(msContextPtr, deviceType, isEnableFloat16, npuFreq); + } + + /** + * Add device info to context. + * + * @param deviceType support cpu,npu and gpu. + * @param isEnableFloat16 whether to use float16 operator for priority. + * @return add status. + */ + public boolean addDeviceInfo(int deviceType, boolean isEnableFloat16) { + return addDeviceInfo(msContextPtr, deviceType, isEnableFloat16, 3); + } + + /** + * Init Context. + * + * @param threadNum thread nums. + * @param cpuBindMode support bind high,mid cpu.0,no bind.1,bind mid cpu.2. bind high cpu. + * @return init status. + */ + public boolean init(int threadNum, int cpuBindMode) { + this.msContextPtr = createMSContext(threadNum, cpuBindMode, false); + return this.msContextPtr != 0; + } + + /** + * Init Context. + * + * @param threadNum thread nums. + * @param cpuBindMode support bind high,mid cpu.0,no bind.1,bind mid cpu.2. bind high cpu. + * @param isEnableParallel enable parallel in multi devices. + * @return init status. + */ + public boolean init(int threadNum, int cpuBindMode, boolean isEnableParallel) { + this.msContextPtr = createMSContext(threadNum, cpuBindMode, isEnableParallel); + return this.msContextPtr != 0; + } + + /** + * Free context. + */ + public void free() { + this.free(this.msContextPtr); + this.msContextPtr = 0; + } + + /** + * Get context pointer. + * + * @return context pointer. + */ + public long getMSContextPtr() { + return msContextPtr; + } + + private native long createMSContext(int threadNum, int cpuBindMode, boolean enableParallel); + + private native boolean addDeviceInfo(long msContextPtr, int deviceType, boolean isEnableFloat16, int npuFrequency); + + private native void free(long msContextPtr); +} \ No newline at end of file diff --git a/mindspore/lite/java/src/main/java/com/mindspore/config/ModelType.java b/mindspore/lite/java/src/main/java/com/mindspore/config/ModelType.java new file mode 100644 index 00000000000..88a3e1b6bdd --- /dev/null +++ b/mindspore/lite/java/src/main/java/com/mindspore/config/ModelType.java @@ -0,0 +1,29 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mindspore.config; + +/** + * define model type + * + * @since v1.0 + */ +public class ModelType { + public static final int MT_MINDIR = 0; + public static final int MT_AIR = 1; + public static final int MT_OM = 2; + public static final int MT_ONNX = 3; + public static final int MT_MINDIR_OPT = 4; +} \ No newline at end of file diff --git a/mindspore/lite/java/src/main/java/com/mindspore/config/TrainCfg.java b/mindspore/lite/java/src/main/java/com/mindspore/config/TrainCfg.java new file mode 100644 index 00000000000..ae2a10105b6 --- /dev/null +++ b/mindspore/lite/java/src/main/java/com/mindspore/config/TrainCfg.java @@ -0,0 +1,89 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.config; + +public class TrainCfg { + static { + System.loadLibrary("mindspore-lite-train-jni"); + } + + private long trainCfgPtr; + + /** + * Construct function. + */ + public TrainCfg() { + this.trainCfgPtr = 0; + } + + /** + * Init train config. + * + * @return init status. + */ + public boolean init() { + this.trainCfgPtr = createTrainCfg(null, 0, false); + return this.trainCfgPtr != 0; + } + + /** + * Init train config specified loss name. + * + * @param loss_name loss name used for split inference and train part. + * @return init status. + */ + public boolean init(String loss_name) { + this.trainCfgPtr = createTrainCfg(loss_name, 0, false); + return this.trainCfgPtr != 0; + } + + /** + * Free train config. + */ + public void free() { + this.free(this.trainCfgPtr); + this.trainCfgPtr = 0; + } + + /** + * Add mix precision config to train config. + * + * @param dynamicLossScale if dynamic or fix loss scale factor. + * @param lossScale loss scale factor. + * @param thresholdIterNum a threshold for modifying loss scale when dynamic loss scale is enabled. + * @return add status. + */ + public boolean addMixPrecisionCfg(boolean dynamicLossScale, float lossScale, int thresholdIterNum) { + return addMixPrecisionCfg(trainCfgPtr, dynamicLossScale, lossScale, thresholdIterNum); + } + + /** + * Get train config pointer. + * + * @return train config pointer. + */ + public long getTrainCfgPtr() { + return trainCfgPtr; + } + + private native long createTrainCfg(String lossName, int optimizationLevel, boolean accmulateGrads); + + private native boolean addMixPrecisionCfg(long trainCfgPtr, boolean dynamicLossScale, float lossScale, + int thresholdIterNum); + + private native void free(long trainCfgPtr); +} \ No newline at end of file diff --git a/mindspore/lite/java/src/main/java/com/mindspore/config/Version.java b/mindspore/lite/java/src/main/java/com/mindspore/config/Version.java new file mode 100644 index 00000000000..f1f3e9abc27 --- /dev/null +++ b/mindspore/lite/java/src/main/java/com/mindspore/config/Version.java @@ -0,0 +1,35 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.config; + +/** + * Define mindspore version info. + * + * @since v1.0 + */ +public class Version { + static { + System.loadLibrary("mindspore-lite-jni"); + } + + /** + * Get MindSpore Lite version info. + * + * @return MindSpore Lite version info. + */ + public static native String version(); +} \ No newline at end of file