!49152 Adapt for empty tuple

Merge pull request !49152 from ZPaC/type-match-for-unknown-type
This commit is contained in:
i-robot 2023-02-21 07:11:00 +00:00 committed by Gitee
commit 13cfed8045
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 11 additions and 4 deletions

View File

@ -1661,7 +1661,7 @@ std::pair<bool, size_t> MatchKernelAttr(const KernelAttr &kernel_attr,
bool mis_match = false;
for (size_t i = 0; i < input_num; ++i) {
auto dtype = cur_kernel_attr.GetInputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).dtype;
if (kernel_attr.GetInputAttr(i).dtype != dtype) {
if (kernel_attr.GetInputAttr(i).dtype != dtype && kernel_attr.GetInputAttr(i).dtype != kTypeUnknown) {
mis_match = true;
break;
}
@ -1672,7 +1672,7 @@ std::pair<bool, size_t> MatchKernelAttr(const KernelAttr &kernel_attr,
for (size_t i = 0; i < output_num; ++i) {
auto dtype = cur_kernel_attr.GetOutputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).dtype;
if (kernel_attr.GetOutputAttr(i).dtype != dtype) {
if (kernel_attr.GetOutputAttr(i).dtype != dtype && kernel_attr.GetOutputAttr(i).dtype != kTypeUnknown) {
mis_match = true;
break;
}

View File

@ -74,6 +74,9 @@ TypeId KernelTensor::GetDtype() const {
if (info.base_->dynamic_len()) {
return info.base_->dynamic_len_element_abs()->BuildType()->type_id();
}
if (info.base_->elements().empty()) {
return TypeId::kTypeUnknown;
}
return info.base_->elements()[0]->BuildType()->type_id();
} else if (meta_type_ == kObjectTypeList) {
// List
@ -81,6 +84,9 @@ TypeId KernelTensor::GetDtype() const {
if (info.base_->dynamic_len()) {
return info.base_->dynamic_len_element_abs()->BuildType()->type_id();
}
if (info.base_->elements().empty()) {
return TypeId::kTypeUnknown;
}
return info.base_->elements()[0]->BuildType()->type_id();
} else {
// Tensor

View File

@ -136,7 +136,8 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
const abstract::BaseShapePtr origin_shape =
common::AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second);
if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index);
if (origin_type != device_type && origin_type != kTypeUnknown && device_type != kTypeUnknown) {
auto cast = AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape);
MS_EXCEPTION_IF_NULL(cast);
cast->set_scope(cnode->scope());

View File

@ -721,7 +721,7 @@ TypeId AnfAlgo::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
return tuple_ptr->dynamic_element_type()->type_id();
}
if (tuple_ptr->size() == 0) {
return tuple_ptr->type_id();
return kTypeUnknown;
}
MS_EXCEPTION_IF_NULL(tuple_ptr);
if (output_idx >= tuple_ptr->size()) {