diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gather_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gather_gpu_kernel.cc index 33cb277e7d0..51e57743056 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gather_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gather_gpu_kernel.cc @@ -281,6 +281,11 @@ bool GatherFwdGpuKernelMod::SetDimParam(int64_t dim_value) { dim_value += x_rank; } + if (input_shapes_.size() <= LongToSize(dim_value) || output_shapes_.size() <= LongToSize(dim_value)) { + MS_LOG(EXCEPTION) << "Rank should be great than " << dim_value << ", but got inputs size " << input_shapes_.size() + << ", outputs size: " << output_shapes_.size(); + } + size_t dim_before_axis = 1; for (size_t i = 0; i < LongToSize(dim_value); i++) { dim_before_axis *= output_shapes_[i]; @@ -366,7 +371,9 @@ int GatherFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st int64_t dim_value = 0; if (!is_dynamic_case_) { const std::string kAttrDim = "dim"; - auto dim_attr = base_operator->GetPrim()->GetAttr(kAttrDim); + auto prim = base_operator->GetPrim(); + MS_EXCEPTION_IF_NULL(prim); + auto dim_attr = prim->GetAttr(kAttrDim); if (dim_attr == nullptr) { return KRET_RESIZE_FAILED; } @@ -377,7 +384,10 @@ int GatherFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st } dim_value = value_res.second; } else { - GetDynamicAttrIntValue(inputs, 1, inputsOnHost, kernel_name_, &dim_value); + if (!GetDynamicAttrIntValue(inputs, 1, inputsOnHost, kernel_name_, &dim_value)) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', dim value must be valid."; + return KRET_RESIZE_FAILED; + } } if (!SetDimParam(dim_value)) { return KRET_RESIZE_FAILED; diff --git a/mindspore/core/ops/gather_d.cc b/mindspore/core/ops/gather_d.cc index 666d639aeb7..8cfa17cc5f2 100644 --- a/mindspore/core/ops/gather_d.cc +++ b/mindspore/core/ops/gather_d.cc @@ -139,8 +139,11 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv for (auto item : input_args) { MS_EXCEPTION_IF_NULL(item); } - auto prim_name = primitive->name(); + // check + auto prim_name = primitive->name(); + const int64_t inputs_num = 3; + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, inputs_num, prim_name); std::set index_valid_types = {kInt32, kInt64}; std::set dim_valid_types = {kInt32, kInt64, std::make_shared(kInt32), std::make_shared(kInt64)}; diff --git a/mindspore/core/ops/grad/nllloss_grad.cc b/mindspore/core/ops/grad/nllloss_grad.cc index 8938d3ec471..dbc623c9de1 100644 --- a/mindspore/core/ops/grad/nllloss_grad.cc +++ b/mindspore/core/ops/grad/nllloss_grad.cc @@ -15,6 +15,8 @@ */ #include "ops/grad/nllloss_grad.h" + +#include #include #include #include