Reworked MirrorPad

lint issues

convert long -> int64

correcting int64 -> int64_t

lint
This commit is contained in:
danishnxt 2020-10-25 20:27:54 -04:00
parent 70bb0a842a
commit 0f69be06b1
9 changed files with 343 additions and 127 deletions

View File

@ -18,126 +18,214 @@
#include <stdint.h>
#include "backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cuh"
// check for existence in current padded array on X and Y dims
__inline__ __device__ bool range_check(int x, int y, int padded_width, int padded_height) {
return (x >= 0) && (x < padded_width) && (y >= 0) && (y < padded_height);
}
// extract paddings from correct positions given variable paddings_arg size
__inline__ __device__ void extract_paddings(const int64_t *paddings_arg, int padd_dim, int64_t *extracted_paddings) {
const int paddings_offset = MAX_PADDINGS - padd_dim;
for (int i = 0; i < padd_dim; i++) {
extracted_paddings[(paddings_offset + i) * PADDING_SIZE] = paddings_arg[i * PADDING_SIZE];
extracted_paddings[(paddings_offset + i) * PADDING_SIZE + 1] = paddings_arg[i * PADDING_SIZE + 1];
}
}
// for every position, first calculate position it mirrors from in the new padded array
// adjust calculated position to origin dx array dimensions and copy value
template <typename T>
__global__ void MirrorPad(const size_t size, const T *input, const int old_batch, const int old_channel,
const int old_height, const int old_width, const int padded_height, const int padded_width,
const int padd_dim, const int64_t *paddings_arg, int mode, T *output) {
int64_t paddings[MAX_PADDINGS * PADDING_SIZE]; // local and fixed size to keep in registers
for (int i = 0; i < MAX_PADDINGS * PADDING_SIZE; i++) {
paddings[i] = 0;
}
extract_paddings(paddings_arg, padd_dim, paddings);
// Create anchor points for non-mirrored data inside new tensor
int ap1_x = paddings[WIDTH + LEFT];
int ap2_x = paddings[WIDTH + LEFT] + old_width - 1;
int ap1_y = paddings[HEIGHT + TOP];
int ap2_y = paddings[HEIGHT + TOP] + old_height - 1;
int ap1_channel = paddings[CHANNEL + LEFT];
int ap2_channel = paddings[CHANNEL + LEFT] + old_channel - 1;
int ap1_batch = paddings[BATCH + LEFT];
int ap2_batch = paddings[BATCH + LEFT] + old_batch - 1;
int channels_new = old_channel + paddings[CHANNEL + LEFT] + paddings[CHANNEL + RIGHT];
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
int block_num = (pos / padded_width) / padded_height;
// current position
const int padded_x = pos % padded_width;
const int padded_y = (pos / padded_width) % padded_height;
const int padded_channel = block_num % channels_new;
const int padded_batch = block_num / channels_new;
// distance from anchor points
// can be +/- depending on position
int x_dist = 0;
int y_dist = 0;
int channel_dist = 0;
int batch_dist = 0;
// data to mirror from in new tensor dims
int matchval_x_index = padded_x;
int matchval_y_index = padded_y;
int matchval_channel_index = padded_channel;
int matchval_batch_index = padded_batch;
int equiv_block_num = 0;
// update matching index in original tensor across all 4 dims
if ((padded_x < ap1_x) || (padded_x > ap2_x)) {
x_dist = (padded_x < ap1_x) ? (ap1_x - padded_x) : (padded_x - ap2_x);
matchval_x_index = (padded_x < ap1_x) ? (ap1_x + x_dist - mode) : (ap2_x - x_dist + mode);
}
if ((padded_y < ap1_y) || (padded_y > ap2_y)) {
y_dist = (padded_y < ap1_y) ? (ap1_y - padded_y) : (padded_y - ap2_y);
matchval_y_index = (padded_y < ap1_y) ? (ap1_y + y_dist - mode) : (ap2_y - y_dist + mode);
}
if ((padded_channel < ap1_channel) || (padded_channel > ap2_channel)) {
channel_dist = (padded_channel < ap1_channel) ? (ap1_channel - padded_channel) : (padded_channel - ap2_channel);
matchval_channel_index =
(padded_channel < ap1_channel) ? (ap1_channel + channel_dist - mode) : (ap2_channel - channel_dist + mode);
}
if ((padded_batch < ap1_batch) || (padded_batch > ap2_batch)) {
batch_dist = (padded_batch < ap1_batch) ? (ap1_batch - padded_batch) : (padded_batch - ap2_batch);
matchval_batch_index =
(padded_batch < ap1_batch) ? (ap1_batch + batch_dist - mode) : (ap2_batch - batch_dist + mode);
}
// calculate equivalent block in input
equiv_block_num = ((matchval_batch_index - paddings[BATCH + LEFT]) * old_channel) +
(matchval_channel_index - paddings[CHANNEL + LEFT]);
// copy data from equiv block and adjusted x and y values in unpadded tensor
output[pos] = input[(equiv_block_num * old_height + matchval_y_index - paddings[HEIGHT + TOP]) * old_width +
matchval_x_index - paddings[WIDTH + LEFT]];
}
return;
}
// Accumulates mirrored values across batch and channels into an interim workspace array
// One thread for every output value and a sweeping add logic allows the kernel to avoid
// slower lock-based atomic adds
template <typename T>
__global__ void MirrorPadGradBatchChannel(const size_t size, T *dy, T *interim_dy, const int dx_batches,
const int dx_channels, const int dx_height, const int dx_width,
const int dy_height, const int dy_width, const int padd_dim,
const int64_t *paddings_arg, int mode, T *dx) {
int64_t paddings[MAX_PADDINGS * PADDING_SIZE]; // local and fixed size to keep in registers
for (int i = 0; i < MAX_PADDINGS * PADDING_SIZE; i++) {
paddings[i] = 0; // init all to 0
}
extract_paddings(paddings_arg, padd_dim, paddings);
// Create anchor points for non-mirrored data inside new tensor
int ap1_channel = paddings[CHANNEL + LEFT];
int ap2_channel = paddings[CHANNEL + LEFT] + dx_channels - 1;
int ap1_batch = paddings[BATCH + LEFT];
int ap2_batch = paddings[BATCH + LEFT] + dx_batches - 1;
int dy_channels = dx_channels + paddings[CHANNEL + LEFT] + paddings[CHANNEL + RIGHT];
int dy_batches = dx_batches + paddings[BATCH + LEFT] + paddings[BATCH + RIGHT];
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
int block_num = (pos / dy_width) / dy_height;
// Select exact position inside the dy_interim array
const int interim_x = pos % dy_width;
const int interim_y = (pos / dy_width) % dy_height;
const int interim_channel = block_num % dx_channels;
const int interim_batch = block_num / dx_channels;
interim_dy[pos] = 0; // init
// map cur interim channel and batch to equivalent in padded dy array
const int equiv_dy_channel = interim_channel + paddings[CHANNEL + LEFT];
const int equiv_dy_batch = interim_batch + paddings[BATCH + LEFT];
int target_batch = 0;
int target_channel = 0;
int equiv_block_num = 0;
equiv_block_num = ((equiv_dy_batch * dy_channels) + equiv_dy_channel);
// generate values to sweep over all possible mirrored points
auto batch_offsets = {2 * (ap1_batch - equiv_dy_batch) - mode, 0, 2 * (ap2_batch - equiv_dy_batch) + mode};
auto channel_offsets = {2 * (ap1_channel - equiv_dy_channel) - mode, 0,
2 * (ap2_channel - equiv_dy_channel) + mode};
for (auto b_adjust : batch_offsets) {
for (auto c_adjust : channel_offsets) {
target_batch = equiv_dy_batch + b_adjust;
target_channel = equiv_dy_channel + c_adjust;
// bounds check - if within bounds, mirrored value exists - copy dy
if ((target_batch < 0) || (target_batch > (dy_batches - 1)) || (target_channel < 0) ||
(target_channel > (dy_channels - 1))) {
continue; // no mirrored value with these target values
}
equiv_block_num = ((target_batch * dy_channels) + target_channel);
// Copy data and set value at input to 0 to avoid duplicates in reflect mode
interim_dy[pos] = interim_dy[pos] + dy[(equiv_block_num * dy_height + interim_y) * dy_width + interim_x];
dy[(equiv_block_num * dy_height + interim_y) * dy_width + interim_x] = 0;
}
}
}
return;
}
// Accumulate mirrored values across width and height from the interim dy array into output array
// Similar sweep logic again allows for lock-free accumulation
template <typename T>
__global__ void MirrorPadGrad_Width_Height(const size_t size, const T *dy, T *interim_dy, const int dx_batches,
const int dx_channels, const int dx_height, const int dx_width,
const int dy_height, const int dy_width, const int padd_dim,
const int64_t *paddings_arg, int mode, T *dx) {
int64_t paddings[MAX_PADDINGS * PADDING_SIZE]; // local and fixed size to keep in registers
for (int i = 0; i < MAX_PADDINGS * PADDING_SIZE; i++) {
paddings[i] = 0; // init all to 0
}
extract_paddings(paddings_arg, padd_dim, paddings);
// Create required anchor points for non-mirrored data inside new tensor
int ap1_x = paddings[WIDTH + LEFT];
int ap2_x = paddings[WIDTH + LEFT] + dx_width - 1;
int ap1_y = paddings[HEIGHT + TOP];
int ap2_y = paddings[HEIGHT + TOP] + dx_height - 1;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
int dx_block_num = (pos / dx_width) / dx_height;
const int grad_x = (pos % dx_width) + paddings[WIDTH + LEFT];
const int grad_y = ((pos / dx_width) % dx_height) + paddings[HEIGHT + TOP];
// copy position's own value into output
dx[pos] = interim_dy[(dx_block_num * dy_height + grad_y) * dy_width + grad_x];
int x_dist_1 = (ap1_x - grad_x - mode);
int y_dist_1 = (ap1_y - grad_y - mode);
int x_dist_2 = (ap2_x - grad_x + mode);
int y_dist_2 = (ap2_y - grad_y + mode);
int axis_dist[] = {x_dist_1, x_dist_2, y_dist_1, y_dist_2};
int anch_point[] = {ap1_x, ap2_x, ap1_y, ap2_y};
bool x_axis_check[] = {true, true, false, false}; // true - update X , false - update Y
int temp_x = 0;
int temp_y = 0;
// mirroring in axis lines
for (int x = 0; x < 4; x++) {
if (axis_dist[x] != 0) {
if (x_axis_check[x]) {
temp_y = grad_y;
temp_x = anch_point[x] + axis_dist[x];
} else {
temp_x = grad_x;
temp_y = anch_point[x] + axis_dist[x];
}
if (range_check(temp_x, temp_y, dy_width, dy_height)) {
dx[pos] = dx[pos] + interim_dy[(dx_block_num * dy_height + temp_y) * dy_width + temp_x];
}
}
}
// mirroring at corners
for (int x = 0; x < 2; x++) {
for (int y = 2; y < 4; y++) {
if ((axis_dist[x] != 0) && (axis_dist[y] != 0)) {
temp_x = anch_point[x] + axis_dist[x];
temp_y = anch_point[y] + axis_dist[y];
if (range_check(temp_x, temp_y, dy_width, dy_height)) {
dx[pos] = dx[pos] + interim_dy[(dx_block_num * dy_height + temp_y) * dy_width + temp_x];
}
}
}
@ -147,36 +235,49 @@ __global__ void MirrorPadGrad(const size_t size, const T *dy, const int num, con
}
template <typename T>
void CalMirrorPad(const size_t size, const T *input, const int old_batch, const int old_channel, const int old_height,
const int old_width, const int padded_height, const int padded_width, int padd_num,
const int64_t *paddings, const int mode, T *output, cudaStream_t cuda_stream) {
MirrorPad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, old_batch, old_channel, old_height,
old_width, padded_height, padded_width, padd_num,
paddings, mode, output);
return;
}
template <typename T>
void CalMirrorPadGrad(const size_t dx_size, const size_t interim_dy_size, T *dy, T *interim_dy, const int dx_batches,
const int dx_channels, const int dx_height, const int dx_width, const int dy_height,
const int dy_width, const int padd_dim, const int64_t *paddings, int mode, T *dx,
cudaStream_t cuda_stream) {
MirrorPadGradBatchChannel<<<GET_BLOCKS(interim_dy_size), GET_THREADS, 0, cuda_stream>>>(
interim_dy_size, dy, interim_dy, dx_batches, dx_channels, dx_height, dx_width, dy_height, dy_width, padd_dim,
paddings, mode, dx);
MirrorPadGrad_Width_Height<<<GET_BLOCKS(dx_size), GET_THREADS, 0, cuda_stream>>>(
dx_size, dy, interim_dy, dx_batches, dx_channels, dx_height, dx_width, dy_height, dy_width, padd_dim, paddings,
mode, dx);
return;
}
template void CalMirrorPad<float>(const size_t size, const float *input, const int old_batch, const int old_channel,
const int old_height, const int old_width, const int padded_height,
const int padded_width, int padd_num, const int64_t *paddings, int mode,
float *output, cudaStream_t cuda_stream);
template void CalMirrorPad<half>(const size_t size, const half *input, const int old_batch, const int old_channel,
const int old_height, const int old_width, const int padded_height,
const int padded_width, int padd_num, const int64_t *paddings, int mode, half *output,
cudaStream_t cuda_stream);
template void CalMirrorPad<int>(const size_t size, const int *input, const int old_batch, const int old_channel,
const int old_height, const int old_width, const int padded_height,
const int padded_width, int padd_num, const int64_t *paddings, int mode, int *output,
cudaStream_t cuda_stream);
template void CalMirrorPadGrad<float>(const size_t dx_size, const size_t dy_size, float *dy, float *interim_dy,
const int dx_batches, const int dx_channels, const int dx_height,
const int dx_width, const int dy_height, const int dy_width, const int padd_dim,
const int64_t *paddings, int mode, float *dx, cudaStream_t cuda_stream);
template void CalMirrorPadGrad<half>(const size_t dx_size, const size_t dy_size, half *dy, half *interim_dy,
const int dx_batches, const int dx_channels, const int dx_height,
const int dx_width, const int dy_height, const int dy_width, const int padd_dim,
const int64_t *paddings, int mode, half *dx, cudaStream_t cuda_stream);
template void CalMirrorPadGrad<int>(const size_t dx_size, const size_t dy_size, int *dy, int *interim_dy,
const int dx_batches, const int dx_channels, const int dx_height,
const int dx_width, const int dy_height, const int dy_width, const int padd_dim,
const int64_t *paddings, int mode, int *dx, cudaStream_t cuda_stream);

View File

@ -19,13 +19,28 @@
#include <cuda_runtime.h>
#include "runtime/device/gpu/cuda_common.h"
// preset size of paddings
#define MAX_PADDINGS 4
#define PADDING_SIZE 2
// define constants for kernel indexing use
#define BATCH 0 * PADDING_SIZE
#define CHANNEL 1 * PADDING_SIZE
#define HEIGHT 2 * PADDING_SIZE
#define WIDTH 3 * PADDING_SIZE
#define TOP 0
#define BOTTOM 1
#define LEFT 0
#define RIGHT 1
template <typename T>
void CalMirrorPad(const size_t size, const T *input, const int old_batch, const int old_channel, const int old_height,
const int old_width, const int padded_height, const int padded_width, int padd_num,
const int64_t *paddings, int mode, T *output, cudaStream_t cuda_stream);
template <typename T>
void CalMirrorPadGrad(const size_t dx_size, const size_t dy_size, T *dy, T *interim, const int output_batch,
const int output_channel, const int output_height, const int output_width, const int input_height,
const int input_width, const int padd_dim, const int64_t *paddings, int mode, T *dx,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MIRROR_PAD_IMPL_H_

View File

@ -206,3 +206,8 @@ template void CalPadGeneral<half>(const size_t size, const half *input, const in
const int old_width, const int padded_height, const int padded_width,
const int pad_top, const int pad_left, float pad_value, half *output,
cudaStream_t cuda_stream);
template void CalPadGeneral<int>(const size_t size, const int *input, const int num, const int channels_orig,
const int pad_channel_before, const int pad_channel_after, const int old_height,
const int old_width, const int padded_height, const int padded_width,
const int pad_top, const int pad_left, float pad_value, int *output,
cudaStream_t cuda_stream);

View File

@ -26,5 +26,8 @@ MS_REG_GPU_KERNEL_ONE(
MirrorPad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
MirrorPadGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(
MirrorPad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
MirrorPadGpuFwdKernel, int)
} // namespace kernel
} // namespace mindspore

View File

@ -40,7 +40,7 @@ class MirrorPadGpuFwdKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
int64_t *paddings = GetDeviceAddress<int64_t>(inputs, 1);
T *output = GetDeviceAddress<T>(outputs, 0);
size_t size = output_size_ / sizeof(T);
@ -58,13 +58,11 @@ class MirrorPadGpuFwdKernel : public GpuKernel {
MS_LOG(ERROR) << "Input number is " << input_num << ", but MirrorPad needs 2 input.";
return false;
}
// check number of output -> should be 1
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but Pad needs 1 output.";
return false;
}
string mode = GetValue<string>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("mode"));
if (mode == "REFLECT") {
mode_ = 0; // reflected mirroring
@ -89,10 +87,9 @@ class MirrorPadGpuFwdKernel : public GpuKernel {
}
num_input_ = input_size_;
input_size_ *= sizeof(T);
auto padding_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
num_paddings_ = padding_shape[0];
input_size_ += 2 * num_paddings_ * sizeof(int64_t);
output_size_ = sizeof(T);
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
@ -103,7 +100,6 @@ class MirrorPadGpuFwdKernel : public GpuKernel {
int max_width = input_shape_[3];
int max_height = input_shape_[2];
// basic error check for padding value
if (mode_ == 1) { // symmetric
max_width = max_width + (2 * max_width);
@ -112,13 +108,11 @@ class MirrorPadGpuFwdKernel : public GpuKernel {
max_width = max_width + (2 * (max_width - 1));
max_height = max_height + (2 * (max_height - 1));
}
if (output_shape_[(output_shape_.size() - 2) + 0] > max_height ||
output_shape_[(output_shape_.size() - 2) + 1] > max_width) {
MS_LOG(ERROR) << "ERROR: Padding value too high for input Tensor on 1 or more dims";
return false;
}
InitSizeLists();
return true;
}
@ -126,7 +120,7 @@ class MirrorPadGpuFwdKernel : public GpuKernel {
protected:
void InitSizeLists() override {
input_size_list_.push_back(num_input_ * sizeof(T));
input_size_list_.push_back(2 * num_paddings_ * sizeof(int64_t));  // paddings arrive as int64_t per the operator API
output_size_list_.push_back(output_size_);
}

View File

@ -26,5 +26,9 @@ MS_REG_GPU_KERNEL_ONE(
MirrorPadGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
MirrorPadGpuBackKernel, half)
MS_REG_GPU_KERNEL_ONE(
MirrorPadGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
MirrorPadGpuBackKernel, int)
} // namespace kernel
} // namespace mindspore

View File

@ -40,15 +40,15 @@ class MirrorPadGpuBackKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
int64_t *paddings = GetDeviceAddress<int64_t>(inputs, 1);
T *interim = GetDeviceAddress<T>(workspace, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
size_t dx_size = output_size_ / sizeof(T);
size_t interim_dy_size = workspace_size_ / sizeof(T);
CalMirrorPadGrad(dx_size, interim_dy_size, input, interim, output_shape_[0], output_shape_[1], output_shape_[2],
output_shape_[3], input_shape_[2], input_shape_[3], num_paddings_, paddings, mode_, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@ -58,13 +58,11 @@ class MirrorPadGpuBackKernel : public GpuKernel {
MS_LOG(ERROR) << "Input number is " << input_num << ", but MirrorPadGrad needs 2 input.";
return false;
}
// check number of output -> should be 1
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but MirrorPadGrad needs 1 output.";
return false;
}
string mode = GetValue<string>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("mode"));
if (mode == "REFLECT") {
mode_ = 0; // reflected mirroring
@ -82,28 +80,43 @@ class MirrorPadGpuBackKernel : public GpuKernel {
auto it = input_shape.begin();
input_shape.insert(it, 2, 1); // pad to 4D with leading batch and channel dims
}
input_size_ = 1;
for (auto in_shape : input_shape) {
input_size_ *= in_shape;
input_shape_.push_back(in_shape);
}
num_input_ = input_size_;
input_size_ *= sizeof(T);
// account for paddings in input size -> passed as int64_t values
auto padding_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
num_paddings_ = padding_shape[0];
input_size_ += (2 * num_paddings_ * sizeof(int64_t));
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (output_shape.size() == 4) {
} else if (output_shape.size() == 3) {
auto it = output_shape.begin();
output_shape.insert(it, 1); // batch padding
} else if (output_shape.size() == 2) {
auto it = output_shape.begin();
output_shape.insert(it, 2, 1); // pad to 4D with leading batch and channel dims
}
output_size_ = sizeof(T);
for (auto x : output_shape) {
output_size_ *= x;
output_shape_.push_back(x);
}
// calc workspace size
// store dy values with accumulation across batch and channel only
workspace_size_ = sizeof(T);
for (int i = 0; i < 2; i++) {
workspace_size_ *= output_shape[i]; // BATCH, CHANNEL -> Output size
workspace_size_ *= input_shape[i + 2]; // HEIGHT, WIDTH -> input size
}
int max_width = input_shape_[3];
int max_height = input_shape_[2];
// basic error check for padding value
if (mode_ == 1) { // symmetric
max_width = max_width + (2 * max_width);
@ -112,13 +125,11 @@ class MirrorPadGpuBackKernel : public GpuKernel {
max_width = max_width + (2 * (max_width - 1));
max_height = max_height + (2 * (max_height - 1));
}
if (output_shape_[(output_shape_.size() - 2) + 0] > max_height ||
output_shape_[(output_shape_.size() - 2) + 1] > max_width) {
MS_LOG(ERROR) << "ERROR: Padding value too high for input Tensor on 1 or more DIMS";
return false;
}
InitSizeLists();
return true;
}
@ -126,7 +137,8 @@ class MirrorPadGpuBackKernel : public GpuKernel {
protected:
void InitSizeLists() override {
input_size_list_.push_back(num_input_ * sizeof(T));
input_size_list_.push_back(2 * num_paddings_ * sizeof(int64_t));  // paddings arrive as int64_t per the operator API
workspace_size_list_.push_back(workspace_size_);
output_size_list_.push_back(output_size_);
}
@ -134,9 +146,8 @@ class MirrorPadGpuBackKernel : public GpuKernel {
size_t num_input_;
int num_paddings_;
int mode_;
std::vector<int> input_shape_;
std::vector<int> output_shape_;
size_t input_size_;
size_t output_size_;
size_t workspace_size_;

View File

@ -22,5 +22,7 @@ MS_REG_GPU_KERNEL_ONE(Pad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutp
PadGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(Pad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
PadGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(Pad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), PadGpuFwdKernel,
int)
} // namespace kernel
} // namespace mindspore

View File

@ -64,9 +64,9 @@ class Grad(nn.Cell):
return self.grad(self.network)(input_, output_grad)
class Net(nn.Cell):
def __init__(self, pads, mode_):
super(Net, self).__init__()
self.pad = nn.Pad(mode="REFLECT", paddings=((0, 0), (0, 0), (1, 0), (0, 2)))
self.pad = nn.Pad(mode=mode_, paddings=pads)
def construct(self, x):
return self.pad(x)
@ -82,7 +82,88 @@ def test_mirror_pad_backprop():
expected_dx = np.array([[[[0.2, 0.2, 0.1],
[0.4, 0.4, 0.2],
[0.2, 0.2, 0.1]]]])
net = Grad(Net(((0, 0), (0, 0), (1, 0), (0, 2)), "REFLECT"))
dx = net(test_arr_in, Tensor(dy))
dx = dx[0].asnumpy()
np.testing.assert_array_almost_equal(dx, expected_dx)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mirror_pad_fwd_back_4d_int32_reflect():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# set constants
shape = (2, 3, 3, 5)
pads = ((1, 0), (2, 0), (1, 2), (3, 4))
total_val = np.prod(shape)
test_arr_np = np.arange(total_val).reshape(shape) + 1
test_arr_ms = Tensor(test_arr_np, dtype=mindspore.int32)
# fwd_pass_check
op = nn.Pad(mode="REFLECT", paddings=pads)
expected_np_result = np.pad(test_arr_np, pads, 'reflect')
obtained_ms_res = op(test_arr_ms).asnumpy()
np.testing.assert_array_equal(expected_np_result, obtained_ms_res)
# backwards pass check
GradNet = Grad(Net(pads, "REFLECT"))
dy_value = Tensor(np.ones(obtained_ms_res.shape), dtype=mindspore.int32)
dx_value_obtained = GradNet(test_arr_ms, dy_value)[0].asnumpy()
dx_value_expected = np.array([[[[4, 6, 6, 6, 2],
[6, 9, 9, 9, 3],
[2, 3, 3, 3, 1]],
[[8, 12, 12, 12, 4],
[12, 18, 18, 18, 6],
[4, 6, 6, 6, 2]],
[[8, 12, 12, 12, 4],
[12, 18, 18, 18, 6],
[4, 6, 6, 6, 2]]],
[[[8, 12, 12, 12, 4],
[12, 18, 18, 18, 6],
[4, 6, 6, 6, 2]],
[[16, 24, 24, 24, 8],
[24, 36, 36, 36, 12],
[8, 12, 12, 12, 4]],
[[16, 24, 24, 24, 8],
[24, 36, 36, 36, 12],
[8, 12, 12, 12, 4]]]], dtype=np.int32)
np.testing.assert_array_equal(dx_value_expected, dx_value_obtained)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mirror_pad_fwd_back_4d_int32_symm():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# set constants
shape = (2, 3, 3, 5)
pads = ((1, 0), (2, 0), (1, 2), (3, 4))
total_val = np.prod(shape)
test_arr_np = np.arange(total_val).reshape(shape) + 1
test_arr_ms = Tensor(test_arr_np, dtype=mindspore.int32)
# fwd_pass_check
op = nn.Pad(mode="SYMMETRIC", paddings=pads)
expected_np_result = np.pad(test_arr_np, pads, 'symmetric')
obtained_ms_res = op(test_arr_ms).asnumpy()
np.testing.assert_array_equal(expected_np_result, obtained_ms_res)
# backwards pass check
GradNet = Grad(Net(pads, "SYMMETRIC"))
dy_value = Tensor(np.ones(obtained_ms_res.shape), dtype=mindspore.int32)
dx_value_obtained = GradNet(test_arr_ms, dy_value)[0].asnumpy()
dx_value_expected = np.array([[[[16, 24, 24, 16, 16],
[16, 24, 24, 16, 16],
[16, 24, 24, 16, 16]],
[[16, 24, 24, 16, 16],
[16, 24, 24, 16, 16],
[16, 24, 24, 16, 16]],
[[8, 12, 12, 8, 8],
[8, 12, 12, 8, 8],
[8, 12, 12, 8, 8]]],
[[[8, 12, 12, 8, 8],
[8, 12, 12, 8, 8],
[8, 12, 12, 8, 8]],
[[8, 12, 12, 8, 8],
[8, 12, 12, 8, 8],
[8, 12, 12, 8, 8]],
[[4, 6, 6, 4, 4],
[4, 6, 6, 4, 4],
[4, 6, 6, 4, 4]]]], dtype=np.int32)
np.testing.assert_array_equal(dx_value_expected, dx_value_obtained)