!13475 [GraphKernel]adapt for layernorm in ascend

From: @wenfangpei
Reviewed-by: @gaoxiong1,@anyrenwei
Signed-off-by: @anyrenwei
This commit is contained in:
mindspore-ci-bot 2021-04-09 17:04:52 +08:00 committed by Gitee
commit 0920239699
6 changed files with 263 additions and 155 deletions

View File

@ -17,33 +17,118 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.add_format(DF.FRAC_NZ, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon')
class LayerNorm(Expander):
"""LayerNorm expander"""
def to_frac_z_axis(self, ori_shape, ori_axis):
"""
judge the format is fractal NZ
Parameters
----------
ori_shape: list or tuple
original shape of input
ori_axis: list or tuple
original axis of original shape to operate
Returns
-------
output: list
axis of the fractal Nz shape
"""
frac_z_axis = list(ori_axis)
shape_len = len(ori_shape)
axis_count = len(frac_z_axis)
axis_negative_1 = shape_len - 1
axis_negative_2 = shape_len - 2
for i in range(axis_count):
axis_index = (frac_z_axis[i] + shape_len) % shape_len
if axis_index == axis_negative_1:
if frac_z_axis[i] > shape_len - 2: # akg:[2,3] [1,4] tbe:[2,4] [1,3]
frac_z_axis[i] = axis_index - 1
frac_z_axis.append(axis_index + 2)
else: # no case cover this branch now
frac_z_axis[i] = axis_index - 1
frac_z_axis.append(axis_index + 2)
elif axis_index == axis_negative_2:
frac_z_axis[i] = axis_index + 1
frac_z_axis.append(axis_index + 2)
else:
frac_z_axis[i] = axis_index
return frac_z_axis
def infer_shape_from_fractalNz(self, fractal):
"get original shape from fractalNz shape"
shape = []
dims = len(fractal)
batch = dims - 4
for i in range(batch):
shape.append(fractal[i])
m = fractal[dims - 3] * fractal[dims - 2]
n = fractal[dims - 4] * fractal[dims - 1]
shape.append(m)
shape.append(n)
return shape
def get_reduced_ori_shape(self, shape, axis):
"get shape after reduced which is based on original shape"
reduced_ori_shape = []
for i, value in enumerate(shape):
if i in axis:
reduced_ori_shape.append(1)
else:
reduced_ori_shape.append(value)
return reduced_ori_shape
def _expand(self, graph_builder):
input_x, input_gamma, input_beta = self.inputs
processor = self.processor
begin_norm_axis = self.attrs['begin_norm_axis']
epsilon = self.attrs['epsilon']
ori_dtype = input_x.dtype
if processor == 'aicore' and ori_dtype == 'float16':
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
input_gamma = graph_builder.emit('Cast', [input_gamma], attrs={'dst_type': 'float32'})
input_beta = graph_builder.emit('Cast', [input_beta], attrs={'dst_type': 'float32'})
ori_shape_x = input_x.shape
if input_x.data_format == DF.FRAC_NZ:
ori_shape_x = self.infer_shape_from_fractalNz(input_x.shape)
# Calculate the scaling ratio of the average
if begin_norm_axis < 0:
begin_norm_axis += len(input_x.shape)
begin_norm_axis += len(ori_shape_x)
reduce_axis = ()
for i, _ in enumerate(input_x.shape):
for i, _ in enumerate(ori_shape_x):
if i > begin_norm_axis or i == begin_norm_axis:
reduce_axis = reduce_axis + (i,)
reduce_elts = 1.0
for i in reduce_axis:
reduce_elts *= input_x.shape[i]
reduce_elts *= ori_shape_x[i]
if input_x.data_format == DF.FRAC_NZ:
reduce_axis = self.to_frac_z_axis(ori_shape_x, reduce_axis)
ori_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis) # after reduced
mean_cof = 1.0 / reduce_elts
mean_cof_v = graph_builder.value(input_x.dtype, mean_cof)
# Calculate mean
mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
mean_red = graph_builder.emit('ReduceSum', [input_x],
attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
mean = graph_builder.emit('Mul', [mean_red, mean_cof_v])
if input_x.data_format == DF.FRAC_NZ:
mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_shape_x})
# Calculate variance
variance_sub = graph_builder.emit('Sub', [input_x, mean])
@ -51,6 +136,8 @@ class LayerNorm(Expander):
variance_red = graph_builder.emit('ReduceSum', [variance_mul],
attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
variance = graph_builder.emit('Mul', [variance_red, mean_cof_v])
if input_x.data_format == DF.FRAC_NZ:
variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_shape_x})
# Calculate normalize
normalize_sub = graph_builder.emit('Sub', [input_x, mean])
@ -60,7 +147,11 @@ class LayerNorm(Expander):
normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt])
# Calculate scale and translate
scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul])
scale_mul = graph_builder.emit('Mul', [normalize_mul, input_gamma])
res = graph_builder.emit('Add', [scale_mul, input_beta])
if processor == 'aicore' and ori_dtype == 'float16':
res = graph_builder.emit('Cast', [res], attrs={'dst_type': 'float16'})
mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float16'})
variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float16'})
return res, mean, variance

