!13475 [GraphKernel]adapt for layernorm in ascend

From: @wenfangpei
Reviewed-by: @gaoxiong1,@anyrenwei
Signed-off-by: @anyrenwei
This commit is contained in:
mindspore-ci-bot 2021-04-09 17:04:52 +08:00 committed by Gitee
commit 0920239699
6 changed files with 263 additions and 155 deletions

View File

@ -17,33 +17,118 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.add_format(DF.FRAC_NZ, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon') @VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon')
class LayerNorm(Expander): class LayerNorm(Expander):
"""LayerNorm expander""" """LayerNorm expander"""
def to_frac_z_axis(self, ori_shape, ori_axis):
"""
judge the format is fractal NZ
Parameters
----------
ori_shape: list or tuple
original shape of input
ori_axis: list or tuple
original axis of original shape to operate
Returns
-------
output: list
axis of the fractal Nz shape
"""
frac_z_axis = list(ori_axis)
shape_len = len(ori_shape)
axis_count = len(frac_z_axis)
axis_negative_1 = shape_len - 1
axis_negative_2 = shape_len - 2
for i in range(axis_count):
axis_index = (frac_z_axis[i] + shape_len) % shape_len
if axis_index == axis_negative_1:
if frac_z_axis[i] > shape_len - 2: # akg:[2,3] [1,4] tbe:[2,4] [1,3]
frac_z_axis[i] = axis_index - 1
frac_z_axis.append(axis_index + 2)
else: # no case cover this branch now
frac_z_axis[i] = axis_index - 1
frac_z_axis.append(axis_index + 2)
elif axis_index == axis_negative_2:
frac_z_axis[i] = axis_index + 1
frac_z_axis.append(axis_index + 2)
else:
frac_z_axis[i] = axis_index
return frac_z_axis
def infer_shape_from_fractalNz(self, fractal):
"get original shape from fractalNz shape"
shape = []
dims = len(fractal)
batch = dims - 4
for i in range(batch):
shape.append(fractal[i])
m = fractal[dims - 3] * fractal[dims - 2]
n = fractal[dims - 4] * fractal[dims - 1]
shape.append(m)
shape.append(n)
return shape
def get_reduced_ori_shape(self, shape, axis):
"get shape after reduced which is based on original shape"
reduced_ori_shape = []
for i, value in enumerate(shape):
if i in axis:
reduced_ori_shape.append(1)
else:
reduced_ori_shape.append(value)
return reduced_ori_shape
def _expand(self, graph_builder): def _expand(self, graph_builder):
input_x, input_gamma, input_beta = self.inputs input_x, input_gamma, input_beta = self.inputs
processor = self.processor
begin_norm_axis = self.attrs['begin_norm_axis'] begin_norm_axis = self.attrs['begin_norm_axis']
epsilon = self.attrs['epsilon'] epsilon = self.attrs['epsilon']
ori_dtype = input_x.dtype
if processor == 'aicore' and ori_dtype == 'float16':
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
input_gamma = graph_builder.emit('Cast', [input_gamma], attrs={'dst_type': 'float32'})
input_beta = graph_builder.emit('Cast', [input_beta], attrs={'dst_type': 'float32'})
ori_shape_x = input_x.shape
if input_x.data_format == DF.FRAC_NZ:
ori_shape_x = self.infer_shape_from_fractalNz(input_x.shape)
# Calculate the scaling ratio of the average # Calculate the scaling ratio of the average
if begin_norm_axis < 0: if begin_norm_axis < 0:
begin_norm_axis += len(input_x.shape) begin_norm_axis += len(ori_shape_x)
reduce_axis = () reduce_axis = ()
for i, _ in enumerate(input_x.shape): for i, _ in enumerate(ori_shape_x):
if i > begin_norm_axis or i == begin_norm_axis: if i > begin_norm_axis or i == begin_norm_axis:
reduce_axis = reduce_axis + (i,) reduce_axis = reduce_axis + (i,)
reduce_elts = 1.0 reduce_elts = 1.0
for i in reduce_axis: for i in reduce_axis:
reduce_elts *= input_x.shape[i] reduce_elts *= ori_shape_x[i]
if input_x.data_format == DF.FRAC_NZ:
reduce_axis = self.to_frac_z_axis(ori_shape_x, reduce_axis)
ori_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis) # after reduced
mean_cof = 1.0 / reduce_elts mean_cof = 1.0 / reduce_elts
mean_cof_v = graph_builder.value(input_x.dtype, mean_cof) mean_cof_v = graph_builder.value(input_x.dtype, mean_cof)
# Calculate mean # Calculate mean
mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) mean_red = graph_builder.emit('ReduceSum', [input_x],
attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
mean = graph_builder.emit('Mul', [mean_red, mean_cof_v]) mean = graph_builder.emit('Mul', [mean_red, mean_cof_v])
if input_x.data_format == DF.FRAC_NZ:
mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_shape_x})
# Calculate variance # Calculate variance
variance_sub = graph_builder.emit('Sub', [input_x, mean]) variance_sub = graph_builder.emit('Sub', [input_x, mean])
@ -51,6 +136,8 @@ class LayerNorm(Expander):
variance_red = graph_builder.emit('ReduceSum', [variance_mul], variance_red = graph_builder.emit('ReduceSum', [variance_mul],
attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
variance = graph_builder.emit('Mul', [variance_red, mean_cof_v]) variance = graph_builder.emit('Mul', [variance_red, mean_cof_v])
if input_x.data_format == DF.FRAC_NZ:
variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_shape_x})
# Calculate normalize # Calculate normalize
normalize_sub = graph_builder.emit('Sub', [input_x, mean]) normalize_sub = graph_builder.emit('Sub', [input_x, mean])
@ -60,7 +147,11 @@ class LayerNorm(Expander):
normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt]) normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt])
# Calculate scale and translate # Calculate scale and translate
scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul]) scale_mul = graph_builder.emit('Mul', [normalize_mul, input_gamma])
res = graph_builder.emit('Add', [scale_mul, input_beta]) res = graph_builder.emit('Add', [scale_mul, input_beta])
if processor == 'aicore' and ori_dtype == 'float16':
res = graph_builder.emit('Cast', [res], attrs={'dst_type': 'float16'})
mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float16'})
variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float16'})
return res, mean, variance return res, mean, variance

