forked from mindspore-Ecosystem/mindspore
fix bug of conv2d cpp infer
This commit is contained in:
parent
8f9666c93f
commit
a39b312191
|
@ -72,13 +72,13 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
|
||||
auto pad_needed_h =
|
||||
std::max(static_cast<int64_t>(0), (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]);
|
||||
pad_list.emplace_back(floor(pad_needed_h / 2));
|
||||
pad_list.emplace_back(pad_needed_h / 2);
|
||||
pad_list[0] = floor(pad_needed_h / 2);
|
||||
pad_list[1] = pad_needed_h / 2;
|
||||
auto pad_needed_w =
|
||||
std::max(static_cast<int64_t>(0), (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]);
|
||||
auto pad_left = floor(pad_needed_w / 2);
|
||||
pad_list.emplace_back(pad_left);
|
||||
pad_list.emplace_back(pad_needed_h - pad_left);
|
||||
pad_list[2] = pad_left;
|
||||
pad_list[3] = pad_needed_h - pad_left;
|
||||
} else if (pad_mode == PAD) {
|
||||
auto pad = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
|
||||
std::copy(pad.begin(), pad.end(), std::back_inserter(pad_list));
|
||||
|
|
Loading…
Reference in New Issue