forked from mindspore-Ecosystem/mindspore
!11716 [MS][LITE]fix bug of tensorliset when elementshape is nullptr
From: @mengyuanli Reviewed-by: @hangangqiang,@zhanghaibo5 Signed-off-by: @hangangqiang
This commit is contained in:
commit
4719a66be6
|
@ -131,8 +131,13 @@ int TensorListSetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
|
|||
}
|
||||
|
||||
output0->set_max_elements_num(input0->max_elements_num());
|
||||
output0->set_element_shape(input0->element_shape());
|
||||
|
||||
if (input0->tensors().empty() && input0->element_shape().empty() && index == 0) {
|
||||
input0->set_element_shape(value_tensor->shape());
|
||||
output0->set_element_shape(value_tensor->shape());
|
||||
} else {
|
||||
output0->set_element_shape(input0->element_shape());
|
||||
}
|
||||
std::vector<std::vector<int> > out_shape;
|
||||
if (index == 0 && input0->tensors().size() == 0) { // uninitialized tensorlist
|
||||
out_shape.push_back(value_tensor->shape());
|
||||
|
|
|
@ -89,7 +89,9 @@ int CarryDataKernel::MoveTensorLiteData(lite::TensorList *dst_tensor, lite::Tens
|
|||
MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (dst_tensor->element_shape() != src_tensor->element_shape()) {
|
||||
if (dst_tensor->element_shape().empty()) {
|
||||
dst_tensor->set_element_shape(src_tensor->element_shape());
|
||||
} else if (dst_tensor->element_shape() != src_tensor->element_shape()) {
|
||||
MS_LOG(ERROR) << "input tensorlist and output tensorlist element shape is incompatible";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -68,7 +68,7 @@ int TensorListGetItemCPUKernel::Run() {
|
|||
} else {
|
||||
// reset 0 and dtype = dtype_
|
||||
// TODO(DT_VARIANT): dtype = DT_VARIANT is not handle
|
||||
auto out_data = out_tensors_[0]->MutableData();
|
||||
auto out_data = out_tensors_[0]->data_c();
|
||||
if (out_data == nullptr) {
|
||||
MS_LOG(ERROR) << "data of out_tensors_[0] is nullptr";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -30,6 +30,22 @@ namespace mindspore::kernel {
|
|||
|
||||
int TensorListSetItemCPUKernel::Init() { return RET_OK; }
|
||||
|
||||
int TensorListSetItemCPUKernel::CheckParam() {
|
||||
if (dtype_ != kTypeUnknown && dtype_ != input0_->tensors_data_type()) {
|
||||
MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) {
|
||||
MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (in_tensors_[1]->ElementsNum() != 1) {
|
||||
MS_LOG(ERROR) << "in_tensors_[1]->ElementsNum():" << in_tensors_[1]->ElementsNum() << " must be equal to 1!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int TensorListSetItemCPUKernel::IncrementOutputSize(int origin_size) {
|
||||
output0_ = reinterpret_cast<lite::TensorList *>(out_tensors_[0]);
|
||||
int new_tensors_size = origin_size + 1;
|
||||
|
@ -46,19 +62,13 @@ int TensorListSetItemCPUKernel::IncrementOutputSize(int origin_size) {
|
|||
|
||||
int TensorListSetItemCPUKernel::Run() {
|
||||
input0_ = reinterpret_cast<lite::TensorList *>(in_tensors_[0]);
|
||||
if (dtype_ != kTypeUnknown && dtype_ != input0_->tensors_data_type()) {
|
||||
MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type();
|
||||
|
||||
if (CheckParam() != RET_OK) {
|
||||
MS_LOG(ERROR) << "check param failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
int dim0 = input0_->ElementsNum() - 1;
|
||||
if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) {
|
||||
MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (in_tensors_[1]->ElementsNum() != 1) {
|
||||
MS_LOG(ERROR) << "in_tensors_[1]->ElementsNum():" << in_tensors_[1]->ElementsNum() << " must be equal to 1!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
index_ = reinterpret_cast<int *>(in_tensors_[1]->data_c())[0];
|
||||
if (index_ < 0 || index_ > dim0) {
|
||||
if (IncrementOutputSize(output0_->shape()[0]) != RET_OK) {
|
||||
|
@ -81,6 +91,10 @@ int TensorListSetItemCPUKernel::Run() {
|
|||
}
|
||||
}
|
||||
// copy each tensor in tensors_
|
||||
if (input0_->tensors().empty() && index_ == 0) {
|
||||
input0_->set_element_shape(input2_->shape());
|
||||
output0_->set_element_shape(input2_->shape());
|
||||
}
|
||||
for (int i = 0; i < output0_->ElementsNum(); ++i) {
|
||||
if (i == index_) {
|
||||
auto dst = output0_->GetTensor(i);
|
||||
|
|
|
@ -39,6 +39,7 @@ class TensorListSetItemCPUKernel : public LiteKernel {
|
|||
int IncrementOutputSize(int origin_size);
|
||||
|
||||
private:
|
||||
int CheckParam();
|
||||
lite::TensorList *input0_ = nullptr;
|
||||
lite::Tensor *input2_ = nullptr;
|
||||
lite::TensorList *output0_ = nullptr;
|
||||
|
|
|
@ -240,6 +240,9 @@ Tensor *TensorList::GetTensor(int index) {
|
|||
}
|
||||
|
||||
bool TensorList::IsCompatibleShape(const std::vector<int> &shape) {
|
||||
if (this->tensors_.empty() && this->element_shape_.empty()) {
|
||||
return true;
|
||||
}
|
||||
if (shape.size() != this->element_shape_.size()) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -439,6 +439,9 @@ STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, const FuncG
|
|||
MS_LOG(ERROR) << "memcpy error: " << ret;
|
||||
return RET_ERROR;
|
||||
}
|
||||
copy_param_value->set_tensor_shape(param_value->tensor_shape());
|
||||
copy_param_value->set_format(param_value->format());
|
||||
copy_param_value->set_tensor_type(param_value->tensor_type());
|
||||
copy_param_value->SetTensorData(copy_data, param_value->tensor_size());
|
||||
ext_subgraph_input->set_default_param(copy_param_value);
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue