diff --git a/mindspore/lite/src/ops/tensorlist_setitem.cc b/mindspore/lite/src/ops/tensorlist_setitem.cc index 82bcd95c163..7891d7253cc 100644 --- a/mindspore/lite/src/ops/tensorlist_setitem.cc +++ b/mindspore/lite/src/ops/tensorlist_setitem.cc @@ -131,8 +131,13 @@ int TensorListSetItem::InferShape(std::vector inputs_, std::vect } output0->set_max_elements_num(input0->max_elements_num()); - output0->set_element_shape(input0->element_shape()); + if (input0->tensors().empty() && input0->element_shape().empty() && index == 0) { + input0->set_element_shape(value_tensor->shape()); + output0->set_element_shape(value_tensor->shape()); + } else { + output0->set_element_shape(input0->element_shape()); + } std::vector > out_shape; if (index == 0 && input0->tensors().size() == 0) { // uninitialized tensorlist out_shape.push_back(value_tensor->shape()); diff --git a/mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc b/mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc index b42e00e2fbd..9adc022b977 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc @@ -89,7 +89,9 @@ int CarryDataKernel::MoveTensorLiteData(lite::TensorList *dst_tensor, lite::Tens MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible"; return RET_ERROR; } - if (dst_tensor->element_shape() != src_tensor->element_shape()) { + if (dst_tensor->element_shape().empty()) { + dst_tensor->set_element_shape(src_tensor->element_shape()); + } else if (dst_tensor->element_shape() != src_tensor->element_shape()) { MS_LOG(ERROR) << "input tensorlist and output tensorlist element shape is incompatible"; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_getitem_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_getitem_fp32.cc index 82e741bf262..6b4618fe766 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_getitem_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_getitem_fp32.cc @@ -68,7 +68,7 @@ int TensorListGetItemCPUKernel::Run() { } else { // reset 0 and dtype = dtype_ // TODO(DT_VARIANT): dtype = DT_VARIANT is not handle - auto out_data = out_tensors_[0]->MutableData(); + auto out_data = out_tensors_[0]->data_c(); if (out_data == nullptr) { MS_LOG(ERROR) << "data of out_tensors_[0] is nullptr"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc index d7d2ecaae60..9ce67e1dd51 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc @@ -30,6 +30,22 @@ namespace mindspore::kernel { int TensorListSetItemCPUKernel::Init() { return RET_OK; } +int TensorListSetItemCPUKernel::CheckParam() { + if (dtype_ != kTypeUnknown && dtype_ != input0_->tensors_data_type()) { + MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type(); + return RET_ERROR; + } + if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) { + MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int"; + return RET_ERROR; + } + if (in_tensors_[1]->ElementsNum() != 1) { + MS_LOG(ERROR) << "in_tensors_[1]->ElementsNum():" << in_tensors_[1]->ElementsNum() << " must be equal to 1!"; + return RET_ERROR; + } + return RET_OK; +} + int TensorListSetItemCPUKernel::IncrementOutputSize(int origin_size) { output0_ = reinterpret_cast(out_tensors_[0]); int new_tensors_size = origin_size + 1; @@ -46,19 +62,13 @@ int TensorListSetItemCPUKernel::IncrementOutputSize(int origin_size) { int TensorListSetItemCPUKernel::Run() { input0_ = reinterpret_cast(in_tensors_[0]); - if (dtype_ != kTypeUnknown && dtype_ != input0_->tensors_data_type()) { - MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type(); + + if (CheckParam() != RET_OK) { + MS_LOG(ERROR) << "check param failed."; return RET_ERROR; } + int dim0 = input0_->ElementsNum() - 1; - if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) { - MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int"; - return RET_ERROR; - } - if (in_tensors_[1]->ElementsNum() != 1) { - MS_LOG(ERROR) << "in_tensors_[1]->ElementsNum():" << in_tensors_[1]->ElementsNum() << " must be equal to 1!"; - return RET_ERROR; - } index_ = reinterpret_cast(in_tensors_[1]->data_c())[0]; if (index_ < 0 || index_ > dim0) { if (IncrementOutputSize(output0_->shape()[0]) != RET_OK) { @@ -81,6 +91,10 @@ int TensorListSetItemCPUKernel::Run() { } } // copy each tensor in tensors_ + if (input0_->tensors().empty() && index_ == 0) { + input0_->set_element_shape(input2_->shape()); + output0_->set_element_shape(input2_->shape()); + } for (int i = 0; i < output0_->ElementsNum(); ++i) { if (i == index_) { auto dst = output0_->GetTensor(i); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.h index c4e89fd77f8..1c3ccb8b5e0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.h @@ -39,6 +39,7 @@ class TensorListSetItemCPUKernel : public LiteKernel { int IncrementOutputSize(int origin_size); private: + int CheckParam(); lite::TensorList *input0_ = nullptr; lite::Tensor *input2_ = nullptr; lite::TensorList *output0_ = nullptr; diff --git a/mindspore/lite/src/tensorlist.cc b/mindspore/lite/src/tensorlist.cc index 2301a1fd06d..0875714483b 100644 --- a/mindspore/lite/src/tensorlist.cc +++ b/mindspore/lite/src/tensorlist.cc @@ -240,6 +240,9 @@ Tensor *TensorList::GetTensor(int index) { } bool TensorList::IsCompatibleShape(const std::vector &shape) { + if (this->tensors_.empty() && this->element_shape_.empty()) { + return true; + } if (shape.size() != this->element_shape_.size()) { return false; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index a5353c5ed12..fcdb4069a3d 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -439,6 +439,9 @@ STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, const FuncG MS_LOG(ERROR) << "memcpy error: " << ret; return RET_ERROR; } + copy_param_value->set_tensor_shape(param_value->tensor_shape()); + copy_param_value->set_format(param_value->format()); + copy_param_value->set_tensor_type(param_value->tensor_type()); copy_param_value->SetTensorData(copy_data, param_value->tensor_size()); ext_subgraph_input->set_default_param(copy_param_value); } else {