forked from mindspore-Ecosystem/mindspore
!9979 BUGFIX: Correct the calculation of the output 'd_x' of the operator LayernormGradGrad
From: @david-he91 Reviewed-by: @liangchenghui,@linqingke Signed-off-by: @liangchenghui
This commit is contained in:
commit
a3d4dded12
|
@ -20,9 +20,10 @@
|
|||
#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh"
|
||||
|
||||
constexpr int THREAD_PER_BLOCK = 256;
|
||||
constexpr int NUM_PER_THREAD_REDUCE = 4;
|
||||
constexpr int WARP_SIZE = 32;
|
||||
constexpr int NUM_SHARED_SUM_INPUT = 6;
|
||||
constexpr int NUM_SHARED_SUM_INPUT = 7;
|
||||
constexpr int NUM_SHARED_SUM_GAMMA = 3;
|
||||
|
||||
template <typename T>
|
||||
|
@ -30,6 +31,7 @@ inline __device__ T my_pow(T a, double b) {
|
|||
return pow(a, static_cast<float>(b));
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline __device__ half my_pow(half a, double b) {
|
||||
return __float2half(pow(__half2float(a), static_cast<float>(b)));
|
||||
|
@ -52,15 +54,17 @@ inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_d
|
|||
int pos = row * col_dim + col;
|
||||
int mean_offset = pos / mean_dim;
|
||||
|
||||
T v1 = my_pow(var[mean_offset] + epsilon, -0.5);
|
||||
T v1 = x[pos] - mean[mean_offset];
|
||||
T v2 = my_pow(var[mean_offset] + epsilon, -0.5);
|
||||
|
||||
part1[0] += dy[pos] * v1 * (x[pos] - mean[mean_offset]) * global_sum2[pos];
|
||||
part1[0] += dy[pos] * v1 * v2 * global_sum2[pos];
|
||||
part2[0] += dy[pos] * global_sum1[pos];
|
||||
part3[0] += dy[pos] * grad_dx[pos] * v1;
|
||||
part3[0] += dy[pos] * v2 * grad_dx[pos];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void GammaAndBetaWarpReduce(T *part1, T *part2, T *part3) {
|
||||
for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
|
||||
|
@ -70,6 +74,7 @@ inline __device__ void GammaAndBetaWarpReduce(T *part1, T *part2, T *part3) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_dim, T *part1, T *part2, T *part3,
|
||||
T *d_gamma) {
|
||||
|
@ -77,7 +82,7 @@ inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_di
|
|||
// thread(0, 32, 64, 96, ...) keep the data
|
||||
DynamicSharedMem<T> share_mem;
|
||||
if (threadIdx.x % WARP_SIZE == 0) {
|
||||
int offset = threadIdx.x / WARP_SIZE * 3;
|
||||
int offset = threadIdx.x / WARP_SIZE * NUM_SHARED_SUM_GAMMA;
|
||||
share_mem.addr()[offset] = part1[0];
|
||||
share_mem.addr()[offset + 1] = part2[0];
|
||||
share_mem.addr()[offset + 2] = part3[0];
|
||||
|
@ -86,10 +91,10 @@ inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_di
|
|||
|
||||
for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
|
||||
if (threadIdx.x < stride) {
|
||||
int offset = (threadIdx.x + stride) * 3;
|
||||
share_mem.addr()[threadIdx.x * 3] += share_mem.addr()[offset];
|
||||
share_mem.addr()[threadIdx.x * 3 + 1] += share_mem.addr()[offset + 1];
|
||||
share_mem.addr()[threadIdx.x * 3 + 2] += share_mem.addr()[offset + 2];
|
||||
int offset = (threadIdx.x + stride) * NUM_SHARED_SUM_GAMMA;
|
||||
share_mem.addr()[threadIdx.x * NUM_SHARED_SUM_GAMMA] += share_mem.addr()[offset];
|
||||
share_mem.addr()[threadIdx.x * NUM_SHARED_SUM_GAMMA + 1] += share_mem.addr()[offset + 1];
|
||||
share_mem.addr()[threadIdx.x * NUM_SHARED_SUM_GAMMA + 2] += share_mem.addr()[offset + 2];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
@ -99,6 +104,7 @@ inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_di
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const int mean_dim, const T epsilon,
|
||||
const T *dy, const T *x, const T *mean, const T *var, const T *grad_dx,
|
||||
|
@ -143,6 +149,7 @@ inline __device__ void InputThreadReduceInnerMean(const int &row, const int &col
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void InputWarpReduceInnerMean(T *sum1, T *sum2, T *sum3, T *sum4) {
|
||||
for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
|
||||
|
@ -153,12 +160,13 @@ inline __device__ void InputWarpReduceInnerMean(T *sum1, T *sum2, T *sum3, T *su
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void InputBlockReduceInnerMean(const int &col_dim, T *sum1, T *sum2, T *sum3, T *sum4, T *share_mem) {
|
||||
// load data to share memory
|
||||
// thread(0, 32, 64, 96, ...) keep the data
|
||||
if (threadIdx.x % WARP_SIZE == 0) {
|
||||
int offset = threadIdx.x / WARP_SIZE * 6;
|
||||
int offset = threadIdx.x / WARP_SIZE * NUM_SHARED_SUM_INPUT;
|
||||
share_mem[offset] = sum1[0];
|
||||
share_mem[offset + 1] = sum2[0];
|
||||
share_mem[offset + 2] = sum3[0];
|
||||
|
@ -168,12 +176,12 @@ inline __device__ void InputBlockReduceInnerMean(const int &col_dim, T *sum1, T
|
|||
|
||||
for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
|
||||
if (threadIdx.x < stride) {
|
||||
int offset = (threadIdx.x + stride) * 6;
|
||||
int offset = (threadIdx.x + stride) * NUM_SHARED_SUM_INPUT;
|
||||
|
||||
share_mem[threadIdx.x * 3] += share_mem[offset];
|
||||
share_mem[threadIdx.x * 3 + 1] += share_mem[offset + 1];
|
||||
share_mem[threadIdx.x * 3 + 2] += share_mem[offset + 2];
|
||||
share_mem[threadIdx.x * 3 + 3] += share_mem[offset + 3];
|
||||
share_mem[threadIdx.x * NUM_SHARED_SUM_INPUT] += share_mem[offset];
|
||||
share_mem[threadIdx.x * NUM_SHARED_SUM_INPUT + 1] += share_mem[offset + 1];
|
||||
share_mem[threadIdx.x * NUM_SHARED_SUM_INPUT + 2] += share_mem[offset + 2];
|
||||
share_mem[threadIdx.x * NUM_SHARED_SUM_INPUT + 3] += share_mem[offset + 3];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
@ -182,8 +190,8 @@ inline __device__ void InputBlockReduceInnerMean(const int &col_dim, T *sum1, T
|
|||
|
||||
template <typename T>
|
||||
inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col_dim, const int ¶m_dim,
|
||||
const T &epsilon, T *sum5, T *sum6, T *share_mem, const T *dy,
|
||||
const T *x, const T *mean, const T *var, const T *gamma,
|
||||
const T &epsilon, T *sum5, T *sum6, T *sum7, T *share_mem,
|
||||
const T *dy, const T *x, const T *mean, const T *var, const T *gamma,
|
||||
const T *grad_dx, const T *grad_dg, T *d_x) {
|
||||
int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
|
||||
for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
|
||||
|
@ -198,17 +206,21 @@ inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col
|
|||
T v1 = x[pos] - mean[row];
|
||||
T v2 = my_pow(var[row] + epsilon, -0.5);
|
||||
T v3 = dy[pos] * gamma[gamma_offset];
|
||||
T v4 = v3 * share_mem[1] * (1.0 / col_dim);
|
||||
T v5 = grad_dx[pos] * v2 * share_mem[3] * (-1.0 / col_dim);
|
||||
T v6 = dy[pos] * grad_dg[gamma_offset];
|
||||
T v7 = v4 + v5 + v6;
|
||||
|
||||
T part1 = v1 * v7;
|
||||
T part2 = v2 * v7;
|
||||
d_x[pos] = part2;
|
||||
T v4 = v3 - share_mem[2] * (1.0 / col_dim) - v1 * v2 * share_mem[3] * (1.0 / col_dim);
|
||||
T v5 = v3 * share_mem[1] * (1.0 / col_dim);
|
||||
T v6 = grad_dx[pos] * v2 * share_mem[3] * (-1.0 / col_dim);
|
||||
T v7 = dy[pos] * grad_dg[gamma_offset];
|
||||
T v8 = v5 + v6 + v7;
|
||||
|
||||
T part1 = v4 * grad_dx[pos];
|
||||
T part2 = v1 * v8;
|
||||
T part3 = v2 * v8;
|
||||
d_x[pos] = part3;
|
||||
|
||||
sum5[0] += part1;
|
||||
sum6[0] -= part2;
|
||||
sum6[0] += part2;
|
||||
sum7[0] -= part3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -216,10 +228,10 @@ inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col
|
|||
|
||||
template <>
|
||||
inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col_dim, const int ¶m_dim,
|
||||
const half &epsilon, half *sum5, half *sum6, half *share_mem,
|
||||
const half *dy, const half *x, const half *mean, const half *var,
|
||||
const half *gamma, const half *grad_dx, const half *grad_dg,
|
||||
half *d_x) {
|
||||
const half &epsilon, half *sum5, half *sum6, half *sum7,
|
||||
half *share_mem, const half *dy, const half *x, const half *mean,
|
||||
const half *var, const half *gamma, const half *grad_dx,
|
||||
const half *grad_dg, half *d_x) {
|
||||
int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
|
||||
for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
|
||||
for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
|
||||
|
@ -233,48 +245,54 @@ inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col
|
|||
half v1 = x[pos] - mean[row];
|
||||
half v2 = my_pow(var[row] + epsilon, -0.5);
|
||||
half v3 = dy[pos] * gamma[gamma_offset];
|
||||
half v4 = v3 * share_mem[1] * __float2half(1.0 / col_dim);
|
||||
half v5 = grad_dx[pos] * v2 * share_mem[3] * __float2half(-1.0 / col_dim);
|
||||
half v6 = dy[pos] * grad_dg[gamma_offset];
|
||||
half v7 = v4 + v5 + v6;
|
||||
half v4 = v3 - share_mem[2] * __float2half(1.0 / col_dim) - v1 * v2 * share_mem[3] * __float2half(1.0 / col_dim);
|
||||
half v5 = v3 * share_mem[1] * __float2half(1.0 / col_dim);
|
||||
half v6 = grad_dx[pos] * v2 * share_mem[3] * __float2half(-1.0 / col_dim);
|
||||
half v7 = dy[pos] * grad_dg[gamma_offset];
|
||||
half v8 = v5 + v6 + v7;
|
||||
|
||||
half part1 = v1 * v7;
|
||||
half part2 = v2 * v7;
|
||||
d_x[pos] = part2;
|
||||
half part1 = v4 * grad_dx[pos];
|
||||
half part2 = v1 * v8;
|
||||
half part3 = v2 * v8;
|
||||
d_x[pos] = part3;
|
||||
|
||||
sum5[0] += part1;
|
||||
sum6[0] -= part2;
|
||||
sum6[0] += part2;
|
||||
sum7[0] -= part3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void InputWarpReduceOuterMean(T *sum5, T *sum6) {
|
||||
inline __device__ void InputWarpReduceOuterMean(T *sum5, T *sum6, T *sum7) {
|
||||
for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
|
||||
sum5[0] += __shfl_down_sync(0xffffffff, sum5[0], delta);
|
||||
sum6[0] += __shfl_down_sync(0xffffffff, sum6[0], delta);
|
||||
sum7[0] += __shfl_down_sync(0xffffffff, sum7[0], delta);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void InputBlockReduceOuterMean(const int &col_dim, T *sum5, T *sum6, T *share_mem) {
|
||||
inline __device__ void InputBlockReduceOuterMean(const int &col_dim, T *sum5, T *sum6, T *sum7, T *share_mem) {
|
||||
// load data to share memory
|
||||
// thread(0, 32, 64, 96, ...) keep the data
|
||||
if (threadIdx.x % WARP_SIZE == 0) {
|
||||
int offset = threadIdx.x / WARP_SIZE * 6;
|
||||
int offset = threadIdx.x / WARP_SIZE * NUM_SHARED_SUM_INPUT;
|
||||
|
||||
share_mem[offset + 4] = sum5[0];
|
||||
share_mem[offset + 5] = sum6[0];
|
||||
share_mem[offset + 6] = sum7[0];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
|
||||
if (threadIdx.x < stride) {
|
||||
int offset = (threadIdx.x + stride) * 6;
|
||||
int offset = (threadIdx.x + stride) * NUM_SHARED_SUM_INPUT;
|
||||
|
||||
share_mem[threadIdx.x * 6 + 4] += share_mem[offset + 4];
|
||||
share_mem[threadIdx.x * 6 + 5] += share_mem[offset + 5];
|
||||
share_mem[threadIdx.x * NUM_SHARED_SUM_INPUT + 4] += share_mem[offset + 4];
|
||||
share_mem[threadIdx.x * NUM_SHARED_SUM_INPUT + 5] += share_mem[offset + 5];
|
||||
share_mem[threadIdx.x * NUM_SHARED_SUM_INPUT + 6] += share_mem[offset + 6];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
@ -300,8 +318,8 @@ inline __device__ void InputProp(const int &row, const int &col_dim, const int &
|
|||
T part4 = v3 * grad_dg[gamma_offset];
|
||||
d_dy[pos] = part1 + part2 + part3 + part4 + grad_db[gamma_offset];
|
||||
|
||||
T part5 = v1 * (my_pow(var[row] + epsilon, -1.5) * (share_mem[4] * (-1.0 / col_dim)));
|
||||
d_x[pos] += part5 + share_mem[5] * (1.0 / col_dim);
|
||||
T part5 = v1 * (my_pow(var[row] + epsilon, -1.5) * ((share_mem[4]+ share_mem[5]) * (-1.0 / col_dim)));
|
||||
d_x[pos] += part5 + share_mem[6] * (1.0 / col_dim);
|
||||
|
||||
global_sum1[pos] = share_mem[0] * (1.0 / col_dim);
|
||||
global_sum2[pos] = share_mem[1] * (1.0 / col_dim);
|
||||
|
@ -328,8 +346,9 @@ inline __device__ void InputProp(const int &row, const int &col_dim, const int &
|
|||
half part4 = v3 * grad_dg[gamma_offset];
|
||||
d_dy[pos] = part1 + part2 + part3 + part4 + grad_db[gamma_offset];
|
||||
|
||||
half part5 = v1 * (my_pow(var[row] + epsilon, -1.5) * (share_mem[4] * __float2half(-1.0 / col_dim)));
|
||||
d_x[pos] += part5 + share_mem[5] * __float2half(1.0 / col_dim);
|
||||
half part5 = v1 * (my_pow(var[row] + epsilon, -1.5) *
|
||||
((share_mem[4]+ share_mem[5]) * __float2half(-1.0 / col_dim)));
|
||||
d_x[pos] += part5 + share_mem[6] * __float2half(1.0 / col_dim);
|
||||
|
||||
global_sum1[pos] = share_mem[0] * __float2half(1.0 / col_dim);
|
||||
global_sum2[pos] = share_mem[1] * __float2half(1.0 / col_dim);
|
||||
|
@ -349,6 +368,7 @@ __global__ void InputPropKernel(const int row_dim, const int col_dim, const int
|
|||
T sum4 = 0;
|
||||
T sum5 = 0;
|
||||
T sum6 = 0;
|
||||
T sum7 = 0;
|
||||
DynamicSharedMem<T> share_mem;
|
||||
|
||||
InputThreadReduceInnerMean(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, &sum4, dy, x, mean, var, gamma,
|
||||
|
@ -356,34 +376,34 @@ __global__ void InputPropKernel(const int row_dim, const int col_dim, const int
|
|||
InputWarpReduceInnerMean(&sum1, &sum2, &sum3, &sum4);
|
||||
InputBlockReduceInnerMean(col_dim, &sum1, &sum2, &sum3, &sum4, share_mem.addr());
|
||||
|
||||
InputThreadReduceOuterMean(row, col_dim, param_dim, epsilon, &sum5, &sum6, share_mem.addr(), dy, x, mean,
|
||||
InputThreadReduceOuterMean(row, col_dim, param_dim, epsilon, &sum5, &sum6, &sum7, share_mem.addr(), dy, x, mean,
|
||||
var, gamma, grad_dx, grad_dg, d_x);
|
||||
InputWarpReduceOuterMean(&sum5, &sum6);
|
||||
InputBlockReduceOuterMean(col_dim, &sum5, &sum6, share_mem.addr());
|
||||
InputWarpReduceOuterMean(&sum5, &sum6, &sum7);
|
||||
InputBlockReduceOuterMean(col_dim, &sum5, &sum6, &sum7, share_mem.addr());
|
||||
InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, grad_dx, grad_dg, grad_db, d_dy, d_x,
|
||||
share_mem.addr(), global_sum1, global_sum2);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
void LayerNormGradGrad(const int &row_dim, const int &col_dim, const int ¶m_dim, T *global_sum1, T *global_sum2,
|
||||
const T &epsilon, const T *dy, const T *x, const T *mean, const T *var, const T *gamma,
|
||||
const T* grad_dx, const T* grad_dg, const T* grad_db, T *d_dy, T *d_x, T *d_gamma,
|
||||
cudaStream_t stream) {
|
||||
const int thread_per_block = 256;
|
||||
|
||||
int share_mem_size = thread_per_block / WARP_SIZE * NUM_SHARED_SUM_INPUT * sizeof(T);
|
||||
InputPropKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x,
|
||||
int share_mem_size = THREAD_PER_BLOCK / WARP_SIZE * NUM_SHARED_SUM_INPUT * sizeof(T);
|
||||
InputPropKernel<<<row_dim, THREAD_PER_BLOCK, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x,
|
||||
mean, var, gamma, grad_dx, grad_dg, grad_db,
|
||||
d_dy, d_x, global_sum1, global_sum2);
|
||||
share_mem_size = thread_per_block / WARP_SIZE * NUM_SHARED_SUM_GAMMA * sizeof(T);
|
||||
share_mem_size = THREAD_PER_BLOCK / WARP_SIZE * NUM_SHARED_SUM_GAMMA * sizeof(T);
|
||||
int param_reduce_dim = row_dim * col_dim / param_dim;
|
||||
GammaAndBetaPropKernel<<<param_dim, thread_per_block, share_mem_size, stream>>>(param_reduce_dim, param_dim,
|
||||
GammaAndBetaPropKernel<<<param_dim, THREAD_PER_BLOCK, share_mem_size, stream>>>(param_reduce_dim, param_dim,
|
||||
col_dim, epsilon, dy, x, mean, var,
|
||||
grad_dx, d_gamma, global_sum1,
|
||||
global_sum2);
|
||||
}
|
||||
|
||||
|
||||
template void LayerNormGradGrad(const int &row_dim, const int &col_dim, const int ¶m_dim, float *global_sum1,
|
||||
float *global_sum2, const float &epsilon, const float *dy, const float *x,
|
||||
const float *mean, const float *var, const float *gamma, const float *grad_dx,
|
||||
|
|
|
@ -58,8 +58,7 @@ class LayerNormGradGradGpuKernel : public GpuKernel {
|
|||
cudaMemsetAsync(global_sum2, 0, input_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemsetAsync global_sum2 failed");
|
||||
|
||||
const T epsilon = 10e-12;
|
||||
LayerNormGradGrad(input_row_, input_col_, param_dim_, global_sum1, global_sum2, epsilon, dy, x, mean, var, gamma,
|
||||
LayerNormGradGrad(input_row_, input_col_, param_dim_, global_sum1, global_sum2, epsilon_, dy, x, mean, var, gamma,
|
||||
grad_dx, grad_dg, grad_db, d_dy, d_x, d_gamma, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
@ -88,6 +87,12 @@ class LayerNormGradGradGpuKernel : public GpuKernel {
|
|||
param_dim_ *= input_shape[i];
|
||||
}
|
||||
|
||||
epsilon_ = 1e-12;
|
||||
auto type_id = TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0));
|
||||
if (std::strncmp(type_id, "kNumberTypeFloat16", std::strlen(type_id)) == 0) {
|
||||
epsilon_ = 1e-7;
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
@ -122,6 +127,7 @@ class LayerNormGradGradGpuKernel : public GpuKernel {
|
|||
int input_col_;
|
||||
int param_dim_;
|
||||
int input_size_;
|
||||
T epsilon_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -696,13 +696,13 @@ def get_bprop_layer_norm(self):
|
|||
|
||||
@bprop_getters.register(G.LayerNormGrad)
|
||||
def get_bprop_layer_norm_grad(self):
|
||||
"""Grad definition for `LayerNorm` operation."""
|
||||
"""Grad definition for `LayerNormGrad` operation."""
|
||||
layer_norm_grad_grad = G.LayerNormGradGrad(self.begin_norm_axis, self.begin_params_axis)
|
||||
|
||||
def bprop(x, dy, variance, mean, gamma, out, dout):
|
||||
d_x, d_dy, d_gamma = layer_norm_grad_grad(
|
||||
x, dy, variance, mean, gamma, dout[0], dout[1], dout[2])
|
||||
return d_x, d_dy, d_gamma, zeros_like(variance), zeros_like(mean)
|
||||
return d_x, d_dy, zeros_like(variance), zeros_like(mean), d_gamma
|
||||
|
||||
return bprop
|
||||
|
||||
|
|
|
@ -1105,6 +1105,7 @@ class LayerNormGradGrad(PrimitiveWithInfer):
|
|||
|
||||
def __call__(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db):
|
||||
raise NotImplementedError
|
||||
|
||||
def infer_shape(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db):
|
||||
return x, dy, gamma
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore import Tensor
|
|||
from mindspore.ops.operations import _grad_ops as G
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
class LayerNormGradGradNet(nn.Cell):
|
||||
def __init__(self, begin_norm_axis, begin_params_axis):
|
||||
|
@ -61,6 +61,7 @@ def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_
|
|||
dx = dx1 + dx2 + dx3
|
||||
return dx, dg, db, mean, var
|
||||
|
||||
|
||||
def LayerNormGradGradReference(x, dy, gamma, epsilon, grad_dx_np, grad_dg_np, grad_db_np, begin_norm_axis,
|
||||
begin_params_axis):
|
||||
begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
|
||||
|
@ -77,27 +78,31 @@ def LayerNormGradGradReference(x, dy, gamma, epsilon, grad_dx_np, grad_dg_np, gr
|
|||
inv_std = np.power(var + epsilon, -0.5)
|
||||
x_hat = (x - mean) * inv_std
|
||||
|
||||
sum1 = np.mean((-1.0) * inv_std * grad_dx_np, tuple(norm_axis), keepdims=True)
|
||||
sum2 = np.mean(x_hat * (-1.0) * inv_std * grad_dx_np, tuple(norm_axis), keepdims=True)
|
||||
sum3 = np.mean(dy * gamma * x_hat, tuple(norm_axis), keepdims=True)
|
||||
part = dy * gamma * sum2 + sum3 * (-1.0) * grad_dx_np * inv_std + dy * grad_dg_np
|
||||
sum4 = np.mean((x - mean) * part, tuple(norm_axis), keepdims=True)
|
||||
sum5 = np.mean(-inv_std * part, tuple(norm_axis), keepdims=True)
|
||||
sum1 = np.mean(-inv_std * grad_dx_np, tuple(norm_axis), keepdims=True)
|
||||
sum2 = np.mean(-x_hat * inv_std * grad_dx_np, tuple(norm_axis), keepdims=True)
|
||||
sum3 = np.mean(dy * gamma, tuple(norm_axis), keepdims=True)
|
||||
sum4 = np.mean(dy * gamma * x_hat, tuple(norm_axis), keepdims=True)
|
||||
part_sum1 = dy * gamma - sum3 - x_hat * sum4
|
||||
part_sum2 = dy * gamma * sum2 - sum4 * grad_dx_np * inv_std + dy * grad_dg_np
|
||||
part1 = np.mean(grad_dx_np * part_sum1, tuple(norm_axis), keepdims=True)
|
||||
part2 = np.mean((x - mean) * part_sum2, tuple(norm_axis), keepdims=True)
|
||||
part3 = inv_std * part_sum2
|
||||
sum5 = np.mean(part1, tuple(norm_axis), keepdims=True)
|
||||
sum6 = np.mean(part2, tuple(norm_axis), keepdims=True)
|
||||
sum7 = np.mean(-part3, tuple(norm_axis), keepdims=True)
|
||||
part4 = -(x - mean) * np.power(var + epsilon, -1.5) * (sum5 + sum6)
|
||||
d_x = part3 + part4 + sum7
|
||||
|
||||
part1 = inv_std * part
|
||||
part2 = (x - mean) * (-1.0) * np.power(var + epsilon, -1.5) * sum4
|
||||
d_x = part1 + part2 + sum5
|
||||
part5 = gamma * grad_dx_np * inv_std
|
||||
part6 = gamma * sum1
|
||||
part7 = gamma * x_hat * sum2
|
||||
part8 = x_hat * grad_dg_np
|
||||
d_dy = part5 + part6 + part7 + part8 + grad_db_np
|
||||
|
||||
part3 = gamma * grad_dx_np * inv_std
|
||||
part4 = gamma * sum1
|
||||
part5 = gamma * x_hat * sum2
|
||||
part6 = x_hat * grad_dg_np
|
||||
d_dy = part3 + part4 + part5 + part6 + grad_db_np
|
||||
|
||||
part7 = np.sum(dy * x_hat * sum2, tuple(param_axis), keepdims=True)
|
||||
part8 = np.sum(dy * sum1, tuple(param_axis), keepdims=True)
|
||||
part9 = np.sum(dy * grad_dx_np * inv_std, tuple(param_axis), keepdims=True)
|
||||
d_gamma = part7 + part8 + part9
|
||||
part9 = np.sum(dy * x_hat * sum2, tuple(param_axis), keepdims=True)
|
||||
part10 = np.sum(dy * sum1, tuple(param_axis), keepdims=True)
|
||||
part11 = np.sum(dy * grad_dx_np * inv_std, tuple(param_axis), keepdims=True)
|
||||
d_gamma = part9 + part10 + part11
|
||||
|
||||
return d_x, d_dy, d_gamma
|
||||
|
||||
|
@ -106,17 +111,16 @@ def LayerNormGradGradReference(x, dy, gamma, epsilon, grad_dx_np, grad_dg_np, gr
|
|||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernormgradgrad0():
|
||||
np.random.seed(1)
|
||||
begin_norm_axis = 1
|
||||
begin_params_axis = 1
|
||||
|
||||
x_np = np.random.rand(2, 2, 4).astype(np.float32)
|
||||
dy_np = np.random.rand(2, 2, 4).astype(np.float32)
|
||||
gamma_np = np.random.rand(2, 4).astype(np.float32)
|
||||
x_np = np.random.randn(4096, 3072).astype(np.float32)
|
||||
dy_np = np.random.randn(4096, 3072).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
grad_dx_np = np.random.rand(2, 2, 4).astype(np.float32)
|
||||
grad_dg_np = np.random.rand(2, 4).astype(np.float32)
|
||||
grad_db_np = np.random.rand(2, 4).astype(np.float32)
|
||||
grad_dx_np = np.random.randn(*x_np.shape).astype(np.float32)
|
||||
grad_dg_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
grad_db_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
epsilon = 1e-12
|
||||
_, _, _, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
|
||||
|
@ -125,18 +129,433 @@ def test_layernormgradgrad0():
|
|||
d_x_np, d_dy_np, d_gamma_np = LayerNormGradGradReference(x_np, dy_np, gamma_np, epsilon, grad_dx_np, grad_dg_np,
|
||||
grad_db_np, begin_norm_axis, begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np)
|
||||
x_ms = Tensor(x_np)
|
||||
var_ms = Tensor(var_np)
|
||||
mean_ms = Tensor(mean_np)
|
||||
gamma_ms = Tensor(gamma_np)
|
||||
grad_dx_ms = Tensor(grad_dx_np)
|
||||
grad_dg_ms = Tensor(grad_dg_np)
|
||||
grad_db_ms = Tensor(grad_db_np)
|
||||
dy_ms = Tensor(dy_np.astype(np.float32))
|
||||
x_ms = Tensor(x_np.astype(np.float32))
|
||||
var_ms = Tensor(var_np.astype(np.float32))
|
||||
mean_ms = Tensor(mean_np.astype(np.float32))
|
||||
gamma_ms = Tensor(gamma_np.astype(np.float32))
|
||||
grad_dx_ms = Tensor(grad_dx_np.astype(np.float32))
|
||||
grad_dg_ms = Tensor(grad_dg_np.astype(np.float32))
|
||||
grad_db_ms = Tensor(grad_db_np.astype(np.float32))
|
||||
|
||||
net = LayerNormGradGradNet(begin_norm_axis, begin_params_axis)
|
||||
d_x_ms, d_dy_ms, d_gamma_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms, grad_dx_ms, grad_dg_ms, grad_db_ms)
|
||||
|
||||
assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=1e-6, atol=1e-3)
|
||||
assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=1e-6, atol=1e-6)
|
||||
assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=1e-6, atol=1e-3)
|
||||
assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=3e-3, atol=3e-3)
|
||||
assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=3e-3, atol=3e-3)
|
||||
assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=3e-3, atol=3e-3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernormgradgrad1():
|
||||
begin_norm_axis = 1
|
||||
begin_params_axis = 1
|
||||
x_np = np.random.randn(640, 768).astype(np.float32)
|
||||
dy_np = np.random.randn(640, 768).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
grad_dx_np = np.random.randn(*x_np.shape).astype(np.float32)
|
||||
grad_dg_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
grad_db_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
epsilon = 1e-12
|
||||
_, _, _, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
|
||||
begin_params_axis)
|
||||
|
||||
d_x_np, d_dy_np, d_gamma_np = LayerNormGradGradReference(x_np, dy_np, gamma_np, epsilon, grad_dx_np, grad_dg_np,
|
||||
grad_db_np, begin_norm_axis, begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np.astype(np.float32))
|
||||
x_ms = Tensor(x_np.astype(np.float32))
|
||||
var_ms = Tensor(var_np.astype(np.float32))
|
||||
mean_ms = Tensor(mean_np.astype(np.float32))
|
||||
gamma_ms = Tensor(gamma_np.astype(np.float32))
|
||||
grad_dx_ms = Tensor(grad_dx_np.astype(np.float32))
|
||||
grad_dg_ms = Tensor(grad_dg_np.astype(np.float32))
|
||||
grad_db_ms = Tensor(grad_db_np.astype(np.float32))
|
||||
|
||||
net = LayerNormGradGradNet(begin_norm_axis, begin_params_axis)
|
||||
d_x_ms, d_dy_ms, d_gamma_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms, grad_dx_ms, grad_dg_ms, grad_db_ms)
|
||||
|
||||
assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=3e-3, atol=3e-3)
|
||||
assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=3e-3, atol=3e-3)
|
||||
assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=3e-3, atol=3e-3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernormgradgrad2():
|
||||
begin_norm_axis = -1
|
||||
begin_params_axis = -1
|
||||
x_np = np.random.randn(32, 128, 768).astype(np.float32)
|
||||
dy_np = np.random.randn(32, 128, 768).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
grad_dx_np = np.random.randn(*x_np.shape).astype(np.float32)
|
||||
grad_dg_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
grad_db_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
epsilon = 1e-12
|
||||
_, _, _, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
|
||||
begin_params_axis)
|
||||
|
||||
d_x_np, d_dy_np, d_gamma_np = LayerNormGradGradReference(x_np, dy_np, gamma_np, epsilon, grad_dx_np, grad_dg_np,
|
||||
grad_db_np, begin_norm_axis, begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np.astype(np.float32))
|
||||
x_ms = Tensor(x_np.astype(np.float32))
|
||||
var_ms = Tensor(var_np.astype(np.float32))
|
||||
mean_ms = Tensor(mean_np.astype(np.float32))
|
||||
gamma_ms = Tensor(gamma_np.astype(np.float32))
|
||||
grad_dx_ms = Tensor(grad_dx_np.astype(np.float32))
|
||||
grad_dg_ms = Tensor(grad_dg_np.astype(np.float32))
|
||||
grad_db_ms = Tensor(grad_db_np.astype(np.float32))
|
||||
|
||||
net = LayerNormGradGradNet(begin_norm_axis, begin_params_axis)
|
||||
d_x_ms, d_dy_ms, d_gamma_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms, grad_dx_ms, grad_dg_ms, grad_db_ms)
|
||||
|
||||
assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=3e-3, atol=3e-3)
|
||||
assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=3e-3, atol=3e-3)
|
||||
assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=3e-3, atol=3e-3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernormgradgrad3():
|
||||
begin_norm_axis = -1
|
||||
begin_params_axis = -1
|
||||
x_np = np.random.randn(32, 64).astype(np.float32)
|
||||
dy_np = np.random.randn(32, 64).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
grad_dx_np = np.random.randn(*x_np.shape).astype(np.float32)
|
||||
grad_dg_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
grad_db_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
epsilon = 1e-12
|
||||
_, _, _, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
|
||||
begin_params_axis)
|
||||
|
||||
d_x_np, d_dy_np, d_gamma_np = LayerNormGradGradReference(x_np, dy_np, gamma_np, epsilon, grad_dx_np, grad_dg_np,
|
||||
grad_db_np, begin_norm_axis, begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np.astype(np.float32))
|
||||
x_ms = Tensor(x_np.astype(np.float32))
|
||||
var_ms = Tensor(var_np.astype(np.float32))
|
||||
mean_ms = Tensor(mean_np.astype(np.float32))
|
||||
gamma_ms = Tensor(gamma_np.astype(np.float32))
|
||||
grad_dx_ms = Tensor(grad_dx_np.astype(np.float32))
|
||||
grad_dg_ms = Tensor(grad_dg_np.astype(np.float32))
|
||||
grad_db_ms = Tensor(grad_db_np.astype(np.float32))
|
||||
|
||||
net = LayerNormGradGradNet(begin_norm_axis, begin_params_axis)
|
||||
d_x_ms, d_dy_ms, d_gamma_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms, grad_dx_ms, grad_dg_ms, grad_db_ms)
|
||||
|
||||
assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=3e-3, atol=3e-3)
|
||||
assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=3e-3, atol=3e-3)
|
||||
assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=3e-3, atol=3e-3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernormgradgrad4():
|
||||
begin_norm_axis = -1
|
||||
begin_params_axis = -1
|
||||
x_np = np.random.randn(32, 64).astype(np.float32)
|
||||
dy_np = np.random.randn(32, 64).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
grad_dx_np = np.random.randn(*x_np.shape).astype(np.float32)
|
||||
grad_dg_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
grad_db_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
epsilon = 1e-12
|
||||
_, _, _, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
|
||||
begin_params_axis)
|
||||
|
||||
d_x_np, d_dy_np, d_gamma_np = LayerNormGradGradReference(x_np, dy_np, gamma_np, epsilon, grad_dx_np, grad_dg_np,
|
||||
grad_db_np, begin_norm_axis, begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np.astype(np.float32))
|
||||
x_ms = Tensor(x_np.astype(np.float32))
|
||||
var_ms = Tensor(var_np.astype(np.float32))
|
||||
mean_ms = Tensor(mean_np.astype(np.float32))
|
||||
gamma_ms = Tensor(gamma_np.astype(np.float32))
|
||||
grad_dx_ms = Tensor(grad_dx_np.astype(np.float32))
|
||||
grad_dg_ms = Tensor(grad_dg_np.astype(np.float32))
|
||||
grad_db_ms = Tensor(grad_db_np.astype(np.float32))
|
||||
|
||||
net = LayerNormGradGradNet(begin_norm_axis, begin_params_axis)
|
||||
d_x_ms, d_dy_ms, d_gamma_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms, grad_dx_ms, grad_dg_ms, grad_db_ms)
|
||||
|
||||
assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=3e-3, atol=3e-3)
|
||||
assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=3e-3, atol=3e-3)
|
||||
assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=3e-3, atol=3e-3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernormgradgrad5():
|
||||
begin_norm_axis = 2
|
||||
begin_params_axis = 1
|
||||
x_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
|
||||
dy_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
grad_dx_np = np.random.randn(*x_np.shape).astype(np.float32)
|
||||
grad_dg_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
grad_db_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
epsilon = 1e-12
|
||||
_, _, _, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
|
||||
begin_params_axis)
|
||||
|
||||
d_x_np, d_dy_np, d_gamma_np = LayerNormGradGradReference(x_np, dy_np, gamma_np, epsilon, grad_dx_np, grad_dg_np,
|
||||
grad_db_np, begin_norm_axis, begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np.astype(np.float32))
|
||||
x_ms = Tensor(x_np.astype(np.float32))
|
||||
var_ms = Tensor(var_np.astype(np.float32))
|
||||
mean_ms = Tensor(mean_np.astype(np.float32))
|
||||
gamma_ms = Tensor(gamma_np.astype(np.float32))
|
||||
grad_dx_ms = Tensor(grad_dx_np.astype(np.float32))
|
||||
grad_dg_ms = Tensor(grad_dg_np.astype(np.float32))
|
||||
grad_db_ms = Tensor(grad_db_np.astype(np.float32))
|
||||
|
||||
net = LayerNormGradGradNet(begin_norm_axis, begin_params_axis)
|
||||
d_x_ms, d_dy_ms, d_gamma_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms, grad_dx_ms, grad_dg_ms, grad_db_ms)
|
||||
|
||||
assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=3e-3, atol=3e-3)
|
||||
assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=3e-3, atol=3e-3)
|
||||
assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=3e-3, atol=3e-3)
|
||||
|
||||
|
||||
def test_layernormgradgrad6():
|
||||
begin_norm_axis = 1
|
||||
begin_params_axis = 1
|
||||
|
||||
x_np = np.random.randn(4096, 3072).astype(np.float32)
|
||||
dy_np = np.random.randn(4096, 3072).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
grad_dx_np = np.random.randn(*x_np.shape).astype(np.float32)
|
||||
grad_dg_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
grad_db_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
epsilon = 1e-7
|
||||
_, _, _, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
|
||||
begin_params_axis)
|
||||
|
||||
d_x_np, d_dy_np, d_gamma_np = LayerNormGradGradReference(x_np, dy_np, gamma_np, epsilon, grad_dx_np, grad_dg_np,
|
||||
grad_db_np, begin_norm_axis, begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np.astype(np.float16))
|
||||
x_ms = Tensor(x_np.astype(np.float16))
|
||||
var_ms = Tensor(var_np.astype(np.float16))
|
||||
mean_ms = Tensor(mean_np.astype(np.float16))
|
||||
gamma_ms = Tensor(gamma_np.astype(np.float16))
|
||||
grad_dx_ms = Tensor(grad_dx_np.astype(np.float16))
|
||||
grad_dg_ms = Tensor(grad_dg_np.astype(np.float16))
|
||||
grad_db_ms = Tensor(grad_db_np.astype(np.float16))
|
||||
|
||||
net = LayerNormGradGradNet(begin_norm_axis, begin_params_axis)
|
||||
d_x_ms, d_dy_ms, d_gamma_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms, grad_dx_ms, grad_dg_ms, grad_db_ms)
|
||||
|
||||
assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=5e-3, atol=5e-1)
|
||||
assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=5e-3, atol=5e-1)
|
||||
assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=5e-3, atol=5e-1)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernormgradgrad7():
|
||||
begin_norm_axis = 1
|
||||
begin_params_axis = 1
|
||||
x_np = np.random.randn(640, 768).astype(np.float32)
|
||||
dy_np = np.random.randn(640, 768).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
grad_dx_np = np.random.randn(*x_np.shape).astype(np.float32)
|
||||
grad_dg_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
grad_db_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
epsilon = 1e-7
|
||||
_, _, _, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
|
||||
begin_params_axis)
|
||||
|
||||
d_x_np, d_dy_np, d_gamma_np = LayerNormGradGradReference(x_np, dy_np, gamma_np, epsilon, grad_dx_np, grad_dg_np,
|
||||
grad_db_np, begin_norm_axis, begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np.astype(np.float16))
|
||||
x_ms = Tensor(x_np.astype(np.float16))
|
||||
var_ms = Tensor(var_np.astype(np.float16))
|
||||
mean_ms = Tensor(mean_np.astype(np.float16))
|
||||
gamma_ms = Tensor(gamma_np.astype(np.float16))
|
||||
grad_dx_ms = Tensor(grad_dx_np.astype(np.float16))
|
||||
grad_dg_ms = Tensor(grad_dg_np.astype(np.float16))
|
||||
grad_db_ms = Tensor(grad_db_np.astype(np.float16))
|
||||
|
||||
net = LayerNormGradGradNet(begin_norm_axis, begin_params_axis)
|
||||
d_x_ms, d_dy_ms, d_gamma_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms, grad_dx_ms, grad_dg_ms, grad_db_ms)
|
||||
|
||||
assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=5e-3, atol=5e-1)
|
||||
assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=5e-3, atol=5e-1)
|
||||
assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=5e-3, atol=5e-1)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernormgradgrad8():
|
||||
begin_norm_axis = -1
|
||||
begin_params_axis = -1
|
||||
x_np = np.random.randn(32, 128, 768).astype(np.float32)
|
||||
dy_np = np.random.randn(32, 128, 768).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
grad_dx_np = np.random.randn(*x_np.shape).astype(np.float32)
|
||||
grad_dg_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
grad_db_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
epsilon = 1e-7
|
||||
_, _, _, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
|
||||
begin_params_axis)
|
||||
|
||||
d_x_np, d_dy_np, d_gamma_np = LayerNormGradGradReference(x_np, dy_np, gamma_np, epsilon, grad_dx_np, grad_dg_np,
|
||||
grad_db_np, begin_norm_axis, begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np.astype(np.float16))
|
||||
x_ms = Tensor(x_np.astype(np.float16))
|
||||
var_ms = Tensor(var_np.astype(np.float16))
|
||||
mean_ms = Tensor(mean_np.astype(np.float16))
|
||||
gamma_ms = Tensor(gamma_np.astype(np.float16))
|
||||
grad_dx_ms = Tensor(grad_dx_np.astype(np.float16))
|
||||
grad_dg_ms = Tensor(grad_dg_np.astype(np.float16))
|
||||
grad_db_ms = Tensor(grad_db_np.astype(np.float16))
|
||||
|
||||
net = LayerNormGradGradNet(begin_norm_axis, begin_params_axis)
|
||||
d_x_ms, d_dy_ms, d_gamma_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms, grad_dx_ms, grad_dg_ms, grad_db_ms)
|
||||
|
||||
assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=5e-3, atol=5e-1)
|
||||
assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=5e-3, atol=5e-1)
|
||||
assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=5e-3, atol=5e-1)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernormgradgrad9():
|
||||
begin_norm_axis = -1
|
||||
begin_params_axis = -1
|
||||
x_np = np.random.randn(32, 64).astype(np.float32)
|
||||
dy_np = np.random.randn(32, 64).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
grad_dx_np = np.random.randn(*x_np.shape).astype(np.float32)
|
||||
grad_dg_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
grad_db_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
epsilon = 1e-7
|
||||
_, _, _, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
|
||||
begin_params_axis)
|
||||
|
||||
d_x_np, d_dy_np, d_gamma_np = LayerNormGradGradReference(x_np, dy_np, gamma_np, epsilon, grad_dx_np, grad_dg_np,
|
||||
grad_db_np, begin_norm_axis, begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np.astype(np.float16))
|
||||
x_ms = Tensor(x_np.astype(np.float16))
|
||||
var_ms = Tensor(var_np.astype(np.float16))
|
||||
mean_ms = Tensor(mean_np.astype(np.float16))
|
||||
gamma_ms = Tensor(gamma_np.astype(np.float16))
|
||||
grad_dx_ms = Tensor(grad_dx_np.astype(np.float16))
|
||||
grad_dg_ms = Tensor(grad_dg_np.astype(np.float16))
|
||||
grad_db_ms = Tensor(grad_db_np.astype(np.float16))
|
||||
|
||||
net = LayerNormGradGradNet(begin_norm_axis, begin_params_axis)
|
||||
d_x_ms, d_dy_ms, d_gamma_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms, grad_dx_ms, grad_dg_ms, grad_db_ms)
|
||||
|
||||
assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=5e-3, atol=5e-1)
|
||||
assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=5e-3, atol=5e-1)
|
||||
assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=5e-3, atol=5e-1)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernormgradgrad10():
|
||||
begin_norm_axis = -1
|
||||
begin_params_axis = -1
|
||||
x_np = np.random.randn(32, 64).astype(np.float32)
|
||||
dy_np = np.random.randn(32, 64).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
grad_dx_np = np.random.randn(*x_np.shape).astype(np.float32)
|
||||
grad_dg_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
grad_db_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
epsilon = 1e-7
|
||||
_, _, _, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
|
||||
begin_params_axis)
|
||||
|
||||
d_x_np, d_dy_np, d_gamma_np = LayerNormGradGradReference(x_np, dy_np, gamma_np, epsilon, grad_dx_np, grad_dg_np,
|
||||
grad_db_np, begin_norm_axis, begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np.astype(np.float16))
|
||||
x_ms = Tensor(x_np.astype(np.float16))
|
||||
var_ms = Tensor(var_np.astype(np.float16))
|
||||
mean_ms = Tensor(mean_np.astype(np.float16))
|
||||
gamma_ms = Tensor(gamma_np.astype(np.float16))
|
||||
grad_dx_ms = Tensor(grad_dx_np.astype(np.float16))
|
||||
grad_dg_ms = Tensor(grad_dg_np.astype(np.float16))
|
||||
grad_db_ms = Tensor(grad_db_np.astype(np.float16))
|
||||
|
||||
net = LayerNormGradGradNet(begin_norm_axis, begin_params_axis)
|
||||
d_x_ms, d_dy_ms, d_gamma_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms, grad_dx_ms, grad_dg_ms, grad_db_ms)
|
||||
|
||||
assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=5e-3, atol=5e-1)
|
||||
assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=5e-3, atol=5e-1)
|
||||
assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=5e-3, atol=5e-1)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernormgradgrad11():
|
||||
begin_norm_axis = 2
|
||||
begin_params_axis = 1
|
||||
x_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
|
||||
dy_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
|
||||
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
grad_dx_np = np.random.randn(*x_np.shape).astype(np.float32)
|
||||
grad_dg_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
grad_db_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
|
||||
|
||||
epsilon = 1e-7
|
||||
_, _, _, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
|
||||
begin_params_axis)
|
||||
|
||||
d_x_np, d_dy_np, d_gamma_np = LayerNormGradGradReference(x_np, dy_np, gamma_np, epsilon, grad_dx_np, grad_dg_np,
|
||||
grad_db_np, begin_norm_axis, begin_params_axis)
|
||||
|
||||
dy_ms = Tensor(dy_np.astype(np.float16))
|
||||
x_ms = Tensor(x_np.astype(np.float16))
|
||||
var_ms = Tensor(var_np.astype(np.float16))
|
||||
mean_ms = Tensor(mean_np.astype(np.float16))
|
||||
gamma_ms = Tensor(gamma_np.astype(np.float16))
|
||||
grad_dx_ms = Tensor(grad_dx_np.astype(np.float16))
|
||||
grad_dg_ms = Tensor(grad_dg_np.astype(np.float16))
|
||||
grad_db_ms = Tensor(grad_db_np.astype(np.float16))
|
||||
|
||||
net = LayerNormGradGradNet(begin_norm_axis, begin_params_axis)
|
||||
d_x_ms, d_dy_ms, d_gamma_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms, grad_dx_ms, grad_dg_ms, grad_db_ms)
|
||||
|
||||
assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=5e-3, atol=5e-1)
|
||||
assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=5e-3, atol=5e-1)
|
||||
assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=5e-3, atol=5e-1)
|
||||
|
|
Loading…
Reference in New Issue