!32713 restore conv2d bug of infershape

Merge pull request !32713 from wangyanling/conv2d_debug
This commit is contained in:
i-robot 2022-04-09 09:59:59 +00:00 committed by Gitee
commit 3906a598ca
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 36 additions and 5 deletions

View File

@ -36,6 +36,11 @@ constexpr size_t stride_num = 2;
constexpr size_t dilation_num = 2;
constexpr size_t padding_num = 4;
constexpr size_t start_index = 2;
constexpr size_t top_padding = 0;
constexpr size_t bottom_padding = 1;
constexpr size_t left_padding = 2;
constexpr size_t right_padding = 3;
void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) {
for (size_t i = 0; i < shape.size(); ++i) {
if ((shape[i] < 0) && (shape[i] != Shape::SHP_ANY)) {
@ -153,12 +158,30 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa
}
}
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
bool CheckConv2dShape(const std::string &prim_name, const std::vector<AbstractBasePtr> &input_args,
const std::vector<int64_t> &x_shape, const std::vector<int64_t> &w_shape,
const std::vector<int64_t> &padding, int64_t pad_mode, uint64_t w_axis, uint64_t h_axis) {
auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
auto w_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1);
if (x_shape_ptr->IsDynamic() || w_shape_ptr->IsDynamic()) {
return true;
}
if (w_shape[w_axis] != Shape::SHP_ANY && pad_mode != PadMode::SAME) {
int64_t input_height = x_shape[h_axis];
int64_t input_width = x_shape[w_axis];
if (pad_mode == PadMode::PAD) {
input_height += padding[left_padding] + padding[right_padding];
input_width += padding[top_padding] + padding[bottom_padding];
}
if (input_height < w_shape[h_axis] || input_width < w_shape[w_axis]) {
return false;
}
}
return true;
}
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto w_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
auto x_shape = x_shape_map[kShape];
@ -219,6 +242,11 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
std::vector<int64_t> padding = CheckAttrIntOrTuple(primitive->GetAttr("pad"), 0, padding_num);
int64_t pad_mode;
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode);
if (!CheckConv2dShape(prim_name, input_args, x_shape, w_shape, padding, pad_mode, w_axis, h_axis)) {
MS_LOG(EXCEPTION)
<< "Shape error for Conv2d, input shape's h and w after padding is less than kernel_size's h and w dims.";
}
std::vector<int64_t> output_hw;
std::vector<int64_t> pad_list;
std::vector<int64_t> output_hw_min;
@ -231,6 +259,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
dilation, pad_mode, padding, true);
Conv2DPadFunction(&output_hw_max, &pad_list_max, x_max_shape[h_axis], x_max_shape[w_axis], kernel_size, stride,
dilation, pad_mode, padding);
std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]),
MakeValue(pad_list[3])};
primitive->set_attr("pad_list", MakeValue(pad_list_val));
@ -374,9 +403,11 @@ Format Conv2D::get_format() const {
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
for (auto item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("Conv2d infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
primitive->name());