!4434 fix bug that lite_session GetInputs returns repeated input tensors

Merge pull request !4434 from hangq/master
mindspore-ci-bot 2020-08-15 09:57:42 +08:00 committed by Gitee
commit b34b7973be
4 changed files with 37 additions and 16 deletions
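
Background: GetInputs() used to flatten input_map_, which maps each graph-input node name to that node's input MSTensors. Because the map layer evidently creates a separate LiteTensor wrapper for every consuming node (the same pattern as the new InitGraphInputMSTensors below), two wrappers around one underlying tensor compare unequal as pointers, so the pointer-based IsContain() dedup let the same input through more than once whenever a tensor fed several nodes. The fix caches one wrapper per distinct graph input in input_vec_ and returns that vector. A minimal standalone sketch of the failure mode, using hypothetical stand-in types rather than MindSpore code:

    #include <algorithm>
    #include <initializer_list>
    #include <vector>

    struct Tensor {};                  // stand-in for lite::tensor::Tensor
    struct LiteTensor {                // stand-in for the MSTensor wrapper
      explicit LiteTensor(Tensor *t) : impl(t) {}
      Tensor *impl;
    };

    template <typename T>
    bool IsContain(const std::vector<T> &vec, const T &elem) {
      return std::find(vec.begin(), vec.end(), elem) != vec.end();
    }

    int main() {
      Tensor shared;                   // one graph input feeding two nodes
      LiteTensor a(&shared);           // wrapper built for node A
      LiteTensor b(&shared);           // wrapper built for node B
      std::vector<LiteTensor *> ret;
      for (LiteTensor *t : {&a, &b}) {
        if (!IsContain(ret, t)) {      // compares wrapper pointers, not tensors
          ret.push_back(t);
        }
      }
      return static_cast<int>(ret.size());  // 2: the "repeated input tensor" bug
    }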

lite_session.cc

@@ -92,6 +92,16 @@ void LiteSession::InitGraphInputTensors(const lite::Model *model) {
   }
 }
 
+void LiteSession::InitGraphInputMSTensors(const lite::Model *model) {
+  auto meta_graph = model->GetMetaGraph();
+  MS_ASSERT(this->input_vec_.empty());
+  MS_ASSERT(meta_graph != nullptr);
+  for (auto &input_tensor : this->inputs_) {
+    MS_ASSERT(input_tensor != nullptr);
+    this->input_vec_.emplace_back(new lite::tensor::LiteTensor(input_tensor));
+  }
+}
+
 void LiteSession::InitGraphOutputTensors(const lite::Model *model) {
   auto meta_graph = model->GetMetaGraph();
   MS_ASSERT(this->outputs_.empty());
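
Note: the new InitGraphInputMSTensors() builds exactly one heap-allocated LiteTensor wrapper per entry of this->inputs_ (the already-deduplicated graph-input tensors), which is what makes the simplified GetInputs() below correct. meta_graph is fetched only for the null assertion and is otherwise unused; the wrappers allocated with new are presumably released by the session's teardown path, which this diff does not show.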
@@ -169,6 +179,7 @@ void LiteSession::InitGraphOutputMap(const lite::Model *model) {
 
 void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
   InitGraphInputTensors(model);
+  InitGraphInputMSTensors(model);
   InitGraphOutputTensors(model);
   InitGraphInputMap(model);
   InitGraphOutputMap(model);
@@ -201,16 +212,7 @@ int LiteSession::CompileGraph(Model *model) {
 }
 
 std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputs() const {
-  std::vector<mindspore::tensor::MSTensor *> ret;
-  for (auto &iter : this->input_map_) {
-    auto &node_input_tensors = iter.second;
-    for (auto tensor : node_input_tensors) {
-      if (!IsContain(ret, tensor)) {
-        ret.emplace_back(tensor);
-      }
-    }
-  }
-  return ret;
+  return this->input_vec_;
 }
 
 int LiteSession::RunGraph(const session::KernelCallBack &before, const session::KernelCallBack &after) {
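
With the cached vector in place, GetInputs() is a constant-time accessor and callers no longer see one entry per consuming node. A hedged usage sketch; the include paths and the MutableData() call assume the public headers of this release and may need adapting:

    #include <cstring>

    #include "include/context.h"
    #include "include/lite_session.h"

    // Fill every distinct graph input exactly once. Before this fix, a tensor
    // feeding several nodes could appear repeatedly in GetInputs().
    int FillInputs(mindspore::session::LiteSession *session) {
      for (auto *input : session->GetInputs()) {
        void *data = input->MutableData();   // buffer for one distinct input
        if (data == nullptr) {
          return -1;
        }
        memset(data, 0, input->Size());      // placeholder: zero-fill
      }
      return 0;
    }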

lite_session.h

@@ -57,13 +57,15 @@ class LiteSession : public session::LiteSession {
   int ConvertTensors(const lite::Model *model);
 
   void InitGraphInOutTensors(const lite::Model *model);
 
   // init this->inputs_
   void InitGraphInputTensors(const lite::Model *model);
+  // init this->input_vec_
+  void InitGraphInputMSTensors(const lite::Model *model);
   // init this->outputs_
   void InitGraphOutputTensors(const lite::Model *model);
   // init this->input_map_
   void InitGraphInputMap(const lite::Model *model);
   // init this->output_map_
   void InitGraphOutputMap(const lite::Model *model);
 
  protected:
@@ -74,6 +76,8 @@ class LiteSession : public session::LiteSession {
   std::vector<tensor::Tensor *> inputs_;
   // graph output tensors
   std::vector<tensor::Tensor *> outputs_;
+  // graph input MSTensors
+  std::vector<mindspore::tensor::MSTensor *> input_vec_;
   // graph input node name -- input tensors
   std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> input_map_;
   // graph output node name -- output tensors
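
Taken together, the header now separates three views of the graph inputs: inputs_ holds the session-owned tensor::Tensor objects, input_vec_ caches their public MSTensor wrappers one-to-one, and input_map_ keeps the per-node grouping for name-based lookup. Only the map can legitimately mention the same underlying tensor more than once, which is why GetInputs() now reads from input_vec_ instead.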

benchmark.cc

@@ -49,7 +49,8 @@ int Benchmark::GenerateInputData() {
     auto tensorByteSize = tensor->Size();
     auto status = GenerateRandomData(tensorByteSize, inputData);
     if (status != 0) {
-      MS_LOG(ERROR) << "GenerateRandomData for inTensor failed %d" << status;
+      std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl;
+      MS_LOG(ERROR) << "GenerateRandomData for inTensor failed: " << status;
       return status;
     }
   }
@@ -60,12 +61,14 @@ int Benchmark::LoadInput() {
   if (_flags->inDataPath.empty()) {
     auto status = GenerateInputData();
     if (status != 0) {
+      std::cerr << "Generate input data error " << status << std::endl;
       MS_LOG(ERROR) << "Generate input data error " << status;
       return status;
     }
   } else {
     auto status = ReadInputFile();
     if (status != 0) {
+      std::cerr << "ReadInputFile error, " << status << std::endl;
      MS_LOG(ERROR) << "ReadInputFile error, " << status;
       return status;
     }
@@ -97,6 +100,7 @@ int Benchmark::ReadInputFile() {
       char *binBuf = ReadFile(_flags->input_data_list[i].c_str(), &size);
       auto tensorDataSize = cur_tensor->Size();
       if (size != tensorDataSize) {
+        std::cerr << "Input binary file size error, required: " << tensorDataSize << ", in fact: " << size << std::endl;
         MS_LOG(ERROR) << "Input binary file size error, required: " << tensorDataSize << ", in fact: " << size;
         return RET_ERROR;
       }
@@ -113,11 +117,13 @@ int Benchmark::ReadCalibData() {
   // read calib data
   std::ifstream inFile(calibDataPath);
   if (!inFile.good()) {
+    std::cerr << "file: " << calibDataPath << " does not exist" << std::endl;
     MS_LOG(ERROR) << "file: " << calibDataPath << " does not exist";
     return RET_ERROR;
   }
 
   if (!inFile.is_open()) {
+    std::cerr << "file: " << calibDataPath << " open failed" << std::endl;
     MS_LOG(ERROR) << "file: " << calibDataPath << " open failed";
     inFile.close();
     return RET_ERROR;
@@ -181,6 +187,7 @@ float Benchmark::CompareData(const std::string &nodeName, std::vector<int> msSha
       oss << dim << ",";
     }
     oss << ") are different";
+    std::cerr << oss.str() << std::endl;
     MS_LOG(ERROR) << oss.str();
     return RET_ERROR;
   }
@@ -193,6 +200,7 @@ float Benchmark::CompareData(const std::string &nodeName, std::vector<int> msSha
     }
 
     if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) {
+      std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl;
       MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
       return RET_ERROR;
     }
@@ -524,6 +532,13 @@ int Benchmark::Init() {
   return RET_OK;
 }
 
+Benchmark::~Benchmark() {
+  for (auto iter : this->calibData) {
+    delete (iter.second);
+  }
+  this->calibData.clear();
+}
+
 int RunBenchmark(int argc, const char **argv) {
   BenchmarkFlags flags;
   Option<std::string> err = flags.ParseFlags(argc, argv);
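
The benchmark.cc edits follow a single pattern: every failure path now mirrors its message to std::cerr before MS_LOG(ERROR), so errors stay visible even when the log sink is redirected, and the new destructor releases the calibration entries owned by calibData. A hypothetical helper could remove the duplicated message literals; this is a sketch, not part of the commit:

    #include <iostream>
    #include <string>

    // Hypothetical dual-sink error reporter (not in this commit): build the
    // message once and emit it to both destinations.
    inline void BenchError(const std::string &msg) {
      std::cerr << msg << std::endl;
      // MS_LOG(ERROR) << msg;  // second sink inside a MindSpore build
    }

    // Call-site example, matching one of the messages above:
    //   BenchError("ReadInputFile error, " + std::to_string(status));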

benchmark.h

@@ -104,7 +104,7 @@ class MS_API Benchmark {
  public:
   explicit Benchmark(BenchmarkFlags *flags) : _flags(flags) {}
 
-  virtual ~Benchmark() = default;
+  virtual ~Benchmark();
 
   int Init();
   int RunBenchmark(const std::string &deviceType = "NPU");
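
The defaulted destructor gives way to a declared one because Benchmark now owns heap allocations: the out-of-line ~Benchmark() in benchmark.cc deletes each calibData value. An alternative that would keep the defaulted destructor is to let the map own its values; a sketch with an illustrative stand-in value type, since the real calibData element type is not shown in this diff:

    #include <memory>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Illustrative stand-in for the calibration record; the real benchmark
    // type and its fields may differ.
    struct CheckTensor {
      std::vector<size_t> shape;
      std::vector<float> data;
    };

    // Owning map: the compiler-generated destructor frees every entry.
    std::unordered_map<std::string, std::unique_ptr<CheckTensor>> calibData;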