forked from mindspore-Ecosystem/mindspore

!13475 [GraphKernel] adapt for layernorm in ascend
From: @wenfangpei  Reviewed-by: @gaoxiong1, @anyrenwei  Signed-off-by: @anyrenwei
Commit: 0920239699
@@ -17,33 +17,118 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from ._utils import Expander, ExpanderInfoValidator as VLD


+@VLD.add_format(DF.FRAC_NZ, DF.DEFAULT, DF.DEFAULT)
 @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
 @VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon')
 class LayerNorm(Expander):
     """LayerNorm expander"""

+    def to_frac_z_axis(self, ori_shape, ori_axis):
+        """
+        Map the reduce axes of the original shape to the corresponding axes of the fractal NZ shape.
+
+        Parameters
+        ----------
+        ori_shape: list or tuple
+            original shape of the input
+        ori_axis: list or tuple
+            axes of the original shape to reduce over
+
+        Returns
+        -------
+        output: list
+            axes of the fractal NZ shape
+        """
+        frac_z_axis = list(ori_axis)
+        shape_len = len(ori_shape)
+        axis_count = len(frac_z_axis)
+        axis_negative_1 = shape_len - 1
+        axis_negative_2 = shape_len - 2
+
+        for i in range(axis_count):
+            axis_index = (frac_z_axis[i] + shape_len) % shape_len
+            if axis_index == axis_negative_1:
+                if frac_z_axis[i] > shape_len - 2:  # akg:[2,3] [1,4] tbe:[2,4] [1,3]
+                    frac_z_axis[i] = axis_index - 1
+                    frac_z_axis.append(axis_index + 2)
+                else:  # no case covers this branch now
+                    frac_z_axis[i] = axis_index - 1
+                    frac_z_axis.append(axis_index + 2)
+            elif axis_index == axis_negative_2:
+                frac_z_axis[i] = axis_index + 1
+                frac_z_axis.append(axis_index + 2)
+            else:
+                frac_z_axis[i] = axis_index
+        return frac_z_axis
+
+    def infer_shape_from_fractalNz(self, fractal):
+        """get the original shape from a fractal NZ shape"""
+        shape = []
+        dims = len(fractal)
+        batch = dims - 4
+        for i in range(batch):
+            shape.append(fractal[i])
+
+        m = fractal[dims - 3] * fractal[dims - 2]
+        n = fractal[dims - 4] * fractal[dims - 1]
+        shape.append(m)
+        shape.append(n)
+
+        return shape
+
+    def get_reduced_ori_shape(self, shape, axis):
+        """get the reduced shape, based on the original shape"""
+        reduced_ori_shape = []
+        for i, value in enumerate(shape):
+            if i in axis:
+                reduced_ori_shape.append(1)
+            else:
+                reduced_ori_shape.append(value)
+        return reduced_ori_shape
+
     def _expand(self, graph_builder):
         input_x, input_gamma, input_beta = self.inputs
+        processor = self.processor
         begin_norm_axis = self.attrs['begin_norm_axis']
         epsilon = self.attrs['epsilon']

+        ori_dtype = input_x.dtype
+        if processor == 'aicore' and ori_dtype == 'float16':
+            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
+            input_gamma = graph_builder.emit('Cast', [input_gamma], attrs={'dst_type': 'float32'})
+            input_beta = graph_builder.emit('Cast', [input_beta], attrs={'dst_type': 'float32'})
+
+        ori_shape_x = input_x.shape
+        if input_x.data_format == DF.FRAC_NZ:
+            ori_shape_x = self.infer_shape_from_fractalNz(input_x.shape)
+
         # Calculate the scaling ratio of the average
         if begin_norm_axis < 0:
-            begin_norm_axis += len(input_x.shape)
+            begin_norm_axis += len(ori_shape_x)

         reduce_axis = ()
-        for i, _ in enumerate(input_x.shape):
+        for i, _ in enumerate(ori_shape_x):
             if i > begin_norm_axis or i == begin_norm_axis:
                 reduce_axis = reduce_axis + (i,)

         reduce_elts = 1.0
         for i in reduce_axis:
-            reduce_elts *= input_x.shape[i]
+            reduce_elts *= ori_shape_x[i]

+        if input_x.data_format == DF.FRAC_NZ:
+            reduce_axis = self.to_frac_z_axis(ori_shape_x, reduce_axis)
+            ori_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis)  # shape after reduction
+
         mean_cof = 1.0 / reduce_elts
         mean_cof_v = graph_builder.value(input_x.dtype, mean_cof)

         # Calculate mean
-        mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
+        mean_red = graph_builder.emit('ReduceSum', [input_x],
+                                      attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
         mean = graph_builder.emit('Mul', [mean_red, mean_cof_v])
+        if input_x.data_format == DF.FRAC_NZ:
+            mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_shape_x})

         # Calculate variance
         variance_sub = graph_builder.emit('Sub', [input_x, mean])
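(Aside, not part of the diff.) A quick sanity check of the FRAC_NZ helpers added above, assuming the usual 16x16 fractal blocks: a DEFAULT tensor of shape [8, 128, 1024] is stored in FRAC_NZ as [8, 64, 8, 16, 16], i.e. [..., n1, m1, m0, n0], so reducing over the last DEFAULT axis maps to the NZ axes [1, 4]:

fractal = [8, 64, 8, 16, 16]          # [..., n1, m1, m0, n0], illustrative values
batch, n1, m1, m0, n0 = fractal
# infer_shape_from_fractalNz recovers the DEFAULT shape [..., m1*m0, n1*n0]
assert [batch, m1 * m0, n1 * n0] == [8, 128, 1024]
# to_frac_z_axis([8, 128, 1024], (2,)) maps the last DEFAULT axis (the n axis)
# to its two NZ fragments n1 and n0, i.e. it returns [1, 4]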
@@ -51,6 +136,8 @@ class LayerNorm(Expander):
         variance_red = graph_builder.emit('ReduceSum', [variance_mul],
                                           attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
         variance = graph_builder.emit('Mul', [variance_red, mean_cof_v])
+        if input_x.data_format == DF.FRAC_NZ:
+            variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_shape_x})

         # Calculate normalize
         normalize_sub = graph_builder.emit('Sub', [input_x, mean])
@@ -60,7 +147,11 @@ class LayerNorm(Expander):
         normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt])

         # Calculate scale and translate
-        scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul])
+        scale_mul = graph_builder.emit('Mul', [normalize_mul, input_gamma])
         res = graph_builder.emit('Add', [scale_mul, input_beta])

+        if processor == 'aicore' and ori_dtype == 'float16':
+            res = graph_builder.emit('Cast', [res], attrs={'dst_type': 'float16'})
+            mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float16'})
+            variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float16'})
         return res, mean, variance
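(Aside, not part of the diff.) For reference, a minimal NumPy sketch of what the forward expansion above computes: standard layer normalization over the axes from begin_norm_axis onward. On FRAC_NZ inputs the same reduction is simply expressed over the mapped NZ axes, and mean/variance are reshaped back to the reduced original shape.

import numpy as np

def layernorm_reference(x, gamma, beta, begin_norm_axis, epsilon):
    """Plain-layout equivalent of the expanded graph (illustrative only)."""
    axes = tuple(range(begin_norm_axis % x.ndim, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)                          # ReduceSum * mean_cof
    variance = np.mean((x - mean) ** 2, axis=axes, keepdims=True)
    normalized = (x - mean) * np.power(variance + epsilon, -0.5)     # Sub, Rsqrt, Mul
    return normalized * gamma + beta, mean, variance                 # scale and translate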
@@ -24,23 +24,33 @@ class LayerNormGrad(Expander):

     def _expand(self, graph_builder):
         x, dy, variance, mean, gamma = self.inputs
+        processor = self.processor
         begin_norm_axis = self.attrs['begin_norm_axis']
         begin_params_axis = self.attrs['begin_params_axis']
-        epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-11
+        epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-12

+        ori_dtype = x.dtype
+        if processor == 'aicore' and ori_dtype == 'float16':
+            x = graph_builder.emit('Cast', [x], attrs={'dst_type': 'float32'})
+            dy = graph_builder.emit('Cast', [dy], attrs={'dst_type': 'float32'})
+            variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float32'})
+            mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float32'})
+            gamma = graph_builder.emit('Cast', [gamma], attrs={'dst_type': 'float32'})
+
         if begin_norm_axis < 0:
             begin_norm_axis += len(x.shape)
         if begin_params_axis < 0:
             begin_params_axis += len(x.shape)

         norm_axis = tuple(range(begin_norm_axis, len(x.shape)))
         param_axis = tuple(range(0, begin_params_axis))

         reduce_size = 1.0
         for i in norm_axis:
             reduce_size *= x.shape[i]

         # set some constant val.
         eps = graph_builder.value(x.dtype, epsilon)
-        const_one = graph_builder.value(x.dtype, 1.0)
         const_neg_half = graph_builder.value(x.dtype, -0.5)
         const_neg_two = graph_builder.value(x.dtype, -2.0)
         const_two = graph_builder.value(x.dtype, 2.0)
@@ -49,42 +59,55 @@ class LayerNormGrad(Expander):

         # cal dg db
         var_eps = graph_builder.emit('Add', [variance, eps])
-        sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps])
-        rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps])
+        var_eps_log = graph_builder.emit('Log', [var_eps])
+        var_eps_mul = graph_builder.emit('Mul', [var_eps_log, const_neg_half])
+        rsqrt_var_eps = graph_builder.emit('Exp', [var_eps_mul])

         x_sub_mean = graph_builder.emit('Sub', [x, mean])

         x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean])
         dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps])
         dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False})
         db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False})

-        # cal sum_1
-        tmp_var_eps = graph_builder.emit('Mul', [sqrt_var_eps, var_eps])
-        r_tmp_var_eps = graph_builder.emit('RealDiv', [const_one, tmp_var_eps])
-        x_sub_mean_mul_r_tmp_var_eps = graph_builder.emit('Mul', [x_sub_mean, r_tmp_var_eps])
+        # pd_var
+        tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, rsqrt_var_eps])
+        r_tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, tmp_var_eps])
         dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma])
-        tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean_mul_r_tmp_var_eps])
-        sum_1_mul = graph_builder.emit('Mul', [const_neg_half, tmp_mul])
-        sum_1 = graph_builder.emit('ReduceSum', [sum_1_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
+        tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean])
+        padvar_mul1 = graph_builder.emit('ReduceSum', [tmp_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
+        padvar_mul3 = graph_builder.emit('Mul', [padvar_mul1, r_tmp_var_eps])
+        pd_var = graph_builder.emit('Mul', [padvar_mul3, const_neg_half])

-        # cal sum_2
-        sum_2 = graph_builder.emit('ReduceSum', [dy_mul_gamma], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
+        # pd_mean
+        pdmean1_sum = graph_builder.emit('ReduceSum', [dy_mul_gamma],
+                                         attrs={'reduce_axis': norm_axis, 'keep_dims': True})
+        neg_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, const_neg_one])
+        pd_mean_1 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, pdmean1_sum])

-        # cal sum_3
-        sum_3_mul = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
-        sum_3 = graph_builder.emit('ReduceSum', [sum_3_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
+        pdmean2_mul1 = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
+        pdmean2_sum = graph_builder.emit('ReduceSum', [pdmean2_mul1],
+                                         attrs={'reduce_axis': norm_axis, 'keep_dims': True})
+        pdmean2_mul3 = graph_builder.emit('Mul', [pdmean2_sum, mean_cof])
+        pd_mean_2 = graph_builder.emit('Mul', [pdmean2_mul3, pd_var])
+        pd_mean = graph_builder.emit('Add', [pd_mean_1, pd_mean_2])

-        # cal dx = dx1 + dx2 + dx3
-        dx_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
-        sum_1_mul_two = graph_builder.emit('Mul', [sum_1, const_two])
-        sum_1_mul_two_tmp = graph_builder.emit('Mul', [sum_1_mul_two, mean_cof])
-        dx_2 = graph_builder.emit('Mul', [sum_1_mul_two_tmp, x_sub_mean])
-        neg_rsqrt_var_eps = graph_builder.emit('Mul', [const_neg_one, rsqrt_var_eps])
-        neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2])
-        sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3])
-        mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3])
-        add_tmp = graph_builder.emit('Add', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
-        dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof])
-        dx_tmp = graph_builder.emit('Add', [dx_1, dx_2])
-        dx = graph_builder.emit('Add', [dx_tmp, dx_3])
+        # cal dx
+        pd_x_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
+
+        pdx2_mul = graph_builder.emit('Mul', [pd_var, x_sub_mean])
+        pdx2_mul_two = graph_builder.emit('Mul', [pdx2_mul, const_two])
+        pd_x_2 = graph_builder.emit('Mul', [pdx2_mul_two, mean_cof])
+
+        pd_x_3 = graph_builder.emit('Mul', [pd_mean, mean_cof])
+
+        dx_tmp = graph_builder.emit('Add', [pd_x_1, pd_x_2])
+        dx = graph_builder.emit('Add', [dx_tmp, pd_x_3])
+
+        if processor == 'aicore' and ori_dtype == 'float16':
+            dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})
+            dg = graph_builder.emit('Cast', [dg], attrs={'dst_type': 'float16'})
+            db = graph_builder.emit('Cast', [db], attrs={'dst_type': 'float16'})
         return dx, dg, db
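(Aside, not part of the diff.) Two notes on the rewritten backward expansion. First, the Log/Mul/Exp sequence is just a reformulation of rsqrt, since exp(-0.5 * ln(v + eps)) = (v + eps)^(-1/2); it replaces the earlier Sqrt and RealDiv pair and lets const_one be dropped. Second, the pd_var / pd_mean / dx terms follow the usual layer-norm backward formulas; a minimal NumPy sketch, assuming mean and var carry keep-dims shapes and norm_axis / param_axis are tuples of ints:

import numpy as np

def layernorm_grad_reference(x, dy, gamma, mean, var, norm_axis, param_axis, eps=1e-12):
    """Minimal sketch of the same math as the expanded graph (illustrative only)."""
    rsqrt = np.power(var + eps, -0.5)                                 # rsqrt_var_eps
    mean_cof = 1.0 / np.prod([x.shape[i] for i in norm_axis])
    dy_g = dy * gamma                                                 # dy_mul_gamma
    dg = np.sum(dy * (x - mean) * rsqrt, axis=param_axis)
    db = np.sum(dy, axis=param_axis)
    pd_var = -0.5 * np.power(var + eps, -1.5) * np.sum(dy_g * (x - mean), axis=norm_axis, keepdims=True)
    pd_mean = (-rsqrt * np.sum(dy_g, axis=norm_axis, keepdims=True)
               + pd_var * mean_cof * np.sum(-2.0 * (x - mean), axis=norm_axis, keepdims=True))
    dx = dy_g * rsqrt + pd_var * 2.0 * mean_cof * (x - mean) + pd_mean * mean_cof
    return dx, dg, db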
@@ -47,6 +47,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
     prim::kPrimSquare,
     prim::kPrimGeLUGrad,
     prim::kPrimAssignAdd,
+    prim::kPrimLayerNorm,
+    prim::kPrimLayerNormGrad,
 #if ENABLE_D
     prim::kPrimTile,
     prim::kPrimSqrtGrad,
@@ -67,8 +69,6 @@ std::vector<PrimitivePtr> GetExpandOps() {
     prim::kPrimDropout,
     prim::kPrimDropoutGrad,
     prim::kPrimSoftmax,
-    prim::kPrimLayerNorm,
-    prim::kPrimLayerNormGrad,
     prim::kPrimRelu,
     prim::kPrimReluGrad,
     prim::kPrimSigmoid,
@@ -54,7 +54,6 @@
 #include "debug/data_dump/dump_json_parser.h"
 #include "debug/tensor_load.h"
 #include "debug/anf_ir_utils.h"
-#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
 #include "backend/session/ascend_auto_monad.h"
 #include "debug/data_dump/e2e_dump.h"
@@ -223,12 +223,12 @@ def test_bert_precision(enable_graph_kernel=False):

     # assertion occurs while the loss value, overflow state or loss_scale value is wrong
     loss_value = np.array(callback.loss_list)
-    assert np.allclose(loss_value[0], 12.2066, 0, 0.0005)

     if enable_graph_kernel:
         expect_loss_value = [12.206627, 11.840489, 11.798470, 11.796345, 11.790964, 12.366766, 11.971539, 12.576565,
                              12.185522, 12.386192]
     else:
+        assert np.allclose(loss_value[0], 12.2066, 0, 0.0005)
         expect_loss_value = [12.206587, 11.966410, 11.965916, 11.975922, 11.970262, 12.608881, 12.174048, 12.840656,
                              12.407923, 12.631133]
     print("loss value: {}".format(loss_value))
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,150 +13,145 @@
 # limitations under the License.
 # ============================================================================

+import copy
 import numpy as np
 import pytest

 import mindspore.context as context
 from mindspore import Tensor
-from mindspore.nn import Cell
 import mindspore.nn as nn
 from mindspore.ops.operations import _grad_ops as G
 import mindspore.ops.operations as P


-class Net(Cell):
-    def __init__(self):
-        super(Net, self).__init__()
-        self.layernorm = P.LayerNorm(1, 1)
+class LayerNormNet(nn.Cell):
+    def __init__(self, begin_norm_axis, begin_params_axis):
+        super(LayerNormNet, self).__init__()
+        self.layernorm = P.LayerNorm(begin_norm_axis, begin_params_axis)

-    def construct(self, x, y, z):
-        return self.layernorm(x, y, z)
+    def construct(self, x, gamma, beta):
+        return self.layernorm(x, gamma, beta)


 class LayerNormGradNet(nn.Cell):
     def __init__(self, begin_norm_axis, begin_params_axis):
         super(LayerNormGradNet, self).__init__()
-        self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis)
+        self.layernorm_grad = G.LayerNormGrad(begin_norm_axis, begin_params_axis)

     def construct(self, dy, x, var, mean, gamma):
-        return self.norm(dy, x, var, mean, gamma)
+        return self.layernorm_grad(dy, x, var, mean, gamma)


+def get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
+
+    net = LayerNormNet(begin_norm_axis, begin_params_axis)
+    output = net(x, gamma, beta)
+
+    return output
+
+
-def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
-    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
-    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)
-
-    norm_axis = [i for i in range(begin_norm_axis, len(x.shape))]
-    param_axis = [i for i in range(0, begin_params_axis)]
-    num = 1
-    for i in range(begin_norm_axis, len(x.shape)):
-        num *= x.shape[i]
-
-    mean = np.mean(x, axis=tuple(norm_axis), keepdims=True)
-    var = np.var(x, axis=tuple(norm_axis), keepdims=True)
-    gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
-    dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
-    db = np.sum(dy, axis=tuple(param_axis), keepdims=True)
-
-    sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
-                  keepdims=True)
-    sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
-    sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)
-
-    dx1 = dy * gamma * np.power(var + epsilon, -0.5)
-    dx2 = sum1 * 2.0 / num * (x - mean)
-    dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
-    dx = dx1 + dx2 + dx3
-    return dx, dg, db, mean, var
-
-
-def test_basic():
-    input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
-    gamma = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
-    beta = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
-    shape_x = [2, 3, 4, 3]
-    begin_norm_axis = 1
-
-    in_rank = len(shape_x)
-    if begin_norm_axis < 0:
-        norm_axis = begin_norm_axis + in_rank
-    else:
-        norm_axis = begin_norm_axis
-    norm_axes = tuple(range(norm_axis, in_rank))
-    mean = np.mean(input_x, axis=norm_axes, keepdims=True)
-    mean_b = np.broadcast_to(mean, shape_x)
-    diff = input_x - mean_b
-    square = np.square(diff)
-    smean = np.mean(square, axis=norm_axes, keepdims=True)
-    smean_b = np.broadcast_to(smean, shape_x)
-    meps = smean_b + 1e-5
-    logs = np.log(meps)
-    mul = logs * (-0.5)
-    rsqrt = np.exp(mul)
-    out = diff * rsqrt
-    bn = out * gamma + beta
-    expect = (bn, mean, smean)
-
-    net = Net()
-
-    net_result = net(Tensor(input_x), Tensor(gamma), Tensor(beta))
-    if isinstance(net_result, tuple) and len(net_result) == 3:
-        result = (net_result[0].asnumpy(), net_result[1].asnumpy(), net_result[2].asnumpy())
-        res0 = np.allclose(expect[0], result[0], rtol=1.e-4, atol=1.e-4, equal_nan=True)
-        assert res0
-        res1 = np.allclose(expect[1], result[1], rtol=1.e-4, atol=1.e-7, equal_nan=True)
-        assert res1
-        res2 = np.allclose(expect[2], result[2], rtol=1.e-4, atol=1.e-7, equal_nan=True)
-        assert res2
-    else:
-        assert False
-
-
-def test_layernormgrad():
-    np.random.seed(0)
-    begin_norm_axis = 1
-    begin_params_axis = 1
-    x_np = np.random.randn(4096, 3072).astype(np.float32)
-    dy_np = np.random.randn(4096, 3072).astype(np.float32)
-    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
-    epsilon = 1e-11
-    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
-                                                                  begin_params_axis)
-
-    dy_ms = Tensor(dy_np)
-    x_ms = Tensor(x_np)
-    var_ms = Tensor(var_np)
-    mean_ms = Tensor(mean_np)
-    gamma_ms = Tensor(gamma_np)
-
+def get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
+
     net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
-    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
-    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
-    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
-    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
+    output = net(x, dy, var, mean, gamma)
+
+    return output
+
+
+def get_rtol_atol(dtype):
+    if dtype == np.float16:
+        return 1.e-3, 1.e-3
+    return 1.e-4, 1.e-4
+
+
+def compare_result(expect, output, dtype):
+    rtol, atol = get_rtol_atol(dtype)
+    if isinstance(expect, (list, tuple)):
+        assert isinstance(output, (list, tuple)) and len(expect) == len(output)
+        expect_list = list(expect)
+        output_list = list(output)
+        for e, o in zip(expect_list, output_list):
+            assert np.allclose(e.asnumpy(), o.asnumpy(), rtol, atol, equal_nan=True)
+    else:
+        assert np.allclose(expect.asnumpy(), output.asnumpy(), rtol, atol, equal_nan=True)
+
+
+def test_layernorm(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1):
+    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(shape)
+    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(shape)
+    assert 0 <= begin_norm_axis < len(shape)
+    assert 0 <= begin_params_axis < len(shape)
+    normalized_shape = shape[begin_params_axis:]
+
+    np.random.seed(0)
+    # input tensors
+    x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
+    gamma = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype))
+    beta = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype))
+
+    expect = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, False)
+    output = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, True)
+
+    compare_result(expect, output, dtype)
+
+
+def test_layernorm_grad(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1):
+    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(shape)
+    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(shape)
+    assert 0 <= begin_norm_axis < len(shape)
+    assert 0 <= begin_params_axis < len(shape)
+
+    norm_axis = [i for i in range(begin_norm_axis, len(shape))]
+    norm_shape = copy.deepcopy(shape)
+    for i, _ in enumerate(norm_shape):
+        if i in norm_axis:
+            norm_shape[i] = 1
+    params_shape = shape[begin_params_axis:]
+
+    np.random.seed(0)
+    # input tensors
+    dy = Tensor(np.random.normal(0, 1, shape).astype(dtype))
+    x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
+    var = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype))
+    mean = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype))
+    gamma = Tensor(np.random.normal(0, 1, params_shape).astype(dtype))
+
+    expect = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, False)
+    output = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, True)
+
+    compare_result(expect, output, dtype)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernorm_gpu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    test_layernorm([4, 32, 32], np.float32, -1, -1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_layernorm_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_layernorm([4, 32, 32], np.float16, -1, -1)
+    test_layernorm([4, 32, 32], np.float32, -1, -1)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_basic_gpu():
-    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
-    test_basic()
-
-
-def test_basic_ascend():
-    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
-    test_basic()
+def test_layernorm_grad_gpu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    test_layernorm_grad([4, 32, 32], np.float32, -1, -1)


 @pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard
-def test_layernormgrad_gpu():
-    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
-    test_layernormgrad()
-
-
-def test_layernormgrad_ascend():
-    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
-    test_layernormgrad()
+def test_layernorm_grad_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_layernorm_grad([2, 16, 32], np.float16, -1, -1)
+    test_layernorm_grad([4, 32, 32], np.float32, -1, -1)
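(Aside, not part of the diff.) Each new entry point drives the same comparison: run the network once with graph kernel disabled as the reference and once with it enabled, then compare element-wise with dtype-dependent tolerances. For example, test_layernorm([4, 32, 32], np.float32, -1, -1) is roughly equivalent to:

shape, dtype = [4, 32, 32], np.float32
np.random.seed(0)
x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
gamma = Tensor(np.random.normal(0, 1, shape[2:]).astype(dtype))   # normalized_shape = [32]
beta = Tensor(np.random.normal(0, 1, shape[2:]).astype(dtype))
expect = get_layernorm_output(x, gamma, beta, 2, 2, False)  # graph kernel off
output = get_layernorm_output(x, gamma, beta, 2, 2, True)   # graph kernel on
compare_result(expect, output, dtype)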