forked from mindspore-Ecosystem/mindspore
!4434 fix bug that lite_session GetInputs return repeated input tensor
Merge pull request !4434 from hangq/master
This commit is contained in:
commit
b34b7973be
|
@ -92,6 +92,16 @@ void LiteSession::InitGraphInputTensors(const lite::Model *model) {
|
|||
}
|
||||
}
|
||||
|
||||
void LiteSession::InitGraphInputMSTensors(const lite::Model *model) {
|
||||
auto meta_graph = model->GetMetaGraph();
|
||||
MS_ASSERT(this->input_vec_.empty());
|
||||
MS_ASSERT(meta_graph != nullptr);
|
||||
for (auto &input_tensor : this->inputs_) {
|
||||
MS_ASSERT(input_tensor != nullptr);
|
||||
this->input_vec_.emplace_back(new lite::tensor::LiteTensor(input_tensor));
|
||||
}
|
||||
}
|
||||
|
||||
void LiteSession::InitGraphOutputTensors(const lite::Model *model) {
|
||||
auto meta_graph = model->GetMetaGraph();
|
||||
MS_ASSERT(this->outputs_.empty());
|
||||
|
@ -169,6 +179,7 @@ void LiteSession::InitGraphOutputMap(const lite::Model *model) {
|
|||
|
||||
void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
|
||||
InitGraphInputTensors(model);
|
||||
InitGraphInputMSTensors(model);
|
||||
InitGraphOutputTensors(model);
|
||||
InitGraphInputMap(model);
|
||||
InitGraphOutputMap(model);
|
||||
|
@ -201,16 +212,7 @@ int LiteSession::CompileGraph(Model *model) {
|
|||
}
|
||||
|
||||
std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputs() const {
|
||||
std::vector<mindspore::tensor::MSTensor *> ret;
|
||||
for (auto &iter : this->input_map_) {
|
||||
auto &node_input_tensors = iter.second;
|
||||
for (auto tensor : node_input_tensors) {
|
||||
if (!IsContain(ret, tensor)) {
|
||||
ret.emplace_back(tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
return this->input_vec_;
|
||||
}
|
||||
|
||||
int LiteSession::RunGraph(const session::KernelCallBack &before, const session::KernelCallBack &after) {
|
||||
|
|
|
@ -57,13 +57,15 @@ class LiteSession : public session::LiteSession {
|
|||
int ConvertTensors(const lite::Model *model);
|
||||
|
||||
void InitGraphInOutTensors(const lite::Model *model);
|
||||
|
||||
// init this->inputs_
|
||||
void InitGraphInputTensors(const lite::Model *model);
|
||||
|
||||
// init this->input_vec_
|
||||
void InitGraphInputMSTensors(const lite::Model *model);
|
||||
// init this->outputs_
|
||||
void InitGraphOutputTensors(const lite::Model *model);
|
||||
|
||||
// init this->input_map_
|
||||
void InitGraphInputMap(const lite::Model *model);
|
||||
|
||||
// init this->output_map_
|
||||
void InitGraphOutputMap(const lite::Model *model);
|
||||
|
||||
protected:
|
||||
|
@ -74,6 +76,8 @@ class LiteSession : public session::LiteSession {
|
|||
std::vector<tensor::Tensor *> inputs_;
|
||||
// graph output tensors
|
||||
std::vector<tensor::Tensor *> outputs_;
|
||||
// graph input MSTensors
|
||||
std::vector<mindspore::tensor::MSTensor *> input_vec_;
|
||||
// graph input node name -- input tensors
|
||||
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> input_map_;
|
||||
// graph output node name -- output tensors
|
||||
|
|
|
@ -49,7 +49,8 @@ int Benchmark::GenerateInputData() {
|
|||
auto tensorByteSize = tensor->Size();
|
||||
auto status = GenerateRandomData(tensorByteSize, inputData);
|
||||
if (status != 0) {
|
||||
MS_LOG(ERROR) << "GenerateRandomData for inTensor failed %d" << status;
|
||||
std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl;
|
||||
MS_LOG(ERROR) << "GenerateRandomData for inTensor failed:" << status;
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
@ -60,12 +61,14 @@ int Benchmark::LoadInput() {
|
|||
if (_flags->inDataPath.empty()) {
|
||||
auto status = GenerateInputData();
|
||||
if (status != 0) {
|
||||
std::cerr << "Generate input data error " << status << std::endl;
|
||||
MS_LOG(ERROR) << "Generate input data error " << status;
|
||||
return status;
|
||||
}
|
||||
} else {
|
||||
auto status = ReadInputFile();
|
||||
if (status != 0) {
|
||||
std::cerr << "ReadInputFile error, " << status << std::endl;
|
||||
MS_LOG(ERROR) << "ReadInputFile error, " << status;
|
||||
return status;
|
||||
}
|
||||
|
@ -97,6 +100,7 @@ int Benchmark::ReadInputFile() {
|
|||
char *binBuf = ReadFile(_flags->input_data_list[i].c_str(), &size);
|
||||
auto tensorDataSize = cur_tensor->Size();
|
||||
if (size != tensorDataSize) {
|
||||
std::cerr << "Input binary file size error, required: %zu, in fact: %zu" << tensorDataSize << size << std::endl;
|
||||
MS_LOG(ERROR) << "Input binary file size error, required: %zu, in fact: %zu" << tensorDataSize << size;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -113,11 +117,13 @@ int Benchmark::ReadCalibData() {
|
|||
// read calib data
|
||||
std::ifstream inFile(calibDataPath);
|
||||
if (!inFile.good()) {
|
||||
std::cerr << "file: " << calibDataPath << " is not exist" << std::endl;
|
||||
MS_LOG(ERROR) << "file: " << calibDataPath << " is not exist";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (!inFile.is_open()) {
|
||||
std::cerr << "file: " << calibDataPath << " open failed" << std::endl;
|
||||
MS_LOG(ERROR) << "file: " << calibDataPath << " open failed";
|
||||
inFile.close();
|
||||
return RET_ERROR;
|
||||
|
@ -181,6 +187,7 @@ float Benchmark::CompareData(const std::string &nodeName, std::vector<int> msSha
|
|||
oss << dim << ",";
|
||||
}
|
||||
oss << ") are different";
|
||||
std::cerr << oss.str() << std::endl;
|
||||
MS_LOG(ERROR) << "%s", oss.str().c_str();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -193,6 +200,7 @@ float Benchmark::CompareData(const std::string &nodeName, std::vector<int> msSha
|
|||
}
|
||||
|
||||
if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) {
|
||||
std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl;
|
||||
MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -524,6 +532,13 @@ int Benchmark::Init() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
Benchmark::~Benchmark() {
|
||||
for (auto iter : this->calibData) {
|
||||
delete (iter.second);
|
||||
}
|
||||
this->calibData.clear();
|
||||
}
|
||||
|
||||
int RunBenchmark(int argc, const char **argv) {
|
||||
BenchmarkFlags flags;
|
||||
Option<std::string> err = flags.ParseFlags(argc, argv);
|
||||
|
|
|
@ -104,7 +104,7 @@ class MS_API Benchmark {
|
|||
public:
|
||||
explicit Benchmark(BenchmarkFlags *flags) : _flags(flags) {}
|
||||
|
||||
virtual ~Benchmark() = default;
|
||||
virtual ~Benchmark();
|
||||
|
||||
int Init();
|
||||
int RunBenchmark(const std::string &deviceType = "NPU");
|
||||
|
|
Loading…
Reference in New Issue