forked from mindspore-Ecosystem/mindspore
!34766 add gpu deformable offset grad
Merge pull request !34766 from kisnwang/add-gpu-deformable-offset-grad
commit aacf006016

@ -0,0 +1,300 @@ plugin/device/gpu/kernel/cuda_impl/cuda_ops/deformable_offsets_grad_impl.cu
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/deformable_offsets_grad_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
#include "include/cuda_fp16.h"

template <typename T>
__device__ void DeformableOffsetGradKernel(const uint offset_position_stride,
                                           const uint input_x_deformable_group_channel_stride,
                                           const uint input_x_w_stride, const uint input_x_h_stride,
                                           const uint grad_deformable_group_channel_stride, const uint dim_x_h,
                                           const uint dim_x_w, const uint dim_deformable_group_channel, float input_x_i,
                                           float input_x_j, const uint offset_index_base_pos,
                                           const uint input_grad_base_pos, const uint input_x_base_pos, T *input_grad,
                                           T *input_x, T *input_offset, T *output_grad_x, T *output_grad_offset) {
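  // For one (batch, output position, kernel tap, deformable group) tuple, this device function walks the
  // channels of the group and scatters gradients to the four bilinear corners of the sampled position.
  // The offsets tensor stores three channels per kernel tap: the x offset at offset_index_base_pos, the
  // y offset one position stride above it, and the mask (scale) weight two position strides above it.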
  const uint offset_index_i = offset_index_base_pos + offset_position_stride;
  const uint offset_index_weight = offset_index_base_pos + 2 * offset_position_stride;
  float offset_i = static_cast<float>(input_offset[offset_index_i]);
  float offset_j = static_cast<float>(input_offset[offset_index_base_pos]);
  float scale_weight = static_cast<float>(input_offset[offset_index_weight]);

  float floor_offset_i = floorf(offset_i);
  float floor_offset_j = floorf(offset_j);
  float ceil_offset_i = floor_offset_i + 1;
  float ceil_offset_j = floor_offset_j + 1;

  float floor_i = input_x_i + floor_offset_i;
  float floor_j = input_x_j + floor_offset_j;
  float ceil_i = input_x_i + ceil_offset_i;
  float ceil_j = input_x_j + ceil_offset_j;

  float ceil_weight_i = offset_i + 1 - ceil_offset_i;
  float ceil_weight_j = offset_j + 1 - ceil_offset_j;
  float floor_weight_i = 1 - ceil_weight_i;
  float floor_weight_j = 1 - ceil_weight_j;

  float floor_floor_weight = floor_weight_i * floor_weight_j;
  float ceil_floor_weight = ceil_weight_i * floor_weight_j;
  float floor_ceil_weight = floor_weight_i * ceil_weight_j;
  float ceil_ceil_weight = ceil_weight_i * ceil_weight_j;
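  // Standard bilinear weights: ceil_weight_* is the fractional part of the offset along that axis and
  // floor_weight_* its complement, so each corner's weight is the product of the two per-axis terms.
  // For example, a fractional offset of 0.1 on both axes gives corner weights
  // 0.9 * 0.9 = 0.81, 0.1 * 0.9 = 0.09, 0.9 * 0.1 = 0.09, and 0.1 * 0.1 = 0.01.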

  // Corners that fall outside the feature map contribute neither a value nor a gradient.
  bool floor_floor_valid = false;
  bool ceil_floor_valid = false;
  bool floor_ceil_valid = false;
  bool ceil_ceil_valid = false;
  if (floor_i >= 0 && floor_i < dim_x_h) {
    if (floor_j >= 0 && floor_j < dim_x_w) {
      floor_floor_valid = true;
    }
    if (ceil_j >= 0 && ceil_j < dim_x_w) {
      floor_ceil_valid = true;
    }
  }

  if (ceil_i >= 0 && ceil_i < dim_x_h) {
    if (floor_j >= 0 && floor_j < dim_x_w) {
      ceil_floor_valid = true;
    }
    if (ceil_j >= 0 && ceil_j < dim_x_w) {
      ceil_ceil_valid = true;
    }
  }

  for (uint channel = 0; channel < dim_deformable_group_channel; ++channel) {
    float grad = static_cast<float>(input_grad[input_grad_base_pos + channel * grad_deformable_group_channel_stride]);
    float grad_scale = grad * scale_weight;
    uint tmp_input_x_base_pos = input_x_base_pos + channel * input_x_deformable_group_channel_stride;
    float current_x_pos;
    float floor_floor_value = 0;
    float ceil_floor_value = 0;
    float floor_ceil_value = 0;
    float ceil_ceil_value = 0;
    uint input_x_pos = 0;
    // Several sampling points can land on the same input pixel, so the grad_x scatter must be atomic.
    if (floor_floor_valid) {
      current_x_pos = tmp_input_x_base_pos + floor_i * input_x_h_stride + floor_j * input_x_w_stride;
      input_x_pos = static_cast<uint>(current_x_pos);
      floor_floor_value = static_cast<float>(input_x[input_x_pos]);
      MsAtomicAdd(output_grad_x + input_x_pos, static_cast<T>(grad_scale * floor_floor_weight));
    }

    if (ceil_floor_valid) {
      current_x_pos = tmp_input_x_base_pos + ceil_i * input_x_h_stride + floor_j * input_x_w_stride;
      input_x_pos = static_cast<uint>(current_x_pos);
      ceil_floor_value = static_cast<float>(input_x[input_x_pos]);
      MsAtomicAdd(output_grad_x + input_x_pos, static_cast<T>(grad_scale * ceil_floor_weight));
    }

    if (floor_ceil_valid) {
      current_x_pos = tmp_input_x_base_pos + floor_i * input_x_h_stride + ceil_j * input_x_w_stride;
      input_x_pos = static_cast<uint>(current_x_pos);
      floor_ceil_value = static_cast<float>(input_x[input_x_pos]);
      MsAtomicAdd(output_grad_x + input_x_pos, static_cast<T>(grad_scale * floor_ceil_weight));
    }

    if (ceil_ceil_valid) {
      current_x_pos = tmp_input_x_base_pos + ceil_i * input_x_h_stride + ceil_j * input_x_w_stride;
      input_x_pos = static_cast<uint>(current_x_pos);
      ceil_ceil_value = static_cast<float>(input_x[input_x_pos]);
      MsAtomicAdd(output_grad_x + input_x_pos, static_cast<T>(grad_scale * ceil_ceil_weight));
    }
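
    // Accumulate this channel's contribution to the three offset-gradient channels:
    // d(output)/d(offset_y), d(output)/d(offset_x), and d(output)/d(mask).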
    float delta = -floor_floor_value * floor_weight_j + ceil_floor_value * ceil_weight_j -
                  floor_ceil_value * floor_weight_j + ceil_ceil_value * ceil_weight_j;
    delta *= grad_scale;
    output_grad_offset[offset_index_i] += static_cast<T>(delta);

    delta = -floor_floor_value * floor_weight_i - ceil_floor_value * floor_weight_i + floor_ceil_value * ceil_weight_i +
            ceil_ceil_value * ceil_weight_i;
    delta *= grad_scale;
    output_grad_offset[offset_index_base_pos] += static_cast<T>(delta);

    delta = floor_floor_value * floor_floor_weight + ceil_floor_value * ceil_floor_weight +
            floor_ceil_value * floor_ceil_weight + ceil_ceil_value * ceil_ceil_weight;
    delta *= grad;
    output_grad_offset[offset_index_weight] += static_cast<T>(delta);
  }
}

template <typename T>
__global__ void DeformableOffsetGradNHWCKernel(const uint num_kernels, const uint dim_x_n, const uint dim_x_h,
                                               const uint dim_x_w, const uint dim_offset_h, const uint dim_offset_w,
                                               const uint dim_kernel_h, const uint dim_kernel_w, const uint dim_pad_top,
                                               const uint dim_pad_left, const uint dim_stride_h,
                                               const uint dim_stride_w, const uint dim_dilation_h,
                                               const uint dim_dilation_w, const uint dim_deformable_group,
                                               const uint dim_deformable_group_channel, T *input_grad, T *input_x,
                                               T *input_offset, T *output_grad_x, T *output_grad_offset) {
  const uint offset_kernel_w_stride = 1;
  const uint offset_kernel_h_stride = dim_kernel_w * offset_kernel_w_stride;
  const uint offset_deformable_group_stride = dim_kernel_h * offset_kernel_h_stride;
  const uint offset_position_stride = dim_deformable_group * offset_deformable_group_stride;
  const uint offset_offset_w_stride = 3 * offset_position_stride;
  const uint offset_offset_h_stride = dim_offset_w * offset_offset_w_stride;
  const uint offset_n_stride = dim_offset_h * offset_offset_h_stride;
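  // In NHWC the offsets are laid out as [N, offset_h, offset_w, 3 * deformable_group * kernel_h * kernel_w],
  // with kernel_w innermost; the strides above decompose that last dimension.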

  const uint grad_deformable_group_channel_stride = 1;
  const uint grad_deformable_group_stride = dim_deformable_group_channel * grad_deformable_group_channel_stride;
  const uint grad_kernel_w_stride = dim_deformable_group * grad_deformable_group_stride;
  const uint grad_offset_w_stride = dim_kernel_w * grad_kernel_w_stride;
  const uint grad_kernel_h_stride = dim_offset_w * grad_offset_w_stride;
  const uint grad_offset_h_stride = dim_kernel_h * grad_kernel_h_stride;
  const uint grad_n_stride = dim_offset_h * grad_offset_h_stride;

  const uint input_x_deformable_group_channel_stride = 1;
  const uint input_x_deformable_group_stride = dim_deformable_group_channel * input_x_deformable_group_channel_stride;
  const uint input_x_w_stride = dim_deformable_group * input_x_deformable_group_stride;
  const uint input_x_h_stride = dim_x_w * input_x_w_stride;
  const uint input_x_n_stride = dim_x_h * input_x_h_stride;

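  // Grid-stride loop: each thread owns one flat index and peels it into (kernel_j, kernel_i, group,
  // offset_j, offset_i, n) coordinates by successive mod/div against the corresponding dimension.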
  for (uint index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += gridDim.x * blockDim.x) {
    const uint offset_index_kernel_j = index % dim_kernel_w;
    uint tmp = index / dim_kernel_w;
    const uint offset_index_kernel_i = tmp % dim_kernel_h;
    tmp = tmp / dim_kernel_h;
    const uint offset_index_deformable_group_i = tmp % dim_deformable_group;
    tmp = tmp / dim_deformable_group;
    const uint offset_index_offset_j = tmp % dim_offset_w;
    tmp = tmp / dim_offset_w;
    const uint offset_index_offset_i = tmp % dim_offset_h;
    const uint offset_index_n_i = tmp / dim_offset_h;

    const uint offset_index_base_pos =
      offset_index_n_i * offset_n_stride + offset_index_deformable_group_i * offset_deformable_group_stride +
      offset_index_kernel_i * offset_kernel_h_stride + offset_index_kernel_j * offset_kernel_w_stride +
      offset_index_offset_i * offset_offset_h_stride + offset_index_offset_j * offset_offset_w_stride;
    const uint input_grad_base_pos =
      offset_index_n_i * grad_n_stride + offset_index_offset_i * grad_offset_h_stride +
      offset_index_offset_j * grad_offset_w_stride + offset_index_kernel_i * grad_kernel_h_stride +
      offset_index_kernel_j * grad_kernel_w_stride + offset_index_deformable_group_i * grad_deformable_group_stride;
    const uint input_x_base_pos =
      offset_index_n_i * input_x_n_stride + offset_index_deformable_group_i * input_x_deformable_group_stride;
    float input_x_i = -1.0 * dim_pad_top;
    float input_x_j = -1.0 * dim_pad_left;
    input_x_i += offset_index_offset_i * dim_stride_h + offset_index_kernel_i * dim_dilation_h;
    input_x_j += offset_index_offset_j * dim_stride_w + offset_index_kernel_j * dim_dilation_w;

    DeformableOffsetGradKernel(offset_position_stride, input_x_deformable_group_channel_stride, input_x_w_stride,
                               input_x_h_stride, grad_deformable_group_channel_stride, dim_x_h, dim_x_w,
                               dim_deformable_group_channel, input_x_i, input_x_j, offset_index_base_pos,
                               input_grad_base_pos, input_x_base_pos, input_grad, input_x, input_offset, output_grad_x,
                               output_grad_offset);
  }
}

template <typename T>
__global__ void DeformableOffsetGradNCHWKernel(const uint num_kernels, const uint dim_x_n, const uint dim_x_h,
                                               const uint dim_x_w, const uint dim_offset_h, const uint dim_offset_w,
                                               const uint dim_kernel_h, const uint dim_kernel_w, const uint dim_pad_top,
                                               const uint dim_pad_left, const uint dim_stride_h,
                                               const uint dim_stride_w, const uint dim_dilation_h,
                                               const uint dim_dilation_w, const uint dim_deformable_group,
                                               const uint dim_deformable_group_channel, T *input_grad, T *input_x,
                                               T *input_offset, T *output_grad_x, T *output_grad_offset) {
  const uint offset_offset_w_stride = 1;
  const uint offset_offset_h_stride = dim_offset_w * offset_offset_w_stride;
  const uint offset_kernel_w_stride = dim_offset_h * offset_offset_h_stride;
  const uint offset_kernel_h_stride = dim_kernel_w * offset_kernel_w_stride;
  const uint offset_deformable_group_stride = dim_kernel_h * offset_kernel_h_stride;
  const uint offset_position_stride = dim_deformable_group * offset_deformable_group_stride;
  const uint offset_n_stride = 3 * offset_position_stride;
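  // In NCHW the offsets are laid out as [N, 3 * deformable_group * kernel_h * kernel_w, offset_h, offset_w],
  // with offset_w innermost, so the decomposition mirrors the NHWC kernel with transposed strides.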

  const uint grad_kernel_w_stride = 1;
  const uint grad_offset_w_stride = dim_kernel_w * grad_kernel_w_stride;
  const uint grad_kernel_h_stride = dim_offset_w * grad_offset_w_stride;
  const uint grad_offset_h_stride = dim_kernel_h * grad_kernel_h_stride;
  const uint grad_deformable_group_channel_stride = dim_offset_h * grad_offset_h_stride;
  const uint grad_deformable_group_stride = dim_deformable_group_channel * grad_deformable_group_channel_stride;
  const uint grad_n_stride = dim_deformable_group * grad_deformable_group_stride;

  const uint input_x_w_stride = 1;
  const uint input_x_h_stride = dim_x_w * input_x_w_stride;
  const uint input_x_deformable_group_channel_stride = dim_x_h * input_x_h_stride;
  const uint input_x_deformable_group_stride = dim_deformable_group_channel * input_x_deformable_group_channel_stride;
  const uint input_x_n_stride = dim_deformable_group * input_x_deformable_group_stride;

  for (uint index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += gridDim.x * blockDim.x) {
    const uint offset_index_offset_j = index % dim_offset_w;
    uint tmp = index / dim_offset_w;
    const uint offset_index_offset_i = tmp % dim_offset_h;
    tmp = tmp / dim_offset_h;
    const uint offset_index_kernel_j = tmp % dim_kernel_w;
    tmp = tmp / dim_kernel_w;
    const uint offset_index_kernel_i = tmp % dim_kernel_h;
    tmp = tmp / dim_kernel_h;
    const uint offset_index_deformable_group_i = tmp % dim_deformable_group;
    const uint offset_index_n_i = tmp / dim_deformable_group;

    float input_x_i = -1.0 * dim_pad_top;
    float input_x_j = -1.0 * dim_pad_left;
    input_x_i += offset_index_offset_i * dim_stride_h + offset_index_kernel_i * dim_dilation_h;
    input_x_j += offset_index_offset_j * dim_stride_w + offset_index_kernel_j * dim_dilation_w;

    const uint offset_index_base_pos =
      offset_index_n_i * offset_n_stride + offset_index_deformable_group_i * offset_deformable_group_stride +
      offset_index_kernel_i * offset_kernel_h_stride + offset_index_kernel_j * offset_kernel_w_stride +
      offset_index_offset_i * offset_offset_h_stride + offset_index_offset_j * offset_offset_w_stride;
    const uint input_grad_base_pos =
      offset_index_n_i * grad_n_stride + offset_index_offset_i * grad_offset_h_stride +
      offset_index_offset_j * grad_offset_w_stride + offset_index_kernel_i * grad_kernel_h_stride +
      offset_index_kernel_j * grad_kernel_w_stride + offset_index_deformable_group_i * grad_deformable_group_stride;
    const uint input_x_base_pos =
      offset_index_n_i * input_x_n_stride + offset_index_deformable_group_i * input_x_deformable_group_stride;

    DeformableOffsetGradKernel(offset_position_stride, input_x_deformable_group_channel_stride, input_x_w_stride,
                               input_x_h_stride, grad_deformable_group_channel_stride, dim_x_h, dim_x_w,
                               dim_deformable_group_channel, input_x_i, input_x_j, offset_index_base_pos,
                               input_grad_base_pos, input_x_base_pos, input_grad, input_x, input_offset, output_grad_x,
                               output_grad_offset);
  }
}

template <typename T>
void ApplyDeformableOffsetGrad(const uint dim_x_n, const uint dim_x_h, const uint dim_x_w, const uint dim_offset_h,
                               const uint dim_offset_w, const uint dim_kernel_h, const uint dim_kernel_w,
                               const uint dim_pad_top, const uint dim_pad_left, const uint dim_stride_h,
                               const uint dim_stride_w, const uint dim_dilation_h, const uint dim_dilation_w,
                               const uint dim_deformable_group, const uint dim_deformable_group_channel, bool nchw,
                               T *input_grad, T *input_x, T *input_offset, T *output_grad_x, T *output_grad_offset,
                               const uint device_id, cudaStream_t cuda_stream) {
  const uint num_kernels = dim_x_n * dim_offset_h * dim_offset_w * dim_kernel_h * dim_kernel_w * dim_deformable_group;
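  // One thread per (batch, offset position, kernel tap, deformable group) tuple; each thread iterates
  // over the channels of its group inside DeformableOffsetGradKernel.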
  if (nchw) {
    DeformableOffsetGradNCHWKernel<<<CUDA_BLOCKS(device_id, num_kernels), CUDA_THREADS(device_id), 0, cuda_stream>>>(
      num_kernels, dim_x_n, dim_x_h, dim_x_w, dim_offset_h, dim_offset_w, dim_kernel_h, dim_kernel_w, dim_pad_top,
      dim_pad_left, dim_stride_h, dim_stride_w, dim_dilation_h, dim_dilation_w, dim_deformable_group,
      dim_deformable_group_channel, input_grad, input_x, input_offset, output_grad_x, output_grad_offset);
  } else {
    DeformableOffsetGradNHWCKernel<<<CUDA_BLOCKS(device_id, num_kernels), CUDA_THREADS(device_id), 0, cuda_stream>>>(
      num_kernels, dim_x_n, dim_x_h, dim_x_w, dim_offset_h, dim_offset_w, dim_kernel_h, dim_kernel_w, dim_pad_top,
      dim_pad_left, dim_stride_h, dim_stride_w, dim_dilation_h, dim_dilation_w, dim_deformable_group,
      dim_deformable_group_channel, input_grad, input_x, input_offset, output_grad_x, output_grad_offset);
  }
}

template CUDA_LIB_EXPORT void ApplyDeformableOffsetGrad<float>(
  const uint dim_x_n, const uint dim_x_h, const uint dim_x_w, const uint dim_offset_h, const uint dim_offset_w,
  const uint dim_kernel_h, const uint dim_kernel_w, const uint dim_pad_top, const uint dim_pad_left,
  const uint dim_stride_h, const uint dim_stride_w, const uint dim_dilation_h, const uint dim_dilation_w,
  const uint dim_deformable_group, const uint dim_deformable_group_channel, bool nchw, float *input_grad,
  float *input_x, float *input_offset, float *output_grad_x, float *output_grad_offset, const uint device_id,
  cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyDeformableOffsetGrad<half>(
  const uint dim_x_n, const uint dim_x_h, const uint dim_x_w, const uint dim_offset_h, const uint dim_offset_w,
  const uint dim_kernel_h, const uint dim_kernel_w, const uint dim_pad_top, const uint dim_pad_left,
  const uint dim_stride_h, const uint dim_stride_w, const uint dim_dilation_h, const uint dim_dilation_w,
  const uint dim_deformable_group, const uint dim_deformable_group_channel, bool nchw, half *input_grad, half *input_x,
  half *input_offset, half *output_grad_x, half *output_grad_offset, const uint device_id, cudaStream_t cuda_stream);

@ -0,0 +1,29 @@ plugin/device/gpu/kernel/cuda_impl/cuda_ops/deformable_offsets_grad_impl.cuh
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DEFORMABLE_OFFSETS_GRAD_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DEFORMABLE_OFFSETS_GRAD_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T>
CUDA_LIB_EXPORT void ApplyDeformableOffsetGrad(
  const uint dim_x_n, const uint dim_x_h, const uint dim_x_w, const uint dim_offset_h, const uint dim_offset_w,
  const uint dim_kernel_h, const uint dim_kernel_w, const uint dim_pad_top, const uint dim_pad_left,
  const uint dim_stride_h, const uint dim_stride_w, const uint dim_dilation_h, const uint dim_dilation_w,
  const uint dim_deformable_group, const uint dim_deformable_group_channel, bool nchw, T *input_grad, T *input_x,
  T *input_offset, T *output_grad_x, T *output_grad_offset, const uint device_id, cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DEFORMABLE_OFFSETS_GRAD_IMPL_CUH_

@ -0,0 +1,244 @@ plugin/device/gpu/kernel/nn/deformable_offsets_grad_gpu_kernel.cc
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "plugin/device/gpu/kernel/nn/deformable_offsets_grad_gpu_kernel.h"

#include <algorithm>
#include <string>
#include <vector>
#include <map>
#include <utility>
#include <functional>
#include "abstract/utils.h"
#include "mindspore/core/ops/grad/deformable_offsets_grad.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputNum = 3;
constexpr size_t kOutputNum = 2;
constexpr size_t kInputShapeSize = 4;

constexpr size_t kGradIndex = 0;
constexpr size_t kXIndex = 1;
constexpr size_t kOffsetIndex = 2;
constexpr size_t kGradXIndex = 0;
constexpr size_t kGradOffsetIndex = 1;

constexpr size_t kPadNum = 4;
constexpr size_t kStrideNum = 4;
constexpr size_t kDilationNum = 4;
constexpr size_t kKernelSizeNum = 2;

constexpr size_t kPadTopIndex = 0;
constexpr size_t kPadLeftIndex = 2;
constexpr size_t kStrideHIndex = 2;
constexpr size_t kStrideWIndex = 3;
constexpr size_t kDilationHIndex = 2;
constexpr size_t kDilationWIndex = 3;
constexpr size_t kKernelHIndex = 0;
constexpr size_t kKernelWIndex = 1;

constexpr size_t kCIndexForNCHW = 1;
constexpr size_t kHIndexForNCHW = 2;
constexpr size_t kWIndexForNCHW = 3;

constexpr size_t kHIndexForNHWC = 1;
constexpr size_t kWIndexForNHWC = 2;
constexpr size_t kCIndexForNHWC = 3;
// The offsets tensor carries three values per kernel position: an x offset, a y offset, and a mask weight.
constexpr size_t kOffsetChannel = 3;

void CheckSize(const std::string &kernel_name, const std::string &dim_name, size_t expect, size_t actual) {
  if (actual != expect) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the length of '" << dim_name << "' must be " << expect
                      << ", but got " << actual;
  }
}
}  // namespace

bool DeformableOffsetsGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
                                             const std::vector<KernelTensorPtr> &inputs,
                                             const std::vector<KernelTensorPtr> &outputs) {
  kernel_name_ = base_operator->name();
  if (inputs.size() != kInputNum || outputs.size() != kOutputNum) {
    MS_LOG(ERROR) << kernel_name_ << ": input and output sizes should be " << kInputNum << " and " << kOutputNum
                  << ", but got " << inputs.size() << " and " << outputs.size();
    return false;
  }

  auto kernel_ptr = std::dynamic_pointer_cast<ops::DeformableOffsetsGrad>(base_operator);
  if (kernel_ptr == nullptr) {
    MS_LOG(ERROR) << "Cast DeformableOffsetsGrad failed!";
    return false;
  }
  data_format_ = kernel_ptr->get_format();

  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
    return false;
  }
  kernel_func_ = func_list_[index].second;
  type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(0).first);
  return true;
}

void DeformableOffsetsGradGpuKernelMod::SetDims(const BaseOperatorPtr &base_operator,
                                                const std::vector<KernelTensorPtr> &inputs,
                                                const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::DeformableOffsetsGrad>(base_operator);
  if (kernel_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "Cast DeformableOffsetsGrad failed!";
  }
  dims_.deformable_group = LongToUint(kernel_ptr->get_deformable_groups());
  if (dims_.deformable_group == 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', deformable group must be greater than 0.";
  }
  std::vector<int64_t> pad = kernel_ptr->get_pads();
  CheckSize(kernel_name_, "pads", kPadNum, pad.size());
  dims_.pad_top = LongToUint(pad[kPadTopIndex]);
  dims_.pad_left = LongToUint(pad[kPadLeftIndex]);

  std::vector<int64_t> stride = kernel_ptr->get_strides();
  CheckSize(kernel_name_, "strides", kStrideNum, stride.size());
  dims_.stride_h = LongToUint(stride[kStrideHIndex]);
  dims_.stride_w = LongToUint(stride[kStrideWIndex]);

  std::vector<int64_t> dilation = kernel_ptr->get_dilations();
  CheckSize(kernel_name_, "dilations", kDilationNum, dilation.size());
  dims_.dilation_h = LongToUint(dilation[kDilationHIndex]);
  dims_.dilation_w = LongToUint(dilation[kDilationWIndex]);

  std::vector<int64_t> ksize = kernel_ptr->get_kernel_size();
  CheckSize(kernel_name_, "ksize", kKernelSizeNum, ksize.size());
  dims_.kernel_h = LongToUint(ksize[kKernelHIndex]);
  dims_.kernel_w = LongToUint(ksize[kKernelWIndex]);
  if (dims_.kernel_h == 0 || dims_.kernel_w == 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of 'ksize' must be larger than 0.";
  }
  auto x_shape = inputs[kXIndex]->GetShapeVector();
  CheckSize(kernel_name_, "input_x", kInputShapeSize, x_shape.size());
  dims_.x_n = LongToUint(x_shape[0]);
  auto grad_shape = inputs[kGradIndex]->GetShapeVector();
  CheckSize(kernel_name_, "input_grad", kInputShapeSize, grad_shape.size());
  if (data_format_ == kOpFormat_NCHW) {
    dims_.grad_h = LongToUint(grad_shape[kHIndexForNCHW]);
    dims_.grad_w = LongToUint(grad_shape[kWIndexForNCHW]);
    dims_.x_h = LongToUint(x_shape[kHIndexForNCHW]);
    dims_.x_w = LongToUint(x_shape[kWIndexForNCHW]);
    dims_.deformable_group_channel = LongToUint(x_shape[kCIndexForNCHW]) / dims_.deformable_group;
  } else {
    dims_.grad_h = LongToUint(grad_shape[kHIndexForNHWC]);
    dims_.grad_w = LongToUint(grad_shape[kWIndexForNHWC]);
    dims_.x_h = LongToUint(x_shape[kHIndexForNHWC]);
    dims_.x_w = LongToUint(x_shape[kWIndexForNHWC]);
    dims_.deformable_group_channel = LongToUint(x_shape[kCIndexForNHWC]) / dims_.deformable_group;
  }
  // The forward output expands each offset position by the kernel dims, so dividing the grad's spatial
  // dims by the kernel dims recovers the offset map's spatial dims.
  dims_.offset_h = dims_.grad_h / dims_.kernel_h;
  dims_.offset_w = dims_.grad_w / dims_.kernel_w;

  // Output buffer sizes in bytes: shape product seeded with the element size.
  auto grad_x_shape = outputs[kGradXIndex]->GetShapeVector();
  grad_x_size_ = std::accumulate(grad_x_shape.begin(), grad_x_shape.end(), type_size_, std::multiplies<size_t>());

  auto grad_offset_shape = outputs[kGradOffsetIndex]->GetShapeVector();
  grad_offset_size_ =
    std::accumulate(grad_offset_shape.begin(), grad_offset_shape.end(), type_size_, std::multiplies<size_t>());
}

int DeformableOffsetsGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                              const std::vector<KernelTensorPtr> &inputs,
                                              const std::vector<KernelTensorPtr> &outputs,
                                              const std::map<uint32_t, tensor::TensorPtr> &) {
  int ret = KernelMod::Resize(base_operator, inputs, outputs);
  if (ret != 0) {
    return ret;
  }
  if (input_size_list_.size() != kInputNum || output_size_list_.size() != kOutputNum) {
    MS_LOG(ERROR) << kernel_name_ << " resize: input and output sizes should be " << kInputNum << " and " << kOutputNum
                  << ", but got " << input_size_list_.size() << " and " << output_size_list_.size();
    return KRET_RESIZE_FAILED;
  }
  SetDims(base_operator, inputs, outputs);
  return KRET_OK;
}

template <typename T>
bool DeformableOffsetsGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                                     const std::vector<AddressPtr> &outputs) {
  T *grad_addr = GetDeviceAddress<T>(inputs, kGradIndex);
  T *x_addr = GetDeviceAddress<T>(inputs, kXIndex);
  T *offset_addr = GetDeviceAddress<T>(inputs, kOffsetIndex);
  T *grad_x_addr = GetDeviceAddress<T>(outputs, kGradXIndex);
  T *grad_offset_addr = GetDeviceAddress<T>(outputs, kGradOffsetIndex);
  // The CUDA kernels accumulate into the outputs, so both buffers must be zeroed first.
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(grad_x_addr, 0, grad_x_size_, cuda_stream_),
                                     "Call cudaMemsetAsync grad_x failed");
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(grad_offset_addr, 0, grad_offset_size_, cuda_stream_),
                                     "Call cudaMemsetAsync grad_offset failed");
  uint dim_x_n = dims_.x_n;
  uint dim_x_h = dims_.x_h;
  uint dim_x_w = dims_.x_w;
  uint dim_offset_h = dims_.offset_h;
  uint dim_offset_w = dims_.offset_w;
  uint dim_kernel_h = dims_.kernel_h;
  uint dim_kernel_w = dims_.kernel_w;
  uint dim_pad_top = dims_.pad_top;
  uint dim_pad_left = dims_.pad_left;
  uint dim_stride_h = dims_.stride_h;
  uint dim_stride_w = dims_.stride_w;
  uint dim_dilation_h = dims_.dilation_h;
  uint dim_dilation_w = dims_.dilation_w;
  uint dim_deformable_group = dims_.deformable_group;
  uint dim_deformable_group_channel = dims_.deformable_group_channel;
  if (data_format_ == kOpFormat_NCHW) {
    ApplyDeformableOffsetGrad(dim_x_n, dim_x_h, dim_x_w, dim_offset_h, dim_offset_w, dim_kernel_h, dim_kernel_w,
                              dim_pad_top, dim_pad_left, dim_stride_h, dim_stride_w, dim_dilation_h, dim_dilation_w,
                              dim_deformable_group, dim_deformable_group_channel, true, grad_addr, x_addr, offset_addr,
                              grad_x_addr, grad_offset_addr, device_id_, cuda_stream_);
  } else {
    ApplyDeformableOffsetGrad(dim_x_n, dim_x_h, dim_x_w, dim_offset_h, dim_offset_w, dim_kernel_h, dim_kernel_w,
                              dim_pad_top, dim_pad_left, dim_stride_h, dim_stride_w, dim_dilation_h, dim_dilation_w,
                              dim_deformable_group, dim_deformable_group_channel, false, grad_addr, x_addr, offset_addr,
                              grad_x_addr, grad_offset_addr, device_id_, cuda_stream_);
  }
  return true;
}

std::vector<std::pair<KernelAttr, DeformableOffsetsGradGpuKernelMod::KernelFunc>>
  DeformableOffsetsGradGpuKernelMod::func_list_ = {{KernelAttr()
                                                      .AddInputAttr(kNumberTypeFloat16)
                                                      .AddInputAttr(kNumberTypeFloat16)
                                                      .AddInputAttr(kNumberTypeFloat16)
                                                      .AddOutputAttr(kNumberTypeFloat16)
                                                      .AddOutputAttr(kNumberTypeFloat16),
                                                    &DeformableOffsetsGradGpuKernelMod::LaunchKernel<half>},
                                                   {KernelAttr()
                                                      .AddInputAttr(kNumberTypeFloat32)
                                                      .AddInputAttr(kNumberTypeFloat32)
                                                      .AddInputAttr(kNumberTypeFloat32)
                                                      .AddOutputAttr(kNumberTypeFloat32)
                                                      .AddOutputAttr(kNumberTypeFloat32),
                                                    &DeformableOffsetsGradGpuKernelMod::LaunchKernel<float>}};

std::vector<KernelAttr> DeformableOffsetsGradGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, KernelFunc> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, DeformableOffsetsGrad, DeformableOffsetsGradGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore

@ -0,0 +1,93 @@ plugin/device/gpu/kernel/nn/deformable_offsets_grad_gpu_kernel.h
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_DEFORMABLE_OFFSETS_GRAD_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_DEFORMABLE_OFFSETS_GRAD_KERNEL_H_

#include <string>
#include <vector>
#include <map>
#include <utility>

#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/deformable_offsets_grad_impl.cuh"

namespace mindspore {
namespace kernel {
class DeformableOffsetsGradGpuKernelMod : public NativeGpuKernelMod {
 public:
  DeformableOffsetsGradGpuKernelMod() = default;
  ~DeformableOffsetsGradGpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    cuda_stream_ = reinterpret_cast<cudaStream_t>(stream_ptr);
    return kernel_func_(this, inputs, outputs);
  }

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  std::vector<KernelAttr> GetOpSupport() override;

 private:
  // Shapes and attributes cached by SetDims() and forwarded to the CUDA launcher.
  struct DeformableOffsetsGradDims {
    uint x_n;
    uint x_h;
    uint x_w;
    uint offset_h;
    uint offset_w;
    uint grad_h;
    uint grad_w;
    uint kernel_h;
    uint kernel_w;
    uint pad_top;
    uint pad_left;
    uint stride_h;
    uint stride_w;
    uint dilation_h;
    uint dilation_w;
    uint deformable_group;
    uint deformable_group_channel;
  };
  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
  using KernelFunc = std::function<bool(DeformableOffsetsGradGpuKernelMod *, const std::vector<AddressPtr> &,
                                        const std::vector<AddressPtr> &)>;

  void SetDims(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
               const std::vector<KernelTensorPtr> &outputs);

  std::string kernel_name_;
  cudaStream_t cuda_stream_{nullptr};
  KernelFunc kernel_func_{};
  static std::vector<std::pair<KernelAttr, KernelFunc>> func_list_;
  std::string data_format_ = kOpFormat_NCHW;
  DeformableOffsetsGradDims dims_;
  size_t grad_x_size_{0};
  size_t grad_offset_size_{0};
  size_t type_size_{1};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_DEFORMABLE_OFFSETS_GRAD_KERNEL_H_

@ -0,0 +1,110 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G


class NetDeformableOffsetsGrad(nn.Cell):
    def __init__(self, data_format):
        super(NetDeformableOffsetsGrad, self).__init__()
        strides = (1, 1, 1, 1)
        pads = (0, 0, 0, 0)
        ksize = (3, 3)
        self.grad_op = G.DeformableOffsetsGrad(strides, pads, ksize, data_format=data_format)

    def construct(self, grad, input_x, offsets):
        return self.grad_op(grad, input_x, offsets)
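

# With the default single deformable group and a 3 x 3 kernel, the offsets input carries
# 3 * 1 * 3 * 3 = 27 channels: an x offset, a y offset, and a mask weight per kernel position.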
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float16, np.float32])
def test_deformable_offsets_grad_nchw(dtype):
    """
    Feature: DeformableOffsetsGrad gpu kernel
    Description: test the correctness of the DeformableOffsetsGrad gpu kernel
    Expectation: the output matches the expected result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    net = NetDeformableOffsetsGrad(data_format="NCHW")
    dout = Tensor(np.ones([1, 2, 3, 3]).astype(dtype))
    x = Tensor(np.ones([1, 2, 4, 4]).astype(dtype))
    offsets = Tensor(np.ones([1, 27, 1, 1]).astype(dtype) * 0.1)
    output = net(dout, x, offsets)

    # Each input pixel receives the mask (0.1) times the summed bilinear weights of the kernel
    # taps covering it, e.g. the corner pixel gets 0.1 * 0.9 * 0.9 = 0.081.
    expect_grad_x = np.array([[[0.081, 0.09, 0.09, 0.009],
                               [0.09, 0.1, 0.1, 0.01],
                               [0.09, 0.1, 0.1, 0.01],
                               [0.009, 0.01, 0.01, 0.001]],
                              [[0.081, 0.09, 0.09, 0.009],
                               [0.09, 0.1, 0.1, 0.01],
                               [0.09, 0.1, 0.1, 0.01],
                               [0.009, 0.01, 0.01, 0.001]]]
                             ).astype(dtype)
    expect_grad_offset = np.array([-0.32] * 18 + [2.0] * 9).astype(dtype).reshape([1, 27, 1, 1])
    rtol = 1e-5
    if dtype == np.float16:
        rtol = 1e-3
    assert np.allclose(output[0].asnumpy(), expect_grad_x, rtol)
    assert np.allclose(output[1].asnumpy(), expect_grad_offset, rtol)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float16, np.float32])
def test_deformable_offsets_grad_nhwc(dtype):
    """
    Feature: DeformableOffsetsGrad gpu kernel
    Description: test the correctness of the DeformableOffsetsGrad gpu kernel
    Expectation: the output matches the expected result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    net = NetDeformableOffsetsGrad(data_format="NHWC")
    dout = Tensor(np.ones([1, 3, 3, 2]).astype(dtype))
    x = Tensor(np.ones([1, 4, 4, 2]).astype(dtype))
    offsets = Tensor(np.ones([1, 1, 1, 27]).astype(dtype) * 0.1)
    output = net(dout, x, offsets)
    expect_grad_x = np.array([[[0.081, 0.081],
                               [0.09, 0.09],
                               [0.09, 0.09],
                               [0.009, 0.009]],
                              [[0.09, 0.09],
                               [0.1, 0.1],
                               [0.1, 0.1],
                               [0.01, 0.01]],
                              [[0.09, 0.09],
                               [0.1, 0.1],
                               [0.1, 0.1],
                               [0.01, 0.01]],
                              [[0.009, 0.009],
                               [0.01, 0.01],
                               [0.01, 0.01],
                               [0.001, 0.001]]
                              ]).astype(dtype)
    expect_grad_offset = np.array([-0.32] * 18 + [2.0] * 9).astype(dtype).reshape([1, 1, 1, 27])
    rtol = 1e-5
    if dtype == np.float16:
        rtol = 1e-3
    assert np.allclose(output[0].asnumpy(), expect_grad_x, rtol)
    assert np.allclose(output[1].asnumpy(), expect_grad_offset, rtol)