increase protection for java-interface

This commit is contained in:
xuanyue 2023-01-29 10:33:02 +08:00
parent 3c4b0f1083
commit 5e8fbb38eb
3 changed files with 26 additions and 12 deletions

View File

@ -32,14 +32,12 @@ int CalShape(const int *data, const TensorC *const *inputs, int *out_shape, size
} }
ShapePush(out_shape, out_shape_size, data[i]); ShapePush(out_shape, out_shape_size, data[i]);
} }
if (size == 0) {
return NNACL_ERR;
}
if ((int)(data[index]) == -1) { if ((int)(data[index]) == -1) {
if (index >= MAX_SHAPE_SIZE) { if (index >= MAX_SHAPE_SIZE) {
return NNACL_ERR; return NNACL_ERR;
} }
out_shape[index] = input_count / size; out_shape[index] = size == 0 ? 0 : input_count / size;
} }
return NNACL_OK; return NNACL_OK;
} }

View File

@ -52,7 +52,8 @@ public class MSTensor {
* @param buffer tensor buffer * @param buffer tensor buffer
*/ */
public static MSTensor createTensor(String tensorName, int dataType, int[] tensorShape, ByteBuffer buffer) { public static MSTensor createTensor(String tensorName, int dataType, int[] tensorShape, ByteBuffer buffer) {
if (tensorName == null || tensorShape == null || buffer == null) { if (tensorName == null || tensorShape == null || buffer == null || dataType < DataType.kNumberTypeBool ||
dataType > DataType.kNumberTypeFloat64) {
return null; return null;
} }
long tensorPtr = createTensorByNative(tensorName, dataType, tensorShape, buffer); long tensorPtr = createTensorByNative(tensorName, dataType, tensorShape, buffer);

View File

@ -63,10 +63,17 @@ public class ModelParallelRunner {
* @return init status. * @return init status.
*/ */
public boolean init(String modelPath, RunnerConfig runnerConfig) { public boolean init(String modelPath, RunnerConfig runnerConfig) {
rwLock.writeLock().lock();
if (runnerConfig == null || modelPath == null) { if (runnerConfig == null || modelPath == null) {
rwLock.writeLock().unlock();
return false; return false;
} }
if (modelParallelRunnerPtr != 0L){
rwLock.writeLock().unlock();
return true;
}
modelParallelRunnerPtr = this.init(modelPath, runnerConfig.getRunnerConfigPtr()); modelParallelRunnerPtr = this.init(modelPath, runnerConfig.getRunnerConfigPtr());
rwLock.writeLock().unlock();
return modelParallelRunnerPtr != 0L; return modelParallelRunnerPtr != 0L;
} }
@ -77,10 +84,17 @@ public class ModelParallelRunner {
* @return init status. * @return init status.
*/ */
public boolean init(String modelPath) { public boolean init(String modelPath) {
rwLock.writeLock().lock();
if (modelPath == null) { if (modelPath == null) {
rwLock.writeLock().unlock();
return false; return false;
} }
if (modelParallelRunnerPtr != 0L){
rwLock.writeLock().unlock();
return true;
}
modelParallelRunnerPtr = this.init(modelPath, 0L); modelParallelRunnerPtr = this.init(modelPath, 0L);
rwLock.writeLock().unlock();
return modelParallelRunnerPtr != 0; return modelParallelRunnerPtr != 0;
} }
@ -197,11 +211,12 @@ public class ModelParallelRunner {
break; break;
} }
rwLock.writeLock().lock(); rwLock.writeLock().lock();
if (modelParallelRunnerPtr != 0L) { long modelParallelRunnerTempPtr = modelParallelRunnerPtr;
this.free(modelParallelRunnerPtr);
modelParallelRunnerPtr = 0L; modelParallelRunnerPtr = 0L;
}
rwLock.writeLock().unlock(); rwLock.writeLock().unlock();
if (modelParallelRunnerTempPtr != 0L) {
this.free(modelParallelRunnerTempPtr);
}
} }
private native long init(String modelPath, long runnerConfigPtr); private native long init(String modelPath, long runnerConfigPtr);