forked from mindspore-Ecosystem/mindspore
tensorscatterupdate add bool
This commit is contained in:
parent
79e9d91a5c
commit
0330cfdb9b
|
@ -140,6 +140,9 @@ bool ScatterUpdateCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &in
|
|||
case kNumberTypeInt64:
|
||||
LaunchKernel<int64_t>(inputs, outputs);
|
||||
break;
|
||||
case kNumberTypeBool:
|
||||
LaunchKernel<bool>(inputs, outputs);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the dtype of 'input_x' must be float16, float32, float64, int32 or int64, but got "
|
||||
|
|
|
@ -121,7 +121,19 @@ class TensorScatterUpdateCpuKernelMod : public ScatterUpdateCpuKernelMod {
|
|||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64)};
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddOutputAttr(kNumberTypeBool)};
|
||||
return support_list;
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue