add GetOutputsByNodeName
This commit is contained in:
parent
08ed63ff91
commit
84c1705319
|
@ -49,6 +49,7 @@ class MS_API Model {
|
|||
std::vector<MSTensor> GetOutputs();
|
||||
inline std::vector<std::string> GetOutputTensorNames();
|
||||
inline MSTensor GetOutputByTensorName(const std::string &tensor_name);
|
||||
inline std::vector<MSTensor> GetOutputsByNodeName(const std::string &tensor_name);
|
||||
|
||||
static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
|
||||
|
||||
|
@ -71,5 +72,9 @@ std::vector<std::string> Model::GetOutputTensorNames() { return VectorCharToStri
|
|||
MSTensor Model::GetOutputByTensorName(const std::string &tensor_name) {
|
||||
return GetOutputByTensorName(StringToChar(tensor_name));
|
||||
}
|
||||
|
||||
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::string &tensor_name) {
|
||||
return GetOutputsByNodeName(StringToChar(tensor_name));
|
||||
}
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_MODEL_H
|
||||
|
|
|
@ -45,12 +45,12 @@ MSTensor::Impl *MSTensor::Impl::CreateTensorImpl(const std::string &name, enum D
|
|||
MS_LOG(ERROR) << "Failed to allocate lite tensor.";
|
||||
return nullptr;
|
||||
}
|
||||
auto impl = new (std::nothrow) Impl();
|
||||
auto impl = new (std::nothrow) Impl(lite_tensor);
|
||||
if (impl == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to allocate tensor impl.";
|
||||
return nullptr;
|
||||
}
|
||||
impl->set_lite_tensor(lite_tensor);
|
||||
impl->set_from_session(false);
|
||||
return impl;
|
||||
}
|
||||
|
||||
|
|
|
@ -149,6 +149,8 @@ class MSTensor::Impl {
|
|||
|
||||
void set_own_data(bool own_data) { own_data_ = own_data; }
|
||||
|
||||
void set_from_session(bool from_session) { from_session_ = from_session; }
|
||||
|
||||
private:
|
||||
tensor::MSTensor *lite_tensor_ = nullptr;
|
||||
std::string tensor_name_ = "";
|
||||
|
|
Loading…
Reference in New Issue