!45078 Add check for GatherD and NLLLossGrad.

Merge pull request !45078 from TronZhang/add_more_check_for_gatherd_nlllossgrad
This commit is contained in:
i-robot 2022-11-07 01:02:19 +00:00 committed by Gitee
commit 413dba1523
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 18 additions and 3 deletions

View File

@ -281,6 +281,11 @@ bool GatherFwdGpuKernelMod::SetDimParam(int64_t dim_value) {
dim_value += x_rank;
}
if (input_shapes_.size() <= LongToSize(dim_value) || output_shapes_.size() <= LongToSize(dim_value)) {
MS_LOG(EXCEPTION) << "Rank should be great than " << dim_value << ", but got inputs size " << input_shapes_.size()
<< ", outputs size: " << output_shapes_.size();
}
size_t dim_before_axis = 1;
for (size_t i = 0; i < LongToSize(dim_value); i++) {
dim_before_axis *= output_shapes_[i];
@ -366,7 +371,9 @@ int GatherFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
int64_t dim_value = 0;
if (!is_dynamic_case_) {
const std::string kAttrDim = "dim";
auto dim_attr = base_operator->GetPrim()->GetAttr(kAttrDim);
auto prim = base_operator->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
auto dim_attr = prim->GetAttr(kAttrDim);
if (dim_attr == nullptr) {
return KRET_RESIZE_FAILED;
}
@ -377,7 +384,10 @@ int GatherFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
}
dim_value = value_res.second;
} else {
GetDynamicAttrIntValue(inputs, 1, inputsOnHost, kernel_name_, &dim_value);
if (!GetDynamicAttrIntValue(inputs, 1, inputsOnHost, kernel_name_, &dim_value)) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', dim value must be valid.";
return KRET_RESIZE_FAILED;
}
}
if (!SetDimParam(dim_value)) {
return KRET_RESIZE_FAILED;

View File

@ -139,8 +139,11 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
for (auto item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto prim_name = primitive->name();
// check
auto prim_name = primitive->name();
const int64_t inputs_num = 3;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, inputs_num, prim_name);
std::set<TypePtr> index_valid_types = {kInt32, kInt64};
std::set<TypePtr> dim_valid_types = {kInt32, kInt64, std::make_shared<TensorType>(kInt32),
std::make_shared<TensorType>(kInt64)};

View File

@ -15,6 +15,8 @@
*/
#include "ops/grad/nllloss_grad.h"
#include <sstream>
#include <map>
#include <vector>
#include <memory>