demo use setdata

This commit is contained in:
yefeng 2022-05-11 10:41:12 +08:00
parent e195cba9a6
commit e5eb840055
1 changed files with 35 additions and 22 deletions

View File

@ -93,31 +93,25 @@ void GenerateRandomData(int size, void *data, Distribution distribution) {
[&distribution, &random_engine]() { return static_cast<T>(distribution(random_engine)); });
}
std::vector<mindspore::MSTensor> GenerateInputDataWithRandom(std::vector<mindspore::MSTensor> model_inputs) {
std::vector<mindspore::MSTensor> inputs;
auto tensor_name = model_inputs[0].Name();
size_t size = model_inputs[0].DataSize();
std::vector<int64_t> shape = model_inputs[0].Shape();
int SetInputDataWithRandom(std::vector<mindspore::MSTensor> inputs) {
if (inputs.size() != 1) {
std::cerr << "input size must be 1.\n";
return -1;
}
size_t size = inputs[0].DataSize();
if (size == 0 || size > MAX_MALLOC_SIZE) {
std::cerr << "malloc size is wrong" << std::endl;
return {};
}
// user need malloc data for parallel predict input data;
void *input_data = malloc(size);
if (input_data == nullptr) {
std::cerr << "malloc failed" << std::endl;
return {};
}
GenerateRandomData<float>(size, input_data, std::uniform_real_distribution<float>(0.1f, 1.0f));
auto new_tensor = mindspore::MSTensor::CreateTensor(tensor_name, model_inputs[0].DataType(), shape, input_data, size);
if (new_tensor == nullptr) {
std::cerr << "CreateTensor failed" << std::endl;
return {};
}
inputs.push_back(*new_tensor);
delete new_tensor;
free(input_data);
return inputs;
inputs.at(0).SetData(input_data);
return 0;
}
int QuickStart(int argc, const char **argv) {
@ -167,21 +161,30 @@ int QuickStart(int argc, const char **argv) {
}
// Get Input
auto model_input = model_runner->GetInputs();
if (model_input.empty()) {
auto inputs = model_runner->GetInputs();
if (inputs.empty()) {
delete model_runner;
std::cerr << "model input is empty." << std::endl;
return -1;
}
// Generate random data as input data.
auto inputs = GenerateInputDataWithRandom(model_input);
if (inputs.empty()) {
// set random data to input data.
auto ret = SetInputDataWithRandom(inputs);
if (ret != 0) {
delete model_runner;
std::cerr << "input is empty." << std::endl;
std::cerr << "set input data failed." << std::endl;
return -1;
}
// Get Output
std::vector<mindspore::MSTensor> outputs;
auto outputs = model_runner->GetOutputs();
for (auto &output : outputs) {
size_t size = output.DataSize();
if (size == 0 || size > MAX_MALLOC_SIZE) {
std::cerr << "malloc size is wrong" << std::endl;
return -1;
}
auto out_data = malloc(size);
output.SetData(out_data);
}
// Model Predict
auto predict_ret = model_runner->Predict(inputs, &outputs);
@ -203,6 +206,16 @@ int QuickStart(int argc, const char **argv) {
std::cout << std::endl;
}
// user need free input data and output data
for (auto &input : inputs) {
free(input.MutableData());
input.SetData(nullptr);
}
for (auto &output : outputs) {
free(output.MutableData());
output.SetData(nullptr);
}
// Delete model runner.
delete model_runner;
return mindspore::kSuccess;