!45078 Add check for GatherD and NLLLossGrad.
Merge pull request !45078 from TronZhang/add_more_check_for_gatherd_nlllossgrad
This commit is contained in:
commit
413dba1523
|
@ -281,6 +281,11 @@ bool GatherFwdGpuKernelMod::SetDimParam(int64_t dim_value) {
|
|||
dim_value += x_rank;
|
||||
}
|
||||
|
||||
if (input_shapes_.size() <= LongToSize(dim_value) || output_shapes_.size() <= LongToSize(dim_value)) {
|
||||
MS_LOG(EXCEPTION) << "Rank should be great than " << dim_value << ", but got inputs size " << input_shapes_.size()
|
||||
<< ", outputs size: " << output_shapes_.size();
|
||||
}
|
||||
|
||||
size_t dim_before_axis = 1;
|
||||
for (size_t i = 0; i < LongToSize(dim_value); i++) {
|
||||
dim_before_axis *= output_shapes_[i];
|
||||
|
@ -366,7 +371,9 @@ int GatherFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
|
|||
int64_t dim_value = 0;
|
||||
if (!is_dynamic_case_) {
|
||||
const std::string kAttrDim = "dim";
|
||||
auto dim_attr = base_operator->GetPrim()->GetAttr(kAttrDim);
|
||||
auto prim = base_operator->GetPrim();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto dim_attr = prim->GetAttr(kAttrDim);
|
||||
if (dim_attr == nullptr) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
|
@ -377,7 +384,10 @@ int GatherFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
|
|||
}
|
||||
dim_value = value_res.second;
|
||||
} else {
|
||||
GetDynamicAttrIntValue(inputs, 1, inputsOnHost, kernel_name_, &dim_value);
|
||||
if (!GetDynamicAttrIntValue(inputs, 1, inputsOnHost, kernel_name_, &dim_value)) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', dim value must be valid.";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
}
|
||||
if (!SetDimParam(dim_value)) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
|
|
|
@ -139,8 +139,11 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
for (auto item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto prim_name = primitive->name();
|
||||
|
||||
// check
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t inputs_num = 3;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, inputs_num, prim_name);
|
||||
std::set<TypePtr> index_valid_types = {kInt32, kInt64};
|
||||
std::set<TypePtr> dim_valid_types = {kInt32, kInt64, std::make_shared<TensorType>(kInt32),
|
||||
std::make_shared<TensorType>(kInt64)};
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
|
||||
#include "ops/grad/nllloss_grad.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
|
Loading…
Reference in New Issue