forked from mindspore-Ecosystem/mindspore
increase protection for java-interface
This commit is contained in:
parent
3c4b0f1083
commit
5e8fbb38eb
|
@ -32,14 +32,12 @@ int CalShape(const int *data, const TensorC *const *inputs, int *out_shape, size
|
||||||
}
|
}
|
||||||
ShapePush(out_shape, out_shape_size, data[i]);
|
ShapePush(out_shape, out_shape_size, data[i]);
|
||||||
}
|
}
|
||||||
if (size == 0) {
|
|
||||||
return NNACL_ERR;
|
|
||||||
}
|
|
||||||
if ((int)(data[index]) == -1) {
|
if ((int)(data[index]) == -1) {
|
||||||
if (index >= MAX_SHAPE_SIZE) {
|
if (index >= MAX_SHAPE_SIZE) {
|
||||||
return NNACL_ERR;
|
return NNACL_ERR;
|
||||||
}
|
}
|
||||||
out_shape[index] = input_count / size;
|
out_shape[index] = size == 0 ? 0 : input_count / size;
|
||||||
}
|
}
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,7 +52,8 @@ public class MSTensor {
|
||||||
* @param buffer tensor buffer
|
* @param buffer tensor buffer
|
||||||
*/
|
*/
|
||||||
public static MSTensor createTensor(String tensorName, int dataType, int[] tensorShape, ByteBuffer buffer) {
|
public static MSTensor createTensor(String tensorName, int dataType, int[] tensorShape, ByteBuffer buffer) {
|
||||||
if (tensorName == null || tensorShape == null || buffer == null) {
|
if (tensorName == null || tensorShape == null || buffer == null || dataType < DataType.kNumberTypeBool ||
|
||||||
|
dataType > DataType.kNumberTypeFloat64) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
long tensorPtr = createTensorByNative(tensorName, dataType, tensorShape, buffer);
|
long tensorPtr = createTensorByNative(tensorName, dataType, tensorShape, buffer);
|
||||||
|
|
|
@ -63,10 +63,17 @@ public class ModelParallelRunner {
|
||||||
* @return init status.
|
* @return init status.
|
||||||
*/
|
*/
|
||||||
public boolean init(String modelPath, RunnerConfig runnerConfig) {
|
public boolean init(String modelPath, RunnerConfig runnerConfig) {
|
||||||
|
rwLock.writeLock().lock();
|
||||||
if (runnerConfig == null || modelPath == null) {
|
if (runnerConfig == null || modelPath == null) {
|
||||||
|
rwLock.writeLock().unlock();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
if (modelParallelRunnerPtr != 0L){
|
||||||
|
rwLock.writeLock().unlock();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
modelParallelRunnerPtr = this.init(modelPath, runnerConfig.getRunnerConfigPtr());
|
modelParallelRunnerPtr = this.init(modelPath, runnerConfig.getRunnerConfigPtr());
|
||||||
|
rwLock.writeLock().unlock();
|
||||||
return modelParallelRunnerPtr != 0L;
|
return modelParallelRunnerPtr != 0L;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -77,10 +84,17 @@ public class ModelParallelRunner {
|
||||||
* @return init status.
|
* @return init status.
|
||||||
*/
|
*/
|
||||||
public boolean init(String modelPath) {
|
public boolean init(String modelPath) {
|
||||||
|
rwLock.writeLock().lock();
|
||||||
if (modelPath == null) {
|
if (modelPath == null) {
|
||||||
|
rwLock.writeLock().unlock();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
if (modelParallelRunnerPtr != 0L){
|
||||||
|
rwLock.writeLock().unlock();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
modelParallelRunnerPtr = this.init(modelPath, 0L);
|
modelParallelRunnerPtr = this.init(modelPath, 0L);
|
||||||
|
rwLock.writeLock().unlock();
|
||||||
return modelParallelRunnerPtr != 0;
|
return modelParallelRunnerPtr != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -197,11 +211,12 @@ public class ModelParallelRunner {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
rwLock.writeLock().lock();
|
rwLock.writeLock().lock();
|
||||||
if (modelParallelRunnerPtr != 0L) {
|
long modelParallelRunnerTempPtr = modelParallelRunnerPtr;
|
||||||
this.free(modelParallelRunnerPtr);
|
|
||||||
modelParallelRunnerPtr = 0L;
|
modelParallelRunnerPtr = 0L;
|
||||||
}
|
|
||||||
rwLock.writeLock().unlock();
|
rwLock.writeLock().unlock();
|
||||||
|
if (modelParallelRunnerTempPtr != 0L) {
|
||||||
|
this.free(modelParallelRunnerTempPtr);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private native long init(String modelPath, long runnerConfigPtr);
|
private native long init(String modelPath, long runnerConfigPtr);
|
||||||
|
|
Loading…
Reference in New Issue