diff --git a/mindspore/ccsrc/kernel/common_utils.cc b/mindspore/ccsrc/kernel/common_utils.cc index ee3b192a5ac..a162186fd3c 100644 --- a/mindspore/ccsrc/kernel/common_utils.cc +++ b/mindspore/ccsrc/kernel/common_utils.cc @@ -1661,7 +1661,7 @@ std::pair MatchKernelAttr(const KernelAttr &kernel_attr, bool mis_match = false; for (size_t i = 0; i < input_num; ++i) { auto dtype = cur_kernel_attr.GetInputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).dtype; - if (kernel_attr.GetInputAttr(i).dtype != dtype) { + if (kernel_attr.GetInputAttr(i).dtype != dtype && kernel_attr.GetInputAttr(i).dtype != kTypeUnknown) { mis_match = true; break; } @@ -1672,7 +1672,7 @@ std::pair MatchKernelAttr(const KernelAttr &kernel_attr, for (size_t i = 0; i < output_num; ++i) { auto dtype = cur_kernel_attr.GetOutputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).dtype; - if (kernel_attr.GetOutputAttr(i).dtype != dtype) { + if (kernel_attr.GetOutputAttr(i).dtype != dtype && kernel_attr.GetOutputAttr(i).dtype != kTypeUnknown) { mis_match = true; break; } diff --git a/mindspore/ccsrc/kernel/kernel.cc b/mindspore/ccsrc/kernel/kernel.cc index 2ec52e7b128..3ad4ab61d48 100644 --- a/mindspore/ccsrc/kernel/kernel.cc +++ b/mindspore/ccsrc/kernel/kernel.cc @@ -74,6 +74,9 @@ TypeId KernelTensor::GetDtype() const { if (info.base_->dynamic_len()) { return info.base_->dynamic_len_element_abs()->BuildType()->type_id(); } + if (info.base_->elements().empty()) { + return TypeId::kTypeUnknown; + } return info.base_->elements()[0]->BuildType()->type_id(); } else if (meta_type_ == kObjectTypeList) { // List @@ -81,6 +84,9 @@ TypeId KernelTensor::GetDtype() const { if (info.base_->dynamic_len()) { return info.base_->dynamic_len_element_abs()->BuildType()->type_id(); } + if (info.base_->elements().empty()) { + return TypeId::kTypeUnknown; + } return info.base_->elements()[0]->BuildType()->type_id(); } else { // Tensor diff --git a/mindspore/ccsrc/plugin/device/cpu/optimizer/insert_cast_cpu.cc b/mindspore/ccsrc/plugin/device/cpu/optimizer/insert_cast_cpu.cc index 4d6e251fc1e..e2d9a974d7e 100644 --- a/mindspore/ccsrc/plugin/device/cpu/optimizer/insert_cast_cpu.cc +++ b/mindspore/ccsrc/plugin/device/cpu/optimizer/insert_cast_cpu.cc @@ -136,7 +136,8 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); const abstract::BaseShapePtr origin_shape = common::AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second); - if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) { + TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); + if (origin_type != device_type && origin_type != kTypeUnknown && device_type != kTypeUnknown) { auto cast = AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape); MS_EXCEPTION_IF_NULL(cast); cast->set_scope(cnode->scope()); diff --git a/mindspore/ccsrc/utils/anfalgo.cc b/mindspore/ccsrc/utils/anfalgo.cc index db63453a44e..d385d648971 100644 --- a/mindspore/ccsrc/utils/anfalgo.cc +++ b/mindspore/ccsrc/utils/anfalgo.cc @@ -721,7 +721,7 @@ TypeId AnfAlgo::GetOutputInferDataType(const TypePtr &type, size_t output_idx) { return tuple_ptr->dynamic_element_type()->type_id(); } if (tuple_ptr->size() == 0) { - return tuple_ptr->type_id(); + return kTypeUnknown; } MS_EXCEPTION_IF_NULL(tuple_ptr); if (output_idx >= tuple_ptr->size()) {