forked from mindspore-Ecosystem/mindspore
!756 Gpu support LayerNorm kernel
Merge pull request !756 from chenweifeng/layer_norm
This commit is contained in:
commit
8c035a5171
|
@ -0,0 +1,205 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include "kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh"
|
||||
|
||||
constexpr int NUM_PER_THREAD_REDUCE = 4;
|
||||
constexpr int WARP_SIZE = 32;
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim, const int& col_dim,
|
||||
const T& epsilon, const T* dy, const T* x, const T* mean, const T* var,
|
||||
T* dg, T* db) {
|
||||
int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
|
||||
for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
|
||||
for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
|
||||
int row = NUM_PER_THREAD_REDUCE * i + j;
|
||||
if (row >= row_dim) {
|
||||
return;
|
||||
}
|
||||
|
||||
int pos = row * col_dim + col;
|
||||
dg[0] += dy[pos] * pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]);
|
||||
db[0] += dy[pos];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void GammaAndBetaWarpReduce(T* dg, T* db) {
|
||||
for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
|
||||
dg[0] += __shfl_down_sync(0xffffffff, dg[0], delta);
|
||||
db[0] += __shfl_down_sync(0xffffffff, db[0], delta);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_dim, T* dg, T* db, T* dg_addr,
|
||||
T* db_addr) {
|
||||
if (threadIdx.x >= row_dim) {
|
||||
return;
|
||||
}
|
||||
|
||||
// load data to share memory
|
||||
// thread(0, 32, 64, 96, ...) keep the data
|
||||
extern __shared__ T share_mem[];
|
||||
if (threadIdx.x % WARP_SIZE == 0) {
|
||||
int offset = threadIdx.x / WARP_SIZE * 2;
|
||||
share_mem[offset] = dg[0];
|
||||
share_mem[offset + 1] = db[0];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
|
||||
if (threadIdx.x < stride) {
|
||||
int offset = (threadIdx.x + stride) * 2;
|
||||
share_mem[threadIdx.x * 2] += share_mem[offset];
|
||||
share_mem[threadIdx.x * 2 + 1] += share_mem[offset + 1];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
dg_addr[col] = share_mem[0];
|
||||
db_addr[col] = share_mem[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T* dy, const T* x,
|
||||
const T* mean_addr, const T* var_addr, T* dg_addr, T* db_addr) {
|
||||
// row: [0:param_axis]
|
||||
// col: [param_axis:]
|
||||
// dg[i][j] = dy[i][j] * (var[i] + epsilon, -0.5) * (x[i][j] - mean[i])
|
||||
// dg[j] = \Sigma_{j}dg[i][j]
|
||||
for (int col = blockIdx.x; col < col_dim; col += gridDim.x) {
|
||||
T dg = 0;
|
||||
T db = 0;
|
||||
GammaAndBetaThreadReduce(col, row_dim, col_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db);
|
||||
GammaAndBetaWarpReduce(&dg, &db);
|
||||
GammaAndBetaBlockReduce(col, row_dim, &dg, &db, dg_addr, db_addr);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const T& epsilon,
|
||||
T* sum1, T* sum2, T* sum3, const T* dy, const T* x, const T* mean,
|
||||
const T* var, const T* gamma) {
|
||||
int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
|
||||
for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
|
||||
for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
|
||||
int col = NUM_PER_THREAD_REDUCE * i + j;
|
||||
if (col >= col_dim) {
|
||||
return;
|
||||
}
|
||||
|
||||
int pos = row * col_dim + col;
|
||||
int gamma_offset = pos % param_dim;
|
||||
T v1 = dy[pos] * gamma[gamma_offset];
|
||||
T v2 = x[pos] - mean[row];
|
||||
|
||||
sum1[0] += -0.5 * v1 * v2 * pow(var[row] + epsilon, -1.5);
|
||||
sum2[0] += v1;
|
||||
sum3[0] += -2.0 * v2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) {
|
||||
for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
|
||||
sum1[0] += __shfl_down_sync(0xffffffff, sum1[0], delta);
|
||||
sum2[0] += __shfl_down_sync(0xffffffff, sum2[0], delta);
|
||||
sum3[0] += __shfl_down_sync(0xffffffff, sum3[0], delta);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void InputBlockReduce(const int& col_dim, T* sum1, T* sum2, T* sum3, T* share_mem) {
|
||||
if (threadIdx.x >= col_dim) {
|
||||
return;
|
||||
}
|
||||
|
||||
// load data to share memory
|
||||
// thread(0, 32, 64, 96, ...) keep the data
|
||||
if (threadIdx.x % WARP_SIZE == 0) {
|
||||
int offset = threadIdx.x / WARP_SIZE * 3;
|
||||
share_mem[offset] = sum1[0];
|
||||
share_mem[offset + 1] = sum2[0];
|
||||
share_mem[offset + 2] = sum3[0];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
|
||||
if (threadIdx.x < stride) {
|
||||
int offset = (threadIdx.x + stride) * 3;
|
||||
share_mem[threadIdx.x * 3] += share_mem[offset];
|
||||
share_mem[threadIdx.x * 3 + 1] += share_mem[offset + 1];
|
||||
share_mem[threadIdx.x * 3 + 2] += share_mem[offset + 2];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const T& epsilon,
|
||||
const T* dy, const T* x, const T* mean, const T* var, const T* gamma, T* dx,
|
||||
const T* share_mem) {
|
||||
for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
|
||||
int pos = (row * col_dim + col);
|
||||
int gamma_offset = pos % param_dim;
|
||||
T v1 = dy[pos] * gamma[gamma_offset];
|
||||
T v2 = x[pos] - mean[row];
|
||||
T v3 = pow(var[row] + epsilon, -0.5);
|
||||
dx[pos] = v1 * v3 + share_mem[0] * (2.0 / col_dim) * v2 +
|
||||
(-1.0 * v3 * share_mem[1] + (1.0 / col_dim) * share_mem[0] * share_mem[2]) * (1.0 / col_dim);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* dy,
|
||||
const T* x, const T* mean, const T* var, const T* gamma, T* dx) {
|
||||
for (int row = blockIdx.x; row < row_dim; row += gridDim.x) {
|
||||
T sum1 = 0;
|
||||
T sum2 = 0;
|
||||
T sum3 = 0;
|
||||
extern __shared__ T share_mem[];
|
||||
InputThreadReduce(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, dy, x, mean, var, gamma);
|
||||
InputWarpReduce(&sum1, &sum2, &sum3);
|
||||
InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem);
|
||||
InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy,
|
||||
const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream) {
|
||||
int share_mem =
|
||||
((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
|
||||
InputPropKernel<<<row_dim, 256, share_mem, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var, gamma,
|
||||
dx);
|
||||
|
||||
share_mem =
|
||||
((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T);
|
||||
GammaAndBetaPropKernel<<<col_dim, 256, share_mem, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db);
|
||||
}
|
||||
|
||||
template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon,
|
||||
const float* dy, const float* x, const float* mean, const float* var, const float* gamma,
|
||||
float* dx, float* dg, float* db, cudaStream_t stream);
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_
|
||||
|
||||
#include "device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy,
|
||||
const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_
|
|
@ -0,0 +1,148 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include "kernel/gpu/cuda_impl/layer_norm_impl.cuh"
|
||||
|
||||
constexpr int NUM_PER_THREAD_REDUCE = 4;
|
||||
constexpr int WARP_SIZE = 32;
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void MeanAndVarAccumulation(T* mean, T* var, T* num, const T& val) {
|
||||
// Welford Algorithm:
|
||||
// \mu_k = \mu_{k-1} + (x_k - \mu_{k-1})/k
|
||||
// \sigma_k^2 = \sigma_{k-1}^2 + (x_k - \mu_{k-1}) * (x_k - \mu_k)
|
||||
num[0]++;
|
||||
T mean_new = mean[0] + (val - mean[0]) / num[0];
|
||||
var[0] = var[0] + (val - mean[0]) * (val - mean_new);
|
||||
mean[0] = mean_new;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void MeanAndVarMerge(T* m1, T* v1, T* n1, const T& m2, const T& v2, const T& n2) {
|
||||
if (n2 == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
T count = n1[0] + n2;
|
||||
v1[0] = v1[0] + v2 + (m1[0] - m2) * (m1[0] - m2) * n1[0] * n2 / count;
|
||||
m1[0] = (n1[0] * m1[0] + n2 * m2) / count;
|
||||
n1[0] = count;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void ThreadReduce(const int& col_dim, const T* block_addr, T* mean, T* var, T* num) {
|
||||
int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
|
||||
for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
|
||||
for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
|
||||
int pos = NUM_PER_THREAD_REDUCE * i + j;
|
||||
if (pos >= col_dim) {
|
||||
return;
|
||||
}
|
||||
MeanAndVarAccumulation(mean, var, num, block_addr[pos]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void WarpReduce(T* mean, T* var, T* num) {
|
||||
for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
|
||||
T mean_other = __shfl_down_sync(0xffffffff, mean[0], delta);
|
||||
T var_other = __shfl_down_sync(0xffffffff, var[0], delta);
|
||||
T num_other = __shfl_down_sync(0xffffffff, num[0], delta);
|
||||
MeanAndVarMerge(mean, var, num, mean_other, var_other, num_other);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void BlockReduce(const int& col_dim, T* mean, T* var, T* num, T* mean_addr, T* var_addr,
|
||||
T* share_mem) {
|
||||
if (threadIdx.x >= col_dim) {
|
||||
return;
|
||||
}
|
||||
|
||||
// load data to share memory
|
||||
// thread(0, 32, 64, 96, ...) keep the data
|
||||
if (threadIdx.x % WARP_SIZE == 0) {
|
||||
int offset = threadIdx.x / WARP_SIZE * 3;
|
||||
share_mem[offset] = mean[0];
|
||||
share_mem[offset + 1] = var[0];
|
||||
share_mem[offset + 2] = num[0];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
|
||||
if (threadIdx.x < stride) {
|
||||
int offset = (threadIdx.x + stride) * 3;
|
||||
MeanAndVarMerge(&share_mem[threadIdx.x * 3], &share_mem[threadIdx.x * 3 + 1], &share_mem[threadIdx.x * 3 + 2],
|
||||
share_mem[offset], share_mem[offset + 1], share_mem[offset + 2]);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
mean_addr[blockIdx.x] = share_mem[0]; // todo: blockDim.x < row
|
||||
share_mem[1] /= col_dim;
|
||||
var_addr[blockIdx.x] = share_mem[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void LayerNorm(const int& row, const int& col_dim, const int& param_dim, const T* x,
|
||||
const T* share_mem, const T* gamma, const T* beta, const T epsilon, T* y) {
|
||||
for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
|
||||
int pos = row * col_dim + col;
|
||||
int i = pos % param_dim;
|
||||
y[pos] = (x[pos] - share_mem[0]) / sqrt(share_mem[1] + epsilon) * gamma[i] + beta[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* x,
|
||||
const T* gamma, const T* beta, T* y, T* mean_addr, T* var_addr) {
|
||||
for (auto row = blockIdx.x; row < row_dim; row += gridDim.x) {
|
||||
T mean = 0;
|
||||
T var = 0;
|
||||
T num = 0;
|
||||
const T* block_addr = x + row * col_dim;
|
||||
extern __shared__ T share_mem[];
|
||||
|
||||
ThreadReduce(col_dim, block_addr, &mean, &var, &num);
|
||||
WarpReduce(&mean, &var, &num);
|
||||
BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem);
|
||||
|
||||
__syncthreads();
|
||||
LayerNorm(row, col_dim, param_dim, x, share_mem, gamma, beta, epsilon, y);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LayerNorm(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* x,
|
||||
const T* gamma, const T* beta, T* y, T* mean, T* var, cudaStream_t stream) {
|
||||
const dim3 block(row_dim);
|
||||
const dim3 thread(256);
|
||||
// keep the mean/var/num after warp reduce
|
||||
int share_mem =
|
||||
((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
|
||||
LayerNormKernel<<<block, thread, share_mem, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y, mean,
|
||||
var);
|
||||
}
|
||||
|
||||
template void LayerNorm(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon,
|
||||
const float* x, const float* gamma, const float* beta, float* y, float* mean, float* var,
|
||||
cudaStream_t stream);
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_
|
||||
|
||||
#include "device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
void LayerNorm(const int& outer, const int& inner, const int& param_dim, const T& epsilon, const T* x, const T* gamma,
|
||||
const T* beta, T* y, T* mean, T* var, cudaStream_t stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/gpu/nn/layer_norm_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(LayerNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
LayerNormGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,103 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "kernel/gpu/gpu_kernel.h"
|
||||
#include "kernel/gpu/gpu_kernel_factory.h"
|
||||
#include "kernel/gpu/cuda_impl/layer_norm_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class LayerNormGpuKernel : public GpuKernel {
|
||||
public:
|
||||
LayerNormGpuKernel() : input_row_(1), input_col_(1), param_dim_(1) {}
|
||||
~LayerNormGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
|
||||
auto x = GetDeviceAddress<T>(inputs, 0);
|
||||
auto gamma = GetDeviceAddress<T>(inputs, 1);
|
||||
auto beta = GetDeviceAddress<T>(inputs, 2);
|
||||
auto y = GetDeviceAddress<T>(outputs, 0);
|
||||
auto mean = GetDeviceAddress<T>(outputs, 1);
|
||||
auto variance = GetDeviceAddress<T>(outputs, 2);
|
||||
|
||||
T epsilon = 10e-12;
|
||||
LayerNorm(input_row_, input_col_, param_dim_, epsilon, x, gamma, beta, y, mean, variance,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
int begin_norm_axis = GetAttr<int>(kernel_node, "begin_norm_axis");
|
||||
int begin_params_axis = GetAttr<int>(kernel_node, "begin_params_axis");
|
||||
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (begin_norm_axis < 0) {
|
||||
begin_norm_axis += input_shape.size();
|
||||
}
|
||||
|
||||
if (begin_params_axis < 0) {
|
||||
begin_params_axis += input_shape.size();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) {
|
||||
input_row_ *= input_shape[i];
|
||||
}
|
||||
|
||||
for (size_t i = begin_norm_axis; i < input_shape.size(); i++) {
|
||||
input_col_ *= input_shape[i];
|
||||
}
|
||||
|
||||
for (size_t i = begin_params_axis; i < input_shape.size(); i++) {
|
||||
param_dim_ *= input_shape[i];
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_row_ * input_col_ * sizeof(T));
|
||||
input_size_list_.push_back(param_dim_ * sizeof(T));
|
||||
input_size_list_.push_back(param_dim_ * sizeof(T));
|
||||
|
||||
output_size_list_.push_back(input_row_ * input_col_ * sizeof(T));
|
||||
output_size_list_.push_back(input_row_ * sizeof(T));
|
||||
output_size_list_.push_back(input_row_ * sizeof(T));
|
||||
return;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
||||
int input_row_;
|
||||
int input_col_;
|
||||
int param_dim_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/gpu/nn/layer_norm_grad_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(LayerNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
LayerNormGradGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,107 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "kernel/gpu/gpu_kernel.h"
|
||||
#include "kernel/gpu/gpu_kernel_factory.h"
|
||||
#include "kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class LayerNormGradGpuKernel : public GpuKernel {
|
||||
public:
|
||||
LayerNormGradGpuKernel() : input_row_(1), input_col_(1), param_dim_(1) {}
|
||||
~LayerNormGradGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
|
||||
auto dy = GetDeviceAddress<T>(inputs, 0);
|
||||
auto x = GetDeviceAddress<T>(inputs, 1);
|
||||
auto var = GetDeviceAddress<T>(inputs, 2);
|
||||
auto mean = GetDeviceAddress<T>(inputs, 3);
|
||||
auto gamma = GetDeviceAddress<T>(inputs, 4);
|
||||
auto dx = GetDeviceAddress<T>(outputs, 0);
|
||||
auto dg = GetDeviceAddress<T>(outputs, 1);
|
||||
auto db = GetDeviceAddress<T>(outputs, 2);
|
||||
|
||||
T epsilon = 10e-12;
|
||||
LayerNormGrad(input_row_, input_col_, param_dim_, epsilon, dy, x, mean, var, gamma, dx, dg, db,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
int begin_norm_axis = GetAttr<int>(kernel_node, "begin_norm_axis");
|
||||
int begin_params_axis = GetAttr<int>(kernel_node, "begin_params_axis");
|
||||
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (begin_norm_axis < 0) {
|
||||
begin_norm_axis += input_shape.size();
|
||||
}
|
||||
|
||||
if (begin_params_axis < 0) {
|
||||
begin_params_axis += input_shape.size();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) {
|
||||
input_row_ *= input_shape[i];
|
||||
}
|
||||
|
||||
for (size_t i = begin_norm_axis; i < input_shape.size(); i++) {
|
||||
input_col_ *= input_shape[i];
|
||||
}
|
||||
|
||||
for (size_t i = begin_params_axis; i < input_shape.size(); i++) {
|
||||
param_dim_ *= input_shape[i];
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_row_ * input_col_ * sizeof(T));
|
||||
input_size_list_.push_back(input_row_ * input_col_ * sizeof(T));
|
||||
input_size_list_.push_back(input_row_ * sizeof(T));
|
||||
input_size_list_.push_back(input_row_ * sizeof(T));
|
||||
input_size_list_.push_back(param_dim_ * sizeof(T));
|
||||
|
||||
output_size_list_.push_back(input_row_ * input_col_ * sizeof(T));
|
||||
output_size_list_.push_back(param_dim_ * sizeof(T));
|
||||
output_size_list_.push_back(param_dim_ * sizeof(T));
|
||||
return;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
||||
int input_row_;
|
||||
int input_col_;
|
||||
int param_dim_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_
|
|
@ -0,0 +1,140 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops import composite as C
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
class LayerNormGradNet(nn.Cell):
|
||||
def __init__(self, begin_norm_axis, begin_params_axis):
|
||||
super(LayerNormGradNet, self).__init__()
|
||||
self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis)
|
||||
|
||||
def construct(self, dy, x, var, mean, gamma):
|
||||
return self.norm(dy, x, var, mean, gamma)
|
||||
|
||||
def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
|
||||
begin_norm_axis = begin_norm_axis if begin_norm_axis >=0 else begin_norm_axis + len(x.shape)
|
||||
begin_params_axis = begin_params_axis if begin_params_axis >=0 else begin_params_axis + len(x.shape)
|
||||
|
||||
norm_axis = [i for i in range(begin_norm_axis, len(x.shape))]
|
||||
param_axis = [i for i in range(0, begin_params_axis)]
|
||||
num = 1
|
||||
for i in range(begin_norm_axis, len(x.shape)):
|
||||
num *= x.shape[i]
|
||||
|
||||
mean = np.mean(x, axis=tuple(norm_axis), keepdims=True)
|
||||
var = np.var(x, axis=tuple(norm_axis), keepdims=True)
|
||||
|
||||
gamma = gamma.reshape((*((1,)*begin_params_axis), *x.shape[begin_params_axis:]))
|
||||
dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
|
||||
db = np.sum(dy, axis=tuple(param_axis), keepdims=True)
|
||||
|
||||
sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis), keepdims=True)
|
||||
sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
|
||||
sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)
|
||||
|
||||
dx1 = dy * gamma * np.power(var + epsilon, -0.5)
|
||||
dx2 = sum1 * 2.0 / num * (x - mean)
|
||||
dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
|
||||
dx = dx1 + dx2 + dx3
|
||||
return dx, dg, db, mean, var
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernormgrad0():
|
||||
begin_norm_axis = 1
|
||||
begin_params_axis = 1
|
||||
x_np = np.random.randn(4096, 3072).astype(np.float32)
|
||||
dy_np = np.random.randn(4096, 3072).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
epsilon = 10e-12
|
||||
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis, begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np)
|
||||
x_ms = Tensor(x_np)
|
||||
var_ms = Tensor(var_np)
|
||||
mean_ms = Tensor(mean_np)
|
||||
gamma_ms = Tensor(gamma_np)
|
||||
|
||||
net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
|
||||
dx_ms, dg_ms, db_ms = net(dy_ms, x_ms, var_ms, mean_ms, gamma_ms)
|
||||
|
||||
assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
|
||||
assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernormgrad1():
|
||||
begin_norm_axis = 1
|
||||
begin_params_axis = 1
|
||||
x_np = np.random.randn(640, 768).astype(np.float32)
|
||||
dy_np = np.random.randn(640, 768).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
epsilon = 10e-12
|
||||
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis, begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np)
|
||||
x_ms = Tensor(x_np)
|
||||
var_ms = Tensor(var_np)
|
||||
mean_ms = Tensor(mean_np)
|
||||
gamma_ms = Tensor(gamma_np)
|
||||
|
||||
net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
|
||||
dx_ms, dg_ms, db_ms = net(dy_ms, x_ms, var_ms, mean_ms, gamma_ms)
|
||||
|
||||
assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
|
||||
assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernormgrad2():
|
||||
begin_norm_axis = -1
|
||||
begin_params_axis = -1
|
||||
x_np = np.random.randn(32, 128, 768).astype(np.float32)
|
||||
dy_np = np.random.randn(32, 128, 768).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
epsilon = 10e-12
|
||||
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis, begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np)
|
||||
x_ms = Tensor(x_np)
|
||||
var_ms = Tensor(var_np)
|
||||
mean_ms = Tensor(mean_np)
|
||||
gamma_ms = Tensor(gamma_np)
|
||||
|
||||
net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
|
||||
dx_ms, dg_ms, db_ms = net(dy_ms, x_ms, var_ms, mean_ms, gamma_ms)
|
||||
|
||||
assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
|
||||
assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
|
|
@ -0,0 +1,134 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
|
||||
class LayerNormNet(nn.Cell):
|
||||
def __init__(self, begin_norm_axis, begin_params_axis):
|
||||
super(LayerNormNet, self).__init__()
|
||||
self.norm = P.LayerNorm(begin_norm_axis, begin_params_axis)
|
||||
|
||||
def construct(self, x, gamma, beta):
|
||||
return self.norm(x, gamma, beta)
|
||||
|
||||
def LayerNormReference(begin_norm_axis, begin_params_axis, x, gamma, beta):
|
||||
begin_norm_axis = begin_norm_axis if begin_norm_axis >=0 else begin_norm_axis + len(x.shape)
|
||||
begin_params_axis = begin_params_axis if begin_params_axis >=0 else begin_params_axis + len(x.shape)
|
||||
|
||||
axis = [i for i in range(begin_norm_axis, len(x.shape))]
|
||||
mean = np.mean(x, axis=tuple(axis), keepdims=True)
|
||||
var = np.var(x, axis=tuple(axis), keepdims=True)
|
||||
|
||||
gamma = gamma.reshape((*((1,)*begin_params_axis), *x.shape[begin_params_axis:]))
|
||||
beta = beta.reshape((*((1,)*begin_params_axis), *x.shape[begin_params_axis:]))
|
||||
y = np.subtract(x, mean) / np.sqrt(var + 1e-12) * gamma + beta
|
||||
return y, mean, var
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernorm0():
|
||||
begin_norm_axis = 1
|
||||
begin_params_axis = 1
|
||||
x_np = np.random.randn(4096, 3072).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
|
||||
|
||||
x_ms = Tensor(x_np)
|
||||
gamma_ms = Tensor(gamma_np)
|
||||
beta_ms = Tensor(beta_np)
|
||||
net = LayerNormNet(begin_norm_axis, begin_params_axis)
|
||||
y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
|
||||
|
||||
assert np.allclose(y_ms.asnumpy(), y_np, atol=1e-6)
|
||||
assert np.allclose(mean_ms.asnumpy(), mean_np, atol=1e-6)
|
||||
assert np.allclose(var_ms.asnumpy(), var_np, atol=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernorm1():
|
||||
begin_norm_axis = 1
|
||||
begin_params_axis = 1
|
||||
x_np = np.random.randn(640, 768).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
|
||||
|
||||
x_ms = Tensor(x_np)
|
||||
gamma_ms = Tensor(gamma_np)
|
||||
beta_ms = Tensor(beta_np)
|
||||
net = LayerNormNet(begin_norm_axis, begin_params_axis)
|
||||
y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
|
||||
|
||||
|
||||
assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernorm3d_1():
|
||||
begin_norm_axis = -1
|
||||
begin_params_axis = -1
|
||||
x_np = np.random.randn(32, 128, 768).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
|
||||
|
||||
x_ms = Tensor(x_np)
|
||||
gamma_ms = Tensor(gamma_np)
|
||||
beta_ms = Tensor(beta_np)
|
||||
net = LayerNormNet(begin_norm_axis, begin_params_axis)
|
||||
y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
|
||||
|
||||
assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernorm3d_2():
|
||||
begin_norm_axis = -1
|
||||
begin_params_axis = 1
|
||||
x_np = np.random.randn(32, 128, 768).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
|
||||
|
||||
x_ms = Tensor(x_np)
|
||||
gamma_ms = Tensor(gamma_np)
|
||||
beta_ms = Tensor(beta_np)
|
||||
net = LayerNormNet(begin_norm_axis, begin_params_axis)
|
||||
y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
|
||||
|
||||
assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)
|
Loading…
Reference in New Issue