model worker

This commit is contained in:
yefeng 2022-09-30 10:53:54 +08:00
parent 93d0a60a4e
commit 2d31bdb8ab
1 changed files with 13 additions and 10 deletions

View File

@ -156,22 +156,25 @@ std::pair<std::vector<std::vector<int64_t>>, bool> ModelWorker::GetModelResize(
Status ModelWorker::CopyOutputTensor(std::vector<MSTensor> model_outputs, std::vector<MSTensor> *user_outputs) {
user_outputs->clear();
user_outputs->insert(user_outputs->end(), model_outputs.begin(), model_outputs.end());
std::vector<MSTensor> new_outputs;
auto output_size = user_outputs->size();
auto output_size = model_outputs.size();
for (size_t i = 0; i < output_size; i++) {
auto copy_tensor = mindspore::MSTensor::CreateTensor(user_outputs->at(i).Name(), user_outputs->at(i).DataType(),
user_outputs->at(i).Shape(), user_outputs->at(i).MutableData(),
user_outputs->at(i).DataSize());
auto copy_tensor =
mindspore::MSTensor::CreateTensor(model_outputs[i].Name(), model_outputs[i].DataType(), {}, nullptr, 0);
if (copy_tensor == nullptr) {
MS_LOG(ERROR) << "model thread copy output tensor failed.";
MS_LOG(ERROR) << "model worker copy output tensor failed.";
return kLiteError;
}
new_outputs.push_back(*copy_tensor);
copy_tensor->SetShape(model_outputs[i].Shape());
copy_tensor->SetFormat(model_outputs[i].format());
copy_tensor->SetQuantParams(model_outputs[i].QuantParams());
copy_tensor->SetData(model_outputs[i].MutableData());
copy_tensor->SetAllocator(model_outputs[i].allocator());
// The memory of the model output tensor is requested by the framework, and the framework will release it, so after
// being acquired by the user, the tensor needs to be set to nullptr
model_outputs[i].SetData(nullptr);
user_outputs->push_back(*copy_tensor);
delete copy_tensor;
}
user_outputs->clear();
user_outputs->insert(user_outputs->end(), new_outputs.begin(), new_outputs.end());
return kSuccess;
}