diff --git a/mindspore/core/ops/max_pool_with_argmax.cc b/mindspore/core/ops/max_pool_with_argmax.cc
index e1a3f93dce1..424cf0994a2 100644
--- a/mindspore/core/ops/max_pool_with_argmax.cc
+++ b/mindspore/core/ops/max_pool_with_argmax.cc
@@ -91,6 +91,57 @@ void MaxPoolWithArgmax::Init(const std::vector<int64_t> &kernel_size, const std:
 }
 
 namespace {
+std::vector<int64_t> GetOutShape(const string &op_name, const std::vector<int64_t> &in_shape, Format format,
+                                 PadMode pad_mode, const std::vector<int64_t> &strides,
+                                 const std::vector<int64_t> &kernel_size) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  int64_t batch = 0, in_h = 0, in_w = 0, channel = 0;
+  int64_t kernel_h = kernel_size[kMaxPoolIdx1];
+  int64_t kernel_w = kernel_size[kMaxPoolIdx2];
+  int64_t stride_h = strides[kMaxPoolIdx1];
+  int64_t stride_w = strides[kMaxPoolIdx2];
+  if (format == NCHW) {
+    batch = in_shape[kMaxPoolIdx0];
+    channel = in_shape[kMaxPoolIdx1];
+    in_h = in_shape[kMaxPoolIdx2];
+    in_w = in_shape[kMaxPoolIdx3];
+  } else if (format == NHWC) {
+    batch = in_shape[kMaxPoolIdx0];
+    in_h = in_shape[kMaxPoolIdx1];
+    in_w = in_shape[kMaxPoolIdx2];
+    channel = in_shape[kMaxPoolIdx3];
+  }
+  int64_t out_h = abstract::Shape::kShapeDimAny, out_w = abstract::Shape::kShapeDimAny;
+  if (pad_mode == VALID && in_h != abstract::Shape::kShapeDimAny) {
+    out_h = static_cast<int64_t>(std::ceil((in_h - (kernel_h - 1)) / static_cast<float>(stride_h)));
+  }
+  if (pad_mode == VALID && in_w != abstract::Shape::kShapeDimAny) {
+    out_w = static_cast<int64_t>(std::ceil((in_w - (kernel_w - 1)) / static_cast<float>(stride_w)));
+  }
+
+  if (pad_mode == SAME && in_h != abstract::Shape::kShapeDimAny) {
+    out_h = static_cast<int64_t>(std::ceil(in_h / static_cast<float>(stride_h)));
+  }
+  if (pad_mode == SAME && in_w != abstract::Shape::kShapeDimAny) {
+    out_w = static_cast<int64_t>(std::ceil(in_w / static_cast<float>(stride_w)));
+  }
+  std::vector<int64_t> out_shape{batch, channel, out_h, out_w};
+  bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
+  bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
+  if (is_ascend || is_gpu) {
+    for (size_t i = 0; i < out_shape.size(); i++) {
+      if (out_shape[i] <= 0 && out_shape[i] != -1) {
+        MS_EXCEPTION(ValueError) << "For '" << op_name << "',"
+                                 << " each element of the output shape must be larger than 0, but got: "
+                                 << "output shape: [" << batch << ", " << channel << ", " << out_h << ", " << out_w
+                                 << "].";
+      }
+    }
+  }
+  return out_shape;
+}
+
 abstract::TupleShapePtr MaxPoolWithArgmaxInferShape(const PrimitivePtr &primitive,
                                                     const std::vector<AbstractBasePtr> &input_args) {
   auto context = MsContext::GetInstance();
@@ -125,43 +176,7 @@ abstract::TupleShapePtr MaxPoolWithArgmaxInferShape(const PrimitivePtr &primitiv
   (void)CheckAndConvertUtils::CheckInteger("kernel size", SizeToLong(kernel_size.size()), kEqual, attr_size, op_name);
   (void)CheckAndConvertUtils::CheckInteger("strides size", SizeToLong(strides.size()), kEqual, attr_size, op_name);
-  int64_t batch = 0, in_h = 0, in_w = 0, channel = 0;
-  int64_t kernel_h = kernel_size[kMaxPoolIdx1];
-  int64_t kernel_w = kernel_size[kMaxPoolIdx2];
-  int64_t stride_h = strides[kMaxPoolIdx1];
-  int64_t stride_w = strides[kMaxPoolIdx2];
-  if (format == Format::NCHW) {
-    batch = in_shape[kMaxPoolIdx0];
-    channel = in_shape[kMaxPoolIdx1];
-    in_h = in_shape[kMaxPoolIdx2];
-    in_w = in_shape[kMaxPoolIdx3];
-  } else if (format == Format::NHWC) {
-    batch = in_shape[kMaxPoolIdx0];
-    in_h = in_shape[kMaxPoolIdx1];
-    in_w = in_shape[kMaxPoolIdx2];
-    channel = in_shape[kMaxPoolIdx3];
-  }
-  int64_t out_h = abstract::Shape::kShapeDimAny, out_w = abstract::Shape::kShapeDimAny;
-  if (pad_mode == PadMode::VALID) {
-    out_h = static_cast<int64_t>(std::ceil((in_h - (kernel_h - 1)) / static_cast<float>(stride_h)));
-    out_w = static_cast<int64_t>(std::ceil((in_w - (kernel_w - 1)) / static_cast<float>(stride_w)));
-  } else if (pad_mode == PadMode::SAME) {
-    out_h = static_cast<int64_t>(std::ceil(in_h / static_cast<float>(stride_h)));
-    out_w = static_cast<int64_t>(std::ceil(in_w / static_cast<float>(stride_w)));
-  }
-  std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
-  bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
-  bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
-  if (is_ascend || is_gpu) {
-    for (size_t i = 0; i < out_shape.size(); i++) {
-      if (out_shape[i] <= 0 && out_shape[i] != -1) {
-        MS_EXCEPTION(ValueError) << "For '" << op_name << "',"
-                                 << " the each element of the output shape must be larger than 0, but got: "
-                                 << "output shape: [" << batch << ", " << channel << ", " << out_h << ", " << out_w
-                                 << "].";
-      }
-    }
-  }
+  std::vector<int64_t> out_shape = GetOutShape(op_name, in_shape, format, pad_mode, strides, kernel_size);
 
   // Process attr mapping problems from mindspore to tbe
   // kernel_size -> ksize
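For reference, the output-size rule that `GetOutShape` applies can be checked in isolation: VALID uses ceil((in - (kernel - 1)) / stride), SAME uses ceil(in / stride), and `kShapeDimAny` (-1) is passed through for dynamic dimensions. The sketch below mirrors those formulas from the diff; `PooledDim`, the local `PadMode` enum, and the constant are illustration-only names, not part of the MindSpore API.

```cpp
// Standalone sketch of the per-dimension output-size rule used by GetOutShape
// (illustration only, not part of the patch).
#include <cmath>
#include <cstdint>
#include <iostream>

constexpr int64_t kShapeDimAny = -1;  // unknown (dynamic) dimension marker, as in the diff

enum class PadMode { VALID, SAME };

int64_t PooledDim(int64_t in, int64_t kernel, int64_t stride, PadMode pad_mode) {
  if (in == kShapeDimAny) {
    return kShapeDimAny;  // an unknown input size stays unknown
  }
  if (pad_mode == PadMode::VALID) {
    // VALID: ceil((in - (kernel - 1)) / stride)
    return static_cast<int64_t>(std::ceil((in - (kernel - 1)) / static_cast<float>(stride)));
  }
  // SAME: ceil(in / stride)
  return static_cast<int64_t>(std::ceil(in / static_cast<float>(stride)));
}

int main() {
  // 28x28 input, 3x3 kernel, stride 2:
  std::cout << PooledDim(28, 3, 2, PadMode::VALID) << "\n";          // 13
  std::cout << PooledDim(28, 3, 2, PadMode::SAME) << "\n";           // 14
  std::cout << PooledDim(kShapeDimAny, 3, 2, PadMode::SAME) << "\n"; // -1 (dynamic)
  return 0;
}
```

For a 28x28 input with a 3x3 kernel and stride 2 this gives 13 under VALID and 14 under SAME, and a dynamic height or width propagates as -1, which is why the refactored function only checks out_shape elements against 0 when they are not -1.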