forked from mindspore-Ecosystem/mindspore
!49152 Adapt for empty tuple
Merge pull request !49152 from ZPaC/type-match-for-unknown-type
This commit is contained in:
commit
13cfed8045
|
@ -1661,7 +1661,7 @@ std::pair<bool, size_t> MatchKernelAttr(const KernelAttr &kernel_attr,
|
|||
bool mis_match = false;
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto dtype = cur_kernel_attr.GetInputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).dtype;
|
||||
if (kernel_attr.GetInputAttr(i).dtype != dtype) {
|
||||
if (kernel_attr.GetInputAttr(i).dtype != dtype && kernel_attr.GetInputAttr(i).dtype != kTypeUnknown) {
|
||||
mis_match = true;
|
||||
break;
|
||||
}
|
||||
|
@ -1672,7 +1672,7 @@ std::pair<bool, size_t> MatchKernelAttr(const KernelAttr &kernel_attr,
|
|||
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
auto dtype = cur_kernel_attr.GetOutputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).dtype;
|
||||
if (kernel_attr.GetOutputAttr(i).dtype != dtype) {
|
||||
if (kernel_attr.GetOutputAttr(i).dtype != dtype && kernel_attr.GetOutputAttr(i).dtype != kTypeUnknown) {
|
||||
mis_match = true;
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -74,6 +74,9 @@ TypeId KernelTensor::GetDtype() const {
|
|||
if (info.base_->dynamic_len()) {
|
||||
return info.base_->dynamic_len_element_abs()->BuildType()->type_id();
|
||||
}
|
||||
if (info.base_->elements().empty()) {
|
||||
return TypeId::kTypeUnknown;
|
||||
}
|
||||
return info.base_->elements()[0]->BuildType()->type_id();
|
||||
} else if (meta_type_ == kObjectTypeList) {
|
||||
// List
|
||||
|
@ -81,6 +84,9 @@ TypeId KernelTensor::GetDtype() const {
|
|||
if (info.base_->dynamic_len()) {
|
||||
return info.base_->dynamic_len_element_abs()->BuildType()->type_id();
|
||||
}
|
||||
if (info.base_->elements().empty()) {
|
||||
return TypeId::kTypeUnknown;
|
||||
}
|
||||
return info.base_->elements()[0]->BuildType()->type_id();
|
||||
} else {
|
||||
// Tensor
|
||||
|
|
|
@ -136,7 +136,8 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
|||
const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
|
||||
const abstract::BaseShapePtr origin_shape =
|
||||
common::AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second);
|
||||
if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
|
||||
TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index);
|
||||
if (origin_type != device_type && origin_type != kTypeUnknown && device_type != kTypeUnknown) {
|
||||
auto cast = AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape);
|
||||
MS_EXCEPTION_IF_NULL(cast);
|
||||
cast->set_scope(cnode->scope());
|
||||
|
|
|
@ -721,7 +721,7 @@ TypeId AnfAlgo::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
|
|||
return tuple_ptr->dynamic_element_type()->type_id();
|
||||
}
|
||||
if (tuple_ptr->size() == 0) {
|
||||
return tuple_ptr->type_id();
|
||||
return kTypeUnknown;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(tuple_ptr);
|
||||
if (output_idx >= tuple_ptr->size()) {
|
||||
|
|
Loading…
Reference in New Issue