View File

@ -24,23 +24,33 @@ class LayerNormGrad(Expander):
def _expand(self, graph_builder):
x, dy, variance, mean, gamma = self.inputs
processor = self.processor
begin_norm_axis = self.attrs['begin_norm_axis']
begin_params_axis = self.attrs['begin_params_axis']
epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-11
epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-12
ori_dtype = x.dtype
if processor == 'aicore' and ori_dtype == 'float16':
x = graph_builder.emit('Cast', [x], attrs={'dst_type': 'float32'})
dy = graph_builder.emit('Cast', [dy], attrs={'dst_type': 'float32'})
variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float32'})
mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float32'})
gamma = graph_builder.emit('Cast', [gamma], attrs={'dst_type': 'float32'})
if begin_norm_axis < 0:
begin_norm_axis += len(x.shape)
if begin_params_axis < 0:
begin_params_axis += len(x.shape)
norm_axis = tuple(range(begin_norm_axis, len(x.shape)))
param_axis = tuple(range(0, begin_params_axis))
reduce_size = 1.0
for i in norm_axis:
reduce_size *= x.shape[i]
# set some constant val.
eps = graph_builder.value(x.dtype, epsilon)
const_one = graph_builder.value(x.dtype, 1.0)
const_neg_half = graph_builder.value(x.dtype, -0.5)
const_neg_two = graph_builder.value(x.dtype, -2.0)
const_two = graph_builder.value(x.dtype, 2.0)
@ -49,42 +59,55 @@ class LayerNormGrad(Expander):
# cal dg db
var_eps = graph_builder.emit('Add', [variance, eps])
sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps])
rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps])
var_eps_log = graph_builder.emit('Log', [var_eps])
var_eps_mul = graph_builder.emit('Mul', [var_eps_log, const_neg_half])
rsqrt_var_eps = graph_builder.emit('Exp', [var_eps_mul])
x_sub_mean = graph_builder.emit('Sub', [x, mean])
x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean])
dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps])
dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False})
db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False})
# cal sum_1
tmp_var_eps = graph_builder.emit('Mul', [sqrt_var_eps, var_eps])
r_tmp_var_eps = graph_builder.emit('RealDiv', [const_one, tmp_var_eps])
x_sub_mean_mul_r_tmp_var_eps = graph_builder.emit('Mul', [x_sub_mean, r_tmp_var_eps])
# pd_var
tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, rsqrt_var_eps])
r_tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, tmp_var_eps])
dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma])
tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean_mul_r_tmp_var_eps])
sum_1_mul = graph_builder.emit('Mul', [const_neg_half, tmp_mul])
sum_1 = graph_builder.emit('ReduceSum', [sum_1_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean])
padvar_mul1 = graph_builder.emit('ReduceSum', [tmp_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
padvar_mul3 = graph_builder.emit('Mul', [padvar_mul1, r_tmp_var_eps])
pd_var = graph_builder.emit('Mul', [padvar_mul3, const_neg_half])
# cal sum_2
sum_2 = graph_builder.emit('ReduceSum', [dy_mul_gamma], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
# pd_mean
pdmean1_sum = graph_builder.emit('ReduceSum', [dy_mul_gamma],
attrs={'reduce_axis': norm_axis, 'keep_dims': True})
neg_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, const_neg_one])
pd_mean_1 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, pdmean1_sum])
# cal sum_3
sum_3_mul = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
sum_3 = graph_builder.emit('ReduceSum', [sum_3_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
pdmean2_mul1 = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
pdmean2_sum = graph_builder.emit('ReduceSum', [pdmean2_mul1],
attrs={'reduce_axis': norm_axis, 'keep_dims': True})
pdmean2_mul3 = graph_builder.emit('Mul', [pdmean2_sum, mean_cof])
pd_mean_2 = graph_builder.emit('Mul', [pdmean2_mul3, pd_var])
# cal dx = dx1 + dx2 + dx3
dx_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
sum_1_mul_two = graph_builder.emit('Mul', [sum_1, const_two])
sum_1_mul_two_tmp = graph_builder.emit('Mul', [sum_1_mul_two, mean_cof])
dx_2 = graph_builder.emit('Mul', [sum_1_mul_two_tmp, x_sub_mean])
neg_rsqrt_var_eps = graph_builder.emit('Mul', [const_neg_one, rsqrt_var_eps])
neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2])
sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3])
mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3])
add_tmp = graph_builder.emit('Add', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof])
dx_tmp = graph_builder.emit('Add', [dx_1, dx_2])
dx = graph_builder.emit('Add', [dx_tmp, dx_3])
pd_mean = graph_builder.emit('Add', [pd_mean_1, pd_mean_2])
# cal dx
pd_x_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
pdx2_mul = graph_builder.emit('Mul', [pd_var, x_sub_mean])
pdx2_mul_two = graph_builder.emit('Mul', [pdx2_mul, const_two])
pd_x_2 = graph_builder.emit('Mul', [pdx2_mul_two, mean_cof])
pd_x_3 = graph_builder.emit('Mul', [pd_mean, mean_cof])
dx_tmp = graph_builder.emit('Add', [pd_x_1, pd_x_2])
dx = graph_builder.emit('Add', [dx_tmp, pd_x_3])
if processor == 'aicore' and ori_dtype == 'float16':
dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})
dg = graph_builder.emit('Cast', [dg], attrs={'dst_type': 'float16'})
db = graph_builder.emit('Cast', [db], attrs={'dst_type': 'float16'})
return dx, dg, db

