!43693 fix cleancode check in ROIAlign&Grad kernels, add dynamic rank testcases

Merge pull request !43693 from zhengzuohe/roialigngrad
This commit is contained in:
i-robot 2022-10-12 09:05:04 +00:00 committed by Gitee
commit 2490ab5152
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 194 additions and 123 deletions

View File

@ -24,23 +24,23 @@ namespace {
template <typename T>
void bilinear_interpolate(const int height, const int width, T y, T x, int *x_low, int *y_low, int *x_high, int *y_high,
T *w1, T *w2, T *w3, T *w4) {
constexpr float eps = 0.00007;
const T ZERO = T(0.0);
const T ONE = T(1.0);
const T NEG_ONE = static_cast<T>(-1.0);
if (y < NEG_ONE || y > static_cast<T>(height) || x < NEG_ONE || x > static_cast<T>(width)) {
constexpr float kEps = 0.00007;
const T kZero = T(0.0);
const T kOne = T(1.0);
const T kMinusOne = static_cast<T>(-1.0);
if (y < kMinusOne || y > static_cast<T>(height) || x < kMinusOne || x > static_cast<T>(width)) {
*w1 = *w2 = *w3 = *w4 = static_cast<T>(0);
*x_low = *x_high = *y_low = *y_high = -1;
return;
}
// low bounder is at least zero
y = y <= ZERO ? ZERO : y;
x = x <= ZERO ? ZERO : x;
y = y <= kZero ? kZero : y;
x = x <= kZero ? kZero : x;
// top left point
*y_low = (y <= static_cast<T>(eps) ? 0 : static_cast<int>(floor(y)));
*x_low = (x <= static_cast<T>(eps) ? 0 : static_cast<int>(floor(x)));
*y_low = (y <= static_cast<T>(kEps) ? 0 : static_cast<int>(floor(y)));
*x_low = (x <= static_cast<T>(kEps) ? 0 : static_cast<int>(floor(x)));
// bottom right point
if (*y_low >= height - 1) {
@ -60,7 +60,7 @@ void bilinear_interpolate(const int height, const int width, T y, T x, int *x_lo
// distance to nearest points
T lx, ly, hx, hy;
ly = y - static_cast<T>(*y_low), lx = x - static_cast<T>(*x_low);
hy = ONE - ly, hx = ONE - lx;
hy = kOne - ly, hx = kOne - lx;
// weight is evaluated by the distance to point away.
// the closer to point home, the more weight, the farther to point away.
@ -73,12 +73,12 @@ void bin_box(int thread_idx, const T *roi_boxes, int roi_cols, const T spatial_s
int roi_end_mode, const int channels, const int height, const int width, const int pooled_height,
const int pooled_width, int *offset, int *n, int *c, int *ph, int *pw, int *roi_bin_grid_h,
int *roi_bin_grid_w, T *bin_size_h, T *bin_size_w, T *roi_start_h, T *roi_start_w) {
constexpr int START_W = 0;
constexpr int START_H = 1;
constexpr int END_W = 2;
constexpr int END_H = 3;
constexpr float eps = 0.00007;
constexpr int ROIS_COLS = 5;
constexpr float kEps = 0.00007;
constexpr int kStartW = 0;
constexpr int kStartH = 1;
constexpr int kEndW = 2;
constexpr int kEndH = 3;
constexpr int kRoisCols = 5;
// (n, c, ph, pw) is the base param of pooled map
*pw = thread_idx % pooled_width;
*ph = (thread_idx / pooled_width) % pooled_height;
@ -90,16 +90,16 @@ void bin_box(int thread_idx, const T *roi_boxes, int roi_cols, const T spatial_s
// 2. indicator + 4 points (1 + 4)
const T *roi_box = roi_boxes + (*n) * roi_cols;
int roi_batch_ind = 0;
if (roi_cols == ROIS_COLS) {
roi_batch_ind = FloatToInt(rintf(static_cast<float>(roi_box[0]) + eps));
if (roi_cols == kRoisCols) {
roi_batch_ind = FloatToInt(rintf(static_cast<float>(roi_box[0]) + kEps));
roi_box++;
}
// Scale and shift ROI
*roi_start_w = roi_box[START_W] * spatial_scale;
*roi_start_h = roi_box[START_H] * spatial_scale;
T roi_end_w = (roi_box[END_W] + static_cast<T>(roi_end_mode)) * spatial_scale;
T roi_end_h = (roi_box[END_H] + static_cast<T>(roi_end_mode)) * spatial_scale;
*roi_start_w = roi_box[kStartW] * spatial_scale;
*roi_start_h = roi_box[kStartH] * spatial_scale;
T roi_end_w = (roi_box[kEndW] + static_cast<T>(roi_end_mode)) * spatial_scale;
T roi_end_h = (roi_box[kEndH] + static_cast<T>(roi_end_mode)) * spatial_scale;
// New ROI height/width
T roi_width = roi_end_w - (*roi_start_w);
@ -159,14 +159,14 @@ int ROIAlignCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std
// Get the input shapes
auto x_shape = inputs[kIndex0]->GetShapeVector();
auto rois_shape = inputs[kIndex1]->GetShapeVector();
constexpr size_t X_DIMS = 4;
constexpr size_t ROIS_DIMS = 2;
if (x_shape.size() > X_DIMS) {
constexpr size_t kFeatureDims = 4;
constexpr size_t kRoisDims = 2;
if (x_shape.size() > kFeatureDims) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'features' cannot be greater than 4, but got "
<< x_shape.size() << ".";
return KRET_RESIZE_FAILED;
}
if (rois_shape.size() != ROIS_DIMS) {
if (rois_shape.size() != kRoisDims) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'rois' must be equal to 2, but got "
<< rois_shape.size() << ".";
return KRET_RESIZE_FAILED;
@ -201,13 +201,13 @@ bool ROIAlignCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &i
size_t elem_num = IntToSize(roi_rows_ * channels_ * pooled_height_ * pooled_width_);
auto task = [this, &input, &rois, &out_data](size_t start, size_t end) {
const T OFFSET = T(0.001);
const T ZERO = T(0.0);
const T kOffset = T(0.001);
const T kZero = T(0.0);
const T spatial_scale = static_cast<T>(spatial_scale_);
for (size_t thread_idx = start; thread_idx < end; thread_idx++) {
int n = SizeToInt(thread_idx) / pooled_width_ / pooled_height_ / channels_;
const T *roi_box = rois + n * roi_cols_;
if (roi_box[1] < OFFSET && roi_box[3] < OFFSET && roi_box[1] > -OFFSET && roi_box[3] > -OFFSET) {
if (roi_box[1] < kOffset && roi_box[3] < kOffset && roi_box[1] > -kOffset && roi_box[3] > -kOffset) {
continue;
}
int offset = -1;
@ -221,7 +221,7 @@ bool ROIAlignCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &i
// (n, c, ph, pw) is the base param of pooled map
const T count_points_in_grid_cell = static_cast<T>(roi_bin_grid_h) * static_cast<T>(roi_bin_grid_w);
T accumulate_val = ZERO;
T accumulate_val = kZero;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
// Shift half point RIGHT for y / x, while previous scaled roi shift half point LEFT
const T y = roi_start_h + static_cast<T>(ph) * bin_size_h +

View File

@ -28,23 +28,23 @@ namespace {
template <typename T>
void bilinear_interpolate(const int height, const int width, T y, T x, int *x_low, int *y_low, int *x_high, int *y_high,
T *w1, T *w2, T *w3, T *w4) {
constexpr float eps = 0.00007;
const T ZERO = T(0.0);
const T ONE = T(1.0);
const T NEG_ONE = static_cast<T>(-1.0);
if (y < NEG_ONE || y > static_cast<T>(height) || x < NEG_ONE || x > static_cast<T>(width)) {
*w1 = *w2 = *w3 = *w4 = ZERO;
constexpr float kEps = 0.00007;
const T kZero = T(0.0);
const T kOne = T(1.0);
const T kMinusOne = static_cast<T>(-1.0);
if (y < kMinusOne || y > static_cast<T>(height) || x < kMinusOne || x > static_cast<T>(width)) {
*w1 = *w2 = *w3 = *w4 = kZero;
*x_low = *x_high = *y_low = *y_high = -1;
return;
}
// low bounder is at least zero
y = y <= ZERO ? ZERO : y;
x = x <= ZERO ? ZERO : x;
y = y <= kZero ? kZero : y;
x = x <= kZero ? kZero : x;
// top left point
*y_low = (y <= static_cast<T>(eps) ? 0 : static_cast<int>(floor(y)));
*x_low = (x <= static_cast<T>(eps) ? 0 : static_cast<int>(floor(x)));
*y_low = (y <= static_cast<T>(kEps) ? 0 : static_cast<int>(floor(y)));
*x_low = (x <= static_cast<T>(kEps) ? 0 : static_cast<int>(floor(x)));
// bottom right point
if (*y_low >= height - 1) {
@ -64,7 +64,7 @@ void bilinear_interpolate(const int height, const int width, T y, T x, int *x_lo
// distance to nearest points
T lx, ly, hx, hy;
ly = y - static_cast<T>(*y_low), lx = x - static_cast<T>(*x_low);
hy = ONE - ly, hx = ONE - lx;
hy = kOne - ly, hx = kOne - lx;
// weight is evaluated by the distance to point away.
// the closer to point home, the more weight, the farther to point away.
@ -77,12 +77,12 @@ void bin_box(int thread_idx, const T *roi_boxes, int roi_cols, const T spatial_s
int roi_end_mode, const int channels, const int height, const int width, const int pooled_height,
const int pooled_width, int *offset, int *n, int *c, int *ph, int *pw, int *roi_bin_grid_h,
int *roi_bin_grid_w, T *bin_size_h, T *bin_size_w, T *roi_start_h, T *roi_start_w) {
constexpr float eps = 0.00007;
constexpr int START_W = 0;
constexpr int START_H = 1;
constexpr int END_W = 2;
constexpr int END_H = 3;
constexpr size_t ROIS_COLS = 5;
constexpr float kEps = 0.00007;
constexpr int kStartW = 0;
constexpr int kStartH = 1;
constexpr int kEndW = 2;
constexpr int kEndH = 3;
constexpr size_t kRoisCols = 5;
// (n, c, ph, pw) is the base param of pooled map
*pw = thread_idx % pooled_width;
*ph = (thread_idx / pooled_width) % pooled_height;
@ -94,16 +94,16 @@ void bin_box(int thread_idx, const T *roi_boxes, int roi_cols, const T spatial_s
// 2. indicator + 4 points (1 + 4)
const T *roi_box = roi_boxes + (*n) * roi_cols;
int roi_batch_ind = 0;
if (roi_cols == ROIS_COLS) {
roi_batch_ind = FloatToInt(rintf(static_cast<float>(roi_box[0]) + eps));
if (roi_cols == kRoisCols) {
roi_batch_ind = FloatToInt(rintf(static_cast<float>(roi_box[0]) + kEps));
roi_box++;
}
// Scale and shift ROI
*roi_start_w = roi_box[START_W] * spatial_scale;
*roi_start_h = roi_box[START_H] * spatial_scale;
T roi_end_w = (roi_box[END_W] + static_cast<T>(roi_end_mode)) * spatial_scale;
T roi_end_h = (roi_box[END_H] + static_cast<T>(roi_end_mode)) * spatial_scale;
*roi_start_w = roi_box[kStartW] * spatial_scale;
*roi_start_h = roi_box[kStartH] * spatial_scale;
T roi_end_w = (roi_box[kEndW] + static_cast<T>(roi_end_mode)) * spatial_scale;
T roi_end_h = (roi_box[kEndH] + static_cast<T>(roi_end_mode)) * spatial_scale;
// New ROI height/width
T roi_width = roi_end_w - (*roi_start_w);
@ -173,19 +173,19 @@ int ROIAlignGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
// Get the input shapes
auto dy_shape = inputs[kIndex0]->GetShapeVector();
auto rois_shape = inputs[kIndex1]->GetShapeVector();
constexpr size_t DX_DY_DIMS = 4;
constexpr size_t ROIS_DIMS = 2;
if (dy_shape.size() != DX_DY_DIMS) {
constexpr size_t kDiffDims = 4;
constexpr size_t kRoisDims = 2;
if (dy_shape.size() != kDiffDims) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'dy' must be 4, but got " << dy_shape.size()
<< ".";
return KRET_RESIZE_FAILED;
}
if (rois_shape.size() != ROIS_DIMS) {
if (rois_shape.size() != kRoisDims) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'rois' must be 2, but got " << rois_shape.size()
<< ".";
return KRET_RESIZE_FAILED;
}
if (xdiff_shape_.size() > DX_DY_DIMS) {
if (xdiff_shape_.size() > kDiffDims) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the length of xdiff_shape cannot be greater than 4, but got "
<< xdiff_shape_.size() << ".";
return KRET_RESIZE_FAILED;
@ -242,21 +242,21 @@ bool ROIAlignGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &input
int size_init = batch_ * channels_ * height_ * width_;
auto task1 = [this, &dx](size_t start, size_t end) {
const T ZERO = T(0.0);
const T kZero = T(0.0);
for (size_t thread_idx = start; thread_idx < end; thread_idx++) {
dx[thread_idx] = ZERO;
dx[thread_idx] = kZero;
}
};
ParallelLaunchAutoSearch(task1, IntToSize(size_init), this, &parallel_search_info_);
int elem_num = roi_rows_ * channels_ * pooled_height_ * pooled_width_;
auto task2 = [this, &dy, &rois, &dx](size_t start, size_t end) {
const T OFFSET = T(0.001);
const T kOffset = T(0.001);
for (size_t thread_idx = start; thread_idx < end; thread_idx++) {
int n = SizeToInt(thread_idx) / pooled_width_ / pooled_height_ / channels_;
const T *roi_box = rois + n * roi_cols_;
const T spatial_scale = static_cast<T>(spatial_scale_);
if (roi_box[1] < OFFSET && roi_box[3] < OFFSET && roi_box[1] > -OFFSET && roi_box[3] > -OFFSET) {
if (roi_box[1] < kOffset && roi_box[3] < kOffset && roi_box[1] > -kOffset && roi_box[3] > -kOffset) {
continue;
}
int offset = -1;

View File

@ -21,11 +21,11 @@ namespace kernel {
bool ROIAlignGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
// Check input and output numbers
constexpr size_t input_num = 2;
constexpr size_t output_num = 1;
constexpr size_t kInputNum = 2;
constexpr size_t kOutputNum = 1;
kernel_name_ = base_operator->name();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_);
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
@ -48,14 +48,14 @@ int ROIAlignGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std
// Get the input shapes
auto x_shape = inputs[kIndex0]->GetShapeVector();
auto rois_shape = inputs[kIndex1]->GetShapeVector();
constexpr size_t X_DIMS = 4;
constexpr size_t ROIS_DIMS = 2;
if (x_shape.size() > X_DIMS) {
constexpr size_t kFeatureDims = 4;
constexpr size_t kRoisDims = 2;
if (x_shape.size() > kFeatureDims) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of features cannot be greater than 4, but got "
<< x_shape.size() << ".";
return KRET_RESIZE_FAILED;
}
if (rois_shape.size() != ROIS_DIMS) {
if (rois_shape.size() != kRoisDims) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of rois must be equal to 2, but got "
<< rois_shape.size() << ".";
return KRET_RESIZE_FAILED;

View File

@ -22,15 +22,15 @@ namespace kernel {
bool ROIAlignGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
// Check input and output numbers
constexpr size_t input_num_no_xshape = 2;
constexpr size_t input_num_with_xshape = 3;
constexpr size_t output_num = 1;
constexpr size_t kInputNumNoShape = 2;
constexpr size_t kInputNumWithShape = 3;
constexpr size_t kOutputNum = 1;
kernel_name_ = base_operator->name();
if (inputs.size() != input_num_no_xshape && inputs.size() != input_num_with_xshape) {
if (inputs.size() != kInputNumNoShape && inputs.size() != kInputNumWithShape) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2 or 3, but got " << inputs.size()
<< ".";
}
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
@ -40,7 +40,7 @@ bool ROIAlignGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
pooled_width_ = op->get_pooled_width();
spatial_scale_ = op->get_spatial_scale();
sample_num_ = op->get_sample_num();
if (inputs.size() == input_num_with_xshape) {
if (inputs.size() == kInputNumWithShape) {
is_xdiff_shape_dyn_ = true;
return true;
}
@ -63,19 +63,19 @@ int ROIAlignGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
// Get the input shapes
auto dy_shape = inputs[kIndex0]->GetShapeVector();
auto rois_shape = inputs[kIndex1]->GetShapeVector();
constexpr size_t dx_dy_shape_size = 4;
constexpr size_t rois_shape_size = 2;
if (dy_shape.size() != dx_dy_shape_size) {
constexpr size_t kDiffDims = 4;
constexpr size_t kRoisDims = 2;
if (dy_shape.size() != kDiffDims) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of dy must be equal to 4, but got " << dy_shape.size()
<< ".";
return KRET_RESIZE_FAILED;
}
if (rois_shape.size() != rois_shape_size) {
if (rois_shape.size() != kRoisDims) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of rois must be equal to 2, but got "
<< rois_shape.size() << ".";
return KRET_RESIZE_FAILED;
}
if (xdiff_shape_.size() > dx_dy_shape_size) {
if (xdiff_shape_.size() > kDiffDims) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the length of xdiff_shape cannot be greater than 4, but got "
<< xdiff_shape_.size() << ".";
return KRET_RESIZE_FAILED;

View File

@ -46,8 +46,7 @@ class ROIAlignInfer : public abstract::OpInferBase {
constexpr size_t kFeatureShapeSize = 4;
(void)CheckAndConvertUtils::CheckInteger("rank of feature shape", SizeToLong(feature_shape.size()), kLessEqual,
kFeatureShapeSize, op_name);
const int64_t channel_index = kInputIndex1;
out_c = feature_shape[channel_index];
out_c = feature_shape[kInputIndex1];
}
if (IsDynamicRank(rois_shape)) {
out_n = abstract::Shape::kShapeDimAny;
@ -61,8 +60,7 @@ class ROIAlignInfer : public abstract::OpInferBase {
(void)CheckAndConvertUtils::CheckInteger("second dim of rois shape", rois_second_dim, kEqual,
kRoisShapeSecondDim);
}
const int64_t roi_num_index = kInputIndex0;
out_n = rois_shape[roi_num_index];
out_n = rois_shape[kInputIndex0];
}
ShapeVector output_shape;
auto pooled_height_ptr = primitive->GetAttr(kPooledHeight);

View File

@ -15,28 +15,36 @@
import numpy as np
import pytest
import mindspore as ms
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, ops
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import _inner_ops as inner
class NetROIAlignGrad(nn.Cell):
def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num):
def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num, is_dyn_rank=False):
super(NetROIAlignGrad, self).__init__()
self.shape = ops.Shape()
self.dyn_shape = ops.TensorShape()
self.roi_align_grad = G.ROIAlignGrad(pooled_height, pooled_width, spatial_scale, sample_num)
self.is_dyn_rank = is_dyn_rank
self.convert_to_dynamic_rank = inner.ConvertToDynamic(is_dynamic_rank=is_dyn_rank).add_prim_attr(
"primitive_target", "CPU"
)
def construct(self, dy, rois, xdiff):
if self.is_dyn_rank:
dy = self.convert_to_dynamic_rank(dy)
rois = self.convert_to_dynamic_rank(rois)
xdiff = self.convert_to_dynamic_rank(xdiff)
xdiff_shape = self.shape(xdiff)
if -1 in xdiff_shape or -2 in xdiff_shape:
xdiff_shape = self.dyn_shape(xdiff)
return self.roi_align_grad(dy, rois, xdiff_shape)
def roi_align_grad_case(data_type=np.float16, is_dyn_shape=False):
def roi_align_grad_case(data_type=np.float16, is_dyn_shape=False, is_dyn_rank=False):
rois = Tensor(np.array([[0, -2.0, -2.0, 21.0, 21.0]], data_type))
dy = Tensor(np.array([[[[0.1, 0.2, 0.3], [0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]]], data_type))
@ -45,12 +53,11 @@ def roi_align_grad_case(data_type=np.float16, is_dyn_shape=False):
pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2
roi_align_grad = NetROIAlignGrad(pooled_height, pooled_width, spatial_scale, sample_num)
roi_align_grad = NetROIAlignGrad(pooled_height, pooled_width, spatial_scale, sample_num, is_dyn_rank)
if is_dyn_shape:
dtype_map = {np.float16: ms.float16, np.float32: ms.float32}
dyn_dx_dy = Tensor(shape=(None, None, None, None), dtype=dtype_map.get(data_type))
dyn_rois = Tensor(shape=(None, None), dtype=dtype_map.get(data_type))
dyn_dx_dy = Tensor(shape=(None, None, None, None), dtype=dy.dtype)
dyn_rois = Tensor(shape=(None, None), dtype=dy.dtype)
roi_align_grad.set_inputs(dyn_dx_dy, dyn_rois, dyn_dx_dy)
output = roi_align_grad(dy, rois, xdiff)
@ -75,33 +82,63 @@ def roi_align_grad_case(data_type=np.float16, is_dyn_shape=False):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_roi_align_grad():
def test_roi_align_grad_float16():
"""
Feature: Test the operator ROIAlignGrad
Description: Test in GRAPH and PYNATIVE mode using float32 and float16 inputs
Description: Test in GRAPH and PYNATIVE mode using float16 inputs
Expectation: Assert the result is equal to the expectation
"""
context.set_context(mode=context.GRAPH_MODE)
roi_align_grad_case(np.float32)
roi_align_grad_case(np.float16)
context.set_context(mode=context.PYNATIVE_MODE)
roi_align_grad_case(np.float32)
roi_align_grad_case(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_roi_align_grad_dynamic_shape():
def test_roi_align_grad_float32():
"""
Feature: Test the operator ROIAlignGrad
Description: Test in GRAPH and PYNATIVE mode using float32 inputs
Expectation: Assert the result is equal to the expectation
"""
context.set_context(mode=context.GRAPH_MODE)
roi_align_grad_case(np.float32)
context.set_context(mode=context.PYNATIVE_MODE)
roi_align_grad_case(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_roi_align_grad_float16_dynamic_shape():
"""
Feature: Test the operator ROIAlignGrad with dynamic shape inputs
Description: Test in GRAPH and PYNATIVE mode using float32 and float16 dynamic shape inputs
Description: Test in GRAPH and PYNATIVE mode using float16 dynamic shape inputs
Expectation: Assert the result is equal to the expectation
"""
context.set_context(mode=context.GRAPH_MODE)
roi_align_grad_case(np.float16, True)
context.set_context(mode=context.PYNATIVE_MODE)
roi_align_grad_case(np.float16, True)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_roi_align_grad_float32_dynamic_rank():
"""
Feature: Test the operator ROIAlignGrad with dynamic rank inputs
Description: Test in GRAPH and PYNATIVE mode using float32 dynamic rank inputs
Expectation: Assert the result is equal to the expectation
"""
context.set_context(mode=context.GRAPH_MODE)
roi_align_grad_case(np.float32, True)
roi_align_grad_case(np.float16, True)
context.set_context(mode=context.PYNATIVE_MODE)
roi_align_grad_case(np.float32, True)
roi_align_grad_case(np.float16, True)

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -15,24 +15,31 @@
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
class NetROIAlign(nn.Cell):
def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num, roi_end_mode):
def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num, roi_end_mode, is_dyn_rank=False):
super(NetROIAlign, self).__init__()
self.roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, roi_end_mode)
self.is_dyn_rank = is_dyn_rank
self.convert_to_dynamic_rank = inner.ConvertToDynamic(is_dynamic_rank=is_dyn_rank).add_prim_attr(
"primitive_target", "CPU"
)
def construct(self, features, rois):
if self.is_dyn_rank:
features = self.convert_to_dynamic_rank(features)
rois = self.convert_to_dynamic_rank(rois)
return self.roi_align(features, rois)
def roi_align_case(data_type=np.float16, is_dyn_shape=False):
x = Tensor(
def roi_align_case(data_type=np.float16, is_dyn_shape=False, is_dyn_rank=False):
features = Tensor(
np.array(
[
[
@ -50,37 +57,36 @@ def roi_align_case(data_type=np.float16, is_dyn_shape=False):
)
)
dtype_map = {np.float16: ms.float16, np.float32: ms.float32}
dyn_features = Tensor(shape=(None, None, None, None), dtype=dtype_map.get(data_type))
dyn_rois = Tensor(shape=(None, None), dtype=dtype_map.get(data_type))
dyn_features = Tensor(shape=(None, None, None, None), dtype=features.dtype)
dyn_rois = Tensor(shape=(None, None), dtype=features.dtype)
# test case 1
rois = Tensor(np.array([[0, -2.0, -2.0, 21.0, 21.0]], data_type))
pooled_height, pooled_width, spatial_scale, sample_num, roi_end_mode = 3, 3, 0.25, 2, 1
roi_align = NetROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, roi_end_mode)
roi_align = NetROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, roi_end_mode, is_dyn_rank)
if is_dyn_shape:
roi_align.set_inputs(dyn_features, dyn_rois)
output = roi_align(x, rois)
output = roi_align(features, rois)
expect = [[[[4.5, 6.5, 8.5], [16.5, 18.5, 20.5], [28.5, 30.5, 32.5]]]]
assert (output.asnumpy() == expect).all()
# test case 2
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], data_type))
pooled_height, pooled_width, spatial_scale, sample_num, roi_end_mode = 3, 3, 0.25, 2, 0
roi_align = NetROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, roi_end_mode)
roi_align = NetROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, roi_end_mode, is_dyn_rank)
if is_dyn_shape:
roi_align.set_inputs(dyn_features, dyn_rois)
output = roi_align(x, rois)
output = roi_align(features, rois)
expect = [[[[4.5, 6.5, 8.5], [16.5, 18.5, 20.5], [28.5, 30.5, 32.5]]]]
assert (output.asnumpy() == expect).all()
# test case 3
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], data_type))
pooled_height, pooled_width, spatial_scale, sample_num, roi_end_mode = 2, 2, 1.0, -1, 0
roi_align = NetROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, roi_end_mode)
roi_align = NetROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, roi_end_mode, is_dyn_rank)
if is_dyn_shape:
roi_align.set_inputs(dyn_features, dyn_rois)
output = roi_align(x, rois)
output = roi_align(features, rois)
expect = [[[[6.295, 0.0], [0.0, 0.0]]]]
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=2)
@ -91,33 +97,63 @@ def roi_align_case(data_type=np.float16, is_dyn_shape=False):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_roi_align():
def test_roi_align_float16():
"""
Feature: Test the operator ROIAlign
Description: Test in GRAPH and PYNATIVE mode using float32 and float16 inputs
Description: Test in GRAPH and PYNATIVE mode using float16 inputs
Expectation: Assert the result is equal to the expectation
"""
context.set_context(mode=context.PYNATIVE_MODE)
roi_align_case(np.float32)
roi_align_case(np.float16)
context.set_context(mode=context.GRAPH_MODE)
roi_align_case(np.float32)
roi_align_case(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_roi_align_dynamic_shape():
def test_roi_align_float32():
"""
Feature: Test the operator ROIAlign with dynamic shape inputs
Description: Test in GRAPH and PYNATIVE mode using float32 and float16 dynamic shape inputs
Feature: Test the operator ROIAlign
Description: Test in GRAPH and PYNATIVE mode using float32 inputs
Expectation: Assert the result is equal to the expectation
"""
context.set_context(mode=context.PYNATIVE_MODE)
roi_align_case(np.float32)
context.set_context(mode=context.GRAPH_MODE)
roi_align_case(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_roi_align_float16_dynamic_shape():
"""
Feature: Test the operator ROIAlign with dynamic shape inputs
Description: Test in GRAPH and PYNATIVE mode using float16 dynamic shape inputs
Expectation: Assert the result is equal to the expectation
"""
context.set_context(mode=context.PYNATIVE_MODE)
roi_align_case(np.float32, True)
roi_align_case(np.float16, True)
context.set_context(mode=context.GRAPH_MODE)
roi_align_case(np.float32, True)
roi_align_case(np.float16, True)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_roi_align_float32_dynamic_rank():
"""
Feature: Test the operator ROIAlign with dynamic rank inputs
Description: Test in GRAPH and PYNATIVE mode using float32 dynamic rank inputs
Expectation: Assert the result is equal to the expectation
"""
context.set_context(mode=context.PYNATIVE_MODE)
roi_align_case(np.float32, False, True)
context.set_context(mode=context.GRAPH_MODE)
roi_align_case(np.float32, False, True)