fix benchmark
This commit is contained in:
parent
eeefa36747
commit
2952f3d6fc
|
@ -174,7 +174,12 @@ int BenchmarkBase::ReadTensorData(std::ifstream &in_file_stream, const std::stri
|
|||
std::vector<float> data;
|
||||
std::vector<std::string> strings_data;
|
||||
size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1u, std::multiplies<size_t>());
|
||||
if (GetDataTypeByTensorName(tensor_name) == static_cast<int>(kObjectTypeString)) {
|
||||
auto tensor_data_type = GetDataTypeByTensorName(tensor_name);
|
||||
if (tensor_data_type == static_cast<int>(kTypeUnknown)) {
|
||||
MS_LOG(ERROR) << "get data type failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (tensor_data_type == static_cast<int>(kObjectTypeString)) {
|
||||
strings_data.push_back(line);
|
||||
for (size_t i = 1; i < shape_size; i++) {
|
||||
getline(in_file_stream, line);
|
||||
|
|
|
@ -356,6 +356,16 @@ int BenchmarkUnifiedApi::ReadInputFile() {
|
|||
}
|
||||
|
||||
int BenchmarkUnifiedApi::GetDataTypeByTensorName(const std::string &tensor_name) {
|
||||
#ifdef PARALLEL_INFERENCE
|
||||
for (auto tensor : ms_outputs_for_api_) {
|
||||
auto name = tensor.Name();
|
||||
if (name == tensor_name) {
|
||||
return static_cast<int>(tensor.DataType());
|
||||
}
|
||||
}
|
||||
MS_LOG(ERROR) << "not find tensor name in model output.";
|
||||
return static_cast<int>(DataType::kTypeUnknown);
|
||||
#endif
|
||||
return static_cast<int>(ms_model_.GetOutputByTensorName(tensor_name).DataType());
|
||||
}
|
||||
|
||||
|
@ -994,6 +1004,8 @@ int BenchmarkUnifiedApi::ParallelInference(std::shared_ptr<mindspore::Context> c
|
|||
// load data
|
||||
ms_inputs_for_api_ = model_runner_.GetInputs();
|
||||
MS_CHECK_FALSE_MSG(ms_inputs_for_api_.empty(), RET_ERROR, "model pool input is empty.");
|
||||
ms_outputs_for_api_ = model_runner_.GetOutputs();
|
||||
MS_CHECK_FALSE_MSG(ms_outputs_for_api_.empty(), RET_ERROR, "model pool output is empty.");
|
||||
for (int i = 0; i < flags_->parallel_num_ + flags_->warm_up_loop_count_; i++) {
|
||||
status = LoadInput();
|
||||
MS_CHECK_FALSE_MSG(status != RET_OK, status, "Generate input data error");
|
||||
|
@ -1004,6 +1016,7 @@ int BenchmarkUnifiedApi::ParallelInference(std::shared_ptr<mindspore::Context> c
|
|||
for (size_t i = 0; i < ms_inputs_for_api_.size(); i++) {
|
||||
auto &tensor = ms_inputs_for_api_[i];
|
||||
tensor.SetShape(resize_dims_[i]);
|
||||
tensor.SetData(all_inputs_data_[0][i]);
|
||||
}
|
||||
status = PrintInputData();
|
||||
MS_CHECK_FALSE_MSG(status != RET_OK, status, "PrintInputData error ");
|
||||
|
@ -1527,6 +1540,12 @@ int BenchmarkUnifiedApi::InitDumpTensorDataCallbackParameter() {
|
|||
|
||||
BenchmarkUnifiedApi::~BenchmarkUnifiedApi() {
|
||||
#ifdef PARALLEL_INFERENCE
|
||||
for (auto tensor : ms_inputs_for_api_) {
|
||||
auto data = tensor.MutableData();
|
||||
if (data != nullptr) {
|
||||
tensor.SetData(nullptr);
|
||||
}
|
||||
}
|
||||
for (auto &input : all_inputs_data_) {
|
||||
for (auto &data : input) {
|
||||
if (data != nullptr) {
|
||||
|
|
Loading…
Reference in New Issue