layernorm gpu/cpu: support double (float64)

kswang 2022-09-21 11:30:13 +08:00
parent 65c9a8ad19
commit 36e1be579b
13 changed files with 259 additions and 75 deletions
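For context, a minimal usage sketch of what this change enables on the CPU and GPU backends. It assumes the public mindspore.ops.LayerNorm primitive and mindspore.set_context; the snippet is illustrative and not part of the diff.

# Hedged sketch: float64 LayerNorm on CPU/GPU after this commit.
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

ms.set_context(device_target="CPU")  # or "GPU"

x = Tensor(np.random.randn(4, 8).astype(np.float64))
gamma = Tensor(np.ones(8, dtype=np.float64))
beta = Tensor(np.zeros(8, dtype=np.float64))

layer_norm = ops.LayerNorm(begin_norm_axis=1, begin_params_axis=1)
y, mean, var = layer_norm(x, gamma, beta)

# Per the kernel registrations and the infer change below, y stays float64 on CPU/GPU,
# while mean and var are emitted as float32.
print(y.dtype, mean.dtype, var.dtype)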

View File

@ -103,10 +103,11 @@ template <typename T>
void LayerNormCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
size_t f_size = sizeof(T);
if (inputs[1]->size != f_size * param_num_ || inputs[2]->size != f_size * param_num_) {
if (inputs[kLayerNormInputGammaIndex]->size != f_size * param_num_ ||
inputs[kLayerNormInputBetaIndex]->size != f_size * param_num_) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the product of gamma and beta's shape must be " << param_num_;
}
if (outputs[1]->size != f_size * block_num_ || outputs[2]->size != f_size * block_num_) {
if (outputs[kLayerNormOutputMeanIndex]->size != outputs[kLayerNormOutputVarIndex]->size) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the product of mean and var's shape must be " << block_num_;
}
auto x = reinterpret_cast<T *>(inputs[kLayerNormInputXIndex]->addr);
@ -170,7 +171,15 @@ std::vector<std::pair<KernelAttr, LayerNormCpuKernelMod::KernelFunc>> LayerNormC
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&LayerNormCpuKernelMod::LaunchKernel<float>}};
&LayerNormCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&LayerNormCpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> LayerNormCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;

View File

@ -213,7 +213,17 @@ std::vector<std::pair<KernelAttr, LayerNormGradCpuKernelMod::KernelFunc>> LayerN
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&LayerNormGradCpuKernelMod::LaunchKernel<float>}};
&LayerNormGradCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&LayerNormGradCpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> LayerNormGradCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;

View File

@ -385,3 +385,7 @@ template CUDA_LIB_EXPORT void LayerNormGrad(const int row_dim, const int col_dim
const float epsilon, const half *dy, const half *x, const float *mean,
const float *var, const half *gamma, half *dx, half *dg, half *db,
cudaStream_t stream);
template CUDA_LIB_EXPORT void LayerNormGrad(const int row_dim, const int col_dim, const int param_dim,
const float epsilon, const double *dy, const double *x, const float *mean,
const float *var, const double *gamma, double *dx, double *dg, double *db,
cudaStream_t stream);
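For reference, a compact NumPy sketch of the gradients these instantiations compute. It is an equivalent rewrite of the layer_norm_grad_np reference used in the tests below, not the kernel code itself; layer_norm_grad_ref and its axis handling are illustrative assumptions.

import numpy as np

def layer_norm_grad_ref(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
    """Reference LayerNormGrad math (dx, dg, db); dtype handling is left to the kernels."""
    norm_axes = tuple(range(begin_norm_axis, x.ndim))
    param_axes = tuple(range(begin_params_axis))
    mean = np.mean(x, axis=norm_axes, keepdims=True)
    var = np.var(x, axis=norm_axes, keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + epsilon)
    # gamma has shape x.shape[begin_params_axis:]; broadcast it over the leading axes.
    g = dy * gamma.reshape((1,) * begin_params_axis + x.shape[begin_params_axis:])
    dg = np.sum(dy * x_hat, axis=param_axes)  # gradient w.r.t. gamma
    db = np.sum(dy, axis=param_axes)          # gradient w.r.t. beta
    dx = (g - np.mean(g, axis=norm_axes, keepdims=True)
          - x_hat * np.mean(g * x_hat, axis=norm_axes, keepdims=True)) / np.sqrt(var + epsilon)
    return dx, dg, db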

View File

@ -229,3 +229,6 @@ template CUDA_LIB_EXPORT void LayerNorm(const int row_dim, const int col_dim, co
template CUDA_LIB_EXPORT void LayerNorm(const int row_dim, const int col_dim, const int param_dim, const float epsilon,
const half *x, const half *gamma, const half *beta, half *y, float *mean,
float *var, cudaStream_t stream);
template CUDA_LIB_EXPORT void LayerNorm(const int row_dim, const int col_dim, const int param_dim, const float epsilon,
const double *x, const double *gamma, const double *beta, double *y,
float *mean, float *var, cudaStream_t stream);
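The notable detail in the double instantiation above is the mixed precision: x, gamma, beta and y are double while mean and var stay float. A hedged NumPy sketch of that contract, mirroring the layer_norm_np reference in the tests below:

import numpy as np

def layer_norm_fwd_ref(x, gamma, beta, epsilon, begin_norm_axis, begin_params_axis):
    """Forward LayerNorm math; mean/var are downcast to float32 as in the double kernels."""
    norm_axes = tuple(range(begin_norm_axis, x.ndim))
    mean = np.mean(x, axis=norm_axes, keepdims=True)
    var = np.var(x, axis=norm_axes, keepdims=True)
    param_shape = (1,) * begin_params_axis + x.shape[begin_params_axis:]
    y = (x - mean) / np.sqrt(var + epsilon) * gamma.reshape(param_shape) + beta.reshape(param_shape)
    # y keeps x's (double) dtype; mean/var are emitted in single precision.
    return y, mean.astype(np.float32), var.astype(np.float32)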

View File

@ -35,7 +35,13 @@ struct DynamicSharedMem<half> {
return addr_half;
}
};
template <>
struct DynamicSharedMem<double> {
__device__ double *addr() {
extern __shared__ double addr_ptr[];
return addr_ptr;
}
};
template <typename T>
CUDA_LIB_EXPORT void LayerNorm(const int outer, const int inner, const int param_dim, const float epsilon, const T *x,
const T *gamma, const T *beta, T *y, float *mean, float *var, cudaStream_t stream);

View File

@ -123,7 +123,15 @@ std::vector<std::pair<KernelAttr, LayerNormGpuKernelMod::KernelFunc>> LayerNormG
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&LayerNormGpuKernelMod::LaunchKernel<float>}};
&LayerNormGpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&LayerNormGpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> LayerNormGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;

View File

@ -133,7 +133,17 @@ std::vector<std::pair<KernelAttr, LayerNormGradGpuKernelMod::KernelFunc>> LayerN
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&LayerNormGradGpuKernelMod::LaunchKernel<float>}};
&LayerNormGradGpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&LayerNormGradGpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> LayerNormGradGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;

View File

@ -16,6 +16,7 @@
#include "ops/layer_norm.h"
#include "ops/op_utils.h"
#include "utils/ms_context.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
@ -70,7 +71,7 @@ AbstractBasePtr LayerNormInfer(const abstract::AnalysisEnginePtr &, const Primit
abstract::CheckAxis(op_name, "begin_params_axis", bpa_ptr, -1, SizeToLong(input_rank), "input_x");
// the beta and gama shape must be x_shape[begin_params_axis:]
auto valid_types = {kFloat16, kFloat32};
auto valid_types = {kFloat16, kFloat32, kFloat64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", input_args[x_index]->BuildType(), valid_types, op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("gamma_dtype", input_args[gamma_index]->BuildType(), valid_types,
op_name);
@ -112,12 +113,22 @@ AbstractBasePtr LayerNormInfer(const abstract::AnalysisEnginePtr &, const Primit
}
std::vector<BaseShapePtr> shapes_list = {input_x->BuildShape()};
std::vector<TypePtr> types_list = {input_x->BuildType(), kFloat32, kFloat32};
auto mean_var_shape = CalLayerNormMeanAndVarShape(begin_norm_axis, input_shape->shape());
(void)shapes_list.emplace_back(std::make_shared<abstract::Shape>(mean_var_shape));
(void)shapes_list.emplace_back(std::make_shared<abstract::Shape>(mean_var_shape));
return abstract::MakeAbstract(std::make_shared<abstract::TupleShape>(shapes_list),
std::make_shared<Tuple>(types_list));
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
if (is_ascend) {
std::vector<TypePtr> types_list = {input_x->BuildType(), input_x->BuildType(), input_x->BuildType()};
return abstract::MakeAbstract(std::make_shared<abstract::TupleShape>(shapes_list),
std::make_shared<Tuple>(types_list));
} else {
std::vector<TypePtr> types_list = {input_x->BuildType(), kFloat32, kFloat32};
return abstract::MakeAbstract(std::make_shared<abstract::TupleShape>(shapes_list),
std::make_shared<Tuple>(types_list));
}
}
void LayerNorm::Init(const int64_t begin_norm_axis, const int64_t begin_params_axis, const float epsilon) {
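The net effect of the infer change above, written out as a plain Python illustration (the device-target branch and dtypes follow the code; the helper itself is hypothetical, not MindSpore API):

def layer_norm_output_dtypes(x_dtype, device_target):
    """y always follows x; on Ascend mean/var follow x too, elsewhere they are float32."""
    if device_target == "Ascend":
        return (x_dtype, x_dtype, x_dtype)
    return (x_dtype, "float32", "float32")

assert layer_norm_output_dtypes("float64", "GPU") == ("float64", "float32", "float32")
assert layer_norm_output_dtypes("float64", "Ascend") == ("float64", "float64", "float64")
assert layer_norm_output_dtypes("float16", "CPU") == ("float16", "float32", "float32")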

View File

@ -35,7 +35,7 @@ class LayerNormGradNet(nn.Cell):
return self.norm(dy, x, var, mean, gamma)
def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
def layer_norm_grad_np(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)
@ -61,7 +61,8 @@ def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_
dx2 = sum1 * 2.0 / num * (x - mean)
dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
dx = dx1 + dx2 + dx3
return dx, dg, db, mean, var
ret = (dx, dg, db, mean, var)
return ret
@pytest.mark.level0
@ -74,8 +75,8 @@ def test_layernormgrad0():
dy_np = np.random.randn(4096, 3072).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
@ -101,8 +102,8 @@ def test_layernormgrad1():
dy_np = np.random.randn(640, 768).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
@ -128,8 +129,8 @@ def test_layernormgrad2():
dy_np = np.random.randn(32, 128, 768).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
@ -155,8 +156,8 @@ def test_layernormgrad3():
dy_np = np.random.randn(32, 64).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
@ -181,8 +182,8 @@ def test_layernormgrad4():
dy_np = np.random.randn(32, 64).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
@ -207,8 +208,8 @@ def test_layernormgrad5():
dy_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
@ -381,8 +382,8 @@ def test_layernormgrad_dynamic_shape():
dy_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
@ -397,3 +398,35 @@ def test_layernormgrad_dynamic_shape():
assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernormgrad_double():
"""
Feature: Test LayerNormGrad double support.
Description: The input x type is double.
Expectation: match to np benchmark.
"""
begin_norm_axis = 1
begin_params_axis = 1
x_np = np.random.randn(4096, 3072).astype(np.float64)
dy_np = np.random.randn(4096, 3072).astype(np.float64)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float64)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
var_ms = Tensor(var_np.astype(np.float32))
mean_ms = Tensor(mean_np.astype(np.float32))
gamma_ms = Tensor(gamma_np)
net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-4, atol=1e-4)
assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-4, atol=1e-3)
assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-4, atol=1e-3)

View File

@ -35,7 +35,7 @@ class LayerNormNet(nn.Cell):
return self.norm(x, gamma, beta)
def LayerNormReference(begin_norm_axis, begin_params_axis, x, gamma, beta):
def layer_norm_np(begin_norm_axis, begin_params_axis, x, gamma, beta):
begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)
@ -46,7 +46,8 @@ def LayerNormReference(begin_norm_axis, begin_params_axis, x, gamma, beta):
gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
beta = beta.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
y = np.subtract(x, mean) / np.sqrt(var + 1e-12) * gamma + beta
return y, mean, var
ret = (y, mean, var)
return ret
@pytest.mark.level0
@ -58,7 +59,7 @@ def test_layernorm0():
x_np = np.random.randn(4096, 3072).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -80,7 +81,7 @@ def test_layernorm1():
x_np = np.random.randn(640, 768).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -102,7 +103,7 @@ def test_layernorm3d_1():
x_np = np.random.randn(32, 128, 768).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -124,7 +125,7 @@ def test_layernorm3d_2():
x_np = np.random.randn(32, 128, 768).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -146,7 +147,7 @@ def test_layernorm2d_2():
x_np = np.random.randn(64, 32).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -167,7 +168,7 @@ def test_layernorm2d_3():
x_np = np.random.randn(128, 128).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -189,7 +190,7 @@ def test_layernorm2d_4():
x_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -359,7 +360,7 @@ def test_layernorm_dynamic_shape():
x_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -371,3 +372,30 @@ def test_layernorm_dynamic_shape():
assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-4)
assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-4)
assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernorm_double():
"""
Feature: Test LayerNorm double support.
Description: The input x type is double.
Expectation: match to np benchmark.
"""
begin_norm_axis = 1
begin_params_axis = 1
x_np = np.random.randn(4096, 3072).astype(np.float64)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float64)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float64)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
beta_ms = Tensor(beta_np)
net = LayerNormNet(begin_norm_axis, begin_params_axis)
y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
assert np.allclose(y_ms.asnumpy(), y_np, atol=1e-4)
assert np.allclose(mean_ms.asnumpy(), mean_np, atol=1e-4)
assert np.allclose(var_ms.asnumpy(), var_np, atol=1e-4)

View File

@ -357,8 +357,8 @@ def test_layernormgradgrad6():
dy_ms = Tensor(dy_np.astype(np.float16))
x_ms = Tensor(x_np.astype(np.float16))
var_ms = Tensor(var_np.astype(np.float16))
mean_ms = Tensor(mean_np.astype(np.float16))
var_ms = Tensor(var_np.astype(np.float32))
mean_ms = Tensor(mean_np.astype(np.float32))
gamma_ms = Tensor(gamma_np.astype(np.float16))
grad_dx_ms = Tensor(grad_dx_np.astype(np.float16))
grad_dg_ms = Tensor(grad_dg_np.astype(np.float16))
@ -395,8 +395,8 @@ def test_layernormgradgrad7():
dy_ms = Tensor(dy_np.astype(np.float16))
x_ms = Tensor(x_np.astype(np.float16))
var_ms = Tensor(var_np.astype(np.float16))
mean_ms = Tensor(mean_np.astype(np.float16))
var_ms = Tensor(var_np.astype(np.float32))
mean_ms = Tensor(mean_np.astype(np.float32))
gamma_ms = Tensor(gamma_np.astype(np.float16))
grad_dx_ms = Tensor(grad_dx_np.astype(np.float16))
grad_dg_ms = Tensor(grad_dg_np.astype(np.float16))
@ -433,8 +433,8 @@ def test_layernormgradgrad8():
dy_ms = Tensor(dy_np.astype(np.float16))
x_ms = Tensor(x_np.astype(np.float16))
var_ms = Tensor(var_np.astype(np.float16))
mean_ms = Tensor(mean_np.astype(np.float16))
var_ms = Tensor(var_np.astype(np.float32))
mean_ms = Tensor(mean_np.astype(np.float32))
gamma_ms = Tensor(gamma_np.astype(np.float16))
grad_dx_ms = Tensor(grad_dx_np.astype(np.float16))
grad_dg_ms = Tensor(grad_dg_np.astype(np.float16))
@ -470,8 +470,8 @@ def test_layernormgradgrad9():
dy_ms = Tensor(dy_np.astype(np.float16))
x_ms = Tensor(x_np.astype(np.float16))
var_ms = Tensor(var_np.astype(np.float16))
mean_ms = Tensor(mean_np.astype(np.float16))
var_ms = Tensor(var_np.astype(np.float32))
mean_ms = Tensor(mean_np.astype(np.float32))
gamma_ms = Tensor(gamma_np.astype(np.float16))
grad_dx_ms = Tensor(grad_dx_np.astype(np.float16))
grad_dg_ms = Tensor(grad_dg_np.astype(np.float16))
@ -508,8 +508,8 @@ def test_layernormgradgrad10():
dy_ms = Tensor(dy_np.astype(np.float16))
x_ms = Tensor(x_np.astype(np.float16))
var_ms = Tensor(var_np.astype(np.float16))
mean_ms = Tensor(mean_np.astype(np.float16))
var_ms = Tensor(var_np.astype(np.float32))
mean_ms = Tensor(mean_np.astype(np.float32))
gamma_ms = Tensor(gamma_np.astype(np.float16))
grad_dx_ms = Tensor(grad_dx_np.astype(np.float16))
grad_dg_ms = Tensor(grad_dg_np.astype(np.float16))
@ -546,8 +546,8 @@ def test_layernormgradgrad11():
dy_ms = Tensor(dy_np.astype(np.float16))
x_ms = Tensor(x_np.astype(np.float16))
var_ms = Tensor(var_np.astype(np.float16))
mean_ms = Tensor(mean_np.astype(np.float16))
var_ms = Tensor(var_np.astype(np.float32))
mean_ms = Tensor(mean_np.astype(np.float32))
gamma_ms = Tensor(gamma_np.astype(np.float16))
grad_dx_ms = Tensor(grad_dx_np.astype(np.float16))
grad_dg_ms = Tensor(grad_dg_np.astype(np.float16))

View File

@ -34,7 +34,7 @@ class LayerNormGradNet(nn.Cell):
return self.norm(dy, x, var, mean, gamma)
def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
def layer_norm_grad_np(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)
@ -60,7 +60,8 @@ def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_
dx2 = sum1 * 2.0 / num * (x - mean)
dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
dx = dx1 + dx2 + dx3
return dx, dg, db, mean, var
ret = (dx, dg, db, mean, var)
return ret
@pytest.mark.level1
@ -73,8 +74,8 @@ def test_layernormgrad0():
dy_np = np.random.randn(4096, 3072).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
@ -100,8 +101,8 @@ def test_layernormgrad1():
dy_np = np.random.randn(640, 768).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
@ -127,8 +128,8 @@ def test_layernormgrad2():
dy_np = np.random.randn(32, 128, 768).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
@ -154,8 +155,8 @@ def test_layernormgrad3():
dy_np = np.random.randn(32, 64).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
@ -180,8 +181,8 @@ def test_layernormgrad4():
dy_np = np.random.randn(32, 64).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
@ -206,8 +207,8 @@ def test_layernormgrad5():
dy_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
@ -237,8 +238,8 @@ def test_layernormgrad_dynamic_shape():
dy_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
@ -253,3 +254,35 @@ def test_layernormgrad_dynamic_shape():
assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernormgrad_double():
"""
Feature: Test LayerNormGrad double support.
Description: The input x type is double.
Expectation: match to np benchmark.
"""
begin_norm_axis = 1
begin_params_axis = 1
x_np = np.random.randn(4096, 3072).astype(np.float64)
dy_np = np.random.randn(4096, 3072).astype(np.float64)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float64)
epsilon = 10e-12
dx_np, dg_np, db_np, mean_np, var_np = layer_norm_grad_np(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
var_ms = Tensor(var_np.astype(np.float32))
mean_ms = Tensor(mean_np.astype(np.float32))
gamma_ms = Tensor(gamma_np)
net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)

View File

@ -34,7 +34,7 @@ class LayerNormNet(nn.Cell):
return self.norm(x, gamma, beta)
def LayerNormReference(begin_norm_axis, begin_params_axis, x, gamma, beta):
def layer_norm_np(begin_norm_axis, begin_params_axis, x, gamma, beta):
begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)
@ -45,7 +45,8 @@ def LayerNormReference(begin_norm_axis, begin_params_axis, x, gamma, beta):
gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
beta = beta.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
y = np.subtract(x, mean) / np.sqrt(var + 1e-12) * gamma + beta
return y, mean, var
ret = (y, mean, var)
return ret
@pytest.mark.level1
@ -58,7 +59,7 @@ def test_layernorm0():
x_np = np.random.randn(4096, 3072).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -81,7 +82,7 @@ def test_layernorm1():
x_np = np.random.randn(640, 768).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -104,7 +105,7 @@ def test_layernorm3d_1():
x_np = np.random.randn(32, 128, 768).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -127,7 +128,7 @@ def test_layernorm3d_2():
x_np = np.random.randn(32, 128, 768).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -150,7 +151,7 @@ def test_layernorm2d_2():
x_np = np.random.randn(64, 32).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -172,7 +173,7 @@ def test_layernorm2d_3():
x_np = np.random.randn(128, 128).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -194,7 +195,7 @@ def test_layernorm2d_4():
x_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -221,7 +222,7 @@ def test_layernorm_dynamic_shape():
x_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
@ -233,3 +234,31 @@ def test_layernorm_dynamic_shape():
assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-4)
assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-4)
assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernorm_double():
"""
Feature: Test LayerNorm double support.
Description: The input x type is double.
Expectation: match to np benchmark.
"""
begin_norm_axis = 1
begin_params_axis = 1
np.random.seed(42)
x_np = np.random.randn(4096, 3072).astype(np.float64)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float64)
beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float64)
y_np, mean_np, var_np = layer_norm_np(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
x_ms = Tensor(x_np)
gamma_ms = Tensor(gamma_np)
beta_ms = Tensor(beta_np)
net = LayerNormNet(begin_norm_axis, begin_params_axis)
y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
assert np.allclose(y_ms.asnumpy(), y_np, atol=1e-6)
assert np.allclose(mean_ms.asnumpy(), mean_np, atol=1e-6)
assert np.allclose(var_ms.asnumpy(), var_np, atol=1e-6)