!44898 Fix output h/w shape inference error when the input h/w shape is dynamic (anyshape).

Merge pull request !44898 from zhangzhaoju/master_maxpool_any
i-robot 2022-11-02 07:00:05 +00:00 committed by Gitee
commit e9e3915231
1 changed file with 52 additions and 37 deletions
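In MaxPoolWithArgmaxInferShape, out_h and out_w were previously computed from in_h and in_w unconditionally, so a dynamic input dimension (abstract::Shape::kShapeDimAny, i.e. -1) was fed straight into the ceil-division formulas; the result (for example ceil(-1 / 2.0) == 0 under SAME padding with stride 2) then failed the Ascend/GPU check that every output dimension be positive or -1, raising a spurious ValueError. The shape math is factored out into a GetOutShape helper that computes an output dimension only when the corresponding input dimension is known, leaving it as kShapeDimAny otherwise.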


@@ -91,6 +91,57 @@ void MaxPoolWithArgmax::Init(const std::vector<int64_t> &kernel_size, const std:
}
namespace {
std::vector<int64_t> GetOutShape(const string &op_name, const std::vector<int64_t> &in_shape, Format format,
                                 PadMode pad_mode, const std::vector<int64_t> &strides,
                                 const std::vector<int64_t> &kernel_size) {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  int64_t batch = 0, in_h = 0, in_w = 0, channel = 0;
  int64_t kernel_h = kernel_size[kMaxPoolIdx1];
  int64_t kernel_w = kernel_size[kMaxPoolIdx2];
  int64_t stride_h = strides[kMaxPoolIdx1];
  int64_t stride_w = strides[kMaxPoolIdx2];
  if (format == NCHW) {
    batch = in_shape[kMaxPoolIdx0];
    channel = in_shape[kMaxPoolIdx1];
    in_h = in_shape[kMaxPoolIdx2];
    in_w = in_shape[kMaxPoolIdx3];
  } else if (format == NHWC) {
    batch = in_shape[kMaxPoolIdx0];
    in_h = in_shape[kMaxPoolIdx1];
    in_w = in_shape[kMaxPoolIdx2];
    channel = in_shape[kMaxPoolIdx3];
  }
  int64_t out_h = abstract::Shape::kShapeDimAny, out_w = abstract::Shape::kShapeDimAny;
  if (pad_mode == VALID && in_h != abstract::Shape::kShapeDimAny) {
    out_h = static_cast<int64_t>(std::ceil((in_h - (kernel_h - 1)) / static_cast<float>(stride_h)));
  }
  if (pad_mode == VALID && in_w != abstract::Shape::kShapeDimAny) {
    out_w = static_cast<int64_t>(std::ceil((in_w - (kernel_w - 1)) / static_cast<float>(stride_w)));
  }
  if (pad_mode == SAME && in_h != abstract::Shape::kShapeDimAny) {
    out_h = static_cast<int64_t>(std::ceil(in_h / static_cast<float>(stride_h)));
  }
  if (pad_mode == SAME && in_w != abstract::Shape::kShapeDimAny) {
    out_w = static_cast<int64_t>(std::ceil(in_w / static_cast<float>(stride_w)));
  }
  std::vector<int64_t> out_shape{batch, channel, out_h, out_w};
  bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
  bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
  if (is_ascend || is_gpu) {
    for (size_t i = 0; i < out_shape.size(); i++) {
      if (out_shape[i] <= 0 && out_shape[i] != -1) {
        MS_EXCEPTION(ValueError) << "For '" << op_name << "',"
                                 << " each element of the output shape must be larger than 0, but got: "
                                 << "output shape: [" << batch << ", " << channel << ", " << out_h << ", " << out_w
                                 << "].";
      }
    }
  }
  return out_shape;
}
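// Illustrative note (not part of the original patch): with pad_mode == SAME,
// stride_h == 2, and a dynamic input height in_h == -1, the old inline code
// computed ceil(-1 / 2.0f) == 0, and the Ascend/GPU output-shape check then
// raised a ValueError. GetOutShape instead leaves out_h at
// abstract::Shape::kShapeDimAny (-1), which that check explicitly permits.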
abstract::TupleShapePtr MaxPoolWithArgmaxInferShape(const PrimitivePtr &primitive,
                                                    const std::vector<AbstractBasePtr> &input_args) {
  auto context = MsContext::GetInstance();
@@ -125,43 +176,7 @@ abstract::TupleShapePtr MaxPoolWithArgmaxInferShape(const PrimitivePtr &primitiv
  (void)CheckAndConvertUtils::CheckInteger("kernel size", SizeToLong(kernel_size.size()), kEqual, attr_size, op_name);
  (void)CheckAndConvertUtils::CheckInteger("strides size", SizeToLong(strides.size()), kEqual, attr_size, op_name);
  int64_t batch = 0, in_h = 0, in_w = 0, channel = 0;
  int64_t kernel_h = kernel_size[kMaxPoolIdx1];
  int64_t kernel_w = kernel_size[kMaxPoolIdx2];
  int64_t stride_h = strides[kMaxPoolIdx1];
  int64_t stride_w = strides[kMaxPoolIdx2];
  if (format == Format::NCHW) {
    batch = in_shape[kMaxPoolIdx0];
    channel = in_shape[kMaxPoolIdx1];
    in_h = in_shape[kMaxPoolIdx2];
    in_w = in_shape[kMaxPoolIdx3];
  } else if (format == Format::NHWC) {
    batch = in_shape[kMaxPoolIdx0];
    in_h = in_shape[kMaxPoolIdx1];
    in_w = in_shape[kMaxPoolIdx2];
    channel = in_shape[kMaxPoolIdx3];
  }
  int64_t out_h = abstract::Shape::kShapeDimAny, out_w = abstract::Shape::kShapeDimAny;
  if (pad_mode == PadMode::VALID) {
    out_h = static_cast<int64_t>(std::ceil((in_h - (kernel_h - 1)) / static_cast<float>(stride_h)));
    out_w = static_cast<int64_t>(std::ceil((in_w - (kernel_w - 1)) / static_cast<float>(stride_w)));
  } else if (pad_mode == PadMode::SAME) {
    out_h = static_cast<int64_t>(std::ceil(in_h / static_cast<float>(stride_h)));
    out_w = static_cast<int64_t>(std::ceil(in_w / static_cast<float>(stride_w)));
  }
  std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
  bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
  bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
  if (is_ascend || is_gpu) {
    for (size_t i = 0; i < out_shape.size(); i++) {
      if (out_shape[i] <= 0 && out_shape[i] != -1) {
        MS_EXCEPTION(ValueError) << "For '" << op_name << "',"
                                 << " each element of the output shape must be larger than 0, but got: "
                                 << "output shape: [" << batch << ", " << channel << ", " << out_h << ", " << out_w
                                 << "].";
      }
    }
  }
  std::vector<int64_t> out_shape = GetOutShape(op_name, in_shape, format, pad_mode, strides, kernel_size);
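  // Illustrative comment (not in the original patch): the height/width math
  // formerly inlined above is now delegated to the shared GetOutShape helper,
  // which skips the formulas for dynamic (-1) dimensions instead of
  // evaluating them on -1.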
  // Process attr mapping problems from mindspore to tbe
  // kernel_size -> ksize
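For reference, a minimal standalone sketch of the guarded output-size computation. PoolOutDim, kDimAny, and the toy main are hypothetical names for illustration; only the VALID/SAME formulas and the kShapeDimAny == -1 convention mirror the patch:

#include <cmath>
#include <cstdint>
#include <iostream>

// Hypothetical stand-in for abstract::Shape::kShapeDimAny.
constexpr int64_t kDimAny = -1;

enum class Pad { VALID, SAME };

// Computes one pooled output dimension; a dynamic (-1) input dimension is
// passed through untouched, as the patched GetOutShape does.
int64_t PoolOutDim(int64_t in, int64_t kernel, int64_t stride, Pad pad) {
  if (in == kDimAny) {
    return kDimAny;  // unknown input -> unknown output; never run the formula on -1
  }
  if (pad == Pad::VALID) {
    return static_cast<int64_t>(std::ceil((in - (kernel - 1)) / static_cast<float>(stride)));
  }
  return static_cast<int64_t>(std::ceil(in / static_cast<float>(stride)));  // SAME
}

int main() {
  // 3x3 kernel, stride 2, SAME padding: static height 224, dynamic width.
  std::cout << PoolOutDim(224, 3, 2, Pad::SAME) << "\n";      // 112
  std::cout << PoolOutDim(kDimAny, 3, 2, Pad::SAME) << "\n";  // -1 (stays dynamic)
  return 0;
}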