!15249 [MS_LITE] fix infershape datatype

From: @YeFeng_24
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
This commit is contained in:
mindspore-ci-bot 2021-04-16 09:44:00 +08:00 committed by Gitee
commit de0b9e5506
6 changed files with 39 additions and 11 deletions

View File

@ -35,6 +35,29 @@ int MergeInfer(TensorC **inputs, size_t inputs_size, TensorC **outputs, size_t o
return NNACL_OK;
}
void MergeDataTypeInfer(TensorC **inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size) {
for (size_t i = 0; i < outputs_size; i++) {
if (inputs[i]->data_type_ == kObjectTypeTensorType) {
TensorListC *input_tensor_list = (TensorListC *)inputs[i];
if (input_tensor_list->tensors_data_type_ != kTypeUnknown) {
outputs[i] = inputs[i];
inputs[i] = NULL;
} else {
outputs[i] = inputs[i + outputs_size];
inputs[i + outputs_size] = NULL;
}
} else {
if (inputs[i]->data_type_ != kTypeUnknown) {
outputs[i] = inputs[i];
inputs[i] = NULL;
} else {
outputs[i] = inputs[i + outputs_size];
inputs[i + outputs_size] = NULL;
}
}
}
}
int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
#ifdef Debug
@ -49,6 +72,7 @@ int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **
#endif
if (!parameter->infer_flag_) {
MergeDataTypeInfer((struct TensorC **)inputs, inputs_size, outputs, outputs_size);
return NNACL_INFER_INVALID;
}

View File

@ -31,10 +31,6 @@ int SwitchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
}
#endif
if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;
}
for (size_t i = 0; i < outputs_size / 2; i++) {
outputs[i] = (TensorC *)inputs[i + 1];
if (inputs[i + 1]->data_type_ == kObjectTypeTensorType) {
@ -63,7 +59,9 @@ int SwitchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
}
*((const TensorC **)inputs + i + 1) = NULL;
}
if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;
}
return NNACL_OK;
}

View File

@ -27,13 +27,14 @@ int TensorListFromTensorInferShape(const TensorC *const *inputs, size_t inputs_s
#endif
TensorListC *output = (TensorListC *)(outputs[0]);
const TensorC *input0 = inputs[0];
output->data_type_ = kObjectTypeTensorType;
output->format_ = Format_NHWC;
output->tensors_data_type_ = input0->data_type_;
if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;
}
const TensorC *input0 = inputs[0];
if (input0->shape_size_ < 1) {
return NNACL_ERR;

View File

@ -49,6 +49,7 @@ int TensorListSetItemInferShape(const TensorC *const *inputs, size_t inputs_size
TensorListC *output0 = (TensorListC *)(outputs[0]);
output0->data_type_ = input0->data_type_;
output0->format_ = input0->format_;
output0->tensors_data_type_ = value_tensor->data_type_;
if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;

View File

@ -25,11 +25,13 @@ int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size,
return check_ret;
}
#endif
TensorC *output = outputs[0];
TensorListC *input0 = (TensorListC *)(inputs[0]);
output->data_type_ = input0->tensors_data_type_;
output->format_ = input0->format_;
if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;
}
TensorListC *input0 = (TensorListC *)(inputs[0]);
if (input0->element_num_ == 0) {
return NNACL_ERR;
}
@ -63,9 +65,6 @@ int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size,
}
}
}
TensorC *output = outputs[0];
output->data_type_ = input0->tensors_data_type_;
output->format_ = input0->format_;
ShapeInsert(output_shape, &output_shape_size, 0, input0->element_num_);
SetShapeArray(output, output_shape, output_shape_size);
return NNACL_OK;

View File

@ -67,6 +67,11 @@ void SetOutputTensorAttr(const std::vector<TensorC *> &tensors_in, std::vector<l
tensors_out->at(i)->set_format(static_cast<schema::Format>(tensors_in[i]->format_));
tensors_out->at(i)->set_data_type(static_cast<TypeId>(tensors_in[i]->data_type_));
tensors_out->at(i)->set_shape({tensors_in[i]->shape_, tensors_in[i]->shape_ + tensors_in[i]->shape_size_});
if (tensors_in.at(i)->data_type_ == TypeIdC::kObjectTypeTensorType) {
auto tensor_list_in = reinterpret_cast<TensorListC *>(tensors_in.at(i));
auto tensor_list_out = reinterpret_cast<TensorList *>(tensors_out->at(i));
tensor_list_out->set_tensors_data_type(TypeId(tensor_list_in->tensors_data_type_));
}
}
}
}