forked from mindspore-Ecosystem/mindspore
demo use setdata
This commit is contained in:
parent
e195cba9a6
commit
e5eb840055
|
@ -93,31 +93,25 @@ void GenerateRandomData(int size, void *data, Distribution distribution) {
|
|||
[&distribution, &random_engine]() { return static_cast<T>(distribution(random_engine)); });
|
||||
}
|
||||
|
||||
std::vector<mindspore::MSTensor> GenerateInputDataWithRandom(std::vector<mindspore::MSTensor> model_inputs) {
|
||||
std::vector<mindspore::MSTensor> inputs;
|
||||
auto tensor_name = model_inputs[0].Name();
|
||||
size_t size = model_inputs[0].DataSize();
|
||||
std::vector<int64_t> shape = model_inputs[0].Shape();
|
||||
int SetInputDataWithRandom(std::vector<mindspore::MSTensor> inputs) {
|
||||
if (inputs.size() != 1) {
|
||||
std::cerr << "input size must be 1.\n";
|
||||
return -1;
|
||||
}
|
||||
size_t size = inputs[0].DataSize();
|
||||
if (size == 0 || size > MAX_MALLOC_SIZE) {
|
||||
std::cerr << "malloc size is wrong" << std::endl;
|
||||
return {};
|
||||
}
|
||||
// user need malloc data for parallel predict input data;
|
||||
void *input_data = malloc(size);
|
||||
if (input_data == nullptr) {
|
||||
std::cerr << "malloc failed" << std::endl;
|
||||
return {};
|
||||
}
|
||||
GenerateRandomData<float>(size, input_data, std::uniform_real_distribution<float>(0.1f, 1.0f));
|
||||
|
||||
auto new_tensor = mindspore::MSTensor::CreateTensor(tensor_name, model_inputs[0].DataType(), shape, input_data, size);
|
||||
if (new_tensor == nullptr) {
|
||||
std::cerr << "CreateTensor failed" << std::endl;
|
||||
return {};
|
||||
}
|
||||
inputs.push_back(*new_tensor);
|
||||
delete new_tensor;
|
||||
free(input_data);
|
||||
return inputs;
|
||||
inputs.at(0).SetData(input_data);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int QuickStart(int argc, const char **argv) {
|
||||
|
@ -167,21 +161,30 @@ int QuickStart(int argc, const char **argv) {
|
|||
}
|
||||
|
||||
// Get Input
|
||||
auto model_input = model_runner->GetInputs();
|
||||
if (model_input.empty()) {
|
||||
auto inputs = model_runner->GetInputs();
|
||||
if (inputs.empty()) {
|
||||
delete model_runner;
|
||||
std::cerr << "model input is empty." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
// Generate random data as input data.
|
||||
auto inputs = GenerateInputDataWithRandom(model_input);
|
||||
if (inputs.empty()) {
|
||||
// set random data to input data.
|
||||
auto ret = SetInputDataWithRandom(inputs);
|
||||
if (ret != 0) {
|
||||
delete model_runner;
|
||||
std::cerr << "input is empty." << std::endl;
|
||||
std::cerr << "set input data failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
// Get Output
|
||||
std::vector<mindspore::MSTensor> outputs;
|
||||
auto outputs = model_runner->GetOutputs();
|
||||
for (auto &output : outputs) {
|
||||
size_t size = output.DataSize();
|
||||
if (size == 0 || size > MAX_MALLOC_SIZE) {
|
||||
std::cerr << "malloc size is wrong" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
auto out_data = malloc(size);
|
||||
output.SetData(out_data);
|
||||
}
|
||||
|
||||
// Model Predict
|
||||
auto predict_ret = model_runner->Predict(inputs, &outputs);
|
||||
|
@ -203,6 +206,16 @@ int QuickStart(int argc, const char **argv) {
|
|||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
// user need free input data and output data
|
||||
for (auto &input : inputs) {
|
||||
free(input.MutableData());
|
||||
input.SetData(nullptr);
|
||||
}
|
||||
for (auto &output : outputs) {
|
||||
free(output.MutableData());
|
||||
output.SetData(nullptr);
|
||||
}
|
||||
|
||||
// Delete model runner.
|
||||
delete model_runner;
|
||||
return mindspore::kSuccess;
|
||||
|
|
Loading…
Reference in New Issue