fixed scatter_nd gpu kernel bug

This commit is contained in:
huoxinyou 2022-11-03 10:11:35 +08:00
parent 91a40660ec
commit 87b829da5e
1 changed files with 4 additions and 4 deletions

View File

@ -130,14 +130,14 @@ bool ScatterNdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
int ScatterNdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
if (!GetDynamicAttrIntValue(inputs, kShapeIndex_, inputsOnHost, kernel_name_, &attr_shape_)) {
MS_LOG(EXCEPTION) << "For " << kernel_name_ << "can't get shape input!";
return KRET_RESIZE_FAILED;
}
if (int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
memcpy_flag_ = false;
if (!TryGetIntValue(inputs, kShapeIndex_, kernel_name_, &attr_shape_)) {
MS_LOG(EXCEPTION) << "For " << kernel_name_ << "can't get shape input!";
return KRET_RESIZE_FAILED;
}
CalSize(inputs, outputs);
auto indices_unit_size = abstract::TypeIdSize(inputs[0]->GetDtype());