fix demo bug
This commit is contained in:
parent
a56032e8a8
commit
a0ae204c7d
|
@ -22,7 +22,7 @@ get_version
|
|||
MODEL_DOWNLOAD_URL="https://download.mindspore.cn/model_zoo/official/lite/quick_start/mobilenetv2.ms"
|
||||
MINDSPORE_FILE_NAME="mindspore-lite-${VERSION_STR}-linux-x64"
|
||||
MINDSPORE_FILE="${MINDSPORE_FILE_NAME}.tar.gz"
|
||||
MINDSPORE_LITE_DOWNLOAD_URL="https://ms-release.obs.cn-north-4.myhuaweicloud.com/${VERSION_STR}/MindSpore/lite/release/linux/x86_64/${MINDSPORE_FILE}"
|
||||
MINDSPORE_LITE_DOWNLOAD_URL="https://ms-release.obs.cn-north-4.myhuaweicloud.com/${VERSION_STR}.B310/MindSpore/lite/release/linux/x86_64/server/${MINDSPORE_FILE}"
|
||||
|
||||
mkdir -p build || exit
|
||||
mkdir -p lib || exit
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
namespace {
|
||||
constexpr int kNumPrintOfOutData = 50;
|
||||
constexpr int kNumWorkers = 2;
|
||||
constexpr int kElementsNum = 1001;
|
||||
constexpr int64_t MAX_MALLOC_SIZE = static_cast<size_t>(2000) * 1024 * 1024;
|
||||
} // namespace
|
||||
std::string RealPath(const char *path) {
|
||||
|
@ -177,12 +178,13 @@ int QuickStart(int argc, const char **argv) {
|
|||
// Get Output
|
||||
auto outputs = model_runner->GetOutputs();
|
||||
for (auto &output : outputs) {
|
||||
size_t size = output.DataSize();
|
||||
size_t size = kElementsNum * sizeof(float);
|
||||
if (size == 0 || size > MAX_MALLOC_SIZE) {
|
||||
std::cerr << "malloc size is wrong" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
auto out_data = malloc(size);
|
||||
output.SetShape({1, kElementsNum});
|
||||
output.SetData(out_data);
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ get_version
|
|||
MODEL_DOWNLOAD_URL="https://download.mindspore.cn/model_zoo/official/lite/quick_start/mobilenetv2.ms"
|
||||
MINDSPORE_FILE_NAME="mindspore-lite-${VERSION_STR}-linux-x64"
|
||||
MINDSPORE_FILE="${MINDSPORE_FILE_NAME}.tar.gz"
|
||||
MINDSPORE_LITE_DOWNLOAD_URL="https://ms-release.obs.cn-north-4.myhuaweicloud.com/${VERSION_STR}/MindSpore/lite/release/linux/x86_64/${MINDSPORE_FILE}"
|
||||
MINDSPORE_LITE_DOWNLOAD_URL="https://ms-release.obs.cn-north-4.myhuaweicloud.com/${VERSION_STR}.B310/MindSpore/lite/release/linux/x86_64/server/${MINDSPORE_FILE}"
|
||||
|
||||
mkdir -p build
|
||||
mkdir -p lib
|
||||
|
|
|
@ -126,12 +126,7 @@ public class Main {
|
|||
System.err.println("outputs size is wrong.");
|
||||
return;
|
||||
}
|
||||
MSTensor output = outputs.get(0);
|
||||
int outputElementNums = output.elementsNum();
|
||||
float[] outputRandomData = generateArray(outputElementNums);
|
||||
ByteBuffer outputData = floatArrayToByteBuffer(outputRandomData);
|
||||
output.setData(outputData);
|
||||
|
||||
List<MSTensor> outputs = new ArrayList<>();
|
||||
|
||||
// runner do predict
|
||||
ret = runner.predict(inputs,outputs);
|
||||
|
|
|
@ -111,6 +111,19 @@ public class MSTensor {
|
|||
return this.getLongData(this.tensorPtr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the shape of MSTensor.
|
||||
*
|
||||
* @param shape of int[] type.
|
||||
* @return whether set shape success.
|
||||
*/
|
||||
public boolean setShape(int[] tensorShape) {
|
||||
if (tensorShape == null) {
|
||||
return false;
|
||||
}
|
||||
return this.setShape(this.tensorPtr, tensorShape);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the input data of MSTensor.
|
||||
*
|
||||
|
@ -194,6 +207,8 @@ public class MSTensor {
|
|||
|
||||
private native boolean setData(long tensorPtr, byte[] data, long dataLen);
|
||||
|
||||
private native boolean setShape(long tensorPtr, int[] tensorShape);
|
||||
|
||||
private native boolean setByteBufferData(long tensorPtr, ByteBuffer buffer);
|
||||
|
||||
private native long size(long tensorPtr);
|
||||
|
|
|
@ -226,6 +226,26 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_MSTensor_size(JNIEnv *env,
|
|||
return ms_tensor_ptr->DataSize();
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_MSTensor_setShape(JNIEnv *env, jobject thiz, jlong tensor_ptr,
|
||||
jintArray tensor_shape) {
|
||||
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Tensor pointer from java is nullptr");
|
||||
return static_cast<jboolean>(false);
|
||||
}
|
||||
|
||||
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(pointer);
|
||||
auto size = static_cast<int>(env->GetArrayLength(tensor_shape));
|
||||
std::vector<int64_t> c_shape(size);
|
||||
jint *shape_pointer = env->GetIntArrayElements(tensor_shape, nullptr);
|
||||
for (int i = 0; i < size; i++) {
|
||||
c_shape[i] = static_cast<int64_t>(shape_pointer[i]);
|
||||
}
|
||||
env->ReleaseIntArrayElements(tensor_shape, shape_pointer, JNI_ABORT);
|
||||
ms_tensor_ptr->SetShape(c_shape);
|
||||
return static_cast<jboolean>(true);
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_MSTensor_elementsNum(JNIEnv *env, jobject thiz, jlong tensor_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
|
||||
if (pointer == nullptr) {
|
||||
|
|
Loading…
Reference in New Issue