forked from mindspore-Ecosystem/mindspore
!5790 [MS][GPU][CUDA] Dedicated new user facing Pad API kernel
Merge pull request !5790 from danishnxt/GPU_three
This commit is contained in:
commit
98725bc865
|
@ -18,6 +18,7 @@
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"
|
#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"
|
||||||
|
|
||||||
|
// For internal OP use, not user facing
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void Pad(const size_t size, const T* input, const int num, const int channels, const int old_height,
|
__global__ void Pad(const size_t size, const T* input, const int num, const int channels, const int old_height,
|
||||||
const int old_width, const int padded_height, const int padded_width, const int pad_top,
|
const int old_width, const int padded_height, const int padded_width, const int pad_top,
|
||||||
|
@ -37,6 +38,7 @@ __global__ void Pad(const size_t size, const T* input, const int num, const int
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For internal OP use, not user facing
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void PadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width,
|
__global__ void PadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width,
|
||||||
const int channels, const int padded_height, const int padded_width, const int pad_top,
|
const int channels, const int padded_height, const int padded_width, const int pad_top,
|
||||||
|
@ -57,6 +59,37 @@ __global__ void PadNHWC(const size_t size, const T* input, const int num, const
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Used by user facing 'Pad' API
|
||||||
|
template <typename T>
|
||||||
|
__global__ void PadGeneral(const size_t size, const T *input, const int num, const int channels_orig,
|
||||||
|
const int pad_channel_before, const int pad_channel_after, const int old_height,
|
||||||
|
const int old_width, const int padded_height, const int padded_width, const int pad_top,
|
||||||
|
const int pad_left, float pad_value, T *output) {
|
||||||
|
T pad_value_template = static_cast<T>(pad_value);
|
||||||
|
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||||
|
int block_num = (pos / padded_width) / padded_height; // total blocks = (batch * channels)
|
||||||
|
const int padded_w = pos % padded_width; // x coordinate refered to by cur 'pos'
|
||||||
|
const int padded_h = (pos / padded_width) % padded_height; // y coordinate refered to by cur 'pos'
|
||||||
|
|
||||||
|
int channels_new = channels_orig + pad_channel_after + pad_channel_before; // new number of channels from padding
|
||||||
|
int channel_num = block_num % channels_new; // current channel
|
||||||
|
int batch_item = block_num / channels_new; // current item in batch
|
||||||
|
int equiv_block_num = 0; // init variable to select equivalent block to copy data from from input
|
||||||
|
|
||||||
|
if (padded_h - pad_top < 0 || padded_w - pad_left < 0 || padded_h - pad_top >= old_height ||
|
||||||
|
padded_w - pad_left >= old_width || channel_num <= pad_channel_before - 1 ||
|
||||||
|
channel_num > channels_orig + pad_channel_before - 1) {
|
||||||
|
output[pos] = pad_value_template;
|
||||||
|
} else {
|
||||||
|
// on a block/x,y positon that isn't padding, copy data from the correct block/x,y pos the input
|
||||||
|
// calculate from number of blocks of padding (due to channel padding) inserted prior
|
||||||
|
equiv_block_num = block_num - (batch_item * (pad_channel_before + pad_channel_after)) - pad_channel_before;
|
||||||
|
output[pos] = input[(equiv_block_num * old_height + padded_h - pad_top) * old_width + padded_w - pad_left];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void PadGradNHWC(const size_t size, const T* dy, const int num, const int old_height, const int old_width,
|
__global__ void PadGradNHWC(const size_t size, const T* dy, const int num, const int old_height, const int old_width,
|
||||||
const int channels, const int padded_height, const int padded_width, const int pad_top,
|
const int channels, const int padded_height, const int padded_width, const int pad_top,
|
||||||
|
@ -102,6 +135,17 @@ void CalPadNHWC(const size_t size, const T* input, const int num, const int old_
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void CalPadGeneral(const size_t size, const T *input, const int num, const int channels_orig,
|
||||||
|
const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width,
|
||||||
|
const int padded_height, const int padded_width, const int pad_top, const int pad_left,
|
||||||
|
float pad_value, T *output, cudaStream_t cuda_stream) {
|
||||||
|
PadGeneral<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, num, channels_orig, pad_channel_before,
|
||||||
|
pad_channel_after, old_height, old_width, padded_height,
|
||||||
|
padded_width, pad_top, pad_left, pad_value, output);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void CalPadGradNHWC(const size_t size, const T* dy, const int num, const int old_height, const int old_width,
|
void CalPadGradNHWC(const size_t size, const T* dy, const int num, const int old_height, const int old_width,
|
||||||
const int channels, const int padded_height, const int padded_width, const int pad_top,
|
const int channels, const int padded_height, const int padded_width, const int pad_top,
|
||||||
|
@ -152,3 +196,13 @@ template void CalPadGradNHWC<half>(const size_t size, const half* dy, const int
|
||||||
const int old_width, const int channels, const int padded_height,
|
const int old_width, const int channels, const int padded_height,
|
||||||
const int padded_width, const int pad_top, const int pad_left, half* dx,
|
const int padded_width, const int pad_top, const int pad_left, half* dx,
|
||||||
cudaStream_t cuda_stream);
|
cudaStream_t cuda_stream);
|
||||||
|
template void CalPadGeneral<float>(const size_t size, const float *input, const int num, const int channels_orig,
|
||||||
|
const int pad_channel_before, const int pad_channel_after, const int old_height,
|
||||||
|
const int old_width, const int padded_height, const int padded_width,
|
||||||
|
const int pad_top, const int pad_left, float pad_value, float *output,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
|
template void CalPadGeneral<half>(const size_t size, const half *input, const int num, const int channels_orig,
|
||||||
|
const int pad_channel_before, const int pad_channel_after, const int old_height,
|
||||||
|
const int old_width, const int padded_height, const int padded_width,
|
||||||
|
const int pad_top, const int pad_left, float pad_value, half *output,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
|
|
|
@ -31,9 +31,13 @@ template <typename T>
|
||||||
void CalPadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width,
|
void CalPadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width,
|
||||||
const int channels, const int padded_height, const int padded_width, const int pad_top, const int pad_left,
|
const int channels, const int padded_height, const int padded_width, const int pad_top, const int pad_left,
|
||||||
float pad_value, T* output, cudaStream_t cuda_stream);
|
float pad_value, T* output, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void CalPadGradNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width,
|
void CalPadGradNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width,
|
||||||
const int channels, const int padded_height, const int padded_width, const int pad_top,
|
const int channels, const int padded_height, const int padded_width, const int pad_top,
|
||||||
const int pad_left, T* output, cudaStream_t cuda_stream);
|
const int pad_left, T* output, cudaStream_t cuda_stream);
|
||||||
|
template <typename T>
|
||||||
|
void CalPadGeneral(const size_t size, const T *input, const int num, const int channels_orig,
|
||||||
|
const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width,
|
||||||
|
const int padded_height, const int padded_width, const int pad_top, const int pad_left,
|
||||||
|
float pad_value, T *output, cudaStream_t cuda_stream);
|
||||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_
|
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_
|
||||||
|
|
|
@ -42,9 +42,12 @@ class PadGpuFwdKernel : public GpuKernel {
|
||||||
size_t size = output_size_ / sizeof(T);
|
size_t size = output_size_ / sizeof(T);
|
||||||
int pad_left = paddings[3][0];
|
int pad_left = paddings[3][0];
|
||||||
int pad_top = paddings[2][0];
|
int pad_top = paddings[2][0];
|
||||||
|
int pad_channel_before = paddings[1][0];
|
||||||
|
int pad_channel_after = paddings[1][1];
|
||||||
T pad_value = 0.0;
|
T pad_value = 0.0;
|
||||||
CalPad(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], output_shape_[2],
|
CalPadGeneral(size, input, input_shape_[0], input_shape_[1], pad_channel_before, pad_channel_after, input_shape_[2],
|
||||||
output_shape_[3], pad_top, pad_left, pad_value, output, reinterpret_cast<cudaStream_t>(stream_ptr));
|
input_shape_[3], output_shape_[2], output_shape_[3], pad_top, pad_left, pad_value, output,
|
||||||
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -470,6 +470,8 @@ class Pad(Cell):
|
||||||
for item in paddings:
|
for item in paddings:
|
||||||
if len(item) != 2:
|
if len(item) != 2:
|
||||||
raise ValueError('The shape of paddings must be (n, 2).')
|
raise ValueError('The shape of paddings must be (n, 2).')
|
||||||
|
if len(paddings) > 4:
|
||||||
|
raise ValueError('Only padding up to 4 dims is supported')
|
||||||
if mode == "CONSTANT":
|
if mode == "CONSTANT":
|
||||||
self.pad = P.Pad(self.paddings)
|
self.pad = P.Pad(self.paddings)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -0,0 +1,204 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.context as context
|
||||||
|
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops.composite import GradOperation
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_pad_basic():
|
||||||
|
# confirm array is being padded with 0's
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
|
||||||
|
test_arr = np.array([[1, 2], [3, 4]]).astype(np.float32)
|
||||||
|
test_arr_expected = np.array(
|
||||||
|
[[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]).astype(np.float32)
|
||||||
|
x_test = Tensor(test_arr, dtype=mindspore.float32)
|
||||||
|
|
||||||
|
pad_op = nn.Pad(mode='CONSTANT', paddings=((1, 1), (1, 1)))
|
||||||
|
y_test = pad_op(x_test).asnumpy()
|
||||||
|
|
||||||
|
np.testing.assert_array_equal(y_test, test_arr_expected)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_pad_row():
|
||||||
|
# Confirm correct row padding
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
|
||||||
|
test_arr_1 = np.random.rand(40, 40).astype(np.float32)
|
||||||
|
test_paddings_1 = ((2, 3), (0, 0))
|
||||||
|
|
||||||
|
test_arr_2 = np.random.randn(3, 10, 30, 30).astype(np.float32)
|
||||||
|
test_paddings_2 = ((0, 0), (0, 0), (3, 0), (0, 0))
|
||||||
|
|
||||||
|
pad_op_row_1 = nn.Pad(mode='CONSTANT', paddings=test_paddings_1)
|
||||||
|
pad_op_row_2 = nn.Pad(mode='CONSTANT', paddings=test_paddings_2)
|
||||||
|
|
||||||
|
x_test_1 = Tensor(np.array(test_arr_1), dtype=mindspore.float32)
|
||||||
|
x_test_2 = Tensor(np.array(test_arr_2), dtype=mindspore.float32)
|
||||||
|
|
||||||
|
y_test_1 = pad_op_row_1(x_test_1).asnumpy()
|
||||||
|
y_test_2 = pad_op_row_2(x_test_2).asnumpy()
|
||||||
|
|
||||||
|
# check size
|
||||||
|
assert y_test_1.shape == (45, 40)
|
||||||
|
assert y_test_2.shape == (3, 10, 33, 30)
|
||||||
|
|
||||||
|
# check values - select correct sections
|
||||||
|
np.testing.assert_equal(y_test_1[2:-3, :], test_arr_1)
|
||||||
|
np.testing.assert_equal(y_test_2[:, :, 3:, :], test_arr_2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_pad_column():
|
||||||
|
# Confirm correct column padding
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
|
||||||
|
test_arr_1 = np.random.randn(40, 40).astype(np.float32)
|
||||||
|
test_paddings_1 = ((0, 0), (3, 3))
|
||||||
|
|
||||||
|
test_arr_2 = np.random.randn(3, 10, 30, 30).astype(np.float32)
|
||||||
|
test_paddings_2 = ((0, 0), (0, 0), (0, 0), (6, 1))
|
||||||
|
|
||||||
|
pad_op_col_1 = nn.Pad(mode='CONSTANT', paddings=test_paddings_1)
|
||||||
|
pad_op_col_2 = nn.Pad(mode='CONSTANT', paddings=test_paddings_2)
|
||||||
|
|
||||||
|
x_test_1 = Tensor(np.array(test_arr_1), dtype=mindspore.float32)
|
||||||
|
x_test_2 = Tensor(np.array(test_arr_2), dtype=mindspore.float32)
|
||||||
|
|
||||||
|
y_test_1 = pad_op_col_1(x_test_1).asnumpy()
|
||||||
|
y_test_2 = pad_op_col_2(x_test_2).asnumpy()
|
||||||
|
|
||||||
|
# check size
|
||||||
|
assert y_test_1.shape == (40, 46)
|
||||||
|
assert y_test_2.shape == (3, 10, 30, 37)
|
||||||
|
|
||||||
|
# check values - select correct sections - should match
|
||||||
|
np.testing.assert_equal(y_test_1[:, 3:-3], test_arr_1)
|
||||||
|
np.testing.assert_equal(y_test_2[:, :, :, 6:-1], test_arr_2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_pad_3d_pad():
|
||||||
|
# Confirm correct 3d padding - row, column, channel
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
|
||||||
|
test_arr = np.random.randn(5, 3, 30, 30).astype(np.float32)
|
||||||
|
test_paddings = ((0, 0), (2, 1), (0, 1), (0, 2)) # padding 3 dims now
|
||||||
|
|
||||||
|
pad_op_3d = nn.Pad(mode='CONSTANT', paddings=test_paddings)
|
||||||
|
x_test = Tensor(np.array(test_arr), dtype=mindspore.float32)
|
||||||
|
|
||||||
|
y_test = pad_op_3d(x_test).asnumpy()
|
||||||
|
assert y_test.shape == (5, 6, 31, 32)
|
||||||
|
np.testing.assert_equal(test_arr, y_test[:, 2:-1, :-1, :-2])
|
||||||
|
|
||||||
|
|
||||||
|
# For testing backprop
|
||||||
|
class Grad(nn.Cell):
|
||||||
|
def __init__(self, network):
|
||||||
|
super(Grad, self).__init__()
|
||||||
|
self.grad = GradOperation(get_all=True, sens_param=True)
|
||||||
|
self.network = network
|
||||||
|
|
||||||
|
def construct(self, input_, output_grad):
|
||||||
|
return self.grad(self.network)(input_, output_grad)
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.pad = nn.Pad(mode="CONSTANT", paddings=(
|
||||||
|
(0, 0), (4, 3), (1, 1), (0, 2)))
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.pad(x)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_pad_3d_backprop():
|
||||||
|
# Confirm correct 3d padding backprop
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
|
||||||
|
test_arr = np.random.randn(5, 3, 30, 30).astype(np.float32)
|
||||||
|
x_test = Tensor(test_arr, dtype=mindspore.float32)
|
||||||
|
|
||||||
|
padded_shape = (5, 10, 32, 32)
|
||||||
|
dy = np.random.randn(*padded_shape).astype(np.float32)
|
||||||
|
expected_dx = dy[:, 4:-3, 1:-1, :-2]
|
||||||
|
|
||||||
|
net = Grad(Net())
|
||||||
|
dx = net(x_test, Tensor(dy))
|
||||||
|
dx = dx[0].asnumpy()
|
||||||
|
np.testing.assert_array_equal(dx, expected_dx)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_pad_error_cases():
|
||||||
|
# Test against common errorneous inputs to catch correctly
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
|
||||||
|
# TEST 1 - Neg padding values
|
||||||
|
test_op = nn.Pad(paddings=((0, 0), (-1, -1)), mode="CONSTANT")
|
||||||
|
test_arr = np.random.randn(3, 3)
|
||||||
|
test_arr_ms = Tensor(test_arr, dtype=mindspore.float32)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
test_op(test_arr_ms)
|
||||||
|
|
||||||
|
# TEST 2 - Mismatched input size and paddings - 1D tensor
|
||||||
|
test_op = nn.Pad(paddings=((0, 0), (1, 0)), mode="CONSTANT")
|
||||||
|
test_arr = np.random.randn(3) # 1D Tensor
|
||||||
|
test_arr_ms = Tensor(test_arr, dtype=mindspore.float32)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
test_op(test_arr_ms)
|
||||||
|
|
||||||
|
# TEST 3 - Mismatched input size and paddings - 2D tensor, 3D padding
|
||||||
|
test_op = nn.Pad(paddings=((0, 0), (1, 0)), mode="CONSTANT") # 2D Padding
|
||||||
|
test_arr = np.random.randn(1, 3, 3) # 3D Tensor
|
||||||
|
test_arr_ms = Tensor(test_arr, dtype=mindspore.float32)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
test_op(test_arr_ms)
|
||||||
|
|
||||||
|
# TEST 4 - 1D Paddings should not work
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
test_op = nn.Pad(paddings=((0, 2)), mode="CONSTANT")
|
||||||
|
|
||||||
|
# TEST 5 - Padding beyond 4d - (added check in nn file in PR)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_ = nn.Pad(paddings=((0, 0), (0, 0,), (0, 0), (0, 0),
|
||||||
|
(1, 0)), mode="CONSTANT") # 2D Padding
|
Loading…
Reference in New Issue