forked from mindspore-Ecosystem/mindspore
!15249 [MS_LITE] fix infershape datatype
From: @YeFeng_24 Reviewed-by: @hangangqiang,@zhanghaibo5 Signed-off-by: @hangangqiang
commit de0b9e5506
@@ -35,6 +35,29 @@ int MergeInfer(TensorC **inputs, size_t inputs_size, TensorC **outputs, size_t o
   return NNACL_OK;
 }
 
+void MergeDataTypeInfer(TensorC **inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size) {
+  for (size_t i = 0; i < outputs_size; i++) {
+    if (inputs[i]->data_type_ == kObjectTypeTensorType) {
+      TensorListC *input_tensor_list = (TensorListC *)inputs[i];
+      if (input_tensor_list->tensors_data_type_ != kTypeUnknown) {
+        outputs[i] = inputs[i];
+        inputs[i] = NULL;
+      } else {
+        outputs[i] = inputs[i + outputs_size];
+        inputs[i + outputs_size] = NULL;
+      }
+    } else {
+      if (inputs[i]->data_type_ != kTypeUnknown) {
+        outputs[i] = inputs[i];
+        inputs[i] = NULL;
+      } else {
+        outputs[i] = inputs[i + outputs_size];
+        inputs[i + outputs_size] = NULL;
+      }
+    }
+  }
+}
+
 int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                     OpParameter *parameter) {
 #ifdef Debug
@@ -49,6 +72,7 @@ int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **
 #endif
 
   if (!parameter->infer_flag_) {
+    MergeDataTypeInfer((struct TensorC **)inputs, inputs_size, outputs, outputs_size);
     return NNACL_INFER_INVALID;
   }
 
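Note: the sketch below is not taken from the MindSpore sources. It is a minimal, self-contained C illustration (with hypothetical stand-in names SketchTensor, SketchMergeDataTypeInfer, kTypeUnknownSk, kFloat32Sk) of the per-output selection rule that the new MergeDataTypeInfer applies when shape inference is deferred: keep the first-branch input if its data type is already known, otherwise take the input at offset outputs_size from the second branch. Tensor-list inputs follow the same rule, only checking tensors_data_type_ instead of data_type_.

#include <stdio.h>
#include <stddef.h>

enum { kTypeUnknownSk = 0, kFloat32Sk = 43 };  /* arbitrary stand-in type codes */

typedef struct {
  int data_type_;
} SketchTensor;

/* Mirrors the selection rule of MergeDataTypeInfer for plain tensors only. */
void SketchMergeDataTypeInfer(SketchTensor **inputs, SketchTensor **outputs, size_t outputs_size) {
  for (size_t i = 0; i < outputs_size; i++) {
    if (inputs[i]->data_type_ != kTypeUnknownSk) {
      outputs[i] = inputs[i];                 /* first branch already carries a known type */
      inputs[i] = NULL;
    } else {
      outputs[i] = inputs[i + outputs_size];  /* otherwise take the second branch */
      inputs[i + outputs_size] = NULL;
    }
  }
}

int main(void) {
  SketchTensor a = {kTypeUnknownSk};
  SketchTensor b = {kFloat32Sk};
  SketchTensor *inputs[2] = {&a, &b};  /* branch 0 untyped, branch 1 typed */
  SketchTensor *outputs[1] = {NULL};
  SketchMergeDataTypeInfer(inputs, outputs, 1);
  printf("merged data type: %d\n", outputs[0]->data_type_);  /* prints 43 */
  return 0;
}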
@@ -31,10 +31,6 @@ int SwitchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
   }
 #endif
 
-  if (!parameter->infer_flag_) {
-    return NNACL_INFER_INVALID;
-  }
-
   for (size_t i = 0; i < outputs_size / 2; i++) {
     outputs[i] = (TensorC *)inputs[i + 1];
     if (inputs[i + 1]->data_type_ == kObjectTypeTensorType) {
@@ -63,7 +59,9 @@ int SwitchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
     }
     *((const TensorC **)inputs + i + 1) = NULL;
   }
 
+  if (!parameter->infer_flag_) {
+    return NNACL_INFER_INVALID;
+  }
   return NNACL_OK;
 }
@@ -27,13 +27,14 @@ int TensorListFromTensorInferShape(const TensorC *const *inputs, size_t inputs_s
 #endif
 
   TensorListC *output = (TensorListC *)(outputs[0]);
+  const TensorC *input0 = inputs[0];
   output->data_type_ = kObjectTypeTensorType;
   output->format_ = Format_NHWC;
+  output->tensors_data_type_ = input0->data_type_;
 
   if (!parameter->infer_flag_) {
     return NNACL_INFER_INVALID;
   }
-  const TensorC *input0 = inputs[0];
 
   if (input0->shape_size_ < 1) {
     return NNACL_ERR;
@@ -49,6 +49,7 @@ int TensorListSetItemInferShape(const TensorC *const *inputs, size_t inputs_size
   TensorListC *output0 = (TensorListC *)(outputs[0]);
   output0->data_type_ = input0->data_type_;
   output0->format_ = input0->format_;
+  output0->tensors_data_type_ = value_tensor->data_type_;
 
   if (!parameter->infer_flag_) {
     return NNACL_INFER_INVALID;
@@ -25,11 +25,13 @@ int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size,
     return check_ret;
   }
 #endif
 
+  TensorC *output = outputs[0];
+  TensorListC *input0 = (TensorListC *)(inputs[0]);
+  output->data_type_ = input0->tensors_data_type_;
+  output->format_ = input0->format_;
   if (!parameter->infer_flag_) {
     return NNACL_INFER_INVALID;
   }
-  TensorListC *input0 = (TensorListC *)(inputs[0]);
   if (input0->element_num_ == 0) {
     return NNACL_ERR;
   }
@@ -63,9 +65,6 @@ int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size,
       }
     }
   }
-  TensorC *output = outputs[0];
-  output->data_type_ = input0->tensors_data_type_;
-  output->format_ = input0->format_;
   ShapeInsert(output_shape, &output_shape_size, 0, input0->element_num_);
   SetShapeArray(output, output_shape, output_shape_size);
   return NNACL_OK;
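The tensorlist_* hunks above all apply the same fix: the output's data_type_ and format_ (and tensors_data_type_ where relevant) are now written before the !infer_flag_ early return, so only the shape computation is deferred. A minimal sketch of that ordering, assuming hypothetical stand-in names (SketchOut, SketchInfer) rather than the actual nnacl API:

#include <stdbool.h>

typedef struct {
  int data_type_;
  int format_;
  int shape_len_;
} SketchOut;

/* Returns 0 on full inference, -1 when shape inference is deferred
 * (stand-ins for NNACL_OK / NNACL_INFER_INVALID). */
int SketchInfer(SketchOut *out, int elem_type, int fmt, bool infer_flag) {
  out->data_type_ = elem_type;  /* type info is propagated unconditionally */
  out->format_ = fmt;
  if (!infer_flag) {
    return -1;                  /* shapes are resolved in a later pass */
  }
  out->shape_len_ = 1;          /* shape work only when inference is enabled */
  return 0;
}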
@@ -67,6 +67,11 @@ void SetOutputTensorAttr(const std::vector<TensorC *> &tensors_in, std::vector<l
       tensors_out->at(i)->set_format(static_cast<schema::Format>(tensors_in[i]->format_));
       tensors_out->at(i)->set_data_type(static_cast<TypeId>(tensors_in[i]->data_type_));
       tensors_out->at(i)->set_shape({tensors_in[i]->shape_, tensors_in[i]->shape_ + tensors_in[i]->shape_size_});
+      if (tensors_in.at(i)->data_type_ == TypeIdC::kObjectTypeTensorType) {
+        auto tensor_list_in = reinterpret_cast<TensorListC *>(tensors_in.at(i));
+        auto tensor_list_out = reinterpret_cast<TensorList *>(tensors_out->at(i));
+        tensor_list_out->set_tensors_data_type(TypeId(tensor_list_in->tensors_data_type_));
+      }
     }
   }
 }