return unorderd_map rather than vector for LiteSession::GetOutputs
This commit is contained in:
parent
fcdc9c40d0
commit
551cdfe2f5
|
@ -20,6 +20,7 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "include/ms_tensor.h"
|
||||
#include "include/model.h"
|
||||
#include "include/context.h"
|
||||
|
@ -85,8 +86,8 @@ class MS_API LiteSession {
|
|||
|
||||
/// \brief Get output MindSpore Lite MSTensors of model.
|
||||
///
|
||||
/// \return A vector of MindSpore Lite MSTensor.
|
||||
virtual std::vector<tensor::MSTensor *> GetOutputs() const = 0;
|
||||
/// \return A map of output node name and MindSpore Lite MSTensor.
|
||||
virtual std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputs() const = 0;
|
||||
|
||||
/// \brief Get output MindSpore Lite MSTensors of model by node name.
|
||||
///
|
||||
|
|
|
@ -177,17 +177,8 @@ int LiteSession::RunGraph(const session::KernelCallBack &before, const session::
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<mindspore::tensor::MSTensor *> LiteSession::GetOutputs() const {
|
||||
std::vector<mindspore::tensor::MSTensor *> ret;
|
||||
for (auto &iter : this->output_map) {
|
||||
auto &node_output_tensors = iter.second;
|
||||
for (auto tensor : node_output_tensors) {
|
||||
if (!IsContain(ret, tensor)) {
|
||||
ret.emplace_back(tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> LiteSession::GetOutputs() const {
|
||||
return this->output_map;
|
||||
}
|
||||
|
||||
int LiteSession::Init(Context *context) {
|
||||
|
|
|
@ -49,7 +49,7 @@ class LiteSession : public session::LiteSession {
|
|||
int RunGraph(const session::KernelCallBack &before = nullptr,
|
||||
const session::KernelCallBack &after = nullptr) override;
|
||||
|
||||
std::vector<mindspore::tensor::MSTensor *> GetOutputs() const override;
|
||||
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputs() const override;
|
||||
|
||||
std::vector<mindspore::tensor::MSTensor *> GetOutputsByName(const std::string &name) const override;
|
||||
|
||||
|
|
|
@ -130,7 +130,8 @@ TEST_F(InferTest, TestConvNode) {
|
|||
ASSERT_EQ(lite::RET_OK, ret);
|
||||
auto outputs = session->GetOutputs();
|
||||
ASSERT_EQ(outputs.size(), 1);
|
||||
auto outTensor = outputs.front();
|
||||
ASSERT_EQ(outputs.begin()->second.size(), 1);
|
||||
auto outTensor = outputs.begin()->second.front();
|
||||
ASSERT_NE(nullptr, outTensor);
|
||||
ASSERT_EQ(28 * 28 * 32, outTensor->ElementsNum());
|
||||
ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
|
||||
|
@ -220,7 +221,8 @@ TEST_F(InferTest, TestAddNode) {
|
|||
ASSERT_EQ(lite::RET_OK, ret);
|
||||
auto outputs = session->GetOutputs();
|
||||
ASSERT_EQ(outputs.size(), 1);
|
||||
auto outTensor = outputs.front();
|
||||
ASSERT_EQ(outputs.begin()->second.size(), 1);
|
||||
auto outTensor = outputs.begin()->second.front();
|
||||
ASSERT_NE(nullptr, outTensor);
|
||||
ASSERT_EQ(28 * 28 * 3, outTensor->ElementsNum());
|
||||
ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
|
||||
|
|
Loading…
Reference in New Issue