From 661b993475e404180665f2506cf8a345e3487ee2 Mon Sep 17 00:00:00 2001 From: Jonathan Yan Date: Tue, 21 Jul 2020 20:15:49 -0400 Subject: [PATCH] roi align v1 --- .../gpu/cuda_impl/roi_align_impl.cu | 228 ++++++++++++++++++ .../gpu/cuda_impl/roi_align_impl.cuh | 24 ++ .../gpu/nn/roi_align_gpu_kernel.cc | 32 +++ .../gpu/nn/roi_align_gpu_kernel.h | 140 +++++++++++ tests/st/ops/gpu/test_roi_align_op.py | 95 ++++++++ 5 files changed, 519 insertions(+) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_roi_align_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu new file mode 100644 index 0000000000..8d8aa800b2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu @@ -0,0 +1,228 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "roi_align_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" +template +inline __device__ T gpu_atomic_add(const T val, T *address); + +template <> +inline __device__ float gpu_atomic_add(const float val, float *address) { + return atomicAdd(address, val); +} + +template +__device__ void bilinear_interpolate(const int height, const int width, T y, T x, int *x_low, int *y_low, int *x_high, + int *y_high, T *w1, T *w2, T *w3, T *w4) { + // return 0 if out of map boundary + if (y <= static_cast(-1.0) || y >= static_cast(height) || x <= static_cast(-1.0) || + x >= static_cast(width)) { + *w1 = *w2 = *w3 = *w4 = 0; + *x_low = *x_high = *y_low = *y_high = -1; + return; + } + + // low bounder is at least zero + y = y <= static_cast(.0) ? static_cast(.0) : y; + x = x <= static_cast(.0) ? static_cast(.0) : x; + + // top left point + *y_low = static_cast(y); + *x_low = static_cast(x); + + // bottom right point + if (*y_low >= height - 1) { + *y_high = *y_low = height - 1; + y = static_cast(*y_low); + } else { + *y_high = *y_low + 1; + } + + if (*x_low >= width - 1) { + *x_high = *x_low = width - 1; + x = static_cast(*x_low); + } else { + *x_high = *x_low + 1; + } + + // distance to nearest points + T lx, ly, hx, hy; + ly = y - static_cast(*y_low), lx = x - static_cast(*x_low); + hy = static_cast(1.) - ly, hx = static_cast(1.) - lx; + + // weight is evaluated by the distance to point away. + // the closer to point home, the more weight, the farther to point away. + *w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx; + return; +} + +template +__device__ void bin_box(int thread_idx, const T *roi_boxes, int roi_cols, const T spatial_scale, const int sample_num, + int roi_end_mode, const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, int *offset, int *n, int *c, int *ph, int *pw, + int *roi_bin_grid_h, int *roi_bin_grid_w, T *bin_size_h, T *bin_size_w, T *roi_start_h, + T *roi_start_w) { + // (n, c, ph, pw) is the base param of pooled map + *pw = thread_idx % pooled_width; + *ph = (thread_idx / pooled_width) % pooled_height; + *c = (thread_idx / pooled_width / pooled_height) % channels; + *n = thread_idx / pooled_width / pooled_height / channels; + + // Roi has + // 1. 4 points, or + // 2. indicator + 4 points (1 + 4) + const T *roi_box = roi_boxes + (*n) * roi_cols; + int roi_batch_ind = 0; + if (roi_cols == 5) { + roi_batch_ind = roi_box[0]; + roi_box++; + } + + // Scale and shift ROI + T roi_offset = roi_end_mode == 1 ? static_cast(0.5) : static_cast(.0); + *roi_start_w = roi_box[0] * spatial_scale - roi_offset; + *roi_start_h = roi_box[1] * spatial_scale - roi_offset; + T roi_end_w = roi_box[2] * spatial_scale - roi_offset; + T roi_end_h = roi_box[3] * spatial_scale - roi_offset; + + // New ROI height/width + T roi_width = roi_end_w - (*roi_start_w); + T roi_height = roi_end_h - (*roi_start_h); + + // ratio of roi / pooled + *bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + *bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + *offset = (roi_batch_ind * channels + (*c)) * height * width; + + // grid (int) by Sample ratio if defined, otherwise by pooled H/W + *roi_bin_grid_h = (sample_num > 0) ? sample_num : static_cast(roi_height / static_cast(pooled_height)); + *roi_bin_grid_w = (sample_num > 0) ? sample_num : static_cast(roi_width / static_cast(pooled_width)); + return; +} + +template +__global__ void ROIAlignKernel(size_t size, const T *input, const T *roi_boxes, int roi_cols, T *out_data, + const T spatial_scale, const int sample_num, int roi_end_mode, const int channels, + const int height, const int width, const int pooled_height, const int pooled_width) { + for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size; + thread_idx += blockDim.x * gridDim.x) { + int offset, n, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w; + T bin_size_h, bin_size_w, roi_start_h, roi_start_w; + + bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width, + pooled_height, pooled_width, &offset, &n, &c, &ph, &pw, &roi_bin_grid_h, &roi_bin_grid_w, &bin_size_h, + &bin_size_w, &roi_start_h, &roi_start_w); + + // (n, c, ph, pw) is the base param of pooled map + const T count_points_in_grid_cell = roi_bin_grid_h * roi_bin_grid_w; + + T accumulate_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + // Shift half point RIGHT for y / x, while previous scaled roi shift half point LEFT + const T y = roi_start_h + static_cast(ph) * bin_size_h + + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + static_cast(pw) * bin_size_w + + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + // bilinear interpolate by shifted y / x + // calculate bilinear interpolation + int x_low, y_low, x_high, y_high; + T w1, w2, w3, w4; + bilinear_interpolate(height, width, y, x, &x_low, &y_low, &x_high, &y_high, &w1, &w2, &w3, &w4); + T v1 = input[y_low * width + x_low + offset]; + T v2 = input[y_low * width + x_high + offset]; + T v3 = input[y_high * width + x_low + offset]; + T v4 = input[y_high * width + x_high + offset]; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + accumulate_val += val; + } + } + accumulate_val /= count_points_in_grid_cell; + + out_data[thread_idx] = accumulate_val; + } +} + +template +void ROIAlign(const T *x, const T *roi_boxes, int roi_rows, int roi_cols, T *out_data, const T spatial_scale, + const int sample_num, int roi_end_mode, const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, cudaStream_t cuda_stream) { + size_t size = roi_rows * channels * pooled_height * pooled_width; + ROIAlignKernel<<>>(size, x, roi_boxes, roi_cols, out_data, + spatial_scale, sample_num, roi_end_mode, channels, + height, width, pooled_height, pooled_width); + return; +} + +template void ROIAlign(const float *x, const float *roi_boxes, int roi_rows, int roi_cols, float *out_data, + const float spatial_scale, const int sample_num, int roi_end_mode, const int channels, + const int height, const int width, const int pooled_height, const int pooled_width, + cudaStream_t cuda_stream); + +template void ROIAlign(const half *x, const half *roi_boxes, int roi_rows, int roi_cols, half *out_data, + const half spatial_scale, const int sample_num, int roi_end_mode, const int channels, + const int height, const int width, const int pooled_height, const int pooled_width, + cudaStream_t cuda_stream); + +template +__global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes, int roi_cols, T *dx, + const T spatial_scale, const int sample_num, int roi_end_mode, const int channels, + const int height, const int width, const int pooled_height, const int pooled_width) { + for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size; + thread_idx += blockDim.x * gridDim.x) { + int offset, n, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w; + T bin_size_h, bin_size_w, roi_start_h, roi_start_w; + + bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width, + pooled_height, pooled_width, &offset, &n, &c, &ph, &pw, &roi_bin_grid_h, &roi_bin_grid_w, &bin_size_h, + &bin_size_w, &roi_start_h, &roi_start_w); + + // (n, c, ph, pw) is the base param of pooled map + const T count_points_in_grid_cell = roi_bin_grid_h * roi_bin_grid_w; + + int top_offset = (n * channels + c) * pooled_height * pooled_width; + const T *offset_top_diff = dy + top_offset; + const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw]; + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + // Shift half point RIGHT for y / x, while previous scaled roi shift half point LEFT + const T y = + roi_start_h + ph * bin_size_h + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = + roi_start_w + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + // bilinear interpolate by shifted y / x + // calculate bilinear interpolation + int x_low, y_low, x_high, y_high; + T w1, w2, w3, w4; + bilinear_interpolate(height, width, y, x, &x_low, &y_low, &x_high, &y_high, &w1, &w2, &w3, &w4); + T g1 = top_diff_this_bin * w1 / count_points_in_grid_cell; + T g2 = top_diff_this_bin * w2 / count_points_in_grid_cell; + T g3 = top_diff_this_bin * w3 / count_points_in_grid_cell; + T g4 = top_diff_this_bin * w4 / count_points_in_grid_cell; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + gpu_atomic_add(static_cast(g1), dx + offset + y_low * width + x_low); + gpu_atomic_add(static_cast(g2), dx + offset + y_low * width + x_high); + gpu_atomic_add(static_cast(g3), dx + offset + y_high * width + x_low); + gpu_atomic_add(static_cast(g4), dx + offset + y_high * width + x_high); + } + } + } + } +} diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cuh new file mode 100644 index 0000000000..53e31a1d50 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cuh @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ROI_ALIGN_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ROI_ALIGN_IMPL_H_ +template +void ROIAlign(const T *x, const T *roi_boxes, int roi_rows, int roi_cols, T *out_data, const T spatial_scale, + const int sample_num, int roi_end_mode, const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ROI_ALIGN_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.cc new file mode 100644 index 0000000000..c79e3af080 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + ROIAlign, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ROIAlignGpuFwdKernel, float) + +MS_REG_GPU_KERNEL_ONE( + ROIAlign, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ROIAlignGpuFwdKernel, half) + +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.h new file mode 100644 index 0000000000..943749ed11 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.h @@ -0,0 +1,140 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class ROIAlignGpuFwdKernel : public GpuKernel { + public: + ROIAlignGpuFwdKernel() : x_size_(0), rois_size_(0), output_size_(0) {} + ~ROIAlignGpuFwdKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + const T *x = GetDeviceAddress(inputs, 0); + const T *rois = GetDeviceAddress(inputs, 1); + + T *out_data = GetDeviceAddress(outputs, 0); + + ROIAlign(x, rois, roi_rows_, roi_cols_, out_data, spatial_scale_, sample_num_, roi_end_mode_, channels_, height_, + width_, pooled_height_, pooled_width_, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + // Get the number of input args + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but RioAlign needs 2 input."; + return false; + } + + // Get the number of output args + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but RioAlign needs 1 output."; + return false; + } + + // Get the input shapes + auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto rois_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + + auto x_shape_size = x_shape.size(); + if (x_shape_size < 2) { + MS_LOG(ERROR) << "x shape szie is " << x_shape_size << ", but at lease 2D."; + return false; + } + + // Get channels, height & width + channels_ = x_shape_size >= 3 ? x_shape[x_shape_size - 3] : 1; + height_ = x_shape[x_shape_size - 2]; + width_ = x_shape[x_shape_size - 1]; + x_shape_ = {channels_, height_, width_}; + x_size_ = channels_ * height_ * width_ * sizeof(T); + + // Get rois rows and cols + roi_rows_ = rois_shape[0]; + roi_cols_ = rois_shape[1]; + rois_size_ = roi_rows_ * roi_cols_ * sizeof(T); + rois_shape_ = {roi_rows_, roi_cols_}; + + // Get primitive args + pooled_height_ = GetAttr(kernel_node, "pooled_height"); + pooled_width_ = GetAttr(kernel_node, "pooled_width"); + spatial_scale_ = static_cast(GetAttr(kernel_node, "spatial_scale")); + sample_num_ = GetAttr(kernel_node, "sample_num"); + roi_end_mode_ = GetAttr(kernel_node, "roi_end_mode"); + + // Get output_shape + output_shape_ = {roi_rows_, channels_, pooled_height_, pooled_width_}; + output_size_ = 1; + for (size_t i = 0; i < 4; i++) { + output_size_ *= output_shape_[i]; + } + output_size_ *= sizeof(T); + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(x_size_); + input_size_list_.push_back(rois_size_); + output_size_list_.push_back(output_size_); + } + + private: + int pooled_height_; + int pooled_width_; + T spatial_scale_; + int sample_num_; + int roi_end_mode_; + + int roi_rows_; + int roi_cols_; + int channels_; + int height_; + int width_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + std::vector x_shape_; + std::vector rois_shape_; + std::vector output_shape_; + + size_t x_size_; + size_t rois_size_; + size_t output_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GPU_KERNEL_H diff --git a/tests/st/ops/gpu/test_roi_align_op.py b/tests/st/ops/gpu/test_roi_align_op.py new file mode 100644 index 0000000000..e31d54ef05 --- /dev/null +++ b/tests/st/ops/gpu/test_roi_align_op.py @@ -0,0 +1,95 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +from mindspore import Tensor +from mindspore.ops import operations as P + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_roi_align(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + x = Tensor(np.array([[ + [[1, 2, 3, 4, 5, 6], + [7, 8, 9, 10, 11, 12], + [13, 14, 15, 16, 17, 18], + [19, 20, 21, 22, 23, 24], + [25, 26, 27, 28, 29, 30], + [31, 32, 33, 34, 35, 36]] + ]], np.float32)) + + rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32)) + + # test case 1 + pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2 + roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) + output = roi_align(x, rois) + print(output) + expect = [[[[2.75, 4.5, 6.5], + [13.25, 15., 17.], + [25.25, 27., 29.]]]] + assert (output.asnumpy() == expect).all() + + # test case 1 + pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2 + roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) + output = roi_align(x, rois) + print(output) + expect = [[[[2.75, 4.5, 6.5], + [13.25, 15., 17.], + [25.25, 27., 29.]]]] + assert (output.asnumpy() == expect).all() + + # test case 2 + pooled_height, pooled_width, spatial_scale, sample_num = 4, 4, 0.2, 3 + roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) + output = roi_align(x, rois) + print(output) + expect = [[[[1.2333, 2.1000, 3.3000, 4.5000], + [6.4333, 7.3000, 8.5000, 9.7000], + [13.6333, 14.5000, 15.7000, 16.9000], + [20.8333, 21.7000, 22.9000, 24.1000]]]] + np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) + + # test case 3 + pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.3, 3 + rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0], + [0, 1.0, 0.0, 19.0, 18.0]], + np.float32)) + roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) + output = roi_align(x, rois) + print(output) + expect = [[[[3.3333, 5.5000, 7.6667], + [16.3333, 18.5000, 20.6667], + [29.3333, 31.5000, 33.6667]]], + [[[4.5000, 6.3000, 8.1000], + [14.9000, 16.7000, 18.5000], + [25.7000, 27.5000, 29.3000]]]] + np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) + + # test case 4 + pooled_height, pooled_width, spatial_scale, sample_num = 2, 2, 1.0, -1 + rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32)) + roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) + output = roi_align(x, rois) + print(output) + expect = [[[[4.625, 0.], + [0., 0.]]]] + np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4)