!45043 fixed scatter_nd gpu kernel bug
Merge pull request !45043 from huoxinyou/1103scatternd
This commit is contained in:
commit
78957f88cf
|
@ -130,14 +130,14 @@ bool ScatterNdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
|
|||
int ScatterNdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
if (!GetDynamicAttrIntValue(inputs, kShapeIndex_, inputsOnHost, kernel_name_, &attr_shape_)) {
|
||||
MS_LOG(EXCEPTION) << "For " << kernel_name_ << "can't get shape input!";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
if (int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
memcpy_flag_ = false;
|
||||
if (!TryGetIntValue(inputs, kShapeIndex_, kernel_name_, &attr_shape_)) {
|
||||
MS_LOG(EXCEPTION) << "For " << kernel_name_ << "can't get shape input!";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
|
||||
CalSize(inputs, outputs);
|
||||
auto indices_unit_size = abstract::TypeIdSize(inputs[0]->GetDtype());
|
||||
|
|
Loading…
Reference in New Issue