fix bug of conv2d cpp infer

This commit is contained in:
LianLiguang 2021-02-24 19:44:09 +08:00
parent 8f9666c93f
commit a39b312191
1 changed files with 4 additions and 4 deletions

View File

@ -72,13 +72,13 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
auto pad_needed_h =
std::max(static_cast<int64_t>(0), (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]);
pad_list.emplace_back(floor(pad_needed_h / 2));
pad_list.emplace_back(pad_needed_h / 2);
pad_list[0] = floor(pad_needed_h / 2);
pad_list[1] = pad_needed_h / 2;
auto pad_needed_w =
std::max(static_cast<int64_t>(0), (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]);
auto pad_left = floor(pad_needed_w / 2);
pad_list.emplace_back(pad_left);
pad_list.emplace_back(pad_needed_h - pad_left);
pad_list[2] = pad_left;
pad_list[3] = pad_needed_h - pad_left;
} else if (pad_mode == PAD) {
auto pad = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
std::copy(pad.begin(), pad.end(), std::back_inserter(pad_list));