forked from mindspore-Ecosystem/mindspore
!45043 fixed scatter_nd gpu kernel bug
Merge pull request !45043 from huoxinyou/1103scatternd
This commit is contained in:
commit
78957f88cf
|
@ -130,14 +130,14 @@ bool ScatterNdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
|
||||||
int ScatterNdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
int ScatterNdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
const std::vector<KernelTensorPtr> &outputs,
|
const std::vector<KernelTensorPtr> &outputs,
|
||||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||||
if (!GetDynamicAttrIntValue(inputs, kShapeIndex_, inputsOnHost, kernel_name_, &attr_shape_)) {
|
|
||||||
MS_LOG(EXCEPTION) << "For " << kernel_name_ << "can't get shape input!";
|
|
||||||
return KRET_RESIZE_FAILED;
|
|
||||||
}
|
|
||||||
if (int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
if (int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
memcpy_flag_ = false;
|
memcpy_flag_ = false;
|
||||||
|
if (!TryGetIntValue(inputs, kShapeIndex_, kernel_name_, &attr_shape_)) {
|
||||||
|
MS_LOG(EXCEPTION) << "For " << kernel_name_ << "can't get shape input!";
|
||||||
|
return KRET_RESIZE_FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
CalSize(inputs, outputs);
|
CalSize(inputs, outputs);
|
||||||
auto indices_unit_size = abstract::TypeIdSize(inputs[0]->GetDtype());
|
auto indices_unit_size = abstract::TypeIdSize(inputs[0]->GetDtype());
|
||||||
|
|
Loading…
Reference in New Issue