!45043 fixed scatter_nd gpu kernel bug

Merge pull request !45043 from huoxinyou/1103scatternd
This commit is contained in:
i-robot 2022-11-03 06:16:46 +00:00 committed by Gitee
commit 78957f88cf
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 4 additions and 4 deletions

View File

@ -130,14 +130,14 @@ bool ScatterNdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
int ScatterNdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
if (!GetDynamicAttrIntValue(inputs, kShapeIndex_, inputsOnHost, kernel_name_, &attr_shape_)) {
MS_LOG(EXCEPTION) << "For " << kernel_name_ << "can't get shape input!";
return KRET_RESIZE_FAILED;
}
if (int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
memcpy_flag_ = false;
if (!TryGetIntValue(inputs, kShapeIndex_, kernel_name_, &attr_shape_)) {
MS_LOG(EXCEPTION) << "For " << kernel_name_ << "can't get shape input!";
return KRET_RESIZE_FAILED;
}
CalSize(inputs, outputs);
auto indices_unit_size = abstract::TypeIdSize(inputs[0]->GetDtype());