View File

@ -24,23 +24,33 @@ class LayerNormGrad(Expander):
def _expand(self, graph_builder): def _expand(self, graph_builder):
x, dy, variance, mean, gamma = self.inputs x, dy, variance, mean, gamma = self.inputs
processor = self.processor
begin_norm_axis = self.attrs['begin_norm_axis'] begin_norm_axis = self.attrs['begin_norm_axis']
begin_params_axis = self.attrs['begin_params_axis'] begin_params_axis = self.attrs['begin_params_axis']
epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-11 epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-12
ori_dtype = x.dtype
if processor == 'aicore' and ori_dtype == 'float16':
x = graph_builder.emit('Cast', [x], attrs={'dst_type': 'float32'})
dy = graph_builder.emit('Cast', [dy], attrs={'dst_type': 'float32'})
variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float32'})
mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float32'})
gamma = graph_builder.emit('Cast', [gamma], attrs={'dst_type': 'float32'})
if begin_norm_axis < 0: if begin_norm_axis < 0:
begin_norm_axis += len(x.shape) begin_norm_axis += len(x.shape)
if begin_params_axis < 0: if begin_params_axis < 0:
begin_params_axis += len(x.shape) begin_params_axis += len(x.shape)
norm_axis = tuple(range(begin_norm_axis, len(x.shape))) norm_axis = tuple(range(begin_norm_axis, len(x.shape)))
param_axis = tuple(range(0, begin_params_axis)) param_axis = tuple(range(0, begin_params_axis))
reduce_size = 1.0 reduce_size = 1.0
for i in norm_axis: for i in norm_axis:
reduce_size *= x.shape[i] reduce_size *= x.shape[i]
# set some constant val. # set some constant val.
eps = graph_builder.value(x.dtype, epsilon) eps = graph_builder.value(x.dtype, epsilon)
const_one = graph_builder.value(x.dtype, 1.0)
const_neg_half = graph_builder.value(x.dtype, -0.5) const_neg_half = graph_builder.value(x.dtype, -0.5)
const_neg_two = graph_builder.value(x.dtype, -2.0) const_neg_two = graph_builder.value(x.dtype, -2.0)
const_two = graph_builder.value(x.dtype, 2.0) const_two = graph_builder.value(x.dtype, 2.0)
@ -49,42 +59,55 @@ class LayerNormGrad(Expander):
# cal dg db # cal dg db
var_eps = graph_builder.emit('Add', [variance, eps]) var_eps = graph_builder.emit('Add', [variance, eps])
sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps]) var_eps_log = graph_builder.emit('Log', [var_eps])
rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps]) var_eps_mul = graph_builder.emit('Mul', [var_eps_log, const_neg_half])
rsqrt_var_eps = graph_builder.emit('Exp', [var_eps_mul])
x_sub_mean = graph_builder.emit('Sub', [x, mean]) x_sub_mean = graph_builder.emit('Sub', [x, mean])
x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean]) x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean])
dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps]) dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps])
dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False}) dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False})
db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False}) db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False})
# cal sum_1 # pd_var
tmp_var_eps = graph_builder.emit('Mul', [sqrt_var_eps, var_eps]) tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, rsqrt_var_eps])
r_tmp_var_eps = graph_builder.emit('RealDiv', [const_one, tmp_var_eps]) r_tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, tmp_var_eps])
x_sub_mean_mul_r_tmp_var_eps = graph_builder.emit('Mul', [x_sub_mean, r_tmp_var_eps])
dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma]) dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma])
tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean_mul_r_tmp_var_eps]) tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean])
sum_1_mul = graph_builder.emit('Mul', [const_neg_half, tmp_mul]) padvar_mul1 = graph_builder.emit('ReduceSum', [tmp_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
sum_1 = graph_builder.emit('ReduceSum', [sum_1_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True}) padvar_mul3 = graph_builder.emit('Mul', [padvar_mul1, r_tmp_var_eps])
pd_var = graph_builder.emit('Mul', [padvar_mul3, const_neg_half])
# cal sum_2 # pd_mean
sum_2 = graph_builder.emit('ReduceSum', [dy_mul_gamma], attrs={'reduce_axis': norm_axis, 'keep_dims': True}) pdmean1_sum = graph_builder.emit('ReduceSum', [dy_mul_gamma],
attrs={'reduce_axis': norm_axis, 'keep_dims': True})
neg_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, const_neg_one])
pd_mean_1 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, pdmean1_sum])
# cal sum_3 pdmean2_mul1 = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
sum_3_mul = graph_builder.emit('Mul', [const_neg_two, x_sub_mean]) pdmean2_sum = graph_builder.emit('ReduceSum', [pdmean2_mul1],
sum_3 = graph_builder.emit('ReduceSum', [sum_3_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True}) attrs={'reduce_axis': norm_axis, 'keep_dims': True})
pdmean2_mul3 = graph_builder.emit('Mul', [pdmean2_sum, mean_cof])
pd_mean_2 = graph_builder.emit('Mul', [pdmean2_mul3, pd_var])
# cal dx = dx1 + dx2 + dx3 pd_mean = graph_builder.emit('Add', [pd_mean_1, pd_mean_2])
dx_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
sum_1_mul_two = graph_builder.emit('Mul', [sum_1, const_two])
sum_1_mul_two_tmp = graph_builder.emit('Mul', [sum_1_mul_two, mean_cof])
dx_2 = graph_builder.emit('Mul', [sum_1_mul_two_tmp, x_sub_mean])
neg_rsqrt_var_eps = graph_builder.emit('Mul', [const_neg_one, rsqrt_var_eps])
neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2])
sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3])
mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3])
add_tmp = graph_builder.emit('Add', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof])
dx_tmp = graph_builder.emit('Add', [dx_1, dx_2])
dx = graph_builder.emit('Add', [dx_tmp, dx_3])
# cal dx
pd_x_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
pdx2_mul = graph_builder.emit('Mul', [pd_var, x_sub_mean])
pdx2_mul_two = graph_builder.emit('Mul', [pdx2_mul, const_two])
pd_x_2 = graph_builder.emit('Mul', [pdx2_mul_two, mean_cof])
pd_x_3 = graph_builder.emit('Mul', [pd_mean, mean_cof])
dx_tmp = graph_builder.emit('Add', [pd_x_1, pd_x_2])
dx = graph_builder.emit('Add', [dx_tmp, pd_x_3])
if processor == 'aicore' and ori_dtype == 'float16':
dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})
dg = graph_builder.emit('Cast', [dg], attrs={'dst_type': 'float16'})
db = graph_builder.emit('Cast', [db], attrs={'dst_type': 'float16'})
return dx, dg, db return dx, dg, db

