forked from mindspore-Ecosystem/mindspore
increase protection for java-interface
This commit is contained in:
parent
3c4b0f1083
commit
5e8fbb38eb
|
@ -32,14 +32,12 @@ int CalShape(const int *data, const TensorC *const *inputs, int *out_shape, size
|
|||
}
|
||||
ShapePush(out_shape, out_shape_size, data[i]);
|
||||
}
|
||||
if (size == 0) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
|
||||
if ((int)(data[index]) == -1) {
|
||||
if (index >= MAX_SHAPE_SIZE) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
out_shape[index] = input_count / size;
|
||||
out_shape[index] = size == 0 ? 0 : input_count / size;
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -52,7 +52,8 @@ public class MSTensor {
|
|||
* @param buffer tensor buffer
|
||||
*/
|
||||
public static MSTensor createTensor(String tensorName, int dataType, int[] tensorShape, ByteBuffer buffer) {
|
||||
if (tensorName == null || tensorShape == null || buffer == null) {
|
||||
if (tensorName == null || tensorShape == null || buffer == null || dataType < DataType.kNumberTypeBool ||
|
||||
dataType > DataType.kNumberTypeFloat64) {
|
||||
return null;
|
||||
}
|
||||
long tensorPtr = createTensorByNative(tensorName, dataType, tensorShape, buffer);
|
||||
|
|
|
@ -56,36 +56,50 @@ public class ModelParallelRunner {
|
|||
}
|
||||
|
||||
/**
|
||||
* Build a model runner from model path so that it can run on a device.
|
||||
* Build a model runner from model path so that it can run on a device.
|
||||
*
|
||||
* @param modelPath the model path.
|
||||
* @param runnerConfig the RunnerConfig Object.
|
||||
* @return init status.
|
||||
*/
|
||||
public boolean init(String modelPath, RunnerConfig runnerConfig) {
|
||||
rwLock.writeLock().lock();
|
||||
if (runnerConfig == null || modelPath == null) {
|
||||
rwLock.writeLock().unlock();
|
||||
return false;
|
||||
}
|
||||
if (modelParallelRunnerPtr != 0L){
|
||||
rwLock.writeLock().unlock();
|
||||
return true;
|
||||
}
|
||||
modelParallelRunnerPtr = this.init(modelPath, runnerConfig.getRunnerConfigPtr());
|
||||
rwLock.writeLock().unlock();
|
||||
return modelParallelRunnerPtr != 0L;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a model runner from model path so that it can run on a device.
|
||||
* Build a model runner from model path so that it can run on a device.
|
||||
*
|
||||
* @param modelPath the model path.
|
||||
* @return init status.
|
||||
*/
|
||||
public boolean init(String modelPath) {
|
||||
rwLock.writeLock().lock();
|
||||
if (modelPath == null) {
|
||||
rwLock.writeLock().unlock();
|
||||
return false;
|
||||
}
|
||||
if (modelParallelRunnerPtr != 0L){
|
||||
rwLock.writeLock().unlock();
|
||||
return true;
|
||||
}
|
||||
modelParallelRunnerPtr = this.init(modelPath, 0L);
|
||||
rwLock.writeLock().unlock();
|
||||
return modelParallelRunnerPtr != 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a model runner from model path so that it can run on a device.
|
||||
* Build a model runner from model path so that it can run on a device.
|
||||
*
|
||||
* @param inputs inputs A vector where model inputs are arranged in sequence.
|
||||
* @param outputs outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence.
|
||||
|
@ -197,11 +211,12 @@ public class ModelParallelRunner {
|
|||
break;
|
||||
}
|
||||
rwLock.writeLock().lock();
|
||||
if (modelParallelRunnerPtr != 0L) {
|
||||
this.free(modelParallelRunnerPtr);
|
||||
modelParallelRunnerPtr = 0L;
|
||||
}
|
||||
long modelParallelRunnerTempPtr = modelParallelRunnerPtr;
|
||||
modelParallelRunnerPtr = 0L;
|
||||
rwLock.writeLock().unlock();
|
||||
if (modelParallelRunnerTempPtr != 0L) {
|
||||
this.free(modelParallelRunnerTempPtr);
|
||||
}
|
||||
}
|
||||
|
||||
private native long init(String modelPath, long runnerConfigPtr);
|
||||
|
|
Loading…
Reference in New Issue