!17479 [MS][LITE]fix bug of litesession and tensorlist

From: @mengyuanli
Reviewed-by: @zhanghaibo5,@jpc_chenjianping
Signed-off-by: @jpc_chenjianping
This commit is contained in:
mindspore-ci-bot 2021-06-02 15:47:36 +08:00 committed by Gitee
commit 69596b8d4a
6 changed files with 23 additions and 8 deletions

View File

@ -18,7 +18,7 @@
#include "nnacl/infer/infer_register.h"
int PreJudge(const TensorC *get_index, TensorListC *input0, const TensorC *value_tensor) {
if (get_index->data_ == NULL || value_tensor->data_ == NULL) {
if (get_index->data_ == NULL) {
return NNACL_INFER_INVALID;
}

View File

@ -376,7 +376,7 @@ void LiteSession::FreePackOpWeight(const std::vector<kernel::LiteKernel *> &kern
auto inputs = kernel->in_tensors();
for (auto *tensor : inputs) {
MS_ASSERT(tensor != nullptr);
if (!tensor->IsConst() || tensor->init_ref_count() != 1) {
if (!tensor->IsConst()) {
continue;
}
tensor->FreeData();

View File

@ -42,9 +42,9 @@ int TensorListGetItemCPUKernel::Run() {
dtype_ = input0->tensors_data_type();
MS_ASSERT(in_tensors_.at(1)->data_c() != nullptr);
index_ = reinterpret_cast<int *>(in_tensors_.at(1)->data_c())[0];
int dim0 = input0->ElementsNum() - 1;
if (index_ < 0 || index_ > dim0) {
MS_LOG(ERROR) << "index tensor:[" << index_ << "] must be in [0, " << dim0 << "]!";
int dim0 = input0->ElementsNum();
if (index_ < 0 || index_ >= dim0) {
MS_LOG(ERROR) << "index tensor:[" << index_ << "] must be in [0, " << dim0 << ")!";
return RET_ERROR;
}
auto src_ptr = input0->GetTensor(index_);

View File

@ -48,7 +48,13 @@ STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<sch
STATUS NodeInferShpae(const schema::CNodeT &node, const std::vector<Tensor *> &inputs, std::vector<Tensor *> *outputs);
inline schema::PrimitiveType GetCNodeTType(const schema::CNodeT &cNodeT) { return cNodeT.primitive->value.type; }
inline schema::PrimitiveType GetCNodeTType(const schema::CNodeT &cNodeT) {
if (cNodeT.primitive != nullptr) {
return cNodeT.primitive->value.type;
} else {
return schema::PrimitiveType_NONE;
}
}
inline std::string GetCNodeTTypeName(const schema::CNodeT &cNodeT) {
return schema::EnumNamePrimitiveType(GetCNodeTType(cNodeT));

View File

@ -35,6 +35,14 @@ bool SubgraphTensorPass::IsUsing(schema::MetaGraphT *graph, const uint32_t &tens
return true;
}
}
for (const auto &subgraph : graph->subGraph) {
if (IsContain<uint32_t>(subgraph->inputIndices, tensor_idx)) {
return true;
}
if (IsContain<uint32_t>(subgraph->outputIndices, tensor_idx)) {
return true;
}
}
return false;
}

View File

@ -31,7 +31,8 @@ STATUS TopologicalSortPass::Run(schema::MetaGraphT *graph) {
std::vector<size_t> sinked_tensor_idxes;
// put all const tensor index into sinked_tensor_idxes
for (size_t i = 0; i < graph->allTensors.size(); i++) {
if (graph->allTensors.at(i)->nodeType == NodeType_ValueNode) {
if (graph->allTensors.at(i)->nodeType == NodeType_ValueNode ||
graph->allTensors.at(i)->nodeType == NodeType_Parameter) {
sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), i);
}
}
@ -80,7 +81,7 @@ STATUS TopologicalSortPass::Run(schema::MetaGraphT *graph) {
bool TopologicalSortPass::IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node,
const std::vector<size_t> &sinked_tensor_idxes) {
MS_ASSERT(node != nullptr);
if (node->primitive->value.type == schema::PrimitiveType_Merge) {
if (node->primitive && node->primitive->value.type == schema::PrimitiveType_Merge) {
auto node_input_index = node->inputIndex;
MS_ASSERT(node_input_index.size() % 2 == 0);
return std::all_of(node_input_index.begin(), node_input_index.begin() + node_input_index.size() / 2,