!34083 TensorScatterUpdate add bool

Merge pull request !34083 from VectorSL/tensorscatterupdate-add-bool
This commit is contained in:
i-robot 2022-05-09 11:54:21 +00:00 committed by Gitee
commit b3455f9508
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 16 additions and 1 deletions

View File

@ -140,6 +140,9 @@ bool ScatterUpdateCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &in
case kNumberTypeInt64:
LaunchKernel<int64_t>(inputs, outputs);
break;
case kNumberTypeBool:
LaunchKernel<bool>(inputs, outputs);
break;
default:
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dtype of 'input_x' must be float16, float32, float64, int32 or int64, but got "

View File

@ -121,7 +121,19 @@ class TensorScatterUpdateCpuKernelMod : public ScatterUpdateCpuKernelMod {
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64)};
.AddOutputAttr(kNumberTypeInt64),
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeBool),
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeBool)};
return support_list;
}
};