View File

@ -47,6 +47,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
prim::kPrimSquare,
prim::kPrimGeLUGrad,
prim::kPrimAssignAdd,
prim::kPrimLayerNorm,
prim::kPrimLayerNormGrad,
#if ENABLE_D
prim::kPrimTile,
prim::kPrimSqrtGrad,
@ -67,8 +69,6 @@ std::vector<PrimitivePtr> GetExpandOps() {
prim::kPrimDropout,
prim::kPrimDropoutGrad,
prim::kPrimSoftmax,
prim::kPrimLayerNorm,
prim::kPrimLayerNormGrad,
prim::kPrimRelu,
prim::kPrimReluGrad,
prim::kPrimSigmoid,

View File

@ -54,7 +54,6 @@
#include "debug/data_dump/dump_json_parser.h"
#include "debug/tensor_load.h"
#include "debug/anf_ir_utils.h"
#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
#include "backend/session/ascend_auto_monad.h"
#include "debug/data_dump/e2e_dump.h"

View File

@ -223,12 +223,12 @@ def test_bert_precision(enable_graph_kernel=False):
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(callback.loss_list)
assert np.allclose(loss_value[0], 12.2066, 0, 0.0005)
if enable_graph_kernel:
expect_loss_value = [12.206627, 11.840489, 11.798470, 11.796345, 11.790964, 12.366766, 11.971539, 12.576565,
12.185522, 12.386192]
else:
assert np.allclose(loss_value[0], 12.2066, 0, 0.0005)
expect_loss_value = [12.206587, 11.966410, 11.965916, 11.975922, 11.970262, 12.608881, 12.174048, 12.840656,
12.407923, 12.631133]
print("loss value: {}".format(loss_value))

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,150 +13,145 @@
# limitations under the License.
# ============================================================================
import copy
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.nn as nn
from mindspore.ops.operations import _grad_ops as G
import mindspore.ops.operations as P
class Net(Cell):
def __init__(self):
super(Net, self).__init__()
self.layernorm = P.LayerNorm(1, 1)
class LayerNormNet(nn.Cell):
def __init__(self, begin_norm_axis, begin_params_axis):
super(LayerNormNet, self).__init__()
self.layernorm = P.LayerNorm(begin_norm_axis, begin_params_axis)
def construct(self, x, y, z):
return self.layernorm(x, y, z)
def construct(self, x, gamma, beta):
return self.layernorm(x, gamma, beta)
class LayerNormGradNet(nn.Cell):
def __init__(self, begin_norm_axis, begin_params_axis):
super(LayerNormGradNet, self).__init__()
self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis)
self.layernorm_grad = G.LayerNormGrad(begin_norm_axis, begin_params_axis)
def construct(self, dy, x, var, mean, gamma):
return self.norm(dy, x, var, mean, gamma)
return self.layernorm_grad(dy, x, var, mean, gamma)
def get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, enable_graph_kernel=False):
context.set_context(enable_graph_kernel=enable_graph_kernel)
net = LayerNormNet(begin_norm_axis, begin_params_axis)
output = net(x, gamma, beta)
return output
def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)
norm_axis = [i for i in range(begin_norm_axis, len(x.shape))]
param_axis = [i for i in range(0, begin_params_axis)]
num = 1
for i in range(begin_norm_axis, len(x.shape)):
num *= x.shape[i]
mean = np.mean(x, axis=tuple(norm_axis), keepdims=True)
var = np.var(x, axis=tuple(norm_axis), keepdims=True)
gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
db = np.sum(dy, axis=tuple(param_axis), keepdims=True)
sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
keepdims=True)
sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)
dx1 = dy * gamma * np.power(var + epsilon, -0.5)
dx2 = sum1 * 2.0 / num * (x - mean)
dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
dx = dx1 + dx2 + dx3
return dx, dg, db, mean, var
def test_basic():
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
gamma = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
beta = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
shape_x = [2, 3, 4, 3]
begin_norm_axis = 1
in_rank = len(shape_x)
if begin_norm_axis < 0:
norm_axis = begin_norm_axis + in_rank
else:
norm_axis = begin_norm_axis
norm_axes = tuple(range(norm_axis, in_rank))
mean = np.mean(input_x, axis=norm_axes, keepdims=True)
mean_b = np.broadcast_to(mean, shape_x)
diff = input_x - mean_b
square = np.square(diff)
smean = np.mean(square, axis=norm_axes, keepdims=True)
smean_b = np.broadcast_to(smean, shape_x)
meps = smean_b + 1e-5
logs = np.log(meps)
mul = logs * (-0.5)
rsqrt = np.exp(mul)
out = diff * rsqrt
bn = out * gamma + beta
expect = (bn, mean, smean)
net = Net()
net_result = net(Tensor(input_x), Tensor(gamma), Tensor(beta))
if isinstance(net_result, tuple) and len(net_result) == 3:
result = (net_result[0].asnumpy(), net_result[1].asnumpy(), net_result[2].asnumpy())
res0 = np.allclose(expect[0], result[0], rtol=1.e-4, atol=1.e-4, equal_nan=True)
assert res0
res1 = np.allclose(expect[1], result[1], rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res1
res2 = np.allclose(expect[2], result[2], rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res2
else:
assert False
def test_layernormgrad():
np.random.seed(0)
begin_norm_axis = 1
begin_params_axis = 1
x_np = np.random.randn(4096, 3072).astype(np.float32)
dy_np = np.random.randn(4096, 3072).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 1e-11
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
var_ms = Tensor(var_np)
mean_ms = Tensor(mean_np)
gamma_ms = Tensor(gamma_np)
def get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, enable_graph_kernel=False):
context.set_context(enable_graph_kernel=enable_graph_kernel)
net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
output = net(x, dy, var, mean, gamma)
return output
def get_rtol_atol(dtype):
if dtype == np.float16:
return 1.e-3, 1.e-3
return 1.e-4, 1.e-4
def compare_result(expect, output, dtype):
rtol, atol = get_rtol_atol(dtype)
if isinstance(expect, (list, tuple)):
assert isinstance(output, (list, tuple)) and len(expect) == len(output)
expect_list = list(expect)
output_list = list(output)
for e, o in zip(expect_list, output_list):
assert np.allclose(e.asnumpy(), o.asnumpy(), rtol, atol, equal_nan=True)
else:
assert np.allclose(expect.asnumpy(), output.asnumpy(), rtol, atol, equal_nan=True)
def test_layernorm(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1):
begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(shape)
begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(shape)
assert 0 <= begin_norm_axis < len(shape)
assert 0 <= begin_params_axis < len(shape)
normalized_shape = shape[begin_params_axis:]
np.random.seed(0)
# input tensors
x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
gamma = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype))
beta = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype))
expect = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, False)
output = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, True)
compare_result(expect, output, dtype)
def test_layernorm_grad(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1):
begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(shape)
begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(shape)
assert 0 <= begin_norm_axis < len(shape)
assert 0 <= begin_params_axis < len(shape)
norm_axis = [i for i in range(begin_norm_axis, len(shape))]
norm_shape = copy.deepcopy(shape)
for i, _ in enumerate(norm_shape):
if i in norm_axis:
norm_shape[i] = 1
params_shape = shape[begin_params_axis:]
np.random.seed(0)
# input tensors
dy = Tensor(np.random.normal(0, 1, shape).astype(dtype))
x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
var = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype))
mean = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype))
gamma = Tensor(np.random.normal(0, 1, params_shape).astype(dtype))
expect = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, False)
output = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, True)
compare_result(expect, output, dtype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernorm_gpu():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_layernorm([4, 32, 32], np.float32, -1, -1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_layernorm_ascend():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_layernorm([4, 32, 32], np.float16, -1, -1)
test_layernorm([4, 32, 32], np.float32, -1, -1)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_basic()
def test_basic_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_basic()
def test_layernorm_grad_gpu():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_layernorm_grad([4, 32, 32], np.float32, -1, -1)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_layernormgrad_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_layernormgrad()
def test_layernormgrad_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_layernormgrad()
def test_layernorm_grad_ascend():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_layernorm_grad([2, 16, 32], np.float16, -1, -1)
test_layernorm_grad([4, 32, 32], np.float32, -1, -1)