!44898 Fix out h/w shape infer error while in h/w shape is anyshape.
Merge pull request !44898 from zhangzhaoju/master_maxpool_any
This commit is contained in:
commit
e9e3915231
|
@ -91,6 +91,57 @@ void MaxPoolWithArgmax::Init(const std::vector<int64_t> &kernel_size, const std:
|
|||
}
|
||||
|
||||
namespace {
|
||||
std::vector<int64_t> GetOutShape(const string &op_name, const std::vector<int64_t> &in_shape, Format format,
|
||||
PadMode pad_mode, const std::vector<int64_t> &strides,
|
||||
const std::vector<int64_t> &kernel_size) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
int64_t batch = 0, in_h = 0, in_w = 0, channel = 0;
|
||||
int64_t kernel_h = kernel_size[kMaxPoolIdx1];
|
||||
int64_t kernel_w = kernel_size[kMaxPoolIdx2];
|
||||
int64_t stride_h = strides[kMaxPoolIdx1];
|
||||
int64_t stride_w = strides[kMaxPoolIdx2];
|
||||
if (format == NCHW) {
|
||||
batch = in_shape[kMaxPoolIdx0];
|
||||
channel = in_shape[kMaxPoolIdx1];
|
||||
in_h = in_shape[kMaxPoolIdx2];
|
||||
in_w = in_shape[kMaxPoolIdx3];
|
||||
} else if (format == NHWC) {
|
||||
batch = in_shape[kMaxPoolIdx0];
|
||||
in_h = in_shape[kMaxPoolIdx1];
|
||||
in_w = in_shape[kMaxPoolIdx2];
|
||||
channel = in_shape[kMaxPoolIdx3];
|
||||
}
|
||||
int64_t out_h = abstract::Shape::kShapeDimAny, out_w = abstract::Shape::kShapeDimAny;
|
||||
if (pad_mode == VALID && in_h != abstract::Shape::kShapeDimAny) {
|
||||
out_h = static_cast<int64_t>(std::ceil((in_h - (kernel_h - 1)) / static_cast<float>(stride_h)));
|
||||
}
|
||||
if (pad_mode == VALID && in_w != abstract::Shape::kShapeDimAny) {
|
||||
out_w = static_cast<int64_t>(std::ceil((in_w - (kernel_w - 1)) / static_cast<float>(stride_w)));
|
||||
}
|
||||
|
||||
if (pad_mode == SAME && in_h != abstract::Shape::kShapeDimAny) {
|
||||
out_h = static_cast<int64_t>(std::ceil(in_h / static_cast<float>(stride_h)));
|
||||
}
|
||||
if (pad_mode == SAME && in_w != abstract::Shape::kShapeDimAny) {
|
||||
out_w = static_cast<int64_t>(std::ceil(in_w / static_cast<float>(stride_w)));
|
||||
}
|
||||
std::vector<int64_t> out_shape{batch, channel, out_h, out_w};
|
||||
bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
|
||||
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
|
||||
if (is_ascend || is_gpu) {
|
||||
for (size_t i = 0; i < out_shape.size(); i++) {
|
||||
if (out_shape[i] <= 0 && out_shape[i] != -1) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << op_name << "',"
|
||||
<< " the each element of the output shape must be larger than 0, but got: "
|
||||
<< "output shape: [" << batch << ", " << channel << ", " << out_h << ", " << out_w
|
||||
<< "].";
|
||||
}
|
||||
}
|
||||
}
|
||||
return out_shape;
|
||||
}
|
||||
|
||||
abstract::TupleShapePtr MaxPoolWithArgmaxInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto context = MsContext::GetInstance();
|
||||
|
@ -125,43 +176,7 @@ abstract::TupleShapePtr MaxPoolWithArgmaxInferShape(const PrimitivePtr &primitiv
|
|||
(void)CheckAndConvertUtils::CheckInteger("kernel size", SizeToLong(kernel_size.size()), kEqual, attr_size, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("strides size", SizeToLong(strides.size()), kEqual, attr_size, op_name);
|
||||
|
||||
int64_t batch = 0, in_h = 0, in_w = 0, channel = 0;
|
||||
int64_t kernel_h = kernel_size[kMaxPoolIdx1];
|
||||
int64_t kernel_w = kernel_size[kMaxPoolIdx2];
|
||||
int64_t stride_h = strides[kMaxPoolIdx1];
|
||||
int64_t stride_w = strides[kMaxPoolIdx2];
|
||||
if (format == Format::NCHW) {
|
||||
batch = in_shape[kMaxPoolIdx0];
|
||||
channel = in_shape[kMaxPoolIdx1];
|
||||
in_h = in_shape[kMaxPoolIdx2];
|
||||
in_w = in_shape[kMaxPoolIdx3];
|
||||
} else if (format == Format::NHWC) {
|
||||
batch = in_shape[kMaxPoolIdx0];
|
||||
in_h = in_shape[kMaxPoolIdx1];
|
||||
in_w = in_shape[kMaxPoolIdx2];
|
||||
channel = in_shape[kMaxPoolIdx3];
|
||||
}
|
||||
int64_t out_h = abstract::Shape::kShapeDimAny, out_w = abstract::Shape::kShapeDimAny;
|
||||
if (pad_mode == PadMode::VALID) {
|
||||
out_h = static_cast<int64_t>(std::ceil((in_h - (kernel_h - 1)) / static_cast<float>(stride_h)));
|
||||
out_w = static_cast<int64_t>(std::ceil((in_w - (kernel_w - 1)) / static_cast<float>(stride_w)));
|
||||
} else if (pad_mode == PadMode::SAME) {
|
||||
out_h = static_cast<int64_t>(std::ceil(in_h / static_cast<float>(stride_h)));
|
||||
out_w = static_cast<int64_t>(std::ceil(in_w / static_cast<float>(stride_w)));
|
||||
}
|
||||
std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
|
||||
bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
|
||||
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
|
||||
if (is_ascend || is_gpu) {
|
||||
for (size_t i = 0; i < out_shape.size(); i++) {
|
||||
if (out_shape[i] <= 0 && out_shape[i] != -1) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << op_name << "',"
|
||||
<< " the each element of the output shape must be larger than 0, but got: "
|
||||
<< "output shape: [" << batch << ", " << channel << ", " << out_h << ", " << out_w
|
||||
<< "].";
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<int64_t> out_shape = GetOutShape(op_name, in_shape, format, pad_mode, strides, kernel_size);
|
||||
|
||||
// Process attr mapping problems from mindspore to tbe
|
||||
// kernel_size -> ksize
|
||||
|
|
Loading…
Reference in New Issue