tensorscatterupdate add bool

This commit is contained in:
VectorSL 2022-05-09 14:54:12 +08:00
parent 79e9d91a5c
commit 0330cfdb9b
2 changed files with 16 additions and 1 deletions

View File

@ -140,6 +140,9 @@ bool ScatterUpdateCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &in
case kNumberTypeInt64:
LaunchKernel<int64_t>(inputs, outputs);
break;
case kNumberTypeBool:
LaunchKernel<bool>(inputs, outputs);
break;
default:
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dtype of 'input_x' must be float16, float32, float64, int32 or int64, but got "

View File

@ -121,7 +121,19 @@ class TensorScatterUpdateCpuKernelMod : public ScatterUpdateCpuKernelMod {
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64)};
.AddOutputAttr(kNumberTypeInt64),
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeBool),
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeBool)};
return support_list;
}
};