fix benchmark

yefeng 2022-08-09 09:38:19 +08:00
parent eeefa36747
commit 2952f3d6fc
2 changed files with 25 additions and 1 deletion

View File

@@ -174,7 +174,12 @@ int BenchmarkBase::ReadTensorData(std::ifstream &in_file_stream, const std::stri
   std::vector<float> data;
   std::vector<std::string> strings_data;
   size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1u, std::multiplies<size_t>());
-  if (GetDataTypeByTensorName(tensor_name) == static_cast<int>(kObjectTypeString)) {
+  auto tensor_data_type = GetDataTypeByTensorName(tensor_name);
+  if (tensor_data_type == static_cast<int>(kTypeUnknown)) {
+    MS_LOG(ERROR) << "get data type failed.";
+    return RET_ERROR;
+  }
+  if (tensor_data_type == static_cast<int>(kObjectTypeString)) {
     strings_data.push_back(line);
     for (size_t i = 1; i < shape_size; i++) {
       getline(in_file_stream, line);
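
(Side note, not part of the commit: the shape_size used above is just the product of the tensor dimensions, and the string branch reads one further line from in_file_stream per element. A minimal standalone sketch, with a made-up dims value:)

  #include <functional>
  #include <numeric>
  #include <vector>

  std::vector<size_t> dims = {2, 3};  // hypothetical output shape
  size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1u, std::multiplies<size_t>());
  // shape_size == 6: one line was already read before the loop, so the loop
  // reads the remaining shape_size - 1 lines, one per string element.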

View File

@@ -356,6 +356,16 @@ int BenchmarkUnifiedApi::ReadInputFile() {
 }
 int BenchmarkUnifiedApi::GetDataTypeByTensorName(const std::string &tensor_name) {
+#ifdef PARALLEL_INFERENCE
+  for (auto tensor : ms_outputs_for_api_) {
+    auto name = tensor.Name();
+    if (name == tensor_name) {
+      return static_cast<int>(tensor.DataType());
+    }
+  }
+  MS_LOG(ERROR) << "not find tensor name in model output.";
+  return static_cast<int>(DataType::kTypeUnknown);
+#endif
   return static_cast<int>(ms_model_.GetOutputByTensorName(tensor_name).DataType());
 }
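
(For context, not part of the commit: in the parallel-inference build the lookup above now reports kTypeUnknown when the name is absent from the model-pool outputs, and callers are expected to check for that sentinel before branching on the concrete type — the ReadTensorData hunk in the first file does exactly this. A minimal sketch of the calling pattern, with an invented tensor name:)

  auto type = GetDataTypeByTensorName("Softmax-0");  // hypothetical output tensor name
  if (type == static_cast<int>(kTypeUnknown)) {
    MS_LOG(ERROR) << "get data type failed.";  // name not found among the outputs
    return RET_ERROR;
  }
  if (type == static_cast<int>(kObjectTypeString)) {
    // read the entry line by line as strings
  } else {
    // read the entry as numeric data
  }
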
@@ -994,6 +1004,8 @@ int BenchmarkUnifiedApi::ParallelInference(std::shared_ptr<mindspore::Context> c
   // load data
   ms_inputs_for_api_ = model_runner_.GetInputs();
   MS_CHECK_FALSE_MSG(ms_inputs_for_api_.empty(), RET_ERROR, "model pool input is empty.");
+  ms_outputs_for_api_ = model_runner_.GetOutputs();
+  MS_CHECK_FALSE_MSG(ms_outputs_for_api_.empty(), RET_ERROR, "model pool output is empty.");
   for (int i = 0; i < flags_->parallel_num_ + flags_->warm_up_loop_count_; i++) {
     status = LoadInput();
     MS_CHECK_FALSE_MSG(status != RET_OK, status, "Generate input data error");
@@ -1004,6 +1016,7 @@ int BenchmarkUnifiedApi::ParallelInference(std::shared_ptr<mindspore::Context> c
   for (size_t i = 0; i < ms_inputs_for_api_.size(); i++) {
     auto &tensor = ms_inputs_for_api_[i];
     tensor.SetShape(resize_dims_[i]);
+    tensor.SetData(all_inputs_data_[0][i]);
   }
   status = PrintInputData();
   MS_CHECK_FALSE_MSG(status != RET_OK, status, "PrintInputData error ");
@@ -1527,6 +1540,12 @@ int BenchmarkUnifiedApi::InitDumpTensorDataCallbackParameter() {
 BenchmarkUnifiedApi::~BenchmarkUnifiedApi() {
 #ifdef PARALLEL_INFERENCE
+  for (auto tensor : ms_inputs_for_api_) {
+    auto data = tensor.MutableData();
+    if (data != nullptr) {
+      tensor.SetData(nullptr);
+    }
+  }
   for (auto &input : all_inputs_data_) {
     for (auto &data : input) {
       if (data != nullptr) {