forked from mindspore-Ecosystem/mindspore
adapt for layernorm in ascend
This commit is contained in:
parent c8107390d1
commit 66d28af79e
@@ -17,33 +17,118 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD


@VLD.add_format(DF.FRAC_NZ, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon')
class LayerNorm(Expander):
    """LayerNorm expander"""

    def to_frac_z_axis(self, ori_shape, ori_axis):
        """
        Convert the reduction axes of the original shape to the corresponding axes of the fractal NZ shape.

        Parameters
        ----------
        ori_shape: list or tuple
            original shape of input
        ori_axis: list or tuple
            original axes of the original shape to operate on

        Returns
        -------
        output: list
            axes of the fractal NZ shape
        """

        frac_z_axis = list(ori_axis)
        shape_len = len(ori_shape)
        axis_count = len(frac_z_axis)
        axis_negative_1 = shape_len - 1
        axis_negative_2 = shape_len - 2

        for i in range(axis_count):
            axis_index = (frac_z_axis[i] + shape_len) % shape_len
            if axis_index == axis_negative_1:
                if frac_z_axis[i] > shape_len - 2:  # akg:[2,3] [1,4] tbe:[2,4] [1,3]
                    frac_z_axis[i] = axis_index - 1
                    frac_z_axis.append(axis_index + 2)
                else:  # no case covers this branch now
                    frac_z_axis[i] = axis_index - 1
                    frac_z_axis.append(axis_index + 2)
            elif axis_index == axis_negative_2:
                frac_z_axis[i] = axis_index + 1
                frac_z_axis.append(axis_index + 2)
            else:
                frac_z_axis[i] = axis_index
        return frac_z_axis
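
For illustration only (not part of the commit): for a (1024, 768) matrix stored as FRAC_NZ with shape (48, 64, 16, 16), reducing the last original axis maps to fractal axes 0 and 3, the two axes that hold the blocks of that dimension (compare infer_shape_from_fractalNz below). Assuming an expander instance ln:

    # hypothetical usage sketch; `ln` is not defined in this diff
    ln.to_frac_z_axis([1024, 768], (1,))   # -> [0, 3]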

    def infer_shape_from_fractalNz(self, fractal):
        "get original shape from fractalNz shape"
        shape = []
        dims = len(fractal)
        batch = dims - 4
        for i in range(batch):
            shape.append(fractal[i])

        m = fractal[dims - 3] * fractal[dims - 2]
        n = fractal[dims - 4] * fractal[dims - 1]
        shape.append(m)
        shape.append(n)

        return shape
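
A quick sanity check of the helper above (illustrative only): a (1024, 768) matrix in FRAC_NZ layout has fractal shape (48, 64, 16, 16), so m = 64 * 16 and n = 48 * 16:

    # hypothetical usage sketch; `ln` is not defined in this diff
    ln.infer_shape_from_fractalNz([48, 64, 16, 16])   # -> [1024, 768]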

    def get_reduced_ori_shape(self, shape, axis):
        "get the reduced shape based on the original shape"
        reduced_ori_shape = []
        for i, value in enumerate(shape):
            if i in axis:
                reduced_ori_shape.append(1)
            else:
                reduced_ori_shape.append(value)
        return reduced_ori_shape
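
Illustrative trace of this helper in isolation (not part of the commit): axes listed in axis collapse to 1, all other dimensions keep their extent:

    # hypothetical usage sketch; `ln` is not defined in this diff
    ln.get_reduced_ori_shape([4, 1024, 768], (1, 2))   # -> [4, 1, 1]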

    def _expand(self, graph_builder):
        input_x, input_gamma, input_beta = self.inputs
        processor = self.processor
        begin_norm_axis = self.attrs['begin_norm_axis']
        epsilon = self.attrs['epsilon']

        ori_dtype = input_x.dtype
        if processor == 'aicore' and ori_dtype == 'float16':
            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
            input_gamma = graph_builder.emit('Cast', [input_gamma], attrs={'dst_type': 'float32'})
            input_beta = graph_builder.emit('Cast', [input_beta], attrs={'dst_type': 'float32'})

        ori_shape_x = input_x.shape
        if input_x.data_format == DF.FRAC_NZ:
            ori_shape_x = self.infer_shape_from_fractalNz(input_x.shape)

        # Calculate the scaling ratio of the average
        if begin_norm_axis < 0:
-            begin_norm_axis += len(input_x.shape)
+            begin_norm_axis += len(ori_shape_x)

        reduce_axis = ()
-        for i, _ in enumerate(input_x.shape):
+        for i, _ in enumerate(ori_shape_x):
            if i > begin_norm_axis or i == begin_norm_axis:
                reduce_axis = reduce_axis + (i,)

        reduce_elts = 1.0
        for i in reduce_axis:
-            reduce_elts *= input_x.shape[i]
+            reduce_elts *= ori_shape_x[i]

        if input_x.data_format == DF.FRAC_NZ:
            reduce_axis = self.to_frac_z_axis(ori_shape_x, reduce_axis)
            ori_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis)  # after reduced

        mean_cof = 1.0 / reduce_elts
        mean_cof_v = graph_builder.value(input_x.dtype, mean_cof)

        # Calculate mean
-        mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
+        mean_red = graph_builder.emit('ReduceSum', [input_x],
+                                      attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
        mean = graph_builder.emit('Mul', [mean_red, mean_cof_v])
        if input_x.data_format == DF.FRAC_NZ:
            mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_shape_x})

        # Calculate variance
        variance_sub = graph_builder.emit('Sub', [input_x, mean])

@@ -51,6 +136,8 @@ class LayerNorm(Expander):
        variance_red = graph_builder.emit('ReduceSum', [variance_mul],
                                          attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
        variance = graph_builder.emit('Mul', [variance_red, mean_cof_v])
        if input_x.data_format == DF.FRAC_NZ:
            variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_shape_x})

        # Calculate normalize
        normalize_sub = graph_builder.emit('Sub', [input_x, mean])

@@ -60,7 +147,11 @@ class LayerNorm(Expander):
        normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt])

        # Calculate scale and translate
-        scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul])
+        scale_mul = graph_builder.emit('Mul', [normalize_mul, input_gamma])
        res = graph_builder.emit('Add', [scale_mul, input_beta])

        if processor == 'aicore' and ori_dtype == 'float16':
            res = graph_builder.emit('Cast', [res], attrs={'dst_type': 'float16'})
            mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float16'})
            variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float16'})
        return res, mean, variance
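
For reference (editorial, not part of the commit): with m = reduce_elts and mean_cof = 1/m, the expansion above builds the standard LayerNorm

    \mu = \frac{1}{m}\sum_{\text{norm axes}} x, \qquad
    \sigma^2 = \frac{1}{m}\sum_{\text{norm axes}} (x - \mu)^2, \qquad
    y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta

The rsqrt factor (\sigma^2 + \epsilon)^{-1/2} (normlize_rsqrt in the elided lines) is presumably built with the same Log/Mul/Exp sequence used in LayerNormGrad below; compare also the expected-value computation in the old test_basic further down.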

@@ -24,23 +24,33 @@ class LayerNormGrad(Expander):

    def _expand(self, graph_builder):
        x, dy, variance, mean, gamma = self.inputs
        processor = self.processor
        begin_norm_axis = self.attrs['begin_norm_axis']
        begin_params_axis = self.attrs['begin_params_axis']
-        epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-11
+        epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-12

        ori_dtype = x.dtype
        if processor == 'aicore' and ori_dtype == 'float16':
            x = graph_builder.emit('Cast', [x], attrs={'dst_type': 'float32'})
            dy = graph_builder.emit('Cast', [dy], attrs={'dst_type': 'float32'})
            variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float32'})
            mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float32'})
            gamma = graph_builder.emit('Cast', [gamma], attrs={'dst_type': 'float32'})

        if begin_norm_axis < 0:
            begin_norm_axis += len(x.shape)
        if begin_params_axis < 0:
            begin_params_axis += len(x.shape)

        norm_axis = tuple(range(begin_norm_axis, len(x.shape)))
        param_axis = tuple(range(0, begin_params_axis))

        reduce_size = 1.0
        for i in norm_axis:
            reduce_size *= x.shape[i]

        # set some constant val.
        eps = graph_builder.value(x.dtype, epsilon)
        const_one = graph_builder.value(x.dtype, 1.0)
        const_neg_half = graph_builder.value(x.dtype, -0.5)
        const_neg_two = graph_builder.value(x.dtype, -2.0)
        const_two = graph_builder.value(x.dtype, 2.0)

@@ -49,42 +59,55 @@ class LayerNormGrad(Expander):

        # cal dg db
        var_eps = graph_builder.emit('Add', [variance, eps])
-        sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps])
-        rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps])
+        var_eps_log = graph_builder.emit('Log', [var_eps])
+        var_eps_mul = graph_builder.emit('Mul', [var_eps_log, const_neg_half])
+        rsqrt_var_eps = graph_builder.emit('Exp', [var_eps_mul])

        x_sub_mean = graph_builder.emit('Sub', [x, mean])

        x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean])
        dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps])
        dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False})
        db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False})

-        # cal sum_1
-        tmp_var_eps = graph_builder.emit('Mul', [sqrt_var_eps, var_eps])
-        r_tmp_var_eps = graph_builder.emit('RealDiv', [const_one, tmp_var_eps])
-        x_sub_mean_mul_r_tmp_var_eps = graph_builder.emit('Mul', [x_sub_mean, r_tmp_var_eps])
+        # pd_var
+        tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, rsqrt_var_eps])
+        r_tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, tmp_var_eps])

        dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma])
-        tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean_mul_r_tmp_var_eps])
-        sum_1_mul = graph_builder.emit('Mul', [const_neg_half, tmp_mul])
-        sum_1 = graph_builder.emit('ReduceSum', [sum_1_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
+        tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean])
+        padvar_mul1 = graph_builder.emit('ReduceSum', [tmp_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
+        padvar_mul3 = graph_builder.emit('Mul', [padvar_mul1, r_tmp_var_eps])
+        pd_var = graph_builder.emit('Mul', [padvar_mul3, const_neg_half])

-        # cal sum_2
-        sum_2 = graph_builder.emit('ReduceSum', [dy_mul_gamma], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
+        # pd_mean
+        pdmean1_sum = graph_builder.emit('ReduceSum', [dy_mul_gamma],
+                                         attrs={'reduce_axis': norm_axis, 'keep_dims': True})
+        neg_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, const_neg_one])
+        pd_mean_1 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, pdmean1_sum])

-        # cal sum_3
-        sum_3_mul = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
-        sum_3 = graph_builder.emit('ReduceSum', [sum_3_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
+        pdmean2_mul1 = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
+        pdmean2_sum = graph_builder.emit('ReduceSum', [pdmean2_mul1],
+                                         attrs={'reduce_axis': norm_axis, 'keep_dims': True})
+        pdmean2_mul3 = graph_builder.emit('Mul', [pdmean2_sum, mean_cof])
+        pd_mean_2 = graph_builder.emit('Mul', [pdmean2_mul3, pd_var])

-        # cal dx = dx1 + dx2 + dx3
-        dx_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
-        sum_1_mul_two = graph_builder.emit('Mul', [sum_1, const_two])
-        sum_1_mul_two_tmp = graph_builder.emit('Mul', [sum_1_mul_two, mean_cof])
-        dx_2 = graph_builder.emit('Mul', [sum_1_mul_two_tmp, x_sub_mean])
-        neg_rsqrt_var_eps = graph_builder.emit('Mul', [const_neg_one, rsqrt_var_eps])
-        neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2])
-        sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3])
-        mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3])
-        add_tmp = graph_builder.emit('Add', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
-        dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof])
-        dx_tmp = graph_builder.emit('Add', [dx_1, dx_2])
-        dx = graph_builder.emit('Add', [dx_tmp, dx_3])
+        pd_mean = graph_builder.emit('Add', [pd_mean_1, pd_mean_2])

+        # cal dx
+        pd_x_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])

+        pdx2_mul = graph_builder.emit('Mul', [pd_var, x_sub_mean])
+        pdx2_mul_two = graph_builder.emit('Mul', [pdx2_mul, const_two])
+        pd_x_2 = graph_builder.emit('Mul', [pdx2_mul_two, mean_cof])

+        pd_x_3 = graph_builder.emit('Mul', [pd_mean, mean_cof])

+        dx_tmp = graph_builder.emit('Add', [pd_x_1, pd_x_2])
+        dx = graph_builder.emit('Add', [dx_tmp, pd_x_3])

        if processor == 'aicore' and ori_dtype == 'float16':
            dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})
            dg = graph_builder.emit('Cast', [dg], attrs={'dst_type': 'float16'})
            db = graph_builder.emit('Cast', [db], attrs={'dst_type': 'float16'})
        return dx, dg, db
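
For reference (editorial, not part of the commit): with r = (\sigma^2 + \epsilon)^{-1/2}, m = reduce_size, and mean_cof taken as 1/m (mean_cof is defined in the unchanged context outside this hunk), the graph above matches the analytic gradients that the NumPy reference LayerNormGradReference further down also implements:

    d\gamma = \sum_{\text{param axes}} dy \cdot (x - \mu) \cdot r, \qquad
    d\beta = \sum_{\text{param axes}} dy

    \text{pd\_var} = -\tfrac{1}{2}\, r^3 \sum_{\text{norm axes}} dy\,\gamma\,(x - \mu), \qquad
    \text{pd\_mean} = -r \sum_{\text{norm axes}} dy\,\gamma \;+\; \text{pd\_var} \cdot \tfrac{1}{m} \sum_{\text{norm axes}} \bigl(-2\,(x - \mu)\bigr)

    dx = dy\,\gamma\,r \;+\; \tfrac{2}{m}\,\text{pd\_var}\,(x - \mu) \;+\; \tfrac{1}{m}\,\text{pd\_mean}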

@@ -47,6 +47,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
    prim::kPrimSquare,
    prim::kPrimGeLUGrad,
    prim::kPrimAssignAdd,
+   prim::kPrimLayerNorm,
+   prim::kPrimLayerNormGrad,
#if ENABLE_D
    prim::kPrimTile,
    prim::kPrimSqrtGrad,

@@ -67,8 +69,6 @@ std::vector<PrimitivePtr> GetExpandOps() {
    prim::kPrimDropout,
    prim::kPrimDropoutGrad,
    prim::kPrimSoftmax,
-   prim::kPrimLayerNorm,
-   prim::kPrimLayerNormGrad,
    prim::kPrimRelu,
    prim::kPrimReluGrad,
    prim::kPrimSigmoid,

@@ -54,7 +54,6 @@
#include "debug/data_dump/dump_json_parser.h"
#include "debug/tensor_load.h"
#include "debug/anf_ir_utils.h"
#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
#include "backend/session/ascend_auto_monad.h"
#include "debug/data_dump/e2e_dump_util.h"

@@ -223,12 +223,12 @@ def test_bert_precision(enable_graph_kernel=False):

    # assertion occurs while the loss value, overflow state or loss_scale value is wrong
    loss_value = np.array(callback.loss_list)
-    assert np.allclose(loss_value[0], 12.2066, 0, 0.0005)

    if enable_graph_kernel:
        expect_loss_value = [12.206627, 11.840489, 11.798470, 11.796345, 11.790964, 12.366766, 11.971539, 12.576565,
                             12.185522, 12.386192]
    else:
+        assert np.allclose(loss_value[0], 12.2066, 0, 0.0005)
        expect_loss_value = [12.206587, 11.966410, 11.965916, 11.975922, 11.970262, 12.608881, 12.174048, 12.840656,
                             12.407923, 12.631133]
    print("loss value: {}".format(loss_value))

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -13,150 +13,145 @@
# limitations under the License.
# ============================================================================

+import copy
import numpy as np
import pytest

import mindspore.context as context
from mindspore import Tensor
-from mindspore.nn import Cell
+import mindspore.nn as nn
from mindspore.ops.operations import _grad_ops as G
import mindspore.ops.operations as P


-class Net(Cell):
-    def __init__(self):
-        super(Net, self).__init__()
-        self.layernorm = P.LayerNorm(1, 1)
+class LayerNormNet(nn.Cell):
+    def __init__(self, begin_norm_axis, begin_params_axis):
+        super(LayerNormNet, self).__init__()
+        self.layernorm = P.LayerNorm(begin_norm_axis, begin_params_axis)

-    def construct(self, x, y, z):
-        return self.layernorm(x, y, z)
+    def construct(self, x, gamma, beta):
+        return self.layernorm(x, gamma, beta)


class LayerNormGradNet(nn.Cell):
    def __init__(self, begin_norm_axis, begin_params_axis):
        super(LayerNormGradNet, self).__init__()
-        self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis)
+        self.layernorm_grad = G.LayerNormGrad(begin_norm_axis, begin_params_axis)

    def construct(self, dy, x, var, mean, gamma):
-        return self.norm(dy, x, var, mean, gamma)
+        return self.layernorm_grad(dy, x, var, mean, gamma)


def get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, enable_graph_kernel=False):
    context.set_context(enable_graph_kernel=enable_graph_kernel)

    net = LayerNormNet(begin_norm_axis, begin_params_axis)
    output = net(x, gamma, beta)

    return output


def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)

    norm_axis = [i for i in range(begin_norm_axis, len(x.shape))]
    param_axis = [i for i in range(0, begin_params_axis)]
    num = 1
    for i in range(begin_norm_axis, len(x.shape)):
        num *= x.shape[i]

    mean = np.mean(x, axis=tuple(norm_axis), keepdims=True)
    var = np.var(x, axis=tuple(norm_axis), keepdims=True)
    gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
    dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
    db = np.sum(dy, axis=tuple(param_axis), keepdims=True)

    sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
                  keepdims=True)
    sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
    sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)

    dx1 = dy * gamma * np.power(var + epsilon, -0.5)
    dx2 = sum1 * 2.0 / num * (x - mean)
    dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
    dx = dx1 + dx2 + dx3
    return dx, dg, db, mean, var


-def test_basic():
-    input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
-    gamma = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
-    beta = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
-    shape_x = [2, 3, 4, 3]
-    begin_norm_axis = 1

-    in_rank = len(shape_x)
-    if begin_norm_axis < 0:
-        norm_axis = begin_norm_axis + in_rank
-    else:
-        norm_axis = begin_norm_axis
-    norm_axes = tuple(range(norm_axis, in_rank))
-    mean = np.mean(input_x, axis=norm_axes, keepdims=True)
-    mean_b = np.broadcast_to(mean, shape_x)
-    diff = input_x - mean_b
-    square = np.square(diff)
-    smean = np.mean(square, axis=norm_axes, keepdims=True)
-    smean_b = np.broadcast_to(smean, shape_x)
-    meps = smean_b + 1e-5
-    logs = np.log(meps)
-    mul = logs * (-0.5)
-    rsqrt = np.exp(mul)
-    out = diff * rsqrt
-    bn = out * gamma + beta
-    expect = (bn, mean, smean)

-    net = Net()

-    net_result = net(Tensor(input_x), Tensor(gamma), Tensor(beta))
-    if isinstance(net_result, tuple) and len(net_result) == 3:
-        result = (net_result[0].asnumpy(), net_result[1].asnumpy(), net_result[2].asnumpy())
-        res0 = np.allclose(expect[0], result[0], rtol=1.e-4, atol=1.e-4, equal_nan=True)
-        assert res0
-        res1 = np.allclose(expect[1], result[1], rtol=1.e-4, atol=1.e-7, equal_nan=True)
-        assert res1
-        res2 = np.allclose(expect[2], result[2], rtol=1.e-4, atol=1.e-7, equal_nan=True)
-        assert res2
-    else:
-        assert False


-def test_layernormgrad():
-    np.random.seed(0)
-    begin_norm_axis = 1
-    begin_params_axis = 1
-    x_np = np.random.randn(4096, 3072).astype(np.float32)
-    dy_np = np.random.randn(4096, 3072).astype(np.float32)
-    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
-    epsilon = 1e-11
-    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
-                                                                  begin_params_axis)

-    dy_ms = Tensor(dy_np)
-    x_ms = Tensor(x_np)
-    var_ms = Tensor(var_np)
-    mean_ms = Tensor(mean_np)
-    gamma_ms = Tensor(gamma_np)
+def get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)

    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
-    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
-    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
-    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
-    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
+    output = net(x, dy, var, mean, gamma)

+    return output


def get_rtol_atol(dtype):
    if dtype == np.float16:
        return 1.e-3, 1.e-3
    return 1.e-4, 1.e-4


def compare_result(expect, output, dtype):
    rtol, atol = get_rtol_atol(dtype)
    if isinstance(expect, (list, tuple)):
        assert isinstance(output, (list, tuple)) and len(expect) == len(output)
        expect_list = list(expect)
        output_list = list(output)
        for e, o in zip(expect_list, output_list):
            assert np.allclose(e.asnumpy(), o.asnumpy(), rtol, atol, equal_nan=True)
    else:
        assert np.allclose(expect.asnumpy(), output.asnumpy(), rtol, atol, equal_nan=True)


def test_layernorm(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1):
    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(shape)
    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(shape)
    assert 0 <= begin_norm_axis < len(shape)
    assert 0 <= begin_params_axis < len(shape)
    normalized_shape = shape[begin_params_axis:]

    np.random.seed(0)
    # input tensors
    x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
    gamma = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype))
    beta = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype))

    expect = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, False)
    output = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, True)

    compare_result(expect, output, dtype)


def test_layernorm_grad(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1):
    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(shape)
    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(shape)
    assert 0 <= begin_norm_axis < len(shape)
    assert 0 <= begin_params_axis < len(shape)

    norm_axis = [i for i in range(begin_norm_axis, len(shape))]
    norm_shape = copy.deepcopy(shape)
    for i, _ in enumerate(norm_shape):
        if i in norm_axis:
            norm_shape[i] = 1
    params_shape = shape[begin_params_axis:]

    np.random.seed(0)
    # input tensors
    dy = Tensor(np.random.normal(0, 1, shape).astype(dtype))
    x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
    var = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype))
    mean = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype))
    gamma = Tensor(np.random.normal(0, 1, params_shape).astype(dtype))

    expect = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, False)
    output = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, True)

    compare_result(expect, output, dtype)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernorm_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_layernorm([4, 32, 32], np.float32, -1, -1)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_layernorm_ascend():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_layernorm([4, 32, 32], np.float16, -1, -1)
    test_layernorm([4, 32, 32], np.float32, -1, -1)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
-def test_basic_gpu():
-    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
-    test_basic()
-
-
-def test_basic_ascend():
-    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
-    test_basic()
+def test_layernorm_grad_gpu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    test_layernorm_grad([4, 32, 32], np.float32, -1, -1)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
-def test_layernormgrad_gpu():
-    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
-    test_layernormgrad()
-
-
-def test_layernormgrad_ascend():
-    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
-    test_layernormgrad()
+def test_layernorm_grad_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_layernorm_grad([2, 16, 32], np.float16, -1, -1)
+    test_layernorm_grad([4, 32, 32], np.float32, -1, -1)