!45022 fix bug of use of CheckTensorSize in new NativeGpuKernelMod

Merge pull request !45022 from hanhuifeng/kernel_mod_new_4
This commit is contained in:
i-robot 2022-11-03 02:14:14 +00:00 committed by Gitee
commit 16e2d3ec98
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 9 additions and 3 deletions

View File

@ -124,7 +124,9 @@ int Conv2dFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
if (is_null_input_) {
return KRET_OK;
}
CheckTensorSize({in_shape, filter_shape, output_shape});
if (!CheckTensorSize({in_shape, filter_shape, output_shape})) {
return KRET_RESIZE_FAILED;
}
std::vector<int> pad_list;
// The pad_list is computed in infer shape
auto pad_list_me = GetValue<std::vector<int64_t>>(base_operator->GetAttr("pad_list"));

View File

@ -315,7 +315,9 @@ int ConvGradFilterBkwGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
filter_shape_ = value_res.value();
}
auto filter_shape = filter_shape_;
CheckTensorSize({in_shape, dy_shape, filter_shape});
if (!CheckTensorSize({in_shape, dy_shape, filter_shape})) {
return KRET_RESIZE_FAILED;
}
int h_index = k2DHeightIndexNCHW;
int w_index = k2DHeightIndexNCHW + 1;
if (data_format_ == kOpFormat_NHWC) {

View File

@ -314,7 +314,9 @@ int ConvGradInputBkwGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
ShapeNCHW2NHWC(&input_shape);
}
}
CheckTensorSize({input_shape, dy_shape, filter_shape});
if (!CheckTensorSize({input_shape, dy_shape, filter_shape})) {
return KRET_RESIZE_FAILED;
}
SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
Set4DDesc(dy_shape, input_shape, filter_shape);
auto pad_list = pad_list_;