forked from mindspore-Ecosystem/mindspore
!35383 Add cpu kernel for deformable_offsets
Merge pull request !35383 from YuJianfeng/deformable_conv
commit 8c823836a6
@ -46,6 +46,7 @@ functional operators are initialized Primitives and can be called directly as functions

mindspore.ops.adaptive_avgpool2d
mindspore.ops.pdist
mindspore.ops.deformable_conv2d

Activation Functions
^^^^^^^^^^^^^^^^^^^^

@ -0,0 +1,52 @@

mindspore.ops.deformable_conv2d
===============================

.. py:function:: mindspore.ops.deformable_conv2d(x, weight, offsets, kernel_size, strides, padding, bias=None, dilations=(1, 1, 1, 1), groups=1, deformable_groups=1, modulated=True)

Given 4D tensors `x`, `weight` and `offsets`, computes a 2D deformable convolution. The deformable convolution operation can be expressed as follows:

Deformable convolution v1:

.. math::
    y(p)=\sum_{k=1}^{K}w_{k}\cdot x(p+p_{k}+\Delta{p_{k}})

Deformable convolution v2:

.. math::
    y(p)=\sum_{k=1}^{K}w_{k}\cdot x(p+p_{k}+\Delta{p_{k}})\cdot \Delta{m_{k}}

where :math:`\Delta{p_{k}}` and :math:`\Delta{m_{k}}` are the learnable offset and modulation scalar of the k-th location. For details, refer to `Deformable ConvNets v2: More Deformable, Better Results <https://arxiv.org/abs/1811.11168>`_ and `Deformable Convolutional Networks <https://arxiv.org/abs/1703.06211>`_.

**Parameters:**

- **x** (Tensor) - A 4D Tensor of the input image. The data format is "NCHW" and the shape is :math:`(N, C_{in}, H_{in}, W_{in})`. Dtype is float16 or float32.
- **weight** (Tensor) - A 4D Tensor of learnable filters. Must have the same dtype as `x`; the shape is :math:`(C_{out}, C_{in} / groups, H_{f}, W_{f})`.
- **offsets** (Tensor) - A 4D Tensor storing the x and y coordinate offsets and the input mask of the deformable convolution. The data format is "NCHW" and the shape is :math:`(batch, 3 * deformable_groups * H_{f} * W_{f}, H_{out}, W_{out})`; note that the C dimension is stored in the order (offset_x, offset_y, mask). Must have the same dtype as `x`.
- **kernel_size** (tuple[int]) - A tuple of two integers, the size of the convolution kernel.
- **strides** (tuple[int]) - A tuple of four integers, the stride of the sliding window for each dimension of the input. The dimension order follows the data format of `x`; the values for the N and C dimensions must be set to 1.
- **padding** (tuple[int]) - A tuple of four integers, the number of pixels padded to the input along the (top, bottom, left, right) sides.
- **bias** (Tensor, optional) - A 1D Tensor of additive biases applied to the convolution output. The shape is :math:`(out_channels)`. Default: None.
- **dilations** (tuple[int], optional) - A tuple of four integers, the dilation factor for each dimension of the input. The dimension order follows the data format of `x`; the values for the N and C dimensions must be set to 1. Default: (1, 1, 1, 1).
- **groups** (int, optional) - An int32 integer, the number of blocked connections from input channels to output channels. Both the input channels and the output channels must be divisible by `groups`. Default: 1.
- **deformable_groups** (int, optional) - An int32 integer, the number of deformable group partitions. The input channels must be divisible by `deformable_groups`. Default: 1.
- **modulated** (bool, optional) - Specifies the version of deformable 2D convolution: True means v2 and False means v1. Currently only the v2 version is supported. Default: True.

**Returns:**

Tensor, a 4D Tensor of the output feature map. Same dtype as `x`, data format "NCHW", shape :math:`(N, C_{out}, H_{out}, W_{out})`.

.. math::
    \begin{array}{ll} \\
        H_{out} = \left \lfloor{\frac{H_{in} + padding[0] + padding[1] - (H_{f} - 1) \times
        \text{dilations[2]} - 1 }{\text{strides[2]}} + 1} \right \rfloor \\
        W_{out} = \left \lfloor{\frac{W_{in} + padding[2] + padding[3] - (W_{f} - 1) \times
        \text{dilations[3]} - 1 }{\text{strides[3]}} + 1} \right \rfloor \\
    \end{array}

**Raises:**

- **TypeError** - If `strides`, `padding`, `kernel_size` or `dilations` is not a tuple of integers.
- **TypeError** - If `modulated` is not a bool.
- **ValueError** - If the tuple size of `strides`, `padding`, `kernel_size` or `dilations` is not as expected.
- **ValueError** - If the values of `strides` or `dilations` for the N and C dimensions are not set to 1.
- **ValueError** - If `modulated` is not set to True.

.. note::
    - This is an experimental API that may be changed or removed in the future.
    - On the Ascend platform, only cases where :math:`C_{in}` is divisible by 8, `deformable_groups` is 1 and the `offsets` data is floating point (i.e. contains a fractional part) are supported. For example, scenarios where the shape of `x` is :math:`(N, 2, H_{in}, W_{in})`, `deformable_groups` is 2, or `offsets` is filled with "numpy.ones()" are not supported.
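As a reading aid for the formulas above, the following minimal NumPy sketch evaluates the v2 expression at a single output location p for a 3x3 kernel and one input channel: every kernel tap k samples the input bilinearly at its regular grid position plus the learned offset, and the sample is scaled by the modulation scalar. The helper names bilinear_sample and deformable_point are made up for this sketch and are not part of the MindSpore API.

import numpy as np

def bilinear_sample(img, y, x):
    # Bilinearly sample a (H, W) image at fractional (y, x); contributions outside the border are zero.
    h, w = img.shape
    if y <= -1 or y >= h or x <= -1 or x >= w:
        return 0.0
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    ly, lx = y - y0, x - x0
    val = 0.0
    for (r, c), wgt in (((y0, x0), (1 - ly) * (1 - lx)), ((y0, x0 + 1), (1 - ly) * lx),
                        ((y0 + 1, x0), ly * (1 - lx)), ((y0 + 1, x0 + 1), ly * lx)):
        if 0 <= r < h and 0 <= c < w:
            val += wgt * img[r, c]
    return val

def deformable_point(img, weight, p, offsets, mask):
    # y(p) = sum_k w_k * x(p + p_k + dp_k) * dm_k for a 3x3 kernel centred at p.
    out, k = 0.0, 0
    for dy in (-1, 0, 1):                  # regular grid positions p_k
        for dx in (-1, 0, 1):
            oy, ox = offsets[k]            # learned offset dp_k
            sample = bilinear_sample(img, p[0] + dy + oy, p[1] + dx + ox)
            out += weight[k] * sample * mask[k]   # modulation scalar dm_k
            k += 1
    return out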
@ -46,6 +46,7 @@ Neural Network

mindspore.ops.adaptive_avgpool2d
mindspore.ops.pdist
mindspore.ops.deformable_conv2d

Activation Functions
^^^^^^^^^^^^^^^^^^^^
@ -0,0 +1,284 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/deformable_offsets_cpu_kernel.h"
|
||||
#include <memory>
|
||||
#include "ops/deformable_offsets.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kInputsSize = 2;
|
||||
constexpr size_t kOutputsSize = 1;
|
||||
constexpr size_t kStridesSize = 4;
|
||||
constexpr size_t kPadsSize = 4;
|
||||
constexpr size_t kKernelSizeSize = 2;
|
||||
constexpr size_t kKernelSizeHIndex = 0;
|
||||
constexpr size_t kKernelSizeWIndex = 1;
|
||||
constexpr size_t kDilationsSize = 4;
|
||||
constexpr size_t kXShapeSize = 4;
|
||||
constexpr size_t kOutputShapeSize = 4;
|
||||
constexpr size_t kPadTopIndex = 0;
|
||||
constexpr size_t kPadLeftIndex = 2;
|
||||
constexpr size_t kOffsetsSize = 3;
|
||||
|
||||
template <typename T>
|
||||
T DeformableBilinear(const T *input, T x, T y, int64_t width, int64_t height) {
|
||||
if (y <= static_cast<T>(-1) || y >= static_cast<T>(height) || x <= static_cast<T>(-1) || x >= static_cast<T>(width)) {
|
||||
return static_cast<T>(0);
|
||||
}
|
||||
int64_t left;
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
left = static_cast<int64_t>(floorf(x));
|
||||
} else {
|
||||
left = static_cast<int64_t>(floor(x));
|
||||
}
|
||||
auto right = left + 1;
|
||||
int64_t top;
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
top = static_cast<int64_t>(floorf(y));
|
||||
} else {
|
||||
top = static_cast<int64_t>(floor(y));
|
||||
}
|
||||
auto bottom = top + 1;
|
||||
|
||||
T l = x - static_cast<T>(left);
|
||||
T r = static_cast<T>(1) - l;
|
||||
T t = y - static_cast<T>(top);
|
||||
T b = static_cast<T>(1) - t;
|
||||
|
||||
T lt = static_cast<T>(0);
|
||||
T lb = static_cast<T>(0);
|
||||
if (left >= 0) {
|
||||
if (top >= 0) {
|
||||
lt = input[top * width + left];
|
||||
}
|
||||
if (bottom <= height - 1) {
|
||||
lb = input[bottom * width + left];
|
||||
}
|
||||
}
|
||||
T rt = static_cast<T>(0);
|
||||
T rb = static_cast<T>(0);
|
||||
if (right <= width - 1) {
|
||||
if (top >= 0) {
|
||||
rt = input[top * width + right];
|
||||
}
|
||||
if (bottom <= height - 1) {
|
||||
rb = input[bottom * width + right];
|
||||
}
|
||||
}
|
||||
|
||||
T w_lt = r * b;
|
||||
T w_rt = l * b;
|
||||
T w_lb = r * t;
|
||||
T w_rb = l * t;
|
||||
T val = (w_lt * lt + w_rt * rt + w_lb * lb + w_rb * rb);
|
||||
return val;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool DeformableOffsetsCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
if (inputs.size() != kInputsSize || outputs.size() != kOutputsSize) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it should get " << kInputsSize << " inputs and " << kOutputsSize
|
||||
<< " outputs, but got " << inputs.size() << " inputs and " << outputs.size() << " outputs.";
|
||||
return false;
|
||||
}
|
||||
auto kernel_ptr = std::make_shared<ops::DeformableOffsets>(base_operator->GetPrim());
|
||||
// Check args.
|
||||
n_axis_ = kIndex0;
|
||||
c_axis_ = kIndex1;
|
||||
h_axis_ = kIndex2;
|
||||
w_axis_ = kIndex3;
|
||||
strides_ = kernel_ptr->get_strides();
|
||||
if (strides_.size() != kStridesSize || strides_[n_axis_] != 1 || strides_[c_axis_] != 1) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'strides' should be a vector with size " << kStridesSize
|
||||
<< " and the values according to N and C dimensions must be set to 1. But got 'strides': "
|
||||
<< strides_;
|
||||
return false;
|
||||
}
|
||||
pads_ = kernel_ptr->get_pads();
|
||||
if (pads_.size() != kPadsSize) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'pads' should be a vector with size " << kPadsSize
|
||||
<< ". But got 'pads': " << pads_;
|
||||
return false;
|
||||
}
|
||||
kernel_size_ = kernel_ptr->get_kernel_size();
|
||||
if (kernel_size_.size() != kKernelSizeSize) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'kernel_size' should be a vector with size " << kKernelSizeSize
|
||||
<< ". But got 'kernel_size': " << kernel_size_;
|
||||
return false;
|
||||
}
|
||||
dilations_ = kernel_ptr->get_dilations();
|
||||
if (dilations_.size() != kDilationsSize || dilations_[n_axis_] != 1 || dilations_[c_axis_] != 1) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'dilations' should be a vector with size " << kDilationsSize
|
||||
<< " and the values according to N and C dimensions must be set to 1. But got 'dilations': "
|
||||
<< dilations_;
|
||||
return false;
|
||||
}
|
||||
deformable_groups_ = kernel_ptr->get_deformable_groups();
|
||||
if (deformable_groups_ <= 0) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'deformable_groups' should be greater than 0, but got "
|
||||
<< deformable_groups_;
|
||||
return false;
|
||||
}
|
||||
modulated_ = kernel_ptr->get_modulated();
|
||||
if (!modulated_) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the value of 'modulated' only support to be set to True.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return MatchKernelFunc(base_operator, inputs, outputs);
|
||||
}
|
||||
|
||||
void DeformableOffsetsCpuKernelMod::ResetResource() noexcept {
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
int DeformableOffsetsCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
ResetResource();
|
||||
auto ret = KernelMod::Resize(base_operator, inputs, outputs);
|
||||
if (ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
auto x_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
if (x_shape.size() != kXShapeSize) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape size of input 'x' should be " << kXShapeSize
|
||||
<< ", but got " << x_shape.size();
|
||||
}
|
||||
n_ = x_shape[n_axis_];
|
||||
c_ = x_shape[c_axis_];
|
||||
input_h_ = x_shape[h_axis_];
|
||||
input_w_ = x_shape[w_axis_];
|
||||
|
||||
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
|
||||
if (output_shape.size() != kOutputShapeSize) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape size of output 'y' should be " << kOutputShapeSize
|
||||
<< ", but got " << output_shape.size();
|
||||
}
|
||||
output_h_ = output_shape[h_axis_];
|
||||
output_w_ = output_shape[w_axis_];
|
||||
position_grid_size_ = output_h_ * output_w_;
|
||||
(void)workspace_size_list_.emplace_back(sizeof(int64_t) * position_grid_size_ * kKernelSizeSize);
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
void DeformableOffsetsCpuKernelMod::GenPositionGrid(int64_t *position_grid) {
|
||||
auto task = [this, &position_grid](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
auto long_i = SizeToLong(i);
|
||||
int64_t y = long_i / output_w_;
|
||||
int64_t x = long_i % output_w_;
|
||||
int64_t pixel_y = y / kernel_size_[kKernelSizeHIndex];
|
||||
int64_t pixel_x = x / kernel_size_[kKernelSizeWIndex];
|
||||
int64_t kernel_y = y % kernel_size_[kKernelSizeHIndex];
|
||||
int64_t kernel_x = x % kernel_size_[kKernelSizeWIndex];
|
||||
size_t index = i * 2;
|
||||
position_grid[index] = pixel_x * strides_[w_axis_] + kernel_x * dilations_[w_axis_] - pads_[kPadLeftIndex];
|
||||
position_grid[index + 1] = pixel_y * strides_[h_axis_] + kernel_y * dilations_[h_axis_] - pads_[kPadTopIndex];
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, LongToSize(output_h_ * output_w_), this, ¶llel_search_info_, pool_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DeformableOffsetsCpuKernelMod::DeformableOffsets(const T *input_addr, const T *offsets_addr,
|
||||
const int64_t *position_grid_addr, T *output_addr) {
|
||||
int64_t pixel_h = output_h_ / kernel_size_[kKernelSizeHIndex];
|
||||
int64_t pixel_w = output_w_ / kernel_size_[kKernelSizeWIndex];
|
||||
int64_t output_c_dim = output_h_ * output_w_;
|
||||
int64_t output_n_dim = c_ * output_c_dim;
|
||||
int64_t c_size_per_dfm_group = c_ / deformable_groups_;
|
||||
int64_t offset_kw_dim = pixel_h * pixel_w;
|
||||
int64_t offset_kh_dim = offset_kw_dim * kernel_size_[kKernelSizeWIndex];
|
||||
int64_t offset_group_dim = offset_kh_dim * kernel_size_[kKernelSizeHIndex];
|
||||
int64_t offset_mask_dim = offset_group_dim * deformable_groups_;
|
||||
int64_t offset_n_dim = offset_mask_dim * kOffsetsSize;
|
||||
int64_t input_c_dim = input_h_ * input_w_;
|
||||
int64_t input_n_dim = input_c_dim * c_;
|
||||
|
||||
auto task = [this, &input_addr, &offsets_addr, &output_addr, &position_grid_addr, &pixel_w, &output_c_dim,
|
||||
&output_n_dim, &c_size_per_dfm_group, &offset_kw_dim, &offset_kh_dim, &offset_group_dim,
|
||||
&offset_mask_dim, &offset_n_dim, &input_c_dim, &input_n_dim](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
auto long_i = SizeToLong(i);
|
||||
// Get input position
|
||||
int64_t hw_idx = long_i % output_c_dim;
|
||||
int64_t position_grid_idx = hw_idx * 2;
|
||||
int64_t input_x = position_grid_addr[position_grid_idx];
|
||||
int64_t input_y = position_grid_addr[position_grid_idx + 1];
|
||||
// Get offsets
|
||||
int64_t n_index = long_i / output_n_dim;
|
||||
int64_t c_index = long_i / output_c_dim % c_;
|
||||
int64_t x = hw_idx % output_w_;
|
||||
int64_t y = hw_idx / output_w_;
|
||||
int64_t dfm_group_index = c_index / c_size_per_dfm_group;
|
||||
int64_t pixel_x = x / kernel_size_[kKernelSizeWIndex];
|
||||
int64_t pixel_y = y / kernel_size_[kKernelSizeHIndex];
|
||||
int64_t kernel_x = x % kernel_size_[kKernelSizeWIndex];
|
||||
int64_t kernel_y = y % kernel_size_[kKernelSizeHIndex];
|
||||
int64_t x_offsets_offset = n_index * offset_n_dim + dfm_group_index * offset_group_dim +
|
||||
kernel_y * offset_kh_dim + kernel_x * offset_kw_dim + pixel_y * pixel_w + pixel_x;
|
||||
T x_offsets = offsets_addr[x_offsets_offset];
|
||||
int64_t y_offsets_offset = x_offsets_offset + offset_mask_dim;
|
||||
T y_offsets = offsets_addr[y_offsets_offset];
|
||||
int64_t mask_offset = y_offsets_offset + offset_mask_dim;
|
||||
T mask = offsets_addr[mask_offset];
|
||||
|
||||
T new_x = static_cast<T>(input_x) + x_offsets;
|
||||
T new_y = static_cast<T>(input_y) + y_offsets;
|
||||
const T *input_addr_offset = input_addr + n_index * input_n_dim + c_index * input_c_dim;
|
||||
T bilinear_val = DeformableBilinear(input_addr_offset, new_x, new_y, input_w_, input_h_);
|
||||
output_addr[i] = bilinear_val * mask;
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, LongToSize(n_ * output_n_dim), this, ¶llel_search_info_, pool_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool DeformableOffsetsCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspaces,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto *position_grid_addr = GetDeviceAddress<int64_t>(workspaces, kIndex0);
|
||||
GenPositionGrid(position_grid_addr);
|
||||
T *x_addr = GetDeviceAddress<T>(inputs, kIndex0);
|
||||
T *offsets_addr = GetDeviceAddress<T>(inputs, kIndex1);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, kIndex0);
|
||||
DeformableOffsets(x_addr, offsets_addr, position_grid_addr, output_addr);
|
||||
return true;
|
||||
}
|
||||
|
||||
using KernelAttrAndDeformableOffsetsFuncList =
|
||||
std::vector<std::pair<KernelAttr, DeformableOffsetsCpuKernelMod::KernelRunFunc>>;
|
||||
const KernelAttrAndDeformableOffsetsFuncList &DeformableOffsetsCpuKernelMod::GetFuncList() const {
|
||||
static const KernelAttrAndDeformableOffsetsFuncList func_list = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
&DeformableOffsetsCpuKernelMod::LaunchKernel<float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&DeformableOffsetsCpuKernelMod::LaunchKernel<float>}};
|
||||
return func_list;
|
||||
}
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, DeformableOffsets, DeformableOffsetsCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
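For readers who find the flat indexing in the kernel above hard to follow, here is a rough NumPy restatement of what GenPositionGrid and DeformableOffsets compute, assuming the NCHW layout this kernel handles. It is a reading aid under those assumptions, not the actual implementation, and the name deformable_offsets_ref is invented for this sketch.

import numpy as np

def deformable_offsets_ref(x, offsets, kernel_size, strides, pads, dilations, deformable_groups=1):
    # x: (N, C, H_in, W_in); offsets: (N, 3*deformable_groups*kh*kw, pixel_h, pixel_w).
    n, c, in_h, in_w = x.shape
    kh, kw = kernel_size
    pixel_h, pixel_w = offsets.shape[2], offsets.shape[3]
    out_h, out_w = pixel_h * kh, pixel_w * kw
    # The C dimension of offsets decomposes as (3, deformable_groups, kh, kw),
    # with the leading axis ordered as (offset_x, offset_y, mask).
    off = offsets.reshape(n, 3, deformable_groups, kh, kw, pixel_h, pixel_w)
    group_size = c // deformable_groups
    out = np.zeros((n, c, out_h, out_w), dtype=x.dtype)

    def bilinear(img, yy, xx):
        h, w = img.shape
        if yy <= -1 or yy >= h or xx <= -1 or xx >= w:
            return 0.0
        y0, x0 = int(np.floor(yy)), int(np.floor(xx))
        ly, lx = yy - y0, xx - x0
        val = 0.0
        for (r, cc), wgt in (((y0, x0), (1 - ly) * (1 - lx)), ((y0, x0 + 1), (1 - ly) * lx),
                             ((y0 + 1, x0), ly * (1 - lx)), ((y0 + 1, x0 + 1), ly * lx)):
            if 0 <= r < h and 0 <= cc < w:
                val += wgt * img[r, cc]
        return val

    for ni in range(n):
        for ci in range(c):
            g = ci // group_size
            for oy in range(out_h):
                for ox in range(out_w):
                    py, ky = oy // kh, oy % kh      # pixel index / kernel tap, as in GenPositionGrid
                    px, kx = ox // kw, ox % kw
                    base_y = py * strides[2] + ky * dilations[2] - pads[0]
                    base_x = px * strides[3] + kx * dilations[3] - pads[2]
                    sy = base_y + off[ni, 1, g, ky, kx, py, px]
                    sx = base_x + off[ni, 0, g, ky, kx, py, px]
                    out[ni, ci, oy, ox] = bilinear(x[ni, ci], sy, sx) * off[ni, 2, g, ky, kx, py, px]
    return out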
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_DEFORMABLE_OFFSETS_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_DEFORMABLE_OFFSETS_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class DeformableOffsetsCpuKernelMod : public NativeCpuKernelMod,
|
||||
public MatchKernelHelper<DeformableOffsetsCpuKernelMod> {
|
||||
public:
|
||||
DeformableOffsetsCpuKernelMod() { ResetResource(); }
|
||||
~DeformableOffsetsCpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
void ResetResource() noexcept;
|
||||
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
||||
void GenPositionGrid(int64_t *position_grid);
|
||||
|
||||
template <typename T>
|
||||
void DeformableOffsets(const T *x_addr, const T *offsets_addr, const int64_t *position_grid_addr, T *output_addr);
|
||||
|
||||
std::vector<int64_t> strides_;
|
||||
std::vector<int64_t> pads_;
|
||||
std::vector<int64_t> kernel_size_;
|
||||
std::vector<int64_t> dilations_;
|
||||
int64_t deformable_groups_;
|
||||
bool modulated_;
|
||||
|
||||
int64_t n_axis_;
|
||||
int64_t c_axis_;
|
||||
int64_t h_axis_;
|
||||
int64_t w_axis_;
|
||||
int64_t n_;
|
||||
int64_t c_;
|
||||
int64_t input_h_;
|
||||
int64_t input_w_;
|
||||
int64_t output_h_;
|
||||
int64_t output_w_;
|
||||
int64_t position_grid_size_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_DEFORMABLE_OFFSETS_CPU_KERNEL_H_
|
|
@ -96,6 +96,18 @@ void DeformableOffsetsPadFunction(std::vector<int64_t> *output_hw, const std::ve
|
|||
output_hw->push_back(out_w);
|
||||
}
|
||||
|
||||
void CheckOutputHeightAndWight(const std::string &prim_name, const std::vector<int64_t> &output_hw,
|
||||
const std::vector<int64_t> &offset_shape) {
|
||||
if (output_hw[kIndex0] != offset_shape[kIndex2] || output_hw[kIndex1] != offset_shape[kIndex3]) {
|
||||
MS_LOG(EXCEPTION) << "For '" << prim_name
|
||||
<< ", the H and W dims of offsets input should be equal to the computed H and W dims of the "
|
||||
"output of deformable_conv2d. But got H and W dims of offsets input: ("
|
||||
<< offset_shape[kIndex2] << ", " << offset_shape[kIndex3]
|
||||
<< "), computed H and W dims of the output of deformable_conv2d: (" << output_hw[kIndex0] << ", "
|
||||
<< output_hw[kIndex1] << ").";
|
||||
}
|
||||
}
|
||||
|
||||
abstract::ShapePtr DeformableOffsetsInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
@ -124,16 +136,10 @@ abstract::ShapePtr DeformableOffsetsInferShape(const PrimitivePtr &primitive,
|
|||
abstract::CheckShapeAllPositive(prim_name + " offsets_min_shape", offsets_min_shape);
|
||||
abstract::CheckShapeAllPositive(prim_name + " offsets_max_shape", offsets_max_shape);
|
||||
|
||||
const uint64_t n_axis = kIndex0;
|
||||
uint64_t c_axis = kIndex1;
|
||||
uint64_t h_axis = kIndex2;
|
||||
uint64_t w_axis = kIndex3;
|
||||
int64_t data_format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat));
|
||||
if (data_format == Format::NHWC) {
|
||||
c_axis = kIndex3;
|
||||
h_axis = kIndex1;
|
||||
w_axis = kIndex2;
|
||||
}
|
||||
constexpr uint64_t n_axis = 0;
|
||||
constexpr uint64_t c_axis = 1;
|
||||
constexpr uint64_t h_axis = 2;
|
||||
constexpr uint64_t w_axis = 3;
|
||||
|
||||
constexpr size_t strides_num = 4;
|
||||
auto strides = CheckAttrTupleAndNCDimensions(primitive, kAttrStrides, strides_num, n_axis, c_axis);
|
||||
|
@ -176,22 +182,19 @@ abstract::ShapePtr DeformableOffsetsInferShape(const PrimitivePtr &primitive,
|
|||
DeformableOffsetsPadFunction(&output_hw_max, kernel_size, strides, dilations, pads, x_max_shape[h_axis],
|
||||
x_max_shape[w_axis], h_axis, w_axis);
|
||||
|
||||
CheckOutputHeightAndWight(prim_name, output_hw, offsets_shape);
|
||||
CheckOutputHeightAndWight(prim_name, output_hw_min, offsets_min_shape);
|
||||
CheckOutputHeightAndWight(prim_name, output_hw_max, offsets_max_shape);
|
||||
|
||||
ShapeVector output_shape;
|
||||
ShapeVector output_shape_min;
|
||||
ShapeVector output_shape_max;
|
||||
if (data_format == Format::NHWC) {
|
||||
output_shape = {x_shape[n_axis], output_hw[0] * kernel_size[0], output_hw[1] * kernel_size[1], x_shape[c_axis]};
|
||||
output_shape_min = {x_min_shape[n_axis], output_hw_min[0] * kernel_size[0], output_hw_min[1] * kernel_size[1],
|
||||
x_min_shape[c_axis]};
|
||||
output_shape_max = {x_max_shape[n_axis], output_hw_max[0] * kernel_size[0], output_hw_max[1] * kernel_size[1],
|
||||
x_max_shape[c_axis]};
|
||||
} else {
|
||||
output_shape = {x_shape[n_axis], x_shape[c_axis], output_hw[0] * kernel_size[0], output_hw[1] * kernel_size[1]};
|
||||
output_shape_min = {x_min_shape[n_axis], x_min_shape[c_axis], output_hw_min[0] * kernel_size[0],
|
||||
output_hw_min[1] * kernel_size[1]};
|
||||
output_shape_max = {x_max_shape[n_axis], x_max_shape[c_axis], output_hw_max[0] * kernel_size[0],
|
||||
output_hw_max[1] * kernel_size[1]};
|
||||
}
|
||||
output_shape = {x_shape[n_axis], x_shape[c_axis], output_hw[0] * kernel_size[0], output_hw[1] * kernel_size[1]};
|
||||
output_shape_min = {x_min_shape[n_axis], x_min_shape[c_axis], output_hw_min[0] * kernel_size[0],
|
||||
output_hw_min[1] * kernel_size[1]};
|
||||
output_shape_max = {x_max_shape[n_axis], x_max_shape[c_axis], output_hw_max[0] * kernel_size[0],
|
||||
output_hw_max[1] * kernel_size[1]};
|
||||
|
||||
abstract::CheckShapeAnyAndPositive(prim_name + " output_shape", output_shape);
|
||||
abstract::CheckShapeAllPositive(prim_name + " output_shape_min", output_shape_min);
|
||||
abstract::CheckShapeAllPositive(prim_name + " output_shape_max", output_shape_max);
|
||||
|
|
|
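To summarize the shape relationship enforced by the infer-shape logic above: DeformableOffsets produces a feature map of shape (N, C, H_out * K_h, W_out * K_w), where (H_out, W_out) is the usual convolution output size and must match the spatial dims of `offsets`; the Conv2D that follows in `deformable_conv2d` uses stride equal to the kernel size to collapse it back to (N, C_out, H_out, W_out). Below is a small sketch of that arithmetic, assuming NCHW and the output-size formula from the docstring; the name deformable_offsets_out_shape is hypothetical.

from math import floor

def deformable_offsets_out_shape(x_shape, kernel_size, strides, pads, dilations):
    # Expanded output of DeformableOffsets for an NCHW input, per the infer-shape pass above.
    n, c, in_h, in_w = x_shape
    kh, kw = kernel_size
    out_h = floor((in_h + pads[0] + pads[1] - (kh - 1) * dilations[2] - 1) / strides[2]) + 1
    out_w = floor((in_w + pads[2] + pads[3] - (kw - 1) * dilations[3] - 1) / strides[3]) + 1
    return (n, c, out_h * kh, out_w * kw)

# Example matching the CPU tests later in this change: x (2, 3, 5, 5), 3x3 kernel, stride 1,
# no padding -> offsets spatial dims (3, 3) and DeformableOffsets output (2, 3, 9, 9).
assert deformable_offsets_out_shape((2, 3, 5, 5), (3, 3), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)) == (2, 3, 9, 9)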
@ -219,87 +219,94 @@ def softsign(x):
|
|||
|
||||
|
||||
def deformable_conv2d(x, weight, offsets, kernel_size, strides, padding, bias=None, dilations=(1, 1, 1, 1), groups=1,
|
||||
data_format="NCHW", deformable_groups=1, modulated=True):
|
||||
deformable_groups=1, modulated=True):
|
||||
r"""
|
||||
Computes a 2D deformable convolution given 4D `x`, `weight` and `offsets` tensors.
|
||||
Given 4D tensor inputs `x`, `weight` and `offsets`, compute a 2D deformable convolution. The deformable convolution
|
||||
operation can be expressed as follows:
|
||||
Deformable Convolution v1:
|
||||
.. math::
|
||||
y(p)=\sum_{k=1}^{K}w_{k}\cdot x(p+p_{k}+\Delta{p_{k}})
|
||||
|
||||
Note:
|
||||
For Ascend platform, only support cases when `in_channels` can be divisible by 8, `deformable_groups` is 1
|
||||
and `offsets` value is float which needs to contain a decimal part. For example: `x` is
|
||||
`(batch, 2, in_height, in_width)`, or `deformable_groups` is 2, or `offsets` assign with 'numpy.ones()'
|
||||
function, none of these scenarios are supported.
|
||||
This is an experimental interface that is subject to change or deletion.
|
||||
Deformable Convolution v2:
|
||||
.. math::
|
||||
y(p)=\sum_{k=1}^{K}w_{k}\cdot x(p+p_{k}+\Delta{p_{k}})\cdot \Delta{m_{k}}
|
||||
|
||||
Where :math:`\Delta{p_{k}}` and :math:`\Delta{m_{k}}` are the learnable offset and modulation scalar for the k-th
|
||||
location. For details, please refer to `Deformable ConvNets v2: More Deformable, Better Results
|
||||
<https://arxiv.org/abs/1811.11168>`_ and `Deformable Convolutional Networks <https://arxiv.org/abs/1703.06211>`_.
|
||||
|
||||
Args:
|
||||
x (Tensor): A 4D tensor of input image. With the format "NCHW" or "NHWC", the data is stored in the order of:
|
||||
:math:`(batch, in_channels, in_height, in_width)` when the format is "NCHW".
|
||||
weight (Tensor): A 4D tensor of learnable filters. Must have the same type as `x`. The data is stored in the
|
||||
order of: :math:`(out_channels, in_channels / groups, filter_height, filter_width)` when the format of
|
||||
`x` is "NCHW".
|
||||
offsets (Tensor): A 4D tensor of x-y coordinates offset and mask. With the format "NCHW" or "NHWC", when the
|
||||
format is "NCHW", the data is stored in the order of:
|
||||
:math:`(batch, deformable_groups * filter_height * filter_width * 3, out_height, out_width)`.
|
||||
kernel_size (tuple[int]): Required. A tuple of 2 integers. The size of kernel.
|
||||
strides (tuple[int]): Required. A tuple of 4 integers. The stride of the sliding window for each dimension of
|
||||
x (Tensor): A 4D tensor of input image. With the format "NCHW",
|
||||
the shape is :math:`(N, C_{in}, H_{in}, W_{in})`. Dtype: float16 or float32.
|
||||
weight (Tensor): A 4D tensor of learnable filters. Must have the same type as `x`.
|
||||
The shape is :math:`(C_{out}, C_{in} / groups, H_{f}, W_{f})`.
|
||||
offsets (Tensor): A 4D tensor of x-y coordinates offset and mask. With the format "NCHW",
|
||||
the shape is :math:`(batch, 3 * deformable_groups * H_{f} * W_{f}, H_{out}, W_{out})`. Note the C dimension
|
||||
is stored in the order of (offset_x, offset_y, mask). Must have the same type as `x`.
|
||||
kernel_size (tuple[int]): A tuple of 2 integers. The size of kernel.
|
||||
strides (tuple[int]): A tuple of 4 integers. The stride of the sliding window for each dimension of
|
||||
input. The dimension order is interpreted according to the data format of `x`. The N and C dimensions must
|
||||
be set to 1.
|
||||
padding (tuple[int]): Required. A list of 4 integers. The number of pixels to add to each (top, bottom, left,
|
||||
padding (tuple[int]): A tuple of 4 integers. The number of pixels to add to each (top, bottom, left,
|
||||
right) side of the input.
|
||||
bias (Tensor): Optional. An 1D tensor of additive biases to the filter outputs. The data is stored in the
|
||||
order of: :math:`(out_channels)`.
|
||||
dilations (tuple[int]): Optional. A list of 4 integers. The dilation factor for each dimension of input. The
|
||||
bias (Tensor, Optional): A 1D tensor of additive biases to the filter outputs.
The shape is :math:`(out_channels)`. Defaults to None.
|
||||
dilations (tuple[int], Optional): A tuple of 4 integers. The dilation factor for each dimension of input. The
|
||||
dimension order is interpreted according to the data format of `x`. The N and C dimensions must be set
|
||||
to 1. Defaults to (1, 1, 1, 1).
|
||||
groups (int): Optional. An integer of type int32. The number of blocked connections from input channels
|
||||
groups (int, Optional): An integer of type int32. The number of blocked connections from input channels
|
||||
to output channels. In_channels and out_channels must both be divisible by `groups`. Defaults to 1.
|
||||
data_format (str): Optional. The value for data format, is 'NCHW' or 'NHWC'. Defaults to 'NCHW'.
|
||||
deformable_groups (int) - Optional. An integer of type int32. The number of deformable group partitions.
|
||||
deformable_groups (int, Optional): An integer of type int32. The number of deformable group partitions.
|
||||
In_channels must be divisible by `deformable_groups`. Defaults to 1.
|
||||
modulated (bool) - Optional. Specify version of DeformableConv2D, True means v2, False means v1, currently
|
||||
only support v2. Defaults to True.
|
||||
modulated (bool, Optional): Specifies the version of DeformableConv2D, True means v2, False means v1, currently
|
||||
only supports v2. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Tensor, A 4D Tensor of output feature map. With the same type as `x`. With the format "NCHW" or "NHWC", the
|
||||
data is stored in the order of: :math:`(batch, out_channels, out_height, out_width)` when the format is
|
||||
"NCHW".
|
||||
Tensor, A 4D Tensor of output feature map. With the same type as `x`. With the format "NCHW",
|
||||
the shape is :math:`(N, C_{out}, H_{out}, W_{out})`.
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
\text{out\_height} = {\frac{\text{in\_height} + \text{pad\_top} + \text{pad\_bottom}
|
||||
- (\text{dilation\_h} * (\text{filter\_height} - 1) + 1)}{\text{stride\_h}}} + 1 \\
|
||||
\text{out\_width} = {\frac{\text{in\_width} + \text{pad\_left} + \text{pad\_right}
|
||||
- (\text{dilation\_w} * (\text{filter\_width} - 1) + 1)}{\text{stride\_w}}} + 1 \\
|
||||
H_{out} = \left \lfloor{\frac{H_{in} + padding[0] + padding[1] - (H_{f} - 1) \times
\text{dilations[2]} - 1 }{\text{strides[2]}} + 1} \right \rfloor \\
W_{out} = \left \lfloor{\frac{W_{in} + padding[2] + padding[3] - (W_{f} - 1) \times
\text{dilations[3]} - 1 }{\text{strides[3]}} + 1} \right \rfloor \\
|
||||
\end{array}
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU`` ``GPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Raises:
|
||||
TypeError: If `strides`, `padding`, `kernel_size` or `dilations` is not a tuple with integer elements.
|
||||
TypeError: If `modulated` is not a bool.
|
||||
ValueError: If the tuple size of `strides`, `padding`, `kernel_size` or `dilations` is not expected.
|
||||
ValueError: If the N or C dimension of `strides` or `dilations` is not set to 1.
|
||||
ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
|
||||
ValueError: If `modulated` is not set to True.
|
||||
|
||||
.. note::
|
||||
- This is an experimental interface that is subject to change or deletion.
|
||||
- On the Ascend platform, only cases where :math:`C_{in}` is divisible by 8, `deformable_groups` is 1
and the `offsets` data is floating point (i.e. contains a fractional part) are supported. For example,
scenarios where the shape of `x` is :math:`(N, 2, H_{in}, W_{in})`, `deformable_groups` is 2, or
`offsets` is filled with "numpy.ones()" are not supported.
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.ones((4, 3, 10, 10)), mstype.float32)
|
||||
>>> kh, kw = 3, 3
|
||||
>>> weight = Tensor(np.ones((5, 3, kh, kw)), mstype.float32)
|
||||
>>> offsets = Tensor(np.ones((4, 3 * kh * kw, 8, 8)), mstype.float32)
|
||||
>>> output = ops.deformable_conv2d(x, weight, offsets, (kh, kw), (1, 1, 1, 1), (0, 0, 0, 0), data_format="NCHW")
|
||||
>>> output = ops.deformable_conv2d(x, weight, offsets, (kh, kw), (1, 1, 1, 1), (0, 0, 0, 0))
|
||||
>>> print(output.shape)
|
||||
(4, 5, 8, 8)
|
||||
"""
|
||||
deformable_offsets = NN.DeformableOffsets(strides, padding, kernel_size, dilations, data_format, deformable_groups,
|
||||
deformable_offsets = NN.DeformableOffsets(strides, padding, kernel_size, dilations, "NCHW", deformable_groups,
|
||||
modulated)
|
||||
fm_offset = deformable_offsets(x, offsets)
|
||||
|
||||
weight_shape = weight.shape
|
||||
out_channel = weight_shape[0]
|
||||
if data_format == "NHWC":
|
||||
out_channel = weight_shape[3]
|
||||
strides_conv = (kernel_size[0], kernel_size[1])
|
||||
conv = P.Conv2D(out_channel, kernel_size, 1, "valid", 0, strides_conv, 1, groups, data_format)
|
||||
bias_add = P.BiasAdd(data_format)
|
||||
conv = P.Conv2D(out_channel, kernel_size, 1, "valid", 0, strides_conv, 1, groups)
|
||||
bias_add = P.BiasAdd()
|
||||
|
||||
output = conv(fm_offset, weight)
|
||||
if bias is not None:
|
||||
|
|
|
@ -0,0 +1,329 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import mindspore.ops as ops
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
context.set_context(device_target='CPU')
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, weight, offsets, kh, kw, strides=(1, 1, 1, 1), padding=(0, 0, 0, 0), bias=None,
|
||||
dilations=(1, 1, 1, 1)):
|
||||
return ops.deformable_conv2d(x, weight, offsets, (kh, kw), strides, padding, bias, dilations)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_deformable_conv2d():
|
||||
""""
|
||||
Feature: deformable_conv2d function.
|
||||
Description: Test case for simplest deformable_conv2d.
|
||||
Expectation: The results are as expected.
|
||||
"""
|
||||
kh, kw = 1, 1
|
||||
# x shape [1, 1, 1, 2]
|
||||
x = np.array([[[[-0.41675785, -0.05626683]]]]).astype(np.float32)
|
||||
x = Tensor(x, mstype.float32)
|
||||
# weight shape [1, 1, 1, 1]
|
||||
weight = np.array([[[[-2.1361961]]]]).astype(np.float32)
|
||||
weight = Tensor(weight, mstype.float32)
|
||||
# offsets shape [1, 3, 1, 2]
|
||||
offsets = np.array([[[[1.6402708, -1.7934356]],
|
||||
[[-0.84174734, 0.5028814]],
|
||||
[[-1.2452881, -1.0579522]]]]).astype(np.float32)
|
||||
offsets = Tensor(offsets, mstype.float32)
|
||||
out = Net()(x, weight, offsets, kh, kw)
|
||||
# expected output: [1, 1, 1, 2]
|
||||
expected = np.array([[[[-0.00852099, -0.09671781]]]]).astype(np.float32)
|
||||
assert np.allclose(out.asnumpy(), expected)
|
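# Editor's note (illustrative sketch, not part of the test suite): the first expected value above
# can be reproduced by hand from the deformable convolution v2 formula. With a 1x1 kernel there is a
# single tap per output location; for output (0, 0) the sampling point is the regular position (0, 0)
# shifted by the learned offset, and only one bilinear corner lies inside the 1x2 input image.
import numpy as np

x_img = np.array([-0.41675785, -0.05626683])              # the single input row, shape (2,)
w = -2.1361961                                             # the single filter weight
off_x, off_y, mask = 1.6402708, -0.84174734, -1.2452881    # offsets C order: (offset_x, offset_y, mask)
sx, sy = 0.0 + off_x, 0.0 + off_y                          # sampling position for output (0, 0)
lx, ly = sx - np.floor(sx), sy - np.floor(sy)              # fractional parts for bilinear interpolation
sample = (1 - lx) * ly * x_img[1]                          # only corner (row 0, col 1) is inside the image
print(w * sample * mask)                                   # ~ -0.00852, matching expected[0, 0, 0, 0]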
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_required_inputs():
|
||||
""""
|
||||
Feature: deformable_conv2d function.
|
||||
Description: Test case with only the required inputs.
|
||||
Expectation: The results are as expected.
|
||||
"""
|
||||
x = Tensor(np.arange(2 * 3 * 5 * 5).reshape(2, 3, 5, 5), mstype.float32)
|
||||
kh, kw = 3, 3
|
||||
weight = Tensor(np.arange(5 * 3 * kh * kw).reshape(5, 3, kh, kw), mstype.float32)
|
||||
offsets = Tensor(np.ones((2, 3 * kh * kw, 3, 3)), mstype.float32)
|
||||
output = Net()(x, weight, offsets, kh, kw)
|
||||
expect = np.array([[[[17325., 17676., 11547.],
|
||||
[19080., 19431., 12672.],
|
||||
[11991., 12198., 7920.]],
|
||||
|
||||
[[44298., 45378., 30258.],
|
||||
[49698., 50778., 33813.],
|
||||
[33618., 34311., 22824.]],
|
||||
|
||||
[[71271., 73080., 48969.],
|
||||
[80316., 82125., 54954.],
|
||||
[55245., 56424., 37728.]],
|
||||
|
||||
[[98244., 100782., 67680.],
|
||||
[110934., 113472., 76095.],
|
||||
[76872., 78537., 52632.]],
|
||||
|
||||
[[125217., 128484., 86391.],
|
||||
[141552., 144819., 97236.],
|
||||
[98499., 100650., 67536.]]],
|
||||
|
||||
[[[43650., 44001., 28422.],
|
||||
[45405., 45756., 29547.],
|
||||
[27516., 27723., 17820.]],
|
||||
|
||||
[[125298., 126378., 83583.],
|
||||
[130698., 131778., 87138.],
|
||||
[85593., 86286., 57024.]],
|
||||
|
||||
[[206946., 208755., 138744.],
|
||||
[215991., 217800., 144729.],
|
||||
[143670., 144849., 96228.]],
|
||||
|
||||
[[288594., 291132., 193905.],
|
||||
[301284., 303822., 202320.],
|
||||
[201747., 203412., 135432.]],
|
||||
|
||||
[[370242., 373509., 249066.],
|
||||
[386577., 389844., 259911.],
|
||||
[259824., 261975., 174636.]]]]).astype(np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_with_bias():
|
||||
""""
|
||||
Feature: deformable_conv2d function.
|
||||
Description: Test case with bias input.
|
||||
Expectation: The results are as expected.
|
||||
"""
|
||||
x = Tensor(np.arange(2 * 3 * 5 * 5).reshape(2, 3, 5, 5), mstype.float32)
|
||||
kh, kw = 3, 3
|
||||
weight = Tensor(np.arange(5 * 3 * kh * kw).reshape(5, 3, kh, kw), mstype.float32)
|
||||
bias = Tensor(np.ones((5,)), mstype.float32)
|
||||
offsets = Tensor(np.ones((2, 3 * kh * kw, 3, 3)), mstype.float32)
|
||||
output = Net()(x, weight, offsets, kh, kw, bias=bias)
|
||||
expect = np.array([[[[17326., 17677., 11548.],
|
||||
[19081., 19432., 12673.],
|
||||
[11992., 12199., 7921.]],
|
||||
|
||||
[[44299., 45379., 30259.],
|
||||
[49699., 50779., 33814.],
|
||||
[33619., 34312., 22825.]],
|
||||
|
||||
[[71272., 73081., 48970.],
|
||||
[80317., 82126., 54955.],
|
||||
[55246., 56425., 37729.]],
|
||||
|
||||
[[98245., 100783., 67681.],
|
||||
[110935., 113473., 76096.],
|
||||
[76873., 78538., 52633.]],
|
||||
|
||||
[[125218., 128485., 86392.],
|
||||
[141553., 144820., 97237.],
|
||||
[98500., 100651., 67537.]]],
|
||||
|
||||
[[[43651., 44002., 28423.],
|
||||
[45406., 45757., 29548.],
|
||||
[27517., 27724., 17821.]],
|
||||
|
||||
[[125299., 126379., 83584.],
|
||||
[130699., 131779., 87139.],
|
||||
[85594., 86287., 57025.]],
|
||||
|
||||
[[206947., 208756., 138745.],
|
||||
[215992., 217801., 144730.],
|
||||
[143671., 144850., 96229.]],
|
||||
|
||||
[[288595., 291133., 193906.],
|
||||
[301285., 303823., 202321.],
|
||||
[201748., 203413., 135433.]],
|
||||
|
||||
[[370243., 373510., 249067.],
|
||||
[386578., 389845., 259912.],
|
||||
[259825., 261976., 174637.]]]]).astype(np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_with_strides():
|
||||
""""
|
||||
Feature: deformable_conv2d function.
|
||||
Description: Test case with strides input.
|
||||
Expectation: The results are as expected.
|
||||
"""
|
||||
x = Tensor(np.arange(2 * 3 * 5 * 5).reshape(2, 3, 5, 5), mstype.float32)
|
||||
kh, kw = 3, 3
|
||||
weight = Tensor(np.arange(5 * 3 * kh * kw).reshape(5, 3, kh, kw), mstype.float32)
|
||||
offsets = Tensor(np.ones((2, 3 * kh * kw, 2, 2)), mstype.float32)
|
||||
output = Net()(x, weight, offsets, kh, kw, (1, 1, 2, 2))
|
||||
expect = np.array([[[[17325., 11547.],
|
||||
[11991., 7920.]],
|
||||
|
||||
[[44298., 30258.],
|
||||
[33618., 22824.]],
|
||||
|
||||
[[71271., 48969.],
|
||||
[55245., 37728.]],
|
||||
|
||||
[[98244., 67680.],
|
||||
[76872., 52632.]],
|
||||
|
||||
[[125217., 86391.],
|
||||
[98499., 67536.]]],
|
||||
|
||||
[[[43650., 28422.],
|
||||
[27516., 17820.]],
|
||||
|
||||
[[125298., 83583.],
|
||||
[85593., 57024.]],
|
||||
|
||||
[[206946., 138744.],
|
||||
[143670., 96228.]],
|
||||
|
||||
[[288594., 193905.],
|
||||
[201747., 135432.]],
|
||||
|
||||
[[370242., 249066.],
|
||||
[259824., 174636.]]]]).astype(np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_with_padding():
|
||||
""""
|
||||
Feature: deformable_conv2d function.
|
||||
Description: Test case with padding input.
|
||||
Expectation: The results are as expected.
|
||||
"""
|
||||
x = Tensor(np.arange(2 * 3 * 5 * 5).reshape(2, 3, 5, 5), mstype.float32)
|
||||
kh, kw = 3, 3
|
||||
weight = Tensor(np.arange(5 * 3 * kh * kw).reshape(5, 3, kh, kw), mstype.float32)
|
||||
offsets = Tensor(np.ones((2, 3 * kh * kw, 5, 7)), mstype.float32)
|
||||
output = Net()(x, weight, offsets, kh, kw, padding=(1, 1, 2, 2))
|
||||
expect = np.array([[[[10296., 15219., 15570., 15921., 10422., 5112., 0.],
|
||||
[11511., 16974., 17325., 17676., 11547., 5652., 0.],
|
||||
[12726., 18729., 19080., 19431., 12672., 6192., 0.],
|
||||
[8040., 11784., 11991., 12198., 7920., 3852., 0.],
|
||||
[3768., 5496., 5586., 5676., 3666., 1773., 0.]],
|
||||
|
||||
[[25119., 37818., 38898., 39978., 26703., 13374., 0.],
|
||||
[28764., 43218., 44298., 45378., 30258., 15129., 0.],
|
||||
[32409., 48618., 49698., 50778., 33813., 16884., 0.],
|
||||
[21972., 32925., 33618., 34311., 22824., 11385., 0.],
|
||||
[11139., 16674., 17007., 17340., 11523., 5742., 0.]],
|
||||
|
||||
[[39942., 60417., 62226., 64035., 42984., 21636., 0.],
|
||||
[46017., 69462., 71271., 73080., 48969., 24606., 0.],
|
||||
[52092., 78507., 80316., 82125., 54954., 27576., 0.],
|
||||
[35904., 54066., 55245., 56424., 37728., 18918., 0.],
|
||||
[18510., 27852., 28428., 29004., 19380., 9711., 0.]],
|
||||
|
||||
[[54765., 83016., 85554., 88092., 59265., 29898., 0.],
|
||||
[63270., 95706., 98244., 100782., 67680., 34083., 0.],
|
||||
[71775., 108396., 110934., 113472., 76095., 38268., 0.],
|
||||
[49836., 75207., 76872., 78537., 52632., 26451., 0.],
|
||||
[25881., 39030., 39849., 40668., 27237., 13680., 0.]],
|
||||
|
||||
[[69588., 105615., 108882., 112149., 75546., 38160., 0.],
|
||||
[80523., 121950., 125217., 128484., 86391., 43560., 0.],
|
||||
[91458., 138285., 141552., 144819., 97236., 48960., 0.],
|
||||
[63768., 96348., 98499., 100650., 67536., 33984., 0.],
|
||||
[33252., 50208., 51270., 52332., 35094., 17649., 0.]]],
|
||||
|
||||
[[[28521., 41544., 41895., 42246., 27297., 13212., 0.],
|
||||
[29736., 43299., 43650., 44001., 28422., 13752., 0.],
|
||||
[30951., 45054., 45405., 45756., 29547., 14292., 0.],
|
||||
[18840., 27309., 27516., 27723., 17820., 8577., 0.],
|
||||
[8493., 12246., 12336., 12426., 7941., 3798., 0.]],
|
||||
|
||||
[[79794., 118818., 119898., 120978., 80028., 39699., 0.],
|
||||
[83439., 124218., 125298., 126378., 83583., 41454., 0.],
|
||||
[87084., 129618., 130698., 131778., 87138., 43209., 0.],
|
||||
[57072., 84900., 85593., 86286., 57024., 28260., 0.],
|
||||
[28014., 41649., 41982., 42315., 27948., 13842., 0.]],
|
||||
|
||||
[[131067., 196092., 197901., 199710., 132759., 66186., 0.],
|
||||
[137142., 205137., 206946., 208755., 138744., 69156., 0.],
|
||||
[143217., 214182., 215991., 217800., 144729., 72126., 0.],
|
||||
[95304., 142491., 143670., 144849., 96228., 47943., 0.],
|
||||
[47535., 71052., 71628., 72204., 47955., 23886., 0.]],
|
||||
|
||||
[[182340., 273366., 275904., 278442., 185490., 92673., 0.],
|
||||
[190845., 286056., 288594., 291132., 193905., 96858., 0.],
|
||||
[199350., 298746., 301284., 303822., 202320., 101043., 0.],
|
||||
[133536., 200082., 201747., 203412., 135432., 67626., 0.],
|
||||
[67056., 100455., 101274., 102093., 67962., 33930., 0.]],
|
||||
|
||||
[[233613., 350640., 353907., 357174., 238221., 119160., 0.],
|
||||
[244548., 366975., 370242., 373509., 249066., 124560., 0.],
|
||||
[255483., 383310., 386577., 389844., 259911., 129960., 0.],
|
||||
[171768., 257673., 259824., 261975., 174636., 87309., 0.],
|
||||
[86577., 129858., 130920., 131982., 87969., 43974., 0.]]]]).astype(np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_with_dilations():
|
||||
""""
|
||||
Feature: deformable_conv2d function.
|
||||
Description: Test case with dilations input.
|
||||
Expectation: The results are as expected.
|
||||
"""
|
||||
x = Tensor(np.arange(2 * 3 * 5 * 5).reshape(2, 3, 5, 5), mstype.float32)
|
||||
kh, kw = 3, 3
|
||||
weight = Tensor(np.arange(5 * 3 * kh * kw).reshape(5, 3, kh, kw), mstype.float32)
|
||||
offsets = Tensor(np.ones((2, 3 * kh * kw, 1, 1)), mstype.float32)
|
||||
output = Net()(x, weight, offsets, kh, kw, dilations=(1, 1, 2, 2))
|
||||
expect = np.array([[[[6780.]],
|
||||
|
||||
[[18768.]],
|
||||
|
||||
[[30756.]],
|
||||
|
||||
[[42744.]],
|
||||
|
||||
[[54732.]]],
|
||||
|
||||
[[[16680.]],
|
||||
|
||||
[[52968.]],
|
||||
|
||||
[[89256.]],
|
||||
|
||||
[[125544.]],
|
||||
|
||||
[[161832.]]]]).astype(np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
|
@ -49,8 +49,7 @@ def test_deformable_conv2d():
[[-1.2452881, -1.0579522]]]]).astype(np.float32)
offsets = Tensor(offsets, mstype.float32)
out = deformable_conv2d(x, weight, offsets, (kh, kw), (1, 1, stride_h, stride_w), (pad_h, pad_h, pad_w, pad_w),
data_format="NCHW", dilations=(1, 1, dilation_h, dilation_w),
deformable_groups=deformable_group)
dilations=(1, 1, dilation_h, dilation_w), deformable_groups=deformable_group)
# expected output: [1, 1, 1, 2]
expected = np.array([[[[-0.00852099, -0.09671781]]]]).astype(np.float32)
assert np.allclose(out.asnumpy(), expected)