forked from mindspore-Ecosystem/mindspore
add cxx java api
This commit is contained in:
parent
48c3be7953
commit
dd25eba9f1
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore;
|
||||
|
||||
public class Graph {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
}
|
||||
|
||||
private long graphPtr;
|
||||
|
||||
/**
|
||||
* Construct function.
|
||||
*/
|
||||
public Graph() {
|
||||
this.graphPtr = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load file.
|
||||
*
|
||||
* @param file model file.
|
||||
* @return load status.
|
||||
*/
|
||||
public boolean Load(String file) {
|
||||
this.graphPtr = load(file);
|
||||
return this.graphPtr != 0L;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get graph pointer.
|
||||
*
|
||||
* @return graph pointer.
|
||||
*/
|
||||
public long getGraphPtr() {
|
||||
return this.graphPtr;
|
||||
}
|
||||
|
||||
/**
|
||||
* Fre graph pointer.
|
||||
*/
|
||||
public void free() {
|
||||
this.free(graphPtr);
|
||||
graphPtr = 0;
|
||||
}
|
||||
|
||||
private native long load(String file);
|
||||
|
||||
private native boolean free(long graphPtr);
|
||||
}
|
|
@ -0,0 +1,192 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
public class MSTensor {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
}
|
||||
private long tensorPtr;
|
||||
|
||||
/**
|
||||
* MSTensor construct function.
|
||||
*/
|
||||
public MSTensor() {
|
||||
this.tensorPtr = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* MSTensor construct function.
|
||||
*
|
||||
* @param tensorPtr tensor pointer.
|
||||
*/
|
||||
public MSTensor(long tensorPtr) {
|
||||
this.tensorPtr = tensorPtr;
|
||||
}
|
||||
|
||||
/**
|
||||
* MSTensor construct function.
|
||||
*
|
||||
* @param tensorName tensor name
|
||||
* @param buffer tensor buffer
|
||||
*/
|
||||
public MSTensor(String tensorName, ByteBuffer buffer) {
|
||||
this.tensorPtr = createTensor(tensorName, buffer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the shape of the MindSpore MSTensor.
|
||||
*
|
||||
* @return A array of int as the shape of the MindSpore MSTensor.
|
||||
*/
|
||||
public int[] getShape() {
|
||||
return this.getShape(this.tensorPtr);
|
||||
}
|
||||
|
||||
/**
|
||||
* DataType is defined in com.mindspore.DataType.
|
||||
*
|
||||
* @return The MindSpore data type of the MindSpore MSTensor class.
|
||||
*/
|
||||
public int getDataType() {
|
||||
return this.getDataType(this.tensorPtr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get output data of MSTensor, the data type is byte.
|
||||
*
|
||||
* @return The byte array containing all MSTensor output data.
|
||||
*/
|
||||
public byte[] getByteData() {
|
||||
return this.getByteData(this.tensorPtr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get output data of MSTensor, the data type is float.
|
||||
*
|
||||
* @return The float array containing all MSTensor output data.
|
||||
*/
|
||||
public float[] getFloatData() {
|
||||
return this.getFloatData(this.tensorPtr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get output data of MSTensor, the data type is int.
|
||||
*
|
||||
* @return The int array containing all MSTensor output data.
|
||||
*/
|
||||
public int[] getIntData() {
|
||||
return this.getIntData(this.tensorPtr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get output data of MSTensor, the data type is long.
|
||||
*
|
||||
* @return The long array containing all MSTensor output data.
|
||||
*/
|
||||
public long[] getLongData() {
|
||||
return this.getLongData(this.tensorPtr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the input data of MSTensor.
|
||||
*
|
||||
* @param data Input data of byte[] type.
|
||||
* @return whether set data success.
|
||||
*/
|
||||
public boolean setData(byte[] data) {
|
||||
return this.setData(this.tensorPtr, data, data.length);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the input data of MSTensor.
|
||||
*
|
||||
* @param data data Input data of ByteBuffer type
|
||||
* @return whether set data success.
|
||||
*/
|
||||
public boolean setData(ByteBuffer data) {
|
||||
return this.setByteBufferData(this.tensorPtr, data);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the size of the data in MSTensor in bytes.
|
||||
*
|
||||
* @return The size of the data in MSTensor in bytes.
|
||||
*/
|
||||
public long size() {
|
||||
return this.size(this.tensorPtr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of elements in MSTensor.
|
||||
*
|
||||
* @return The number of elements in MSTensor.
|
||||
*/
|
||||
public int elementsNum() {
|
||||
return this.elementsNum(this.tensorPtr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Free all temporary memory in MindSpore MSTensor.
|
||||
*/
|
||||
public void free() {
|
||||
this.free(this.tensorPtr);
|
||||
this.tensorPtr = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Get tensor name
|
||||
*/
|
||||
public String tensorName() {
|
||||
return this.tensorName(this.tensorPtr);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return MSTensor pointer
|
||||
*/
|
||||
public long getMSTensorPtr() {
|
||||
return tensorPtr;
|
||||
}
|
||||
|
||||
private native long createTensor(String tensorName, ByteBuffer buffer);
|
||||
|
||||
private native int[] getShape(long tensorPtr);
|
||||
|
||||
private native int getDataType(long tensorPtr);
|
||||
|
||||
private native byte[] getByteData(long tensorPtr);
|
||||
|
||||
private native long[] getLongData(long tensorPtr);
|
||||
|
||||
private native int[] getIntData(long tensorPtr);
|
||||
|
||||
private native float[] getFloatData(long tensorPtr);
|
||||
|
||||
private native boolean setData(long tensorPtr, byte[] data, long dataLen);
|
||||
|
||||
private native boolean setByteBufferData(long tensorPtr, ByteBuffer buffer);
|
||||
|
||||
private native long size(long tensorPtr);
|
||||
|
||||
private native int elementsNum(long tensorPtr);
|
||||
|
||||
private native void free(long tensorPtr);
|
||||
|
||||
private native String tensorName(long tensorPtr);
|
||||
}
|
|
@ -0,0 +1,277 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore;
|
||||
|
||||
import com.mindspore.config.ModelType;
|
||||
import com.mindspore.config.MSContext;
|
||||
import com.mindspore.config.TrainCfg;
|
||||
|
||||
import java.nio.MappedByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class Model {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
}
|
||||
|
||||
private long modelPtr = 0;
|
||||
|
||||
/**
|
||||
* Construct function.
|
||||
*/
|
||||
public Model() {
|
||||
this.modelPtr = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build model by graph.
|
||||
*
|
||||
* @param graph graph contains the buffer.
|
||||
* @param context model build context.
|
||||
* @param cfg model build train config.used for train.
|
||||
* @return build status.
|
||||
*/
|
||||
public boolean build(Graph graph, MSContext context, TrainCfg cfg) {
|
||||
modelPtr = this.buildByGraph(graph.getGraphPtr(), context.getMSContextPtr(), cfg.getTrainCfgPtr());
|
||||
return modelPtr != 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build model.
|
||||
*
|
||||
* @param buffer model buffer.
|
||||
* @param modelType model type.
|
||||
* @param context model build context.
|
||||
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
|
||||
* @param dec_mode define the decryption mode. Options: AES-GCM, AES-CBC.
|
||||
* @return model build status.
|
||||
*/
|
||||
public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key,
|
||||
String dec_mode) {
|
||||
modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode);
|
||||
return modelPtr != 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build model.
|
||||
*
|
||||
* @param buffer model buffer.
|
||||
* @param modelType model type.
|
||||
* @param context model build context.
|
||||
* @return model build status.
|
||||
*/
|
||||
public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context) {
|
||||
modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), null, "");
|
||||
return modelPtr != 0;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Build model.
|
||||
*
|
||||
* @param modelPath model path.
|
||||
* @param modelType model type.
|
||||
* @param context model build context.
|
||||
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
|
||||
* @param dec_mode define the decryption mode. Options: AES-GCM, AES-CBC.
|
||||
* @return model build status.
|
||||
*/
|
||||
public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode) {
|
||||
modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode);
|
||||
return modelPtr != 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build model.
|
||||
*
|
||||
* @param modelPath model path.
|
||||
* @param modelType model type.
|
||||
* @param context model build context.
|
||||
* @return build status.
|
||||
*/
|
||||
public boolean build(String modelPath, int modelType, MSContext context) {
|
||||
modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), null, "");
|
||||
return modelPtr != 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resize inputs shape.
|
||||
*
|
||||
* @param inputs Model inputs.
|
||||
* @param dims Define the new inputs shape.
|
||||
* @return Whether the resize is successful.
|
||||
*/
|
||||
public boolean resize(List<MSTensor> inputs, int[][] dims) {
|
||||
long[] inputsArray = new long[inputs.size()];
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
inputsArray[i] = inputs.get(i).getMSTensorPtr();
|
||||
}
|
||||
return this.resize(this.modelPtr, inputsArray, dims);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get model inputs tensor.
|
||||
*
|
||||
* @return input tensors.
|
||||
*/
|
||||
public List<MSTensor> getInputs() {
|
||||
List<Long> ret = this.getInputs(this.modelPtr);
|
||||
List<MSTensor> tensors = new ArrayList<>();
|
||||
for (Long msTensorAddr : ret) {
|
||||
MSTensor msTensor = new MSTensor(msTensorAddr);
|
||||
tensors.add(msTensor);
|
||||
}
|
||||
return tensors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get model outputs.
|
||||
*
|
||||
* @return model outputs tensor.
|
||||
*/
|
||||
public List<MSTensor> getOutputs() {
|
||||
List<Long> ret = this.getOutputs(this.modelPtr);
|
||||
List<MSTensor> tensors = new ArrayList<>();
|
||||
for (Long msTensorAddr : ret) {
|
||||
MSTensor msTensor = new MSTensor(msTensorAddr);
|
||||
tensors.add(msTensor);
|
||||
}
|
||||
return tensors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get input tensor by tensor name.
|
||||
*
|
||||
* @param tensorName name.
|
||||
* @return input tensor.
|
||||
*/
|
||||
public MSTensor getInputByTensorName(String tensorName) {
|
||||
long tensorAddr = this.getInputByTensorName(this.modelPtr, tensorName);
|
||||
return new MSTensor(tensorAddr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get output tensor by tensor name.
|
||||
*
|
||||
* @param tensorName output tensor name
|
||||
* @return output tensor
|
||||
*/
|
||||
public MSTensor getOutputByTensorName(String tensorName) {
|
||||
long tensorAddr = this.getOutputByTensorName(this.modelPtr, tensorName);
|
||||
return new MSTensor(tensorAddr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Export the model.
|
||||
*
|
||||
* @param fileName Name Model file name.
|
||||
* @param quantizationType The quant type.0,no_quant,1,weight_quant,2,full_quant.
|
||||
* @param isOnlyExportInfer if export only inferece.
|
||||
* @param outputTensorNames tensor name used for export inference graph.
|
||||
* @return Whether the export is successful.
|
||||
*/
|
||||
public boolean export(String fileName, int quantizationType, boolean isOnlyExportInfer,
|
||||
List<String> outputTensorNames) {
|
||||
return export(modelPtr, fileName, quantizationType, isOnlyExportInfer, outputTensorNames);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the FeatureMap.
|
||||
*
|
||||
* @return FeaturesMap Tensor list.
|
||||
*/
|
||||
public List<MSTensor> getFeatureMaps() {
|
||||
List<Long> ret = this.getFeatureMaps(this.modelPtr);
|
||||
ArrayList<MSTensor> tensors = new ArrayList<>();
|
||||
for (Long msTensorAddr : ret) {
|
||||
MSTensor msTensor = new MSTensor(msTensorAddr);
|
||||
tensors.add(msTensor);
|
||||
}
|
||||
return tensors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Update model Features.
|
||||
*
|
||||
* @param features new FeatureMap Tensor List.
|
||||
* @return Whether the model features is successfully update.
|
||||
*/
|
||||
public boolean updateFeatureMaps(List<MSTensor> features) {
|
||||
long[] inputsArray = new long[features.size()];
|
||||
for (int i = 0; i < features.size(); i++) {
|
||||
inputsArray[i] = features.get(i).getMSTensorPtr();
|
||||
}
|
||||
return this.updateFeatureMaps(modelPtr, inputsArray);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set model work train mode
|
||||
*
|
||||
* @param isTrain is train mode.true work train mode.
|
||||
* @return set status.
|
||||
*/
|
||||
public boolean setTrainMode(boolean isTrain) {
|
||||
return this.setTrainMode(modelPtr, isTrain);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get train mode
|
||||
*
|
||||
* @return train mode.
|
||||
*/
|
||||
public boolean getTrainMode() {
|
||||
return this.getTrainMode(modelPtr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Free model
|
||||
*/
|
||||
public void free() {
|
||||
this.free(modelPtr);
|
||||
}
|
||||
|
||||
private native void free(long modelPtr);
|
||||
|
||||
private native long buildByGraph(long graphPtr, long contextPtr, long cfgPtr);
|
||||
|
||||
private native long buildByPath(String modelPath, int modelType, long contextPtr, char[] dec_key, String dec_mod);
|
||||
|
||||
private native long buildByBuffer(MappedByteBuffer buffer, int modelType, long contextPtr, char[] dec_key,
|
||||
String dec_mod);
|
||||
|
||||
private native List<Long> getInputs(long modelPtr);
|
||||
|
||||
private native long getInputByTensorName(long modelPtr, String tensorName);
|
||||
|
||||
private native List<Long> getOutputs(long modelPtr);
|
||||
|
||||
private native long getOutputByTensorName(long modelPtr, String tensorName);
|
||||
|
||||
private native boolean setTrainMode(long modelPtr, boolean isTrain);
|
||||
|
||||
private native boolean getTrainMode(long modelPtr);
|
||||
|
||||
private native boolean resize(long modelPtr, long[] inputs, int[][] dims);
|
||||
|
||||
private native boolean export(long modelPtr, String fileName, int quantizationType, boolean isOnlyExportInfer,
|
||||
List<String> outputTensorNames);
|
||||
|
||||
private native List<Long> getFeatureMaps(long modelPtr);
|
||||
|
||||
private native boolean updateFeatureMaps(long modelPtr, long[] newFeatures);
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.config;
|
||||
|
||||
/**
|
||||
* Define Cpu core bind mode
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class CpuBindMode {
|
||||
public static final int MID_CPU = 2;
|
||||
public static final int HIGHER_CPU = 1;
|
||||
public static final int NO_BIND = 0;
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.mindspore.config;
|
||||
|
||||
/**
|
||||
* Define tensor data type.
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class DataType {
|
||||
public static final int kNumberTypeBool = 30;
|
||||
public static final int kNumberTypeInt = 31;
|
||||
public static final int kNumberTypeInt8 = 32;
|
||||
public static final int kNumberTypeInt16 = 33;
|
||||
public static final int kNumberTypeInt32 = 34;
|
||||
public static final int kNumberTypeInt64 = 35;
|
||||
public static final int kNumberTypeUInt = 36;
|
||||
public static final int kNumberTypeUInt8 = 37;
|
||||
public static final int kNumberTypeUInt16 = 38;
|
||||
public static final int kNumberTypeUint32 = 39;
|
||||
public static final int kNumberTypeUInt64 = 40;
|
||||
public static final int kNumberTypeFloat = 41;
|
||||
public static final int kNumberTypeFloat16 = 42;
|
||||
public static final int kNumberTypeFloat32 = 43;
|
||||
public static final int kNumberTypeFloat64 = 44;
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.mindspore.config;
|
||||
|
||||
/**
|
||||
* define device type
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class DeviceType {
|
||||
public static final int DT_CPU = 0;
|
||||
public static final int DT_GPU = 1;
|
||||
public static final int DT_NPU = 2;
|
||||
}
|
|
@ -0,0 +1,103 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.config;
|
||||
|
||||
public class MSContext {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
}
|
||||
|
||||
private long msContextPtr;
|
||||
|
||||
/**
|
||||
* Construct function.
|
||||
*/
|
||||
public MSContext() {
|
||||
this.msContextPtr = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add device info to context.
|
||||
*
|
||||
* @param deviceType support cpu,npu and gpu.
|
||||
* @param isEnableFloat16 whether to use float16 operator for priority.
|
||||
* @param npuFreq npu frequency used for npu device.
|
||||
* @return add status.
|
||||
*/
|
||||
public boolean addDeviceInfo(int deviceType, boolean isEnableFloat16, int npuFreq) {
|
||||
return addDeviceInfo(msContextPtr, deviceType, isEnableFloat16, npuFreq);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add device info to context.
|
||||
*
|
||||
* @param deviceType support cpu,npu and gpu.
|
||||
* @param isEnableFloat16 whether to use float16 operator for priority.
|
||||
* @return add status.
|
||||
*/
|
||||
public boolean addDeviceInfo(int deviceType, boolean isEnableFloat16) {
|
||||
return addDeviceInfo(msContextPtr, deviceType, isEnableFloat16, 3);
|
||||
}
|
||||
|
||||
/**
|
||||
* Init Context.
|
||||
*
|
||||
* @param threadNum thread nums.
|
||||
* @param cpuBindMode support bind high,mid cpu.0,no bind.1,bind mid cpu.2. bind high cpu.
|
||||
* @return init status.
|
||||
*/
|
||||
public boolean init(int threadNum, int cpuBindMode) {
|
||||
this.msContextPtr = createMSContext(threadNum, cpuBindMode, false);
|
||||
return this.msContextPtr != 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Init Context.
|
||||
*
|
||||
* @param threadNum thread nums.
|
||||
* @param cpuBindMode support bind high,mid cpu.0,no bind.1,bind mid cpu.2. bind high cpu.
|
||||
* @param isEnableParallel enable parallel in multi devices.
|
||||
* @return init status.
|
||||
*/
|
||||
public boolean init(int threadNum, int cpuBindMode, boolean isEnableParallel) {
|
||||
this.msContextPtr = createMSContext(threadNum, cpuBindMode, isEnableParallel);
|
||||
return this.msContextPtr != 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Free context.
|
||||
*/
|
||||
public void free() {
|
||||
this.free(this.msContextPtr);
|
||||
this.msContextPtr = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get context pointer.
|
||||
*
|
||||
* @return context pointer.
|
||||
*/
|
||||
public long getMSContextPtr() {
|
||||
return msContextPtr;
|
||||
}
|
||||
|
||||
private native long createMSContext(int threadNum, int cpuBindMode, boolean enableParallel);
|
||||
|
||||
private native boolean addDeviceInfo(long msContextPtr, int deviceType, boolean isEnableFloat16, int npuFrequency);
|
||||
|
||||
private native void free(long msContextPtr);
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.mindspore.config;
|
||||
|
||||
/**
|
||||
* define model type
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class ModelType {
|
||||
public static final int MT_MINDIR = 0;
|
||||
public static final int MT_AIR = 1;
|
||||
public static final int MT_OM = 2;
|
||||
public static final int MT_ONNX = 3;
|
||||
public static final int MT_MINDIR_OPT = 4;
|
||||
}
|
|
@ -0,0 +1,89 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.config;
|
||||
|
||||
public class TrainCfg {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-train-jni");
|
||||
}
|
||||
|
||||
private long trainCfgPtr;
|
||||
|
||||
/**
|
||||
* Construct function.
|
||||
*/
|
||||
public TrainCfg() {
|
||||
this.trainCfgPtr = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Init train config.
|
||||
*
|
||||
* @return init status.
|
||||
*/
|
||||
public boolean init() {
|
||||
this.trainCfgPtr = createTrainCfg(null, 0, false);
|
||||
return this.trainCfgPtr != 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Init train config specified loss name.
|
||||
*
|
||||
* @param loss_name loss name used for split inference and train part.
|
||||
* @return init status.
|
||||
*/
|
||||
public boolean init(String loss_name) {
|
||||
this.trainCfgPtr = createTrainCfg(loss_name, 0, false);
|
||||
return this.trainCfgPtr != 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Free train config.
|
||||
*/
|
||||
public void free() {
|
||||
this.free(this.trainCfgPtr);
|
||||
this.trainCfgPtr = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add mix precision config to train config.
|
||||
*
|
||||
* @param dynamicLossScale if dynamic or fix loss scale factor.
|
||||
* @param lossScale loss scale factor.
|
||||
* @param thresholdIterNum a threshold for modifying loss scale when dynamic loss scale is enabled.
|
||||
* @return add status.
|
||||
*/
|
||||
public boolean addMixPrecisionCfg(boolean dynamicLossScale, float lossScale, int thresholdIterNum) {
|
||||
return addMixPrecisionCfg(trainCfgPtr, dynamicLossScale, lossScale, thresholdIterNum);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get train config pointer.
|
||||
*
|
||||
* @return train config pointer.
|
||||
*/
|
||||
public long getTrainCfgPtr() {
|
||||
return trainCfgPtr;
|
||||
}
|
||||
|
||||
private native long createTrainCfg(String lossName, int optimizationLevel, boolean accmulateGrads);
|
||||
|
||||
private native boolean addMixPrecisionCfg(long trainCfgPtr, boolean dynamicLossScale, float lossScale,
|
||||
int thresholdIterNum);
|
||||
|
||||
private native void free(long trainCfgPtr);
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.config;
|
||||
|
||||
/**
|
||||
* Define mindspore version info.
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class Version {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
}
|
||||
|
||||
/**
|
||||
* Get MindSpore Lite version info.
|
||||
*
|
||||
* @return MindSpore Lite version info.
|
||||
*/
|
||||
public static native String version();
|
||||
}
|
Loading…
Reference in New Issue