!45022 fix bug of use of CheckTensorSize in new NativeGpuKernelMod
Merge pull request !45022 from hanhuifeng/kernel_mod_new_4
This commit is contained in:
commit
16e2d3ec98
|
@ -124,7 +124,9 @@ int Conv2dFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
|
|||
if (is_null_input_) {
|
||||
return KRET_OK;
|
||||
}
|
||||
CheckTensorSize({in_shape, filter_shape, output_shape});
|
||||
if (!CheckTensorSize({in_shape, filter_shape, output_shape})) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
std::vector<int> pad_list;
|
||||
// The pad_list is computed in infer shape
|
||||
auto pad_list_me = GetValue<std::vector<int64_t>>(base_operator->GetAttr("pad_list"));
|
||||
|
|
|
@ -315,7 +315,9 @@ int ConvGradFilterBkwGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
|||
filter_shape_ = value_res.value();
|
||||
}
|
||||
auto filter_shape = filter_shape_;
|
||||
CheckTensorSize({in_shape, dy_shape, filter_shape});
|
||||
if (!CheckTensorSize({in_shape, dy_shape, filter_shape})) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
int h_index = k2DHeightIndexNCHW;
|
||||
int w_index = k2DHeightIndexNCHW + 1;
|
||||
if (data_format_ == kOpFormat_NHWC) {
|
||||
|
|
|
@ -314,7 +314,9 @@ int ConvGradInputBkwGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
|||
ShapeNCHW2NHWC(&input_shape);
|
||||
}
|
||||
}
|
||||
CheckTensorSize({input_shape, dy_shape, filter_shape});
|
||||
if (!CheckTensorSize({input_shape, dy_shape, filter_shape})) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
|
||||
Set4DDesc(dy_shape, input_shape, filter_shape);
|
||||
auto pad_list = pad_list_;
|
||||
|
|
Loading…
Reference in New Issue