View File

@ -47,6 +47,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
prim::kPrimSquare, prim::kPrimSquare,
prim::kPrimGeLUGrad, prim::kPrimGeLUGrad,
prim::kPrimAssignAdd, prim::kPrimAssignAdd,
prim::kPrimLayerNorm,
prim::kPrimLayerNormGrad,
#if ENABLE_D #if ENABLE_D
prim::kPrimTile, prim::kPrimTile,
prim::kPrimSqrtGrad, prim::kPrimSqrtGrad,
@ -67,8 +69,6 @@ std::vector<PrimitivePtr> GetExpandOps() {
prim::kPrimDropout, prim::kPrimDropout,
prim::kPrimDropoutGrad, prim::kPrimDropoutGrad,
prim::kPrimSoftmax, prim::kPrimSoftmax,
prim::kPrimLayerNorm,
prim::kPrimLayerNormGrad,
prim::kPrimRelu, prim::kPrimRelu,
prim::kPrimReluGrad, prim::kPrimReluGrad,
prim::kPrimSigmoid, prim::kPrimSigmoid,

View File

@ -54,7 +54,6 @@
#include "debug/data_dump/dump_json_parser.h" #include "debug/data_dump/dump_json_parser.h"
#include "debug/tensor_load.h" #include "debug/tensor_load.h"
#include "debug/anf_ir_utils.h" #include "debug/anf_ir_utils.h"
#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h" #include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
#include "backend/session/ascend_auto_monad.h" #include "backend/session/ascend_auto_monad.h"
#include "debug/data_dump/e2e_dump.h" #include "debug/data_dump/e2e_dump.h"

View File

@ -223,12 +223,12 @@ def test_bert_precision(enable_graph_kernel=False):
# assertion occurs while the loss value, overflow state or loss_scale value is wrong # assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(callback.loss_list) loss_value = np.array(callback.loss_list)
assert np.allclose(loss_value[0], 12.2066, 0, 0.0005)
if enable_graph_kernel: if enable_graph_kernel:
expect_loss_value = [12.206627, 11.840489, 11.798470, 11.796345, 11.790964, 12.366766, 11.971539, 12.576565, expect_loss_value = [12.206627, 11.840489, 11.798470, 11.796345, 11.790964, 12.366766, 11.971539, 12.576565,
12.185522, 12.386192] 12.185522, 12.386192]
else: else:
assert np.allclose(loss_value[0], 12.2066, 0, 0.0005)
expect_loss_value = [12.206587, 11.966410, 11.965916, 11.975922, 11.970262, 12.608881, 12.174048, 12.840656, expect_loss_value = [12.206587, 11.966410, 11.965916, 11.975922, 11.970262, 12.608881, 12.174048, 12.840656,
12.407923, 12.631133] 12.407923, 12.631133]
print("loss value: {}".format(loss_value)) print("loss value: {}".format(loss_value))

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,150 +13,145 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import copy
import numpy as np import numpy as np
import pytest import pytest
import mindspore.context as context import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _grad_ops as G
import mindspore.ops.operations as P import mindspore.ops.operations as P
class Net(Cell): class LayerNormNet(nn.Cell):
def __init__(self): def __init__(self, begin_norm_axis, begin_params_axis):
super(Net, self).__init__() super(LayerNormNet, self).__init__()
self.layernorm = P.LayerNorm(1, 1) self.layernorm = P.LayerNorm(begin_norm_axis, begin_params_axis)
def construct(self, x, y, z): def construct(self, x, gamma, beta):
return self.layernorm(x, y, z) return self.layernorm(x, gamma, beta)
class LayerNormGradNet(nn.Cell): class LayerNormGradNet(nn.Cell):
def __init__(self, begin_norm_axis, begin_params_axis): def __init__(self, begin_norm_axis, begin_params_axis):
super(LayerNormGradNet, self).__init__() super(LayerNormGradNet, self).__init__()
self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis) self.layernorm_grad = G.LayerNormGrad(begin_norm_axis, begin_params_axis)
def construct(self, dy, x, var, mean, gamma): def construct(self, dy, x, var, mean, gamma):
return self.norm(dy, x, var, mean, gamma) return self.layernorm_grad(dy, x, var, mean, gamma)
def get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, enable_graph_kernel=False):
context.set_context(enable_graph_kernel=enable_graph_kernel)
net = LayerNormNet(begin_norm_axis, begin_params_axis)
output = net(x, gamma, beta)
return output
def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis): def get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, enable_graph_kernel=False):
begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape) context.set_context(enable_graph_kernel=enable_graph_kernel)
begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)
norm_axis = [i for i in range(begin_norm_axis, len(x.shape))]
param_axis = [i for i in range(0, begin_params_axis)]
num = 1
for i in range(begin_norm_axis, len(x.shape)):
num *= x.shape[i]
mean = np.mean(x, axis=tuple(norm_axis), keepdims=True)
var = np.var(x, axis=tuple(norm_axis), keepdims=True)
gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
db = np.sum(dy, axis=tuple(param_axis), keepdims=True)
sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
keepdims=True)
sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)
dx1 = dy * gamma * np.power(var + epsilon, -0.5)
dx2 = sum1 * 2.0 / num * (x - mean)
dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
dx = dx1 + dx2 + dx3
return dx, dg, db, mean, var
def test_basic():
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
gamma = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
beta = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
shape_x = [2, 3, 4, 3]
begin_norm_axis = 1
in_rank = len(shape_x)
if begin_norm_axis < 0:
norm_axis = begin_norm_axis + in_rank
else:
norm_axis = begin_norm_axis
norm_axes = tuple(range(norm_axis, in_rank))
mean = np.mean(input_x, axis=norm_axes, keepdims=True)
mean_b = np.broadcast_to(mean, shape_x)
diff = input_x - mean_b
square = np.square(diff)
smean = np.mean(square, axis=norm_axes, keepdims=True)
smean_b = np.broadcast_to(smean, shape_x)
meps = smean_b + 1e-5
logs = np.log(meps)
mul = logs * (-0.5)
rsqrt = np.exp(mul)
out = diff * rsqrt
bn = out * gamma + beta
expect = (bn, mean, smean)
net = Net()
net_result = net(Tensor(input_x), Tensor(gamma), Tensor(beta))
if isinstance(net_result, tuple) and len(net_result) == 3:
result = (net_result[0].asnumpy(), net_result[1].asnumpy(), net_result[2].asnumpy())
res0 = np.allclose(expect[0], result[0], rtol=1.e-4, atol=1.e-4, equal_nan=True)
assert res0
res1 = np.allclose(expect[1], result[1], rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res1
res2 = np.allclose(expect[2], result[2], rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res2
else:
assert False
def test_layernormgrad():
np.random.seed(0)
begin_norm_axis = 1
begin_params_axis = 1
x_np = np.random.randn(4096, 3072).astype(np.float32)
dy_np = np.random.randn(4096, 3072).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 1e-11
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
var_ms = Tensor(var_np)
mean_ms = Tensor(mean_np)
gamma_ms = Tensor(gamma_np)
net = LayerNormGradNet(begin_norm_axis, begin_params_axis) net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms) output = net(x, dy, var, mean, gamma)
assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3) return output
assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
def get_rtol_atol(dtype):
if dtype == np.float16:
return 1.e-3, 1.e-3
return 1.e-4, 1.e-4
def compare_result(expect, output, dtype):
rtol, atol = get_rtol_atol(dtype)
if isinstance(expect, (list, tuple)):
assert isinstance(output, (list, tuple)) and len(expect) == len(output)
expect_list = list(expect)
output_list = list(output)
for e, o in zip(expect_list, output_list):
assert np.allclose(e.asnumpy(), o.asnumpy(), rtol, atol, equal_nan=True)
else:
assert np.allclose(expect.asnumpy(), output.asnumpy(), rtol, atol, equal_nan=True)
def test_layernorm(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1):
begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(shape)
begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(shape)
assert 0 <= begin_norm_axis < len(shape)
assert 0 <= begin_params_axis < len(shape)
normalized_shape = shape[begin_params_axis:]
np.random.seed(0)
# input tensors
x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
gamma = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype))
beta = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype))
expect = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, False)
output = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, True)
compare_result(expect, output, dtype)
def test_layernorm_grad(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1):
begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(shape)
begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(shape)
assert 0 <= begin_norm_axis < len(shape)
assert 0 <= begin_params_axis < len(shape)
norm_axis = [i for i in range(begin_norm_axis, len(shape))]
norm_shape = copy.deepcopy(shape)
for i, _ in enumerate(norm_shape):
if i in norm_axis:
norm_shape[i] = 1
params_shape = shape[begin_params_axis:]
np.random.seed(0)
# input tensors
dy = Tensor(np.random.normal(0, 1, shape).astype(dtype))
x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
var = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype))
mean = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype))
gamma = Tensor(np.random.normal(0, 1, params_shape).astype(dtype))
expect = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, False)
output = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, True)
compare_result(expect, output, dtype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernorm_gpu():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_layernorm([4, 32, 32], np.float32, -1, -1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_layernorm_ascend():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_layernorm([4, 32, 32], np.float16, -1, -1)
test_layernorm([4, 32, 32], np.float32, -1, -1)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_basic_gpu(): def test_layernorm_grad_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_basic() test_layernorm_grad([4, 32, 32], np.float32, -1, -1)
def test_basic_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_basic()
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_layernormgrad_gpu(): def test_layernorm_grad_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_layernormgrad() test_layernorm_grad([2, 16, 32], np.float16, -1, -1)
test_layernorm_grad([4, 32, 32], np.float32, -1, -1)
def test_layernormgrad_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_layernormgrad()