forked from mindspore-Ecosystem/mindspore
gpu layernorm
This commit is contained in:
parent
699e616b3a
commit
4d600e70f1
|
@ -35,8 +35,8 @@ inline __device__ half my_pow(half a, double b) {
|
|||
|
||||
template <typename T>
|
||||
inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_dim, const int &col_dim,
|
||||
const T &epsilon, const T *dy, const T *x, const T *mean, const T *var,
|
||||
T *dg, T *db) {
|
||||
const int &mean_dim, const T &epsilon, const T *dy, const T *x,
|
||||
const T *mean, const T *var, T *dg, T *db) {
|
||||
int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
|
||||
for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
|
||||
for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
|
||||
|
@ -46,7 +46,8 @@ inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_d
|
|||
}
|
||||
|
||||
int pos = row * col_dim + col;
|
||||
dg[0] += dy[pos] * my_pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]);
|
||||
int mean_offset = pos / mean_dim;
|
||||
dg[0] += dy[pos] * my_pow(var[mean_offset] + epsilon, -0.5) * (x[pos] - mean[mean_offset]);
|
||||
db[0] += dy[pos];
|
||||
}
|
||||
}
|
||||
|
@ -89,8 +90,9 @@ inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_di
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T *dy, const T *x,
|
||||
const T *mean_addr, const T *var_addr, T *dg_addr, T *db_addr) {
|
||||
__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const int mean_dim, const T epsilon,
|
||||
const T *dy, const T *x, const T *mean_addr, const T *var_addr, T *dg_addr,
|
||||
T *db_addr) {
|
||||
// row: [0:param_axis]
|
||||
// col: [param_axis:]
|
||||
// dg[i][j] = dy[i][j] * (var[i] + epsilon, -0.5) * (x[i][j] - mean[i])
|
||||
|
@ -98,7 +100,7 @@ __global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, con
|
|||
for (int col = blockIdx.x; col < col_dim; col += gridDim.x) {
|
||||
T dg = 0;
|
||||
T db = 0;
|
||||
GammaAndBetaThreadReduce(col, row_dim, col_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db);
|
||||
GammaAndBetaThreadReduce(col, row_dim, col_dim, mean_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db);
|
||||
GammaAndBetaWarpReduce(&dg, &db);
|
||||
GammaAndBetaBlockReduce(col, row_dim, &dg, &db, dg_addr, db_addr);
|
||||
}
|
||||
|
@ -239,8 +241,12 @@ void LayerNormGrad(const int &row_dim, const int &col_dim, const int ¶m_dim,
|
|||
mean, var, gamma, dx);
|
||||
|
||||
share_mem_size = thread_per_block / WARP_SIZE * 2 * sizeof(T);
|
||||
GammaAndBetaPropKernel<<<col_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean,
|
||||
var, dg, db);
|
||||
// GammaAndBetaPropKernel<<<col_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x,
|
||||
// mean,
|
||||
// var, dg, db);
|
||||
int param_reduce_dim = row_dim * col_dim / param_dim;
|
||||
GammaAndBetaPropKernel<<<param_dim, thread_per_block, share_mem_size, stream>>>(param_reduce_dim, param_dim, col_dim,
|
||||
epsilon, dy, x, mean, var, dg, db);
|
||||
}
|
||||
|
||||
template void LayerNormGrad(const int &row_dim, const int &col_dim, const int ¶m_dim, const float &epsilon,
|
||||
|
|
|
@ -193,3 +193,29 @@ def test_layernormgrad4():
|
|||
assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
|
||||
assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernormgrad5():
|
||||
begin_norm_axis = 2
|
||||
begin_params_axis = 1
|
||||
x_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
|
||||
dy_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
epsilon = 10e-12
|
||||
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
|
||||
begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np)
|
||||
x_ms = Tensor(x_np)
|
||||
var_ms = Tensor(var_np)
|
||||
mean_ms = Tensor(mean_np)
|
||||
gamma_ms = Tensor(gamma_np)
|
||||
|
||||
net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
|
||||
dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
|
||||
assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
|
||||
assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
|
||||
|
|
|
@ -175,3 +175,25 @@ def test_layernorm2d_3():
|
|||
assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernorm2d_4():
|
||||
begin_norm_axis = 2
|
||||
begin_params_axis = 1
|
||||
np.random.seed(42)
|
||||
x_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
|
||||
|
||||
x_ms = Tensor(x_np)
|
||||
gamma_ms = Tensor(gamma_np)
|
||||
beta_ms = Tensor(beta_np)
|
||||
net = LayerNormNet(begin_norm_axis, begin_params_axis)
|
||||
y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
|
||||
assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)
|
||||
|
|
Loading…
Reference in New Issue