!10188 [ms][lite][CPU] tensorlist stack bug
From: @lzkcode Reviewed-by: Signed-off-by:
This commit is contained in:
commit
ccd83ffa42
|
@ -95,7 +95,7 @@ int TensorListStack::UnPackToFlatBuilder(const schema::Primitive *primitive, fla
|
|||
MS_LOG(ERROR) << "value_as_TensorListStack return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateTensorListStack(*fbb, attr->elementDType(), attr->numElements());
|
||||
auto val_offset = schema::CreateTensorListStack(*fbb, attr->numElements(), attr->elementDType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_TensorListStack, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
|
@ -159,9 +159,8 @@ int TensorListStack::InferShape(std::vector<lite::Tensor *> inputs_, std::vector
|
|||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_data_type(input0->tensors_data_type());
|
||||
output->set_shape(std::vector<int>(
|
||||
1,
|
||||
input0->ElementsNum() * std::accumulate(output_shape_.begin(), output_shape_.end(), 1LL, std::multiplies<int>())));
|
||||
output_shape_.insert(output_shape_.begin(), input0->ElementsNum());
|
||||
output->set_shape(output_shape_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ int TensorListSetItemCPUKernel::Init() { return RET_OK; }
|
|||
|
||||
int TensorListSetItemCPUKernel::Run() {
|
||||
input0_ = reinterpret_cast<lite::TensorList *>(in_tensors_[0]);
|
||||
if (dtype_ != input0_->data_type()) {
|
||||
if (dtype_ != input0_->tensors_data_type()) {
|
||||
MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -42,6 +42,16 @@ int TensorListStackCPUKernel::CheckParam() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
num_element_ = input0_->ElementsNum();
|
||||
if (output0_->shape().size() < 1) {
|
||||
MS_LOG(ERROR) << "out_tensors_[0].shape().size():" << output0_->shape().size()
|
||||
<< " must be greater than or equal to 1!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int dim0 = output0_->shape()[0];
|
||||
if (dim0 != num_element_) {
|
||||
MS_LOG(ERROR) << "out_tensors_[0].shape()[0] must be:" << num_element_ << ", but now is:" << dim0;
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -50,16 +60,7 @@ int TensorListStackCPUKernel::Init() {
|
|||
MS_ASSERT(input0_ != nullptr);
|
||||
output0_ = out_tensors_[0];
|
||||
MS_ASSERT(output0_ != nullptr);
|
||||
if (output0_->shape().size() != 2) {
|
||||
MS_LOG(ERROR) << "out_tensors_[0].shape().size():" << output0_->shape().size() << " must be equal to 2!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int dim0 = output0_->shape()[0];
|
||||
if (dim0 != 1) { // dim0 must be 1
|
||||
MS_LOG(ERROR) << "out_tensors_[0].shape()[0] must be 1, but now is:" << dim0;
|
||||
return RET_ERROR;
|
||||
}
|
||||
return CheckParam();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
bool TensorListStackCPUKernel::IsFullyDefined(const std::vector<int> &shape) const {
|
||||
|
@ -129,26 +130,22 @@ int TensorListStackCPUKernel::MergeSubShape(const std::vector<int> &shape) {
|
|||
}
|
||||
|
||||
int TensorListStackCPUKernel::Run() {
|
||||
if (CheckParam() != RET_OK) {
|
||||
MS_LOG(ERROR) << "CheckParam failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (output0_->ElementsNum() == 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
size_t in_ele_num = 0;
|
||||
for (int i = 0; i < num_element_; ++i) {
|
||||
auto tensor = input0_->GetTensorIndex(i);
|
||||
MS_ASSERT(tensor != nullptr);
|
||||
if (tensor->data_type() == kTypeUnknown) {
|
||||
if (TypeUnknownSize == 0) {
|
||||
TypeUnknownSize = MergeElementShape();
|
||||
}
|
||||
in_ele_num += TypeUnknownSize;
|
||||
} else {
|
||||
in_ele_num += std::accumulate(tensor->shape().begin(), tensor->shape().end(), 1LL, std::multiplies<int>());
|
||||
}
|
||||
auto ret = MergeElementShape();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "MergeElementShape failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
size_t in_ele_num = num_element_ * TypeUnknownSize;
|
||||
size_t out_ele_num = output0_->ElementsNum();
|
||||
if (in_ele_num > out_ele_num) {
|
||||
MS_LOG(ERROR) << "out_tensors_[0]->ElementsNum():" << out_ele_num
|
||||
<< "must be greater than or equal to in_ele_num:" << in_ele_num;
|
||||
if (in_ele_num != out_ele_num) {
|
||||
MS_LOG(ERROR) << "out_tensors_[0]->ElementsNum():" << out_ele_num << "must be equal to in_ele_num:" << in_ele_num;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto out_ptr = reinterpret_cast<float *>(output0_->MutableData());
|
||||
|
|
Loading…
Reference in New Issue