From 66d28af79e06fd61900c34755c3351b3f79c07d1 Mon Sep 17 00:00:00 2001 From: wenfangpei Date: Mon, 29 Mar 2021 15:30:30 +0800 Subject: [PATCH] adapt for layernorm in ascend --- .../graph_kernel/expanders/layernorm.py | 101 +++++++- .../graph_kernel/expanders/layernorm_grad.py | 81 ++++--- .../graph_kernel/graph_kernel_expander.cc | 4 +- .../ccsrc/backend/session/ascend_session.cc | 1 - .../bert_precision/test_bert_tdt_lossscale.py | 2 +- tests/st/ops/graph_kernel/test_layernorm.py | 229 +++++++++--------- 6 files changed, 263 insertions(+), 155 deletions(-) diff --git a/mindspore/_extends/graph_kernel/expanders/layernorm.py b/mindspore/_extends/graph_kernel/expanders/layernorm.py index 17955c293b3..2e7ff18500e 100644 --- a/mindspore/_extends/graph_kernel/expanders/layernorm.py +++ b/mindspore/_extends/graph_kernel/expanders/layernorm.py @@ -17,33 +17,118 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF from ._utils import Expander, ExpanderInfoValidator as VLD +@VLD.add_format(DF.FRAC_NZ, DF.DEFAULT, DF.DEFAULT) @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) @VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon') class LayerNorm(Expander): """LayerNorm expander""" + def to_frac_z_axis(self, ori_shape, ori_axis): + """ + judge the format is fractal NZ + + Parameters + ---------- + ori_shape: list or tuple + original shape of input + ori_axis: list or tuple + original axis of original shape to operate + + Returns + ------- + output: list + axis of the fractal Nz shape + """ + + frac_z_axis = list(ori_axis) + shape_len = len(ori_shape) + axis_count = len(frac_z_axis) + axis_negative_1 = shape_len - 1 + axis_negative_2 = shape_len - 2 + + for i in range(axis_count): + axis_index = (frac_z_axis[i] + shape_len) % shape_len + if axis_index == axis_negative_1: + if frac_z_axis[i] > shape_len - 2: # akg:[2,3] [1,4] tbe:[2,4] [1,3] + frac_z_axis[i] = axis_index - 1 + frac_z_axis.append(axis_index + 2) + else: # no case cover this branch now + frac_z_axis[i] = axis_index - 1 + frac_z_axis.append(axis_index + 2) + elif axis_index == axis_negative_2: + frac_z_axis[i] = axis_index + 1 + frac_z_axis.append(axis_index + 2) + else: + frac_z_axis[i] = axis_index + return frac_z_axis + + def infer_shape_from_fractalNz(self, fractal): + "get original shape from fractalNz shape" + shape = [] + dims = len(fractal) + batch = dims - 4 + for i in range(batch): + shape.append(fractal[i]) + + m = fractal[dims - 3] * fractal[dims - 2] + n = fractal[dims - 4] * fractal[dims - 1] + shape.append(m) + shape.append(n) + + return shape + + def get_reduced_ori_shape(self, shape, axis): + "get shape after reduced which is based on original shape" + reduced_ori_shape = [] + for i, value in enumerate(shape): + if i in axis: + reduced_ori_shape.append(1) + else: + reduced_ori_shape.append(value) + return reduced_ori_shape + def _expand(self, graph_builder): input_x, input_gamma, input_beta = self.inputs + processor = self.processor begin_norm_axis = self.attrs['begin_norm_axis'] epsilon = self.attrs['epsilon'] + ori_dtype = input_x.dtype + if processor == 'aicore' and ori_dtype == 'float16': + input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'}) + input_gamma = graph_builder.emit('Cast', [input_gamma], attrs={'dst_type': 'float32'}) + input_beta = graph_builder.emit('Cast', [input_beta], attrs={'dst_type': 'float32'}) + + ori_shape_x = input_x.shape + if input_x.data_format == DF.FRAC_NZ: + ori_shape_x = self.infer_shape_from_fractalNz(input_x.shape) + # Calculate the scaling ratio of the average if begin_norm_axis < 0: - begin_norm_axis += len(input_x.shape) + begin_norm_axis += len(ori_shape_x) + reduce_axis = () - for i, _ in enumerate(input_x.shape): + for i, _ in enumerate(ori_shape_x): if i > begin_norm_axis or i == begin_norm_axis: reduce_axis = reduce_axis + (i,) reduce_elts = 1.0 for i in reduce_axis: - reduce_elts *= input_x.shape[i] + reduce_elts *= ori_shape_x[i] + + if input_x.data_format == DF.FRAC_NZ: + reduce_axis = self.to_frac_z_axis(ori_shape_x, reduce_axis) + ori_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis) # after reduced + mean_cof = 1.0 / reduce_elts mean_cof_v = graph_builder.value(input_x.dtype, mean_cof) # Calculate mean - mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) + mean_red = graph_builder.emit('ReduceSum', [input_x], + attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) mean = graph_builder.emit('Mul', [mean_red, mean_cof_v]) + if input_x.data_format == DF.FRAC_NZ: + mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_shape_x}) # Calculate variance variance_sub = graph_builder.emit('Sub', [input_x, mean]) @@ -51,6 +136,8 @@ class LayerNorm(Expander): variance_red = graph_builder.emit('ReduceSum', [variance_mul], attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) variance = graph_builder.emit('Mul', [variance_red, mean_cof_v]) + if input_x.data_format == DF.FRAC_NZ: + variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_shape_x}) # Calculate normalize normalize_sub = graph_builder.emit('Sub', [input_x, mean]) @@ -60,7 +147,11 @@ class LayerNorm(Expander): normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt]) # Calculate scale and translate - scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul]) + scale_mul = graph_builder.emit('Mul', [normalize_mul, input_gamma]) res = graph_builder.emit('Add', [scale_mul, input_beta]) + if processor == 'aicore' and ori_dtype == 'float16': + res = graph_builder.emit('Cast', [res], attrs={'dst_type': 'float16'}) + mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float16'}) + variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float16'}) return res, mean, variance diff --git a/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py b/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py index 1019ea177e6..2ae7078bdc3 100644 --- a/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +++ b/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py @@ -24,23 +24,33 @@ class LayerNormGrad(Expander): def _expand(self, graph_builder): x, dy, variance, mean, gamma = self.inputs + processor = self.processor begin_norm_axis = self.attrs['begin_norm_axis'] begin_params_axis = self.attrs['begin_params_axis'] - epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-11 + epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-12 + + ori_dtype = x.dtype + if processor == 'aicore' and ori_dtype == 'float16': + x = graph_builder.emit('Cast', [x], attrs={'dst_type': 'float32'}) + dy = graph_builder.emit('Cast', [dy], attrs={'dst_type': 'float32'}) + variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float32'}) + mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float32'}) + gamma = graph_builder.emit('Cast', [gamma], attrs={'dst_type': 'float32'}) if begin_norm_axis < 0: begin_norm_axis += len(x.shape) if begin_params_axis < 0: begin_params_axis += len(x.shape) + norm_axis = tuple(range(begin_norm_axis, len(x.shape))) param_axis = tuple(range(0, begin_params_axis)) + reduce_size = 1.0 for i in norm_axis: reduce_size *= x.shape[i] # set some constant val. eps = graph_builder.value(x.dtype, epsilon) - const_one = graph_builder.value(x.dtype, 1.0) const_neg_half = graph_builder.value(x.dtype, -0.5) const_neg_two = graph_builder.value(x.dtype, -2.0) const_two = graph_builder.value(x.dtype, 2.0) @@ -49,42 +59,55 @@ class LayerNormGrad(Expander): # cal dg db var_eps = graph_builder.emit('Add', [variance, eps]) - sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps]) - rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps]) + var_eps_log = graph_builder.emit('Log', [var_eps]) + var_eps_mul = graph_builder.emit('Mul', [var_eps_log, const_neg_half]) + rsqrt_var_eps = graph_builder.emit('Exp', [var_eps_mul]) + x_sub_mean = graph_builder.emit('Sub', [x, mean]) + x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean]) dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps]) dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False}) db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False}) - # cal sum_1 - tmp_var_eps = graph_builder.emit('Mul', [sqrt_var_eps, var_eps]) - r_tmp_var_eps = graph_builder.emit('RealDiv', [const_one, tmp_var_eps]) - x_sub_mean_mul_r_tmp_var_eps = graph_builder.emit('Mul', [x_sub_mean, r_tmp_var_eps]) + # pd_var + tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, rsqrt_var_eps]) + r_tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, tmp_var_eps]) + dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma]) - tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean_mul_r_tmp_var_eps]) - sum_1_mul = graph_builder.emit('Mul', [const_neg_half, tmp_mul]) - sum_1 = graph_builder.emit('ReduceSum', [sum_1_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True}) + tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean]) + padvar_mul1 = graph_builder.emit('ReduceSum', [tmp_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True}) + padvar_mul3 = graph_builder.emit('Mul', [padvar_mul1, r_tmp_var_eps]) + pd_var = graph_builder.emit('Mul', [padvar_mul3, const_neg_half]) - # cal sum_2 - sum_2 = graph_builder.emit('ReduceSum', [dy_mul_gamma], attrs={'reduce_axis': norm_axis, 'keep_dims': True}) + # pd_mean + pdmean1_sum = graph_builder.emit('ReduceSum', [dy_mul_gamma], + attrs={'reduce_axis': norm_axis, 'keep_dims': True}) + neg_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, const_neg_one]) + pd_mean_1 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, pdmean1_sum]) - # cal sum_3 - sum_3_mul = graph_builder.emit('Mul', [const_neg_two, x_sub_mean]) - sum_3 = graph_builder.emit('ReduceSum', [sum_3_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True}) + pdmean2_mul1 = graph_builder.emit('Mul', [const_neg_two, x_sub_mean]) + pdmean2_sum = graph_builder.emit('ReduceSum', [pdmean2_mul1], + attrs={'reduce_axis': norm_axis, 'keep_dims': True}) + pdmean2_mul3 = graph_builder.emit('Mul', [pdmean2_sum, mean_cof]) + pd_mean_2 = graph_builder.emit('Mul', [pdmean2_mul3, pd_var]) - # cal dx = dx1 + dx2 + dx3 - dx_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps]) - sum_1_mul_two = graph_builder.emit('Mul', [sum_1, const_two]) - sum_1_mul_two_tmp = graph_builder.emit('Mul', [sum_1_mul_two, mean_cof]) - dx_2 = graph_builder.emit('Mul', [sum_1_mul_two_tmp, x_sub_mean]) - neg_rsqrt_var_eps = graph_builder.emit('Mul', [const_neg_one, rsqrt_var_eps]) - neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2]) - sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3]) - mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3]) - add_tmp = graph_builder.emit('Add', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3]) - dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof]) - dx_tmp = graph_builder.emit('Add', [dx_1, dx_2]) - dx = graph_builder.emit('Add', [dx_tmp, dx_3]) + pd_mean = graph_builder.emit('Add', [pd_mean_1, pd_mean_2]) + # cal dx + pd_x_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps]) + + pdx2_mul = graph_builder.emit('Mul', [pd_var, x_sub_mean]) + pdx2_mul_two = graph_builder.emit('Mul', [pdx2_mul, const_two]) + pd_x_2 = graph_builder.emit('Mul', [pdx2_mul_two, mean_cof]) + + pd_x_3 = graph_builder.emit('Mul', [pd_mean, mean_cof]) + + dx_tmp = graph_builder.emit('Add', [pd_x_1, pd_x_2]) + dx = graph_builder.emit('Add', [dx_tmp, pd_x_3]) + + if processor == 'aicore' and ori_dtype == 'float16': + dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'}) + dg = graph_builder.emit('Cast', [dg], attrs={'dst_type': 'float16'}) + db = graph_builder.emit('Cast', [db], attrs={'dst_type': 'float16'}) return dx, dg, db diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc index 0bc97b6389c..266e8a206d0 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc @@ -47,6 +47,8 @@ std::vector GetExpandOps() { prim::kPrimSquare, prim::kPrimGeLUGrad, prim::kPrimAssignAdd, + prim::kPrimLayerNorm, + prim::kPrimLayerNormGrad, #if ENABLE_D prim::kPrimTile, prim::kPrimSqrtGrad, @@ -67,8 +69,6 @@ std::vector GetExpandOps() { prim::kPrimDropout, prim::kPrimDropoutGrad, prim::kPrimSoftmax, - prim::kPrimLayerNorm, - prim::kPrimLayerNormGrad, prim::kPrimRelu, prim::kPrimReluGrad, prim::kPrimSigmoid, diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 741f298093f..3ee7e52c9c8 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -54,7 +54,6 @@ #include "debug/data_dump/dump_json_parser.h" #include "debug/tensor_load.h" #include "debug/anf_ir_utils.h" -#include "backend/optimizer/graph_kernel/shape_ops_splitter.h" #include "backend/optimizer/graph_kernel/graph_kernel_optimization.h" #include "backend/session/ascend_auto_monad.h" #include "debug/data_dump/e2e_dump_util.h" diff --git a/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py index 53e49b58821..d23939795cf 100644 --- a/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py @@ -223,12 +223,12 @@ def test_bert_precision(enable_graph_kernel=False): # assertion occurs while the loss value, overflow state or loss_scale value is wrong loss_value = np.array(callback.loss_list) - assert np.allclose(loss_value[0], 12.2066, 0, 0.0005) if enable_graph_kernel: expect_loss_value = [12.206627, 11.840489, 11.798470, 11.796345, 11.790964, 12.366766, 11.971539, 12.576565, 12.185522, 12.386192] else: + assert np.allclose(loss_value[0], 12.2066, 0, 0.0005) expect_loss_value = [12.206587, 11.966410, 11.965916, 11.975922, 11.970262, 12.608881, 12.174048, 12.840656, 12.407923, 12.631133] print("loss value: {}".format(loss_value)) diff --git a/tests/st/ops/graph_kernel/test_layernorm.py b/tests/st/ops/graph_kernel/test_layernorm.py index bc18bda9e5e..aea6a510295 100644 --- a/tests/st/ops/graph_kernel/test_layernorm.py +++ b/tests/st/ops/graph_kernel/test_layernorm.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,150 +13,145 @@ # limitations under the License. # ============================================================================ +import copy import numpy as np import pytest import mindspore.context as context from mindspore import Tensor -from mindspore.nn import Cell import mindspore.nn as nn from mindspore.ops.operations import _grad_ops as G import mindspore.ops.operations as P -class Net(Cell): - def __init__(self): - super(Net, self).__init__() - self.layernorm = P.LayerNorm(1, 1) +class LayerNormNet(nn.Cell): + def __init__(self, begin_norm_axis, begin_params_axis): + super(LayerNormNet, self).__init__() + self.layernorm = P.LayerNorm(begin_norm_axis, begin_params_axis) - def construct(self, x, y, z): - return self.layernorm(x, y, z) + def construct(self, x, gamma, beta): + return self.layernorm(x, gamma, beta) class LayerNormGradNet(nn.Cell): def __init__(self, begin_norm_axis, begin_params_axis): super(LayerNormGradNet, self).__init__() - self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis) + self.layernorm_grad = G.LayerNormGrad(begin_norm_axis, begin_params_axis) def construct(self, dy, x, var, mean, gamma): - return self.norm(dy, x, var, mean, gamma) + return self.layernorm_grad(dy, x, var, mean, gamma) + +def get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, enable_graph_kernel=False): + context.set_context(enable_graph_kernel=enable_graph_kernel) + + net = LayerNormNet(begin_norm_axis, begin_params_axis) + output = net(x, gamma, beta) + + return output -def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis): - begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape) - begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape) - - norm_axis = [i for i in range(begin_norm_axis, len(x.shape))] - param_axis = [i for i in range(0, begin_params_axis)] - num = 1 - for i in range(begin_norm_axis, len(x.shape)): - num *= x.shape[i] - - mean = np.mean(x, axis=tuple(norm_axis), keepdims=True) - var = np.var(x, axis=tuple(norm_axis), keepdims=True) - gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:])) - dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True) - db = np.sum(dy, axis=tuple(param_axis), keepdims=True) - - sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis), - keepdims=True) - sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True) - sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True) - - dx1 = dy * gamma * np.power(var + epsilon, -0.5) - dx2 = sum1 * 2.0 / num * (x - mean) - dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num) - dx = dx1 + dx2 + dx3 - return dx, dg, db, mean, var - - -def test_basic(): - input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32) - gamma = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32) - beta = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32) - shape_x = [2, 3, 4, 3] - begin_norm_axis = 1 - - in_rank = len(shape_x) - if begin_norm_axis < 0: - norm_axis = begin_norm_axis + in_rank - else: - norm_axis = begin_norm_axis - norm_axes = tuple(range(norm_axis, in_rank)) - mean = np.mean(input_x, axis=norm_axes, keepdims=True) - mean_b = np.broadcast_to(mean, shape_x) - diff = input_x - mean_b - square = np.square(diff) - smean = np.mean(square, axis=norm_axes, keepdims=True) - smean_b = np.broadcast_to(smean, shape_x) - meps = smean_b + 1e-5 - logs = np.log(meps) - mul = logs * (-0.5) - rsqrt = np.exp(mul) - out = diff * rsqrt - bn = out * gamma + beta - expect = (bn, mean, smean) - - net = Net() - - net_result = net(Tensor(input_x), Tensor(gamma), Tensor(beta)) - if isinstance(net_result, tuple) and len(net_result) == 3: - result = (net_result[0].asnumpy(), net_result[1].asnumpy(), net_result[2].asnumpy()) - res0 = np.allclose(expect[0], result[0], rtol=1.e-4, atol=1.e-4, equal_nan=True) - assert res0 - res1 = np.allclose(expect[1], result[1], rtol=1.e-4, atol=1.e-7, equal_nan=True) - assert res1 - res2 = np.allclose(expect[2], result[2], rtol=1.e-4, atol=1.e-7, equal_nan=True) - assert res2 - else: - assert False - - -def test_layernormgrad(): - np.random.seed(0) - begin_norm_axis = 1 - begin_params_axis = 1 - x_np = np.random.randn(4096, 3072).astype(np.float32) - dy_np = np.random.randn(4096, 3072).astype(np.float32) - gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) - epsilon = 1e-11 - dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis, - begin_params_axis) - - dy_ms = Tensor(dy_np) - x_ms = Tensor(x_np) - var_ms = Tensor(var_np) - mean_ms = Tensor(mean_np) - gamma_ms = Tensor(gamma_np) +def get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, enable_graph_kernel=False): + context.set_context(enable_graph_kernel=enable_graph_kernel) net = LayerNormGradNet(begin_norm_axis, begin_params_axis) - dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms) - assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6) - assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3) - assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3) + output = net(x, dy, var, mean, gamma) + + return output + +def get_rtol_atol(dtype): + if dtype == np.float16: + return 1.e-3, 1.e-3 + return 1.e-4, 1.e-4 + + +def compare_result(expect, output, dtype): + rtol, atol = get_rtol_atol(dtype) + if isinstance(expect, (list, tuple)): + assert isinstance(output, (list, tuple)) and len(expect) == len(output) + expect_list = list(expect) + output_list = list(output) + for e, o in zip(expect_list, output_list): + assert np.allclose(e.asnumpy(), o.asnumpy(), rtol, atol, equal_nan=True) + else: + assert np.allclose(expect.asnumpy(), output.asnumpy(), rtol, atol, equal_nan=True) + + +def test_layernorm(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1): + begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(shape) + begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(shape) + assert 0 <= begin_norm_axis < len(shape) + assert 0 <= begin_params_axis < len(shape) + normalized_shape = shape[begin_params_axis:] + + np.random.seed(0) + # input tensors + x = Tensor(np.random.normal(0, 1, shape).astype(dtype)) + gamma = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype)) + beta = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype)) + + expect = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, False) + output = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, True) + + compare_result(expect, output, dtype) + + +def test_layernorm_grad(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1): + begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(shape) + begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(shape) + assert 0 <= begin_norm_axis < len(shape) + assert 0 <= begin_params_axis < len(shape) + + norm_axis = [i for i in range(begin_norm_axis, len(shape))] + norm_shape = copy.deepcopy(shape) + for i, _ in enumerate(norm_shape): + if i in norm_axis: + norm_shape[i] = 1 + params_shape = shape[begin_params_axis:] + + np.random.seed(0) + # input tensors + dy = Tensor(np.random.normal(0, 1, shape).astype(dtype)) + x = Tensor(np.random.normal(0, 1, shape).astype(dtype)) + var = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype)) + mean = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype)) + gamma = Tensor(np.random.normal(0, 1, params_shape).astype(dtype)) + + expect = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, False) + output = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, True) + + compare_result(expect, output, dtype) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_layernorm_gpu(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_layernorm([4, 32, 32], np.float32, -1, -1) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_layernorm_ascend(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + test_layernorm([4, 32, 32], np.float16, -1, -1) + test_layernorm([4, 32, 32], np.float32, -1, -1) @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_basic_gpu(): - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") - test_basic() - - -def test_basic_ascend(): - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend") - test_basic() +def test_layernorm_grad_gpu(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_layernorm_grad([4, 32, 32], np.float32, -1, -1) @pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard -def test_layernormgrad_gpu(): - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") - test_layernormgrad() - - -def test_layernormgrad_ascend(): - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend") - test_layernormgrad() +def test_layernorm_grad_ascend(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + test_layernorm_grad([2, 16, 32], np.float16, -1, -1) + test_layernorm_grad([4, 32, 32], np.float32, -1, -1)