!32713 restore conv2d bug of infershape
Merge pull request !32713 from wangyanling/conv2d_debug
This commit is contained in:
commit
3906a598ca
|
@ -36,6 +36,11 @@ constexpr size_t stride_num = 2;
|
|||
constexpr size_t dilation_num = 2;
|
||||
constexpr size_t padding_num = 4;
|
||||
constexpr size_t start_index = 2;
|
||||
constexpr size_t top_padding = 0;
|
||||
constexpr size_t bottom_padding = 1;
|
||||
constexpr size_t left_padding = 2;
|
||||
constexpr size_t right_padding = 3;
|
||||
|
||||
void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) {
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
if ((shape[i] < 0) && (shape[i] != Shape::SHP_ANY)) {
|
||||
|
@ -153,12 +158,30 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa
|
|||
}
|
||||
}
|
||||
|
||||
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
bool CheckConv2dShape(const std::string &prim_name, const std::vector<AbstractBasePtr> &input_args,
|
||||
const std::vector<int64_t> &x_shape, const std::vector<int64_t> &w_shape,
|
||||
const std::vector<int64_t> &padding, int64_t pad_mode, uint64_t w_axis, uint64_t h_axis) {
|
||||
auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
|
||||
auto w_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1);
|
||||
if (x_shape_ptr->IsDynamic() || w_shape_ptr->IsDynamic()) {
|
||||
return true;
|
||||
}
|
||||
if (w_shape[w_axis] != Shape::SHP_ANY && pad_mode != PadMode::SAME) {
|
||||
int64_t input_height = x_shape[h_axis];
|
||||
int64_t input_width = x_shape[w_axis];
|
||||
if (pad_mode == PadMode::PAD) {
|
||||
input_height += padding[left_padding] + padding[right_padding];
|
||||
input_width += padding[top_padding] + padding[bottom_padding];
|
||||
}
|
||||
if (input_height < w_shape[h_axis] || input_width < w_shape[w_axis]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto w_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
|
||||
auto x_shape = x_shape_map[kShape];
|
||||
|
@ -219,6 +242,11 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
std::vector<int64_t> padding = CheckAttrIntOrTuple(primitive->GetAttr("pad"), 0, padding_num);
|
||||
int64_t pad_mode;
|
||||
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode);
|
||||
if (!CheckConv2dShape(prim_name, input_args, x_shape, w_shape, padding, pad_mode, w_axis, h_axis)) {
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "Shape error for Conv2d, input shape's h and w after padding is less than kernel_size's h and w dims.";
|
||||
}
|
||||
|
||||
std::vector<int64_t> output_hw;
|
||||
std::vector<int64_t> pad_list;
|
||||
std::vector<int64_t> output_hw_min;
|
||||
|
@ -231,6 +259,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
dilation, pad_mode, padding, true);
|
||||
Conv2DPadFunction(&output_hw_max, &pad_list_max, x_max_shape[h_axis], x_max_shape[w_axis], kernel_size, stride,
|
||||
dilation, pad_mode, padding);
|
||||
|
||||
std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]),
|
||||
MakeValue(pad_list[3])};
|
||||
primitive->set_attr("pad_list", MakeValue(pad_list_val));
|
||||
|
@ -374,9 +403,11 @@ Format Conv2D::get_format() const {
|
|||
|
||||
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (auto item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("Conv2d infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
primitive->name());
|
||||
|
|
Loading…
Reference in New Issue