!34770 Add GPU DeformableOffsets

Merge pull request !34770 from tanghuikang/gpu_deformable_offset
i-robot 2022-06-07 10:55:50 +00:00 committed by Gitee
commit 6fe64a18c9
6 changed files with 541 additions and 0 deletions

View File

@@ -0,0 +1,169 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/deformable_offsets_impl.cuh"
#include <stdio.h>
#include <stdint.h>
#include "include/cuda_fp16.h"
constexpr int OFFSET_NUM = 3;
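// The offsets input packs OFFSET_NUM planes per deformable group, in order: x offsets, y offsets, modulation mask.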
template <typename T>
__device__ T DeformableBilinear(const T *input, const uint width, const uint height, const T x, const T y) {
if (y <= static_cast<T>(-1) || y >= static_cast<T>(height) || x <= static_cast<T>(-1) || x >= static_cast<T>(width)) {
return 0;
}
int left = floorf(x);
int top = floorf(y);
int right = left + 1;
int bottom = top + 1;
T l = x - static_cast<T>(left);
T t = y - static_cast<T>(top);
T r = static_cast<T>(1) - l;
T b = static_cast<T>(1) - t;
T lt = 0;
T lb = 0;
if (left >= 0) {
if (top >= 0) {
lt = input[top * width + left];
}
if (bottom <= height - 1) {
lb = input[bottom * width + left];
}
}
T rt = 0;
T rb = 0;
if (right <= width - 1) {
if (top >= 0) {
rt = input[top * width + right];
}
if (bottom <= height - 1) {
rb = input[bottom * width + right];
}
}
T w_lt = r * b;
T w_rt = l * b;
T w_lb = r * t;
T w_rb = l * t;
T val = (w_lt * lt + w_rt * rt + w_lb * lb + w_rb * rb);
return val;
}
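// A worked example of the weighting above (hypothetical sample point): for x = 1.3, y = 2.6
// we get left = 1, top = 2, l = 0.3, t = 0.6, r = 0.7, b = 0.4, so
// val = 0.28 * input[2 * width + 1] + 0.12 * input[2 * width + 2]
//     + 0.42 * input[3 * width + 1] + 0.18 * input[3 * width + 2];
// corners nearer the sample point receive larger weights, and the four weights sum to 1.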
__global__ void GenPositionGridKernel(const uint kernel_h, const uint kernel_w, const uint stride_h,
const uint stride_w, const uint dilations_h, const uint dilations_w,
const uint pad_l, const uint pad_t, const uint output_w, const uint num,
uint *position_grid) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) {
uint y = i / output_w;
uint x = i % output_w;
uint pixel_y = y / kernel_h;
uint pixel_x = x / kernel_w;
uint kernel_y = y % kernel_h;
uint kernel_x = x % kernel_w;
uint index = i * 2;
position_grid[index] = pixel_x * stride_w + kernel_x * dilations_w - pad_l;
position_grid[index + 1] = pixel_y * stride_h + kernel_y * dilations_h - pad_t;
}
}
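// Layout assumed by the indexing below: offsets is a flattened
// [N, OFFSET_NUM, dfm_group, kernel_h, kernel_w, pixel_h, pixel_w] tensor
// (x-offset plane first, then y-offset, then mask), and each flat output index i
// decomposes as (n, c, y, x) with the kernel-tap indices interleaved into y and x.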
template <class T>
__global__ void DeformableOffsetsKernel(const T *input, const T *offsets, const uint *position_grid, const uint c,
const uint output_n_dim, const uint output_c_dim, const uint output_w,
const uint c_size_per_dfm_group, const uint offset_n_dim,
const uint offset_mask_dim, const uint offset_group_dim,
const uint offset_kh_dim, const uint offset_kw_dim, const uint pixel_w,
const uint input_n_dim, const uint input_c_dim, const uint input_h,
const uint input_w, const uint kernel_h, const uint kernel_w, const uint num,
T *output) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) {
// Get original input position
const uint hw_idx = i % output_c_dim;
const uint position_grid_idx = hw_idx * 2;
const uint input_x = position_grid[position_grid_idx];
const uint input_y = position_grid[position_grid_idx + 1];
// Get offsets
const uint n_index = i / output_n_dim;
const uint c_index = i / output_c_dim % c;
const uint x = hw_idx % output_w;
const uint y = hw_idx / output_w;
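// x and y interleave pixel and kernel-tap coordinates, tap index innermost (x = pixel_x * kernel_w + kernel_x).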
const uint dfm_group_index = c_index / c_size_per_dfm_group;
const uint pixel_x = x / kernel_w;
const uint pixel_y = y / kernel_h;
const uint kernel_x = x % kernel_w;
const uint kernel_y = y % kernel_h;
const uint x_offsets_offset = n_index * offset_n_dim // + 0 * offset_mask_dim
+ dfm_group_index * offset_group_dim + kernel_y * offset_kh_dim +
kernel_x * offset_kw_dim + pixel_y * pixel_w + pixel_x;
T x_offsets = offsets[x_offsets_offset];
const int y_offsets_offset = x_offsets_offset + offset_mask_dim;
T y_offsets = offsets[y_offsets_offset];
const int mask_offset = y_offsets_offset + offset_mask_dim;
T mask = offsets[mask_offset];
// Deform
T deformed_x = static_cast<T>(input_x) + x_offsets;
T deformed_y = static_cast<T>(input_y) + y_offsets;
const T *input_base = input + n_index * input_n_dim + c_index * input_c_dim;
T bilinear_val = DeformableBilinear(input_base, input_w, input_h, deformed_x, deformed_y);
output[i] = bilinear_val * mask;
}
}
template <class T>
void DeformableOffsets(const T *input, const T *offsets, const uint *position_grid, uint n, uint c, uint input_h,
uint input_w, uint dfm_group, uint kernel_h, uint kernel_w, uint output_h, uint output_w,
T *output, uint32_t device_id, cudaStream_t cuda_stream) {
const uint pixel_w = output_w / kernel_w;
const uint pixel_h = output_h / kernel_h;
const uint output_c_dim = output_h * output_w;
const uint output_n_dim = c * output_c_dim;
const uint num = n * output_n_dim;
const uint c_size_per_dfm_group = c / dfm_group;
const uint offset_kw_dim = pixel_h * pixel_w;
const uint offset_kh_dim = offset_kw_dim * kernel_w;
const uint offset_group_dim = offset_kh_dim * kernel_h;
const uint offset_mask_dim = offset_group_dim * dfm_group;
const uint offset_n_dim = offset_mask_dim * OFFSET_NUM;
const uint input_c_dim = input_h * input_w;
const uint input_n_dim = input_c_dim * c;
DeformableOffsetsKernel<<<CUDA_BLOCKS(device_id, num), CUDA_THREADS(device_id), 0, cuda_stream>>>(
input, offsets, position_grid, c, output_n_dim, output_c_dim, output_w, c_size_per_dfm_group, offset_n_dim,
offset_mask_dim, offset_group_dim, offset_kh_dim, offset_kw_dim, pixel_w, input_n_dim, input_c_dim, input_h,
input_w, kernel_h, kernel_w, num, output);
}
void GenPositionGrid(const uint kernel_h, const uint kernel_w, const uint stride_h, const uint stride_w,
const uint dilations_h, const uint dilations_w, const uint pad_l, const uint pad_t,
const uint output_w, const uint num, uint *position_grid, const uint32_t device_id,
cudaStream_t cuda_stream) {
GenPositionGridKernel<<<CUDA_BLOCKS(device_id, num), CUDA_THREADS(device_id), 0, cuda_stream>>>(
kernel_h, kernel_w, stride_h, stride_w, dilations_h, dilations_w, pad_l, pad_t, output_w, num, position_grid);
}
template CUDA_LIB_EXPORT void DeformableOffsets<float>(const float *input, const float *offsets,
const uint *position_grid, uint n, uint c, uint input_h,
uint input_w, uint dfm_group, uint kernel_h, uint kernel_w,
uint output_h, uint output_w, float *output,
uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void DeformableOffsets<half>(const half *input, const half *offsets,
const uint *position_grid, uint n, uint c, uint input_h,
uint input_w, uint dfm_group, uint kernel_h, uint kernel_w,
uint output_h, uint output_w, half *output,
uint32_t device_id, cudaStream_t cuda_stream);

View File

@@ -0,0 +1,34 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DEFORMABLE_OFFSETS_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DEFORMABLE_OFFSETS_IMPL_CUH_
#include <cuda_runtime.h>
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
CUDA_LIB_EXPORT void GenPositionGrid(const uint kernel_h, const uint kernel_w, const uint stride_h, const uint stride_w,
const uint dilations_h, const uint dilations_w, const uint pad_l, const uint pad_t,
const uint output_w, const uint num, uint *position_grid, const uint32_t device_id,
cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void DeformableOffsets(const T *input, const T *offsets, const uint *position_grid, uint n, uint c,
uint input_h, uint input_w, uint dfm_group, uint kernel_h, uint kernel_w,
uint output_h, uint output_w, T *output, uint32_t device_id,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DEFORMABLE_OFFSETS_IMPL_CUH_

View File

@@ -0,0 +1,184 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <utility>
#include <algorithm>
#include "abstract/utils.h"
#include "plugin/device/gpu/kernel/nn/deformable_offsets_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/deformable_offsets_impl.cuh"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kTopPadIndex = 0;
constexpr size_t kLeftPadIndex = 2;
constexpr size_t kKernelSizeHIndex = 0;
constexpr size_t kKernelSizeWIndex = 1;
constexpr size_t kInputNum = 2;
constexpr size_t kOutputNum = 1;
constexpr size_t kStrideAttrNum = 4;
constexpr size_t kPadAttrNum = 4;
constexpr size_t kKernelSizeAttrNum = 2;
constexpr size_t kDilationAttrNum = 4;
} // namespace
bool DeformableOffsetsGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}
bool DeformableOffsetsGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::DeformableOffsets>(base_operator);
kernel_name_ = kernel_ptr->name();
if (inputs.size() != kInputNum || outputs.size() != kOutputNum) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it should get two inputs and one output, but got " << inputs.size()
<< "inputs and " << outputs.size() << " outputs";
return false;
}
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
if (!is_match) {
return false;
}
kernel_func_ = func_list_[index].second;
if (!CheckParam(kernel_ptr)) {
return false;
}
return true;
}
bool DeformableOffsetsGpuKernelMod::CheckParam(const std::shared_ptr<ops::DeformableOffsets> &kernel) {
data_format_ = kernel->get_data_format();
if (data_format_ == kOpFormat_NCHW) {
n_axis_ = 0;
c_axis_ = 1;
h_axis_ = 2;
w_axis_ = 3;
} else {
MS_LOG(ERROR) << kernel_name_ << " only supports input with format NCHW, but got format " << data_format_;
return false;
}
const auto to_unsigned = [](const int64_t &value) { return LongToUint(value); };
const auto &strides = kernel->get_strides();
std::transform(strides.begin(), strides.end(), std::back_inserter(strides_), to_unsigned);
if (strides_.size() != kStrideAttrNum || strides_[n_axis_] != 1 || strides_[c_axis_] != 1) {
MS_LOG(ERROR) << "Get invalid strides attr form " << kernel_name_
<< ", strides should be a vector constructed by 4 integer and n&c dim should be 1, but got"
<< strides_;
return false;
}
const auto &pads = kernel->get_pads();
std::transform(pads.begin(), pads.end(), std::back_inserter(pads_), to_unsigned);
if (pads_.size() != kPadAttrNum) {
MS_LOG(ERROR) << "Get invalid pads attr form " << kernel_name_
<< ", padding should be a vector constructed by 4 integer, but got" << pads_;
return false;
}
const auto &kernel_size = kernel->get_kernel_size();
std::transform(kernel_size.begin(), kernel_size.end(), std::back_inserter(kernel_size_), to_unsigned);
if (kernel_size_.size() != kKernelSizeAttrNum) {
MS_LOG(ERROR) << "Get invalid ksize attr form " << kernel_name_
<< ", ksize should be a vector constructed by 2 integer, but got" << kernel_size_;
return false;
}
const auto &dilations = kernel->get_dilations();
std::transform(dilations.begin(), dilations.end(), std::back_inserter(dilations_), to_unsigned);
if (dilations_.size() != kDilationAttrNum || dilations_[n_axis_] != 1 || dilations_[c_axis_] != 1) {
MS_LOG(ERROR) << "Get invalid dilations attr form " << kernel_name_
<< ", dilations should be a vector constructed by 4 integer and n&c dim should be 1, but got"
<< dilations_;
return false;
}
deformable_groups_ = static_cast<uint32_t>(kernel->get_deformable_groups());
if (deformable_groups_ == 0) {
MS_LOG(ERROR) << kernel_name_ << "'s deformable_groups should be greater than 0, but got " << deformable_groups_;
return false;
}
modulated_ = kernel->get_modulated();
if (!modulated_) {
MS_LOG(ERROR) << kernel_name_ << " only supports v2, so modulated should be true, but got false";
return false;
}
return true;
}
int DeformableOffsetsGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (KernelMod::Resize(base_operator, inputs, outputs) == KRET_UNKNOWN_SHAPE) {
return KRET_UNKNOWN_SHAPE;
}
const auto &x_shape = inputs[0]->GetShapeVector();
n_ = x_shape[n_axis_];
c_ = x_shape[c_axis_];
x_h_ = x_shape[h_axis_];
x_w_ = x_shape[w_axis_];
const auto &y_shape = outputs[0]->GetShapeVector();
output_h_ = y_shape[h_axis_];
output_w_ = y_shape[w_axis_];
position_grid_num_ = output_w_ * output_h_;
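// The workspace stores one (x, y) uint coordinate pair per output position, hence the factor of 2 below.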
auto position_grid_size = position_grid_num_ * 2 * sizeof(uint);
workspace_size_list_.emplace_back(position_grid_size);
return KRET_OK;
}
template <class T>
bool DeformableOffsetsGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
unsigned int *position_addr = GetDeviceAddress<unsigned int>(workspace, 0);
const size_t num = output_h_ * output_w_;
GenPositionGrid(kernel_size_[kKernelSizeHIndex], kernel_size_[kKernelSizeWIndex], strides_[h_axis_],
strides_[w_axis_], dilations_[h_axis_], dilations_[w_axis_], pads_[kLeftPadIndex],
pads_[kTopPadIndex], output_w_, num, position_addr, device_id_,
reinterpret_cast<cudaStream_t>(stream_ptr));
T *x_addr = GetDeviceAddress<T>(inputs, 0);
T *offsets_addr = GetDeviceAddress<T>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
DeformableOffsets(x_addr, offsets_addr, position_addr, n_, c_, x_h_, x_w_, deformable_groups_,
kernel_size_[kKernelSizeHIndex], kernel_size_[kKernelSizeWIndex], output_h_, output_w_, output_addr,
device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
std::vector<std::pair<KernelAttr, DeformableOffsetsGpuKernelMod::LaunchKernelFunc>>
DeformableOffsetsGpuKernelMod::func_list_ = {{KernelAttr()
.AddInputAttr(kNumberTypeFloat32, kOpFormat_NCHW)
.AddInputAttr(kNumberTypeFloat32, kOpFormat_NCHW)
.AddOutputAttr(kNumberTypeFloat32, kOpFormat_NCHW),
&DeformableOffsetsGpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16, kOpFormat_NCHW)
.AddInputAttr(kNumberTypeFloat16, kOpFormat_NCHW)
.AddOutputAttr(kNumberTypeFloat16, kOpFormat_NCHW),
&DeformableOffsetsGpuKernelMod::LaunchKernel<half>}};
std::vector<KernelAttr> DeformableOffsetsGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, LaunchKernelFunc> &item) { return item.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, DeformableOffsets, DeformableOffsetsGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,88 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DEFORMABLE_OFFSET_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DEFORMABLE_OFFSET_H_
#include <vector>
#include <string>
#include <functional>
#include <map>
#include <utility>
#include <memory>
#include "ops/deformable_offsets.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class DeformableOffsetsGpuKernelMod : public NativeGpuKernelMod {
public:
DeformableOffsetsGpuKernelMod() {}
~DeformableOffsetsGpuKernelMod() override {}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
using LaunchKernelFunc =
std::function<bool(DeformableOffsetsGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &, void *)>;
static std::vector<std::pair<KernelAttr, LaunchKernelFunc>> func_list_;
template <class T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr);
bool CheckParam(const std::shared_ptr<ops::DeformableOffsets> &kernel);
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>()) override;
// attrs
std::vector<uint32_t> strides_;
std::vector<uint32_t> pads_;
std::vector<uint32_t> kernel_size_;
std::vector<uint32_t> dilations_;
std::string data_format_;
uint32_t deformable_groups_;
bool modulated_;
// Constant value
LaunchKernelFunc kernel_func_{};
// axis
size_t n_axis_;
size_t c_axis_;
size_t h_axis_;
size_t w_axis_;
// Dynamic value
uint32_t position_grid_num_;
// x shape
uint32_t n_;
uint32_t c_;
uint32_t x_h_;
uint32_t x_w_;
// output shape
uint32_t output_h_;
uint32_t output_w_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DEFORMABLE_OFFSET_H_

View File

@@ -73,6 +73,16 @@ inline std::vector<size_t> LongVecToSizeVec(const std::vector<int64_t> &vec) {
return result;
}
inline uint32_t LongToUint(int64_t u) {
if (u < 0) {
MS_LOG(EXCEPTION) << "The int64_t value(" << u << ") is less than 0.";
}
if (u > static_cast<int64_t>((std::numeric_limits<uint32_t>::max)())) {
MS_LOG(EXCEPTION) << "The int64_t value(" << u << ") exceeds the maximum value of uint32_t.";
}
return static_cast<uint32_t>(u);
}
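// e.g. LongToUint(5) returns 5u, while LongToUint(-1) and LongToUint(int64_t(1) << 32) both throw.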
inline size_t FloatToSize(float u) {
if (u < 0) {
MS_LOG(EXCEPTION) << "The float value(" << u << ") is less than 0.";

View File

@@ -0,0 +1,56 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore.ops import deformable_conv2d
from mindspore import Tensor
context.set_context(device_target="GPU")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_deformable_conv2d():
""""
Feature: deformable_conv2d function
Description: Test case for simplest deformable_conv2d
Expectation: The results are as expected
"""
kh, kw = 1, 1
deformable_group = 1
stride_h, stride_w = 1, 1
pad_h, pad_w = 0, 0
dilation_h, dilation_w = 1, 1
# x shape [1, 1, 1, 2]
x = np.array([[[[-0.41675785, -0.05626683]]]]).astype(np.float32)
x = Tensor(x, mstype.float32)
# weight shape [1, 1, 1, 1]
weight = np.array([[[[-2.1361961]]]]).astype(np.float32)
weight = Tensor(weight, mstype.float32)
# offsets shape [1, 3, 1, 2]
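# planes along dim 1: x offsets, y offsets, modulation mask (deformable_group == 1)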
offsets = np.array([[[[1.6402708, -1.7934356]],
[[-0.84174734, 0.5028814]],
[[-1.2452881, -1.0579522]]]]).astype(np.float32)
offsets = Tensor(offsets, mstype.float32)
out = deformable_conv2d(x, weight, offsets, (kh, kw), (1, 1, stride_h, stride_w), (pad_h, pad_h, pad_w, pad_w),
data_format="NCHW", dilations=(1, 1, dilation_h, dilation_w),
deformable_groups=deformable_group)
# expected output: [1, 1, 1, 2]
expected = np.array([[[[-0.00852099, -0.09671781]]]]).astype(np.float32)
assert np.allclose(out.asnumpy(), expected)
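# A NumPy cross-check of `expected` above: a minimal sketch, assuming the same bilinear
# sampling and masking scheme as the CUDA kernel; `_reference_1x1` is a hypothetical
# helper, not part of the MindSpore API. With a 1x1 kernel, stride 1 and no padding,
# output pixel i is bilinear(img, (i + x_off[i], y_off[i])) * mask[i] * weight.
def _reference_1x1(img, x_off, y_off, mask, weight):
    h, w = img.shape
    def sample(px, py):
        # Out-of-range sample points contribute zero, as in the kernel's bounds check.
        if py <= -1 or py >= h or px <= -1 or px >= w:
            return 0.0
        left, top = int(np.floor(px)), int(np.floor(py))
        l, t = px - left, py - top
        def at(yy, xx):
            return img[yy, xx] if 0 <= yy < h and 0 <= xx < w else 0.0
        return ((1 - l) * (1 - t) * at(top, left) + l * (1 - t) * at(top, left + 1)
                + (1 - l) * t * at(top + 1, left) + l * t * at(top + 1, left + 1))
    return [sample(i + x_off[i], y_off[i]) * mask[i] * weight for i in range(w)]
# _reference_1x1(x.asnumpy()[0, 0], [1.6402708, -1.7934356], [-0.84174734, 0.5028814],
#                [-1.2452881, -1.0579522], -2.1361961)
# gives approximately [-0.00852099, -0.09671781], matching `expected`.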