increase protection for java-interface

This commit is contained in:
xuanyue 2023-01-29 10:33:02 +08:00
parent 3c4b0f1083
commit 5e8fbb38eb
3 changed files with 26 additions and 12 deletions

View File

@ -32,14 +32,12 @@ int CalShape(const int *data, const TensorC *const *inputs, int *out_shape, size
}
ShapePush(out_shape, out_shape_size, data[i]);
}
if (size == 0) {
return NNACL_ERR;
}
if ((int)(data[index]) == -1) {
if (index >= MAX_SHAPE_SIZE) {
return NNACL_ERR;
}
out_shape[index] = input_count / size;
out_shape[index] = size == 0 ? 0 : input_count / size;
}
return NNACL_OK;
}

View File

@ -52,7 +52,8 @@ public class MSTensor {
* @param buffer tensor buffer
*/
public static MSTensor createTensor(String tensorName, int dataType, int[] tensorShape, ByteBuffer buffer) {
if (tensorName == null || tensorShape == null || buffer == null) {
if (tensorName == null || tensorShape == null || buffer == null || dataType < DataType.kNumberTypeBool ||
dataType > DataType.kNumberTypeFloat64) {
return null;
}
long tensorPtr = createTensorByNative(tensorName, dataType, tensorShape, buffer);

View File

@ -63,10 +63,17 @@ public class ModelParallelRunner {
* @return init status.
*/
public boolean init(String modelPath, RunnerConfig runnerConfig) {
rwLock.writeLock().lock();
if (runnerConfig == null || modelPath == null) {
rwLock.writeLock().unlock();
return false;
}
if (modelParallelRunnerPtr != 0L){
rwLock.writeLock().unlock();
return true;
}
modelParallelRunnerPtr = this.init(modelPath, runnerConfig.getRunnerConfigPtr());
rwLock.writeLock().unlock();
return modelParallelRunnerPtr != 0L;
}
@ -77,10 +84,17 @@ public class ModelParallelRunner {
* @return init status.
*/
public boolean init(String modelPath) {
rwLock.writeLock().lock();
if (modelPath == null) {
rwLock.writeLock().unlock();
return false;
}
if (modelParallelRunnerPtr != 0L){
rwLock.writeLock().unlock();
return true;
}
modelParallelRunnerPtr = this.init(modelPath, 0L);
rwLock.writeLock().unlock();
return modelParallelRunnerPtr != 0;
}
@ -197,11 +211,12 @@ public class ModelParallelRunner {
break;
}
rwLock.writeLock().lock();
if (modelParallelRunnerPtr != 0L) {
this.free(modelParallelRunnerPtr);
long modelParallelRunnerTempPtr = modelParallelRunnerPtr;
modelParallelRunnerPtr = 0L;
}
rwLock.writeLock().unlock();
if (modelParallelRunnerTempPtr != 0L) {
this.free(modelParallelRunnerTempPtr);
}
}
private native long init(String modelPath, long runnerConfigPtr);