!34770 Add gpu DeformableOffsets
Merge pull request !34770 from tanghuikang/gpu_deformable_offset
commit 6fe64a18c9
@@ -0,0 +1,169 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/deformable_offsets_impl.cuh"
#include <stdio.h>
#include <stdint.h>
#include "include/cuda_fp16.h"

// The offsets tensor packs three sections per deformable group: x offsets, y offsets and the mask.
constexpr int OFFSET_NUM = 3;

// Bilinearly samples input at the (possibly fractional) position (x, y);
// positions outside [-1, width) x [-1, height) contribute zero.
template <typename T>
__device__ T DeformableBilinear(const T *input, const uint width, const uint height, const T x, const T y) {
  if (y <= static_cast<T>(-1) || y >= static_cast<T>(height) || x <= static_cast<T>(-1) || x >= static_cast<T>(width)) {
    return 0;
  }

  int left = floorf(x);
  int top = floorf(y);
  int right = left + 1;
  int bottom = top + 1;

  T l = x - static_cast<T>(left);
  T t = y - static_cast<T>(top);
  T r = static_cast<T>(1) - l;
  T b = static_cast<T>(1) - t;

  T lt = 0;
  T lb = 0;
  if (left >= 0) {
    if (top >= 0) {
      lt = input[top * width + left];
    }
    if (bottom <= height - 1) {
      lb = input[bottom * width + left];
    }
  }
  T rt = 0;
  T rb = 0;
  if (right <= width - 1) {
    if (top >= 0) {
      rt = input[top * width + right];
    }
    if (bottom <= height - 1) {
      rb = input[bottom * width + right];
    }
  }

  T w_lt = r * b;
  T w_rt = l * b;
  T w_lb = r * t;
  T w_rb = l * t;
  T val = (w_lt * lt + w_rt * rt + w_lb * lb + w_rb * rb);
  return val;
}
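For intuition, here is a host-side sketch of the same bilinear rule (illustrative only, not part of the patch; BilinearRef is a made-up name): each in-range neighbour is weighted by the area of the sub-rectangle opposite to it, so sampling at (x, y) = (1.25, 0.75) weights the top-left neighbour by (1 - 0.25) * (1 - 0.75) = 0.1875.

#include <cmath>

// Host-side reference of the bilinear rule above; out-of-range neighbours
// contribute zero, matching the guarded loads in DeformableBilinear.
static float BilinearRef(const float *img, int w, int h, float x, float y) {
  if (y <= -1.0f || y >= static_cast<float>(h) || x <= -1.0f || x >= static_cast<float>(w)) {
    return 0.0f;
  }
  const int left = static_cast<int>(std::floor(x));
  const int top = static_cast<int>(std::floor(y));
  const float l = x - left;
  const float t = y - top;
  auto at = [&](int yy, int xx) {
    return (xx >= 0 && xx < w && yy >= 0 && yy < h) ? img[yy * w + xx] : 0.0f;
  };
  // Each neighbour is weighted by the area of the opposite sub-rectangle.
  return (1 - l) * (1 - t) * at(top, left) + l * (1 - t) * at(top, left + 1) +
         (1 - l) * t * at(top + 1, left) + l * t * at(top + 1, left + 1);
}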
// Precomputes, for each output location (one per kernel tap per output pixel), the
// base sampling position in the input, before the learned offsets are applied.
__global__ void GenPositionGridKernel(const uint kernel_h, const uint kernel_w, const uint stride_h,
                                      const uint stride_w, const uint dilations_h, const uint dilations_w,
                                      const uint pad_l, const uint pad_t, const uint output_w, const uint num,
                                      uint *position_grid) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) {
    uint y = i / output_w;
    uint x = i % output_w;
    uint pixel_y = y / kernel_h;
    uint pixel_x = x / kernel_w;
    uint kernel_y = y % kernel_h;
    uint kernel_x = x % kernel_w;
    uint index = i * 2;
    position_grid[index] = pixel_x * stride_w + kernel_x * dilations_w - pad_l;
    position_grid[index + 1] = pixel_y * stride_h + kernel_y * dilations_h - pad_t;
  }
}
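A quick worked entry (illustrative numbers, not from the patch): with stride 1, dilation 2 and pad_l = pad_t = 2, the grid slot for output pixel (pixel_y, pixel_x) = (0, 0) and kernel tap (kernel_y, kernel_x) = (1, 1) stores x = 0 * 1 + 1 * 2 - 2 = 0 and y = 0 * 1 + 1 * 2 - 2 = 0, the usual top-left alignment of a padded, dilated convolution.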
template <class T>
__global__ void DeformableOffsetsKernel(const T *input, const T *offsets, const uint *position_grid, const uint c,
                                        const uint output_n_dim, const uint output_c_dim, const uint output_w,
                                        const uint c_size_per_dfm_group, const uint offset_n_dim,
                                        const uint offset_mask_dim, const uint offset_group_dim,
                                        const uint offset_kh_dim, const uint offset_kw_dim, const uint pixel_w,
                                        const uint input_n_dim, const uint input_c_dim, const uint input_h,
                                        const uint input_w, const uint kernel_h, const uint kernel_w, const uint num,
                                        T *output) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) {
    // Get the base input position for this output location.
    const uint hw_idx = i % output_c_dim;
    const uint position_grid_idx = hw_idx * 2;
    const uint input_x = position_grid[position_grid_idx];
    const uint input_y = position_grid[position_grid_idx + 1];
    // Decompose the flat index and fetch the learned x offset, y offset and mask.
    const uint n_index = i / output_n_dim;
    const uint c_index = i / output_c_dim % c;
    const uint x = hw_idx % output_w;
    const uint y = hw_idx / output_w;
    const uint dfm_group_index = c_index / c_size_per_dfm_group;
    const uint pixel_x = x / kernel_w;
    const uint pixel_y = y / kernel_h;
    const uint kernel_x = x % kernel_w;
    const uint kernel_y = y % kernel_h;
    const uint x_offsets_offset = n_index * offset_n_dim  // + 0 * offset_mask_dim
                                  + dfm_group_index * offset_group_dim + kernel_y * offset_kh_dim +
                                  kernel_x * offset_kw_dim + pixel_y * pixel_w + pixel_x;
    T x_offsets = offsets[x_offsets_offset];
    const uint y_offsets_offset = x_offsets_offset + offset_mask_dim;
    T y_offsets = offsets[y_offsets_offset];
    const uint mask_offset = y_offsets_offset + offset_mask_dim;
    T mask = offsets[mask_offset];
    // Shift the base position by the offsets, sample bilinearly and apply the mask.
    T deformed_x = static_cast<T>(input_x) + x_offsets;
    T deformed_y = static_cast<T>(input_y) + y_offsets;
    const T *input_base = input + n_index * input_n_dim + c_index * input_c_dim;
    T bilinear_val = DeformableBilinear(input_base, input_w, input_h, deformed_x, deformed_y);
    output[i] = bilinear_val * mask;
  }
}
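To see the index decomposition concretely (illustrative sizes): with n = 1, c = 4 and a 5x5 output, output_c_dim = 25 and output_n_dim = 100, so thread i = 57 handles n_index = 0, c_index = 57 / 25 % 4 = 2 and hw_idx = 57 % 25 = 7, i.e. output coordinates (y, x) = (1, 2).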
template <class T>
void DeformableOffsets(const T *input, const T *offsets, const uint *position_grid, uint n, uint c, uint input_h,
                       uint input_w, uint dfm_group, uint kernel_h, uint kernel_w, uint output_h, uint output_w,
                       T *output, uint32_t device_id, cudaStream_t cuda_stream) {
  const uint pixel_w = output_w / kernel_w;
  const uint pixel_h = output_h / kernel_h;
  const uint output_c_dim = output_h * output_w;
  const uint output_n_dim = c * output_c_dim;
  const uint num = n * output_n_dim;
  const uint c_size_per_dfm_group = c / dfm_group;
  const uint offset_kw_dim = pixel_h * pixel_w;
  const uint offset_kh_dim = offset_kw_dim * kernel_w;
  const uint offset_group_dim = offset_kh_dim * kernel_h;
  const uint offset_mask_dim = offset_group_dim * dfm_group;
  const uint offset_n_dim = offset_mask_dim * OFFSET_NUM;
  const uint input_c_dim = input_h * input_w;
  const uint input_n_dim = input_c_dim * c;
  DeformableOffsetsKernel<<<CUDA_BLOCKS(device_id, num), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    input, offsets, position_grid, c, output_n_dim, output_c_dim, output_w, c_size_per_dfm_group, offset_n_dim,
    offset_mask_dim, offset_group_dim, offset_kh_dim, offset_kw_dim, pixel_w, input_n_dim, input_c_dim, input_h,
    input_w, kernel_h, kernel_w, num, output);
}
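The stride arithmetic above implies the offsets tensor is laid out, when flattened, as [N, 3, dfm_group, kernel_h, kernel_w, output_h / kernel_h, output_w / kernel_w]: all x offsets come first, then all y offsets, then the mask, which is why the kernel reaches the y offset and the mask by adding offset_mask_dim once and twice respectively.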
void GenPositionGrid(const uint kernel_h, const uint kernel_w, const uint stride_h, const uint stride_w,
                     const uint dilations_h, const uint dilations_w, const uint pad_l, const uint pad_t,
                     const uint output_w, const uint num, uint *position_grid, const uint32_t device_id,
                     cudaStream_t cuda_stream) {
  GenPositionGridKernel<<<CUDA_BLOCKS(device_id, num), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    kernel_h, kernel_w, stride_h, stride_w, dilations_h, dilations_w, pad_l, pad_t, output_w, num, position_grid);
}

template CUDA_LIB_EXPORT void DeformableOffsets<float>(const float *input, const float *offsets,
                                                       const uint *position_grid, uint n, uint c, uint input_h,
                                                       uint input_w, uint dfm_group, uint kernel_h, uint kernel_w,
                                                       uint output_h, uint output_w, float *output,
                                                       uint32_t device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void DeformableOffsets<half>(const half *input, const half *offsets,
                                                      const uint *position_grid, uint n, uint c, uint input_h,
                                                      uint input_w, uint dfm_group, uint kernel_h, uint kernel_w,
                                                      uint output_h, uint output_w, half *output,
                                                      uint32_t device_id, cudaStream_t cuda_stream);
@@ -0,0 +1,34 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DEFORMABLE_OFFSETS_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DEFORMABLE_OFFSETS_IMPL_CUH_

#include <cuda_runtime.h>
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"

CUDA_LIB_EXPORT void GenPositionGrid(const uint kernel_h, const uint kernel_w, const uint stride_h, const uint stride_w,
                                     const uint dilations_h, const uint dilations_w, const uint pad_l, const uint pad_t,
                                     const uint output_w, const uint num, uint *position_grid, const uint32_t device_id,
                                     cudaStream_t cuda_stream);

template <typename T>
CUDA_LIB_EXPORT void DeformableOffsets(const T *input, const T *offsets, const uint *position_grid, uint n, uint c,
                                       uint input_h, uint input_w, uint dfm_group, uint kernel_h, uint kernel_w,
                                       uint output_h, uint output_w, T *output, uint32_t device_id,
                                       cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DEFORMABLE_OFFSETS_IMPL_CUH_
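The two exported entry points are meant to be called in sequence: the position grid once per shape, then the main kernel. A minimal sketch, assuming pre-allocated device buffers x_dev, offsets_dev, y_dev and a workspace ws_grid of output_h * output_w * 2 uints (all names here are illustrative, not part of the patch):

// Sketch of the intended call order, mirroring the GPU kernel mod below.
uint num = output_h * output_w;  // one grid entry per output location
GenPositionGrid(kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w,
                pad_l, pad_t, output_w, num, ws_grid, device_id, stream);
DeformableOffsets(x_dev, offsets_dev, ws_grid, n, c, input_h, input_w, dfm_group,
                  kernel_h, kernel_w, output_h, output_w, y_dev, device_id, stream);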
@@ -0,0 +1,184 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <memory>
#include <utility>
#include <algorithm>
#include "abstract/utils.h"
#include "plugin/device/gpu/kernel/nn/deformable_offsets_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/deformable_offsets_impl.cuh"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kTopPadIndex = 0;
constexpr size_t kLeftPadIndex = 2;
constexpr size_t kKernelSizeHIndex = 0;
constexpr size_t kKernelSizeWIndex = 1;
constexpr size_t kInputNum = 2;
constexpr size_t kOutputNum = 1;
constexpr size_t kStrideAttrNum = 4;
constexpr size_t kPadAttrNum = 4;
constexpr size_t kKernelSizeAttrNum = 2;
constexpr size_t kDilationAttrNum = 4;
}  // namespace

bool DeformableOffsetsGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
                                           const std::vector<AddressPtr> &workspace,
                                           const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}

bool DeformableOffsetsGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
                                         const std::vector<KernelTensorPtr> &inputs,
                                         const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::DeformableOffsets>(base_operator);
  MS_EXCEPTION_IF_NULL(kernel_ptr);
  kernel_name_ = kernel_ptr->name();
  if (inputs.size() != kInputNum || outputs.size() != kOutputNum) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it should get two inputs and one output, but got "
                  << inputs.size() << " inputs and " << outputs.size() << " outputs.";
    return false;
  }
  auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
  if (!is_match) {
    return false;
  }
  kernel_func_ = func_list_[index].second;

  if (!CheckParam(kernel_ptr)) {
    return false;
  }
  return true;
}

bool DeformableOffsetsGpuKernelMod::CheckParam(const std::shared_ptr<ops::DeformableOffsets> &kernel) {
  data_format_ = kernel->get_data_format();
  if (data_format_ == kOpFormat_NCHW) {
    n_axis_ = 0;
    c_axis_ = 1;
    h_axis_ = 2;
    w_axis_ = 3;
  } else {
    MS_LOG(ERROR) << kernel_name_ << " only supports input with format NCHW, but got format " << data_format_;
    return false;
  }
  const auto to_unsigned = [](const int64_t &value) { return LongToUint(value); };
  const auto &strides = kernel->get_strides();
  std::transform(strides.begin(), strides.end(), std::back_inserter(strides_), to_unsigned);
  if (strides_.size() != kStrideAttrNum || strides_[n_axis_] != 1 || strides_[c_axis_] != 1) {
    MS_LOG(ERROR) << "Got invalid 'strides' attr from " << kernel_name_
                  << ", strides should be a vector of 4 integers whose n and c dimensions are 1, but got "
                  << strides_;
    return false;
  }
  const auto &pads = kernel->get_pads();
  std::transform(pads.begin(), pads.end(), std::back_inserter(pads_), to_unsigned);
  if (pads_.size() != kPadAttrNum) {
    MS_LOG(ERROR) << "Got invalid 'pads' attr from " << kernel_name_
                  << ", pads should be a vector of 4 integers, but got " << pads_;
    return false;
  }
  const auto &kernel_size = kernel->get_kernel_size();
  std::transform(kernel_size.begin(), kernel_size.end(), std::back_inserter(kernel_size_), to_unsigned);
  if (kernel_size_.size() != kKernelSizeAttrNum) {
    MS_LOG(ERROR) << "Got invalid 'ksize' attr from " << kernel_name_
                  << ", ksize should be a vector of 2 integers, but got " << kernel_size_;
    return false;
  }
  const auto &dilations = kernel->get_dilations();
  std::transform(dilations.begin(), dilations.end(), std::back_inserter(dilations_), to_unsigned);
  if (dilations_.size() != kDilationAttrNum || dilations_[n_axis_] != 1 || dilations_[c_axis_] != 1) {
    MS_LOG(ERROR) << "Got invalid 'dilations' attr from " << kernel_name_
                  << ", dilations should be a vector of 4 integers whose n and c dimensions are 1, but got "
                  << dilations_;
    return false;
  }
  const int64_t deformable_groups = kernel->get_deformable_groups();
  if (deformable_groups <= 0) {
    MS_LOG(ERROR) << kernel_name_ << "'s deformable_groups should be greater than 0, but got " << deformable_groups;
    return false;
  }
  deformable_groups_ = LongToUint(deformable_groups);
  modulated_ = kernel->get_modulated();
  if (!modulated_) {
    MS_LOG(ERROR) << kernel_name_ << " only supports v2, so 'modulated' should be true, but got false.";
    return false;
  }
  return true;
}

int DeformableOffsetsGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                          const std::vector<KernelTensorPtr> &inputs,
                                          const std::vector<KernelTensorPtr> &outputs,
                                          const std::map<uint32_t, tensor::TensorPtr> &) {
  if (KernelMod::Resize(base_operator, inputs, outputs) == KRET_UNKNOWN_SHAPE) {
    return KRET_UNKNOWN_SHAPE;
  }
  const auto &x_shape = inputs[0]->GetShapeVector();
  n_ = x_shape[n_axis_];
  c_ = x_shape[c_axis_];
  x_h_ = x_shape[h_axis_];
  x_w_ = x_shape[w_axis_];
  const auto &y_shape = outputs[0]->GetShapeVector();
  output_h_ = y_shape[h_axis_];
  output_w_ = y_shape[w_axis_];
  // The workspace stores the position grid: an (x, y) pair of uints per output spatial location.
  position_grid_num_ = output_w_ * output_h_;
  auto position_grid_size = position_grid_num_ * 2 * sizeof(uint);
  workspace_size_list_.emplace_back(position_grid_size);
  return KRET_OK;
}
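For scale (a worked example, not from the patch): a 64 x 64 output needs 64 * 64 * 2 * 4 = 32 KiB of workspace for the position grid.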
template <class T>
bool DeformableOffsetsGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                                 const std::vector<AddressPtr> &workspace,
                                                 const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  unsigned int *position_addr = GetDeviceAddress<unsigned int>(workspace, 0);
  const size_t num = output_h_ * output_w_;
  GenPositionGrid(kernel_size_[kKernelSizeHIndex], kernel_size_[kKernelSizeWIndex], strides_[h_axis_],
                  strides_[w_axis_], dilations_[h_axis_], dilations_[w_axis_], pads_[kLeftPadIndex],
                  pads_[kTopPadIndex], output_w_, num, position_addr, device_id_,
                  reinterpret_cast<cudaStream_t>(stream_ptr));

  T *x_addr = GetDeviceAddress<T>(inputs, 0);
  T *offsets_addr = GetDeviceAddress<T>(inputs, 1);
  T *output_addr = GetDeviceAddress<T>(outputs, 0);
  DeformableOffsets(x_addr, offsets_addr, position_addr, n_, c_, x_h_, x_w_, deformable_groups_,
                    kernel_size_[kKernelSizeHIndex], kernel_size_[kKernelSizeWIndex], output_h_, output_w_,
                    output_addr, device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));
  return true;
}

std::vector<std::pair<KernelAttr, DeformableOffsetsGpuKernelMod::LaunchKernelFunc>>
  DeformableOffsetsGpuKernelMod::func_list_ = {{KernelAttr()
                                                  .AddInputAttr(kNumberTypeFloat32, kOpFormat_NCHW)
                                                  .AddInputAttr(kNumberTypeFloat32, kOpFormat_NCHW)
                                                  .AddOutputAttr(kNumberTypeFloat32, kOpFormat_NCHW),
                                                &DeformableOffsetsGpuKernelMod::LaunchKernel<float>},
                                               {KernelAttr()
                                                  .AddInputAttr(kNumberTypeFloat16, kOpFormat_NCHW)
                                                  .AddInputAttr(kNumberTypeFloat16, kOpFormat_NCHW)
                                                  .AddOutputAttr(kNumberTypeFloat16, kOpFormat_NCHW),
                                                &DeformableOffsetsGpuKernelMod::LaunchKernel<half>}};

std::vector<KernelAttr> DeformableOffsetsGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, LaunchKernelFunc> &item) { return item.first; });
  return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, DeformableOffsets, DeformableOffsetsGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,88 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DEFORMABLE_OFFSET_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DEFORMABLE_OFFSET_H_

#include <vector>
#include <string>
#include <functional>
#include <map>
#include <utility>
#include <memory>
#include "ops/deformable_offsets.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
class DeformableOffsetsGpuKernelMod : public NativeGpuKernelMod {
 public:
  DeformableOffsetsGpuKernelMod() {}
  ~DeformableOffsetsGpuKernelMod() override {}

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  using LaunchKernelFunc =
    std::function<bool(DeformableOffsetsGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &, void *)>;
  static std::vector<std::pair<KernelAttr, LaunchKernelFunc>> func_list_;
  template <class T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs, void *stream_ptr);
  bool CheckParam(const std::shared_ptr<ops::DeformableOffsets> &kernel);

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs,
             const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>()) override;

  // Attributes
  std::vector<uint32_t> strides_;
  std::vector<uint32_t> pads_;
  std::vector<uint32_t> kernel_size_;
  std::vector<uint32_t> dilations_;
  std::string data_format_;
  uint32_t deformable_groups_{0};
  bool modulated_{false};

  // Kernel function selected by dtype at Init time
  LaunchKernelFunc kernel_func_{};
  // Axis positions for the NCHW layout
  size_t n_axis_{0};
  size_t c_axis_{1};
  size_t h_axis_{2};
  size_t w_axis_{3};

  // Dynamic values refreshed on every Resize
  uint32_t position_grid_num_{0};
  // x shape
  uint32_t n_{0};
  uint32_t c_{0};
  uint32_t x_h_{0};
  uint32_t x_w_{0};
  // output shape
  uint32_t output_h_{0};
  uint32_t output_w_{0};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DEFORMABLE_OFFSET_H_
@@ -73,6 +73,16 @@ inline std::vector<size_t> LongVecToSizeVec(const std::vector<int64_t> &vec) {
  return result;
}

inline uint32_t LongToUint(int64_t u) {
  if (u < 0) {
    MS_LOG(EXCEPTION) << "The int64_t value(" << u << ") is less than 0.";
  }
  if (u > static_cast<int64_t>((std::numeric_limits<uint32_t>::max)())) {
    MS_LOG(EXCEPTION) << "The int64_t value(" << u << ") exceeds the maximum value of uint32_t.";
  }
  return static_cast<uint32_t>(u);
}

inline size_t FloatToSize(float u) {
  if (u < 0) {
    MS_LOG(EXCEPTION) << "The float value(" << u << ") is less than 0.";
@@ -0,0 +1,56 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore.ops import deformable_conv2d
from mindspore import Tensor

context.set_context(device_target="GPU")


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_deformable_conv2d():
    """
    Feature: deformable_conv2d function
    Description: Test case for the simplest deformable_conv2d
    Expectation: The results are as expected
    """
    kh, kw = 1, 1
    deformable_group = 1
    stride_h, stride_w = 1, 1
    pad_h, pad_w = 0, 0
    dilation_h, dilation_w = 1, 1
    # x shape [1, 1, 1, 2]
    x = np.array([[[[-0.41675785, -0.05626683]]]]).astype(np.float32)
    x = Tensor(x, mstype.float32)
    # weight shape [1, 1, 1, 1]
    weight = np.array([[[[-2.1361961]]]]).astype(np.float32)
    weight = Tensor(weight, mstype.float32)
    # offsets shape [1, 3, 1, 2]: x offsets, then y offsets, then mask
    offsets = np.array([[[[1.6402708, -1.7934356]],
                         [[-0.84174734, 0.5028814]],
                         [[-1.2452881, -1.0579522]]]]).astype(np.float32)
    offsets = Tensor(offsets, mstype.float32)
    out = deformable_conv2d(x, weight, offsets, (kh, kw), (1, 1, stride_h, stride_w), (pad_h, pad_h, pad_w, pad_w),
                            data_format="NCHW", dilations=(1, 1, dilation_h, dilation_w),
                            deformable_groups=deformable_group)
    # expected output shape [1, 1, 1, 2]
    expected = np.array([[[[-0.00852099, -0.09671781]]]]).astype(np.float32)
    assert np.allclose(out.asnumpy(), expected)
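As a sanity check of the first expected value: the 1x1 kernel samples x at (0 + 1.6402708, 0 - 0.84174734). Only the neighbour at (row 0, col 1) is in range, weighted by (1 - 0.6402708) * (-0.84174734 + 1) ≈ 0.05693, so the bilinear sample is 0.05693 * (-0.05626683) ≈ -0.0032032; multiplying by the mask -1.2452881 and the weight -2.1361961 gives ≈ -0.008521, matching expected[0][0][0][0].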