!11180 Add CPU Gelu and LayerNorm

From: @zhao_ting_v
Reviewed-by: @wuxuejian, @liangchenghui
Signed-off-by: @liangchenghui

Committed by mindspore-ci-bot via Gitee on 2021-01-18 14:17:42 +08:00
commit 3e662805f8
13 changed files with 994 additions and 0 deletions

View File

@@ -76,6 +76,16 @@ void Reciprocal(const T *in, T *out, size_t start, size_t end) {
    out[i] = static_cast<T>(1.0 / in[i]);
  }
}

template <typename T>
void Gelu(const T *in, T *out, size_t start, size_t end) {
  // Tanh approximation: gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
  for (size_t i = start; i < end; i++) {
    T x = in[i];
    auto double_x = static_cast<double>(x);
    T tanh_res = (T)std::tanh(0.7978845608 * (double_x + 0.044715 * double_x * double_x * double_x));
    out[i] = x * ((T)1.0 + tanh_res) / (T)2.0;
  }
}
} // namespace
void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -95,6 +105,8 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
    operate_type_ = FLOOR;
  } else if (kernel_name == prim::kPrimReciprocal->name()) {
    operate_type_ = RECIPROCAL;
  } else if (kernel_name == prim::kPrimGelu->name()) {
    operate_type_ = GELU;
  }
  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
}
@@ -150,6 +162,8 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
      threads.emplace_back(std::thread(Floor<T>, input, output, start, end));
    } else if (operate_type_ == RECIPROCAL) {
      threads.emplace_back(std::thread(Reciprocal<T>, input, output, start, end));
    } else if (operate_type_ == GELU) {
      threads.emplace_back(std::thread(Gelu<T>, input, output, start, end));
    }
    start += once_compute_size;
  }
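
For reference, the Gelu kernel above implements the tanh approximation gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), with sqrt(2/pi) hard-coded as 0.7978845608. A minimal NumPy sketch of the same computation (illustrative only, not part of this commit):

import numpy as np

def gelu_tanh(x):
    # sqrt(2 / pi) ~= 0.7978845608, the constant hard-coded in the kernel
    u = 0.7978845608 * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + np.tanh(u))

x = np.linspace(-3.0, 3.0, 7).astype(np.float32)
print(gelu_tanh(x))  # near 0 for large negative x, near x for large positive x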

View File

@@ -62,6 +62,8 @@ MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutput
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                  ArithmeticSelfCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@@ -89,6 +89,8 @@ enum OperateType {
  GREATER,
  GREATEREQUAL,
  RECIPROCAL,
  GELU,
  GELUGRAD,
};
class CPUKernel : public kernel::KernelMod {

View File

@@ -78,6 +78,18 @@ void EltWiseGradCPUKernel::TanhGrad(const T *input1, const T *input2, T *out, si
  }
}

template <typename T>
void EltWiseGradCPUKernel::GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
  // out = dy * d(gelu)/dx under the tanh approximation; input1 is dy, input2 is x.
  for (size_t i = start; i < end; i++) {
    T x = input2[i];
    auto double_x = static_cast<double>(x);
    T tanh_res = (T)std::tanh(0.7978845608 * (double_x + 0.044715 * double_x * double_x * double_x));
    T mul_right = (T)(0.7978845608 + 0.1070322244 * double_x * double_x);
    T y_res = (((T)1.0 + tanh_res) + x * ((T)1.0 - tanh_res * tanh_res) * mul_right) / (T)2.0;
    out[i] = input1[i] * y_res;
  }
}

void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
@@ -93,6 +105,8 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
    operate_type_ = TANHGRAD;
  } else if (kernel_name == "SqrtGrad") {
    operate_type_ = SQRTGRAD;
  } else if (kernel_name == "GeluGrad") {
    operate_type_ = GELUGRAD;
  } else {
    MS_LOG(EXCEPTION) << "Unsupported kernel: " << kernel_name;
  }
@@ -172,6 +186,8 @@ void EltWiseGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, c
      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::TanhGrad<T>, this, input1, input2, output, start, end));
    } else if (operate_type_ == SQRTGRAD) {
      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::SqrtGrad<T>, this, input1, input2, output, start, end));
    } else if (operate_type_ == GELUGRAD) {
      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::GeluGrad<T>, this, input1, input2, output, start, end));
    } else {
      MS_LOG(EXCEPTION) << "Unsupported operate type: " << operate_type_;
    }

View File

@@ -47,6 +47,8 @@ class EltWiseGradCPUKernel : public CPUKernel {
  void SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
  template <typename T>
  void TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
  template <typename T>
  void GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
  std::vector<size_t> input_shape0_;
  std::vector<size_t> input_shape1_;
  std::vector<size_t> input_element_num0_;
@@ -81,6 +83,13 @@ MS_REG_CPU_KERNEL(
  TanhGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(GeluGrad,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  EltWiseGradCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,105 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/layer_norm_cpu_kernel.h"
#include <cmath>  // std::sqrt
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
void LayerNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  CheckParam(kernel_node);
  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
  std::vector<size_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  auto begin_norm_axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "begin_norm_axis");
  auto begin_params_axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "begin_params_axis");
  if (begin_norm_axis < 0) {
    begin_norm_axis += SizeToLong(x_shape.size());
  }
  if (begin_params_axis < 0) {
    begin_params_axis += SizeToLong(x_shape.size());
  }
  // Axes before begin_norm_axis form independent blocks; the remaining axes are normalized together.
  for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) {
    block_num_ *= x_shape[i];
  }
  for (size_t i = IntToSize(begin_norm_axis); i < x_shape.size(); i++) {
    block_size_ *= x_shape[i];
  }
  for (size_t i = IntToSize(begin_params_axis); i < x_shape.size(); i++) {
    param_num_ *= x_shape[i];
  }
  if (block_num_ == 0 || block_size_ == 0) {
    MS_LOG(EXCEPTION) << "LayerNormCPUKernel input shape error, input shape: " << x_shape;
  }
}

bool LayerNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                                const std::vector<kernel::AddressPtr> &outputs) {
  if (dtype_ == kNumberTypeFloat16) {
    LaunchKernel<float16>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat64) {
    LaunchKernel<float>(inputs, outputs);
  } else {
    MS_LOG(EXCEPTION) << "Input dtype only supports float16, float32, and float64";
  }
  return true;
}

template <typename T>
void LayerNormCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
  size_t f_size = sizeof(T);
  if (inputs[1]->size != f_size * param_num_ || inputs[2]->size != f_size * param_num_) {
    MS_LOG(EXCEPTION) << "The element count of gamma and beta must be " << param_num_;
  }
  if (outputs[1]->size != f_size * block_num_ || outputs[2]->size != f_size * block_num_) {
    MS_LOG(EXCEPTION) << "The element count of mean and var must be " << block_num_;
  }
  auto x = reinterpret_cast<T *>(inputs[0]->addr);
  auto gamma = reinterpret_cast<T *>(inputs[1]->addr);
  auto beta = reinterpret_cast<T *>(inputs[2]->addr);
  auto y = reinterpret_cast<T *>(outputs[0]->addr);
  auto mean = reinterpret_cast<T *>(outputs[1]->addr);
  auto var = reinterpret_cast<T *>(outputs[2]->addr);
  for (size_t i = 0; i < block_num_; ++i) {
    T sum = (T)0.0;
    T square_sum = (T)0.0;
    for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) {
      sum += x[j];
      square_sum += x[j] * x[j];
    }
    T block_mean = sum / block_size_;
    T block_var = square_sum / block_size_ - block_mean * block_mean;
    for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) {
      auto param_shift = j % param_num_;
      y[j] = (x[j] - block_mean) / (T)std::sqrt(static_cast<double>(block_var) + eps_) * gamma[param_shift] +
             beta[param_shift];
    }
    mean[i] = block_mean;
    var[i] = block_var;
  }
}

void LayerNormCPUKernel::CheckParam(const CNodePtr &kernel_node) {
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != 3) {
    MS_LOG(EXCEPTION) << "LayerNormCPUKernel needs 3 inputs, but gets " << input_num;
  }
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  if (output_num != 3) {
    MS_LOG(EXCEPTION) << "LayerNormCPUKernel expects 3 outputs, but gets " << output_num;
  }
}
} // namespace kernel
} // namespace mindspore
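
A sketch of the block decomposition used above, assuming begin_norm_axis and begin_params_axis are already non-negative: the input is viewed as block_num_ rows of block_size_ elements (axes split at begin_norm_axis), each row is normalized with its own mean and variance, and gamma/beta (param_num_ elements) are indexed as j % param_num_ on the flat index. Illustrative NumPy, not part of this commit:

import numpy as np

def layer_norm_blocks(x, gamma, beta, begin_norm_axis, begin_params_axis, eps=1e-12):
    block_num = int(np.prod(x.shape[:begin_norm_axis]))
    block_size = int(np.prod(x.shape[begin_norm_axis:]))
    param_num = int(np.prod(x.shape[begin_params_axis:]))
    flat = x.reshape(block_num, block_size)
    mean = flat.mean(axis=1, keepdims=True)
    var = flat.var(axis=1, keepdims=True)
    # gamma/beta are looked up as (flat index) % param_num, exactly as in the kernel loop.
    j = np.arange(block_num * block_size).reshape(block_num, block_size)
    y = (flat - mean) / np.sqrt(var + eps) * gamma.reshape(-1)[j % param_num] + beta.reshape(-1)[j % param_num]
    return y.reshape(x.shape), mean.ravel(), var.ravel()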

View File

@@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_CPU_KERNEL_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class LayerNormCPUKernel : public CPUKernel {
 public:
  LayerNormCPUKernel() = default;
  ~LayerNormCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 private:
  void CheckParam(const CNodePtr &kernel_node);
  TypeId dtype_{kTypeUnknown};
  float eps_{1e-12};
  size_t block_num_{1};
  size_t block_size_{1};
  size_t param_num_{1};
};

MS_REG_CPU_KERNEL(LayerNorm,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddOutputAttr(kNumberTypeFloat16)
                    .AddOutputAttr(kNumberTypeFloat16)
                    .AddOutputAttr(kNumberTypeFloat16),
                  LayerNormCPUKernel);
MS_REG_CPU_KERNEL(LayerNorm,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  LayerNormCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_CPU_KERNEL_H_

View File

@@ -0,0 +1,124 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/layer_norm_grad_cpu_kernel.h"
#include <cmath>  // std::pow
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
void LayerNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  CheckParam(kernel_node);
  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
  std::vector<size_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  auto begin_norm_axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "begin_norm_axis");
  auto begin_params_axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "begin_params_axis");
  if (begin_norm_axis < 0) {
    begin_norm_axis += SizeToLong(x_shape.size());
  }
  if (begin_params_axis < 0) {
    begin_params_axis += SizeToLong(x_shape.size());
  }
  for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) {
    block_num_ *= x_shape[i];
  }
  for (size_t i = IntToSize(begin_norm_axis); i < x_shape.size(); i++) {
    block_size_ *= x_shape[i];
  }
  for (size_t i = 0; i < IntToSize(begin_params_axis); i++) {
    param_size_ *= x_shape[i];
  }
  for (size_t i = IntToSize(begin_params_axis); i < x_shape.size(); i++) {
    param_num_ *= x_shape[i];
  }
  if (block_num_ == 0 || block_size_ == 0) {
    MS_LOG(EXCEPTION) << "LayerNormGradCPUKernel input shape error, input shape: " << x_shape;
  }
}

bool LayerNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                    const std::vector<kernel::AddressPtr> &workspace,
                                    const std::vector<kernel::AddressPtr> &outputs) {
  if (dtype_ == kNumberTypeFloat16) {
    LaunchKernel<float16>(inputs, workspace, outputs);
  } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat64) {
    LaunchKernel<float>(inputs, workspace, outputs);
  } else {
    MS_LOG(EXCEPTION) << "Input dtype only supports float16, float32, and float64";
  }
  return true;
}

template <typename T>
void LayerNormGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                          const std::vector<AddressPtr> &workspace,
                                          const std::vector<AddressPtr> &outputs) {
  auto x = reinterpret_cast<T *>(inputs[0]->addr);
  auto dy = reinterpret_cast<T *>(inputs[1]->addr);
  auto var = reinterpret_cast<T *>(inputs[2]->addr);
  auto mean = reinterpret_cast<T *>(inputs[3]->addr);
  auto gamma = reinterpret_cast<T *>(inputs[4]->addr);
  auto dx = reinterpret_cast<T *>(outputs[0]->addr);
  auto dg = reinterpret_cast<T *>(outputs[1]->addr);
  auto db = reinterpret_cast<T *>(outputs[2]->addr);
  // dgamma/dbeta: element i of gamma accumulates over j = i, i + param_num_, ..., one term per block repeat.
  for (size_t i = 0; i < param_num_; ++i) {
    T dgamma = (T)0.0;
    T dbeta = (T)0.0;
    for (size_t j = i; j < param_size_ * param_num_; j += param_num_) {
      auto norm_shift = static_cast<int>(j / block_size_);
      dgamma += dy[j] * (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -0.5) * (x[j] - mean[norm_shift]);
      dbeta += dy[j];
    }
    dg[i] = dgamma;
    db[i] = dbeta;
  }
  // dx: per block, accumulate the three reduction terms first, then combine them per element.
  for (size_t i = 0; i < block_num_; ++i) {
    T sum1 = (T)0.0;
    T sum2 = (T)0.0;
    T sum3 = (T)0.0;
    for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) {
      auto param_shift = j % param_num_;
      auto norm_shift = static_cast<int>(j / block_size_);
      auto dxm = x[j] - mean[norm_shift];
      auto dyg = dy[j] * gamma[param_shift];
      sum1 += (T)(-0.5) * dyg * dxm * (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -1.5);
      sum2 += dyg;
      sum3 += (T)(-2.0) * dxm;
    }
    for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) {
      auto param_shift = j % param_num_;
      auto norm_shift = static_cast<int>(j / block_size_);
      auto var_sqrt = (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -0.5);
      auto dx1 = dy[j] * gamma[param_shift] * var_sqrt;
      auto dx2 = sum1 * (T)2.0 / block_size_ * (x[j] - mean[norm_shift]);
      auto dx3 = ((T)(-1.0) * var_sqrt * sum2 + ((T)1.0 / block_size_) * sum1 * sum3) * ((T)1.0 / block_size_);
      dx[j] = dx1 + dx2 + dx3;
    }
  }
}

void LayerNormGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != 5) {
    MS_LOG(EXCEPTION) << "LayerNormGradCPUKernel needs 5 inputs, but gets " << input_num;
  }
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  if (output_num != 3) {
    MS_LOG(EXCEPTION) << "LayerNormGradCPUKernel expects 3 outputs, but gets " << output_num;
  }
}
} // namespace kernel
} // namespace mindspore
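
A worked example of the four size fields for the shapes exercised by test_layernormgrad5 below (x of shape (128, 2, 16, 32) with begin_norm_axis=2, begin_params_axis=1), illustrative only:

import numpy as np

shape = (128, 2, 16, 32)                              # x shape from test_layernormgrad5
begin_norm_axis, begin_params_axis = 2, 1

block_num = int(np.prod(shape[:begin_norm_axis]))     # 128 * 2 = 256 rows to normalize
block_size = int(np.prod(shape[begin_norm_axis:]))    # 16 * 32 = 512 elements per row
param_size = int(np.prod(shape[:begin_params_axis]))  # 128 repeats of gamma/beta over the input
param_num = int(np.prod(shape[begin_params_axis:]))   # 2 * 16 * 32 = 1024 gamma/beta elements

# Every element is covered exactly once by the strided dgamma/dbeta loop.
assert param_size * param_num == np.prod(shape)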

View File

@@ -0,0 +1,76 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_GRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_GRAD_CPU_KERNEL_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class LayerNormGradCPUKernel : public CPUKernel {
 public:
  LayerNormGradCPUKernel() = default;
  ~LayerNormGradCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);

 private:
  void CheckParam(const CNodePtr &kernel_node);
  TypeId dtype_{kTypeUnknown};
  float eps_{1e-12};
  size_t block_num_{1};
  size_t block_size_{1};
  size_t param_num_{1};
  size_t param_size_{1};
};

MS_REG_CPU_KERNEL(LayerNormGrad,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddOutputAttr(kNumberTypeFloat16)
                    .AddOutputAttr(kNumberTypeFloat16)
                    .AddOutputAttr(kNumberTypeFloat16),
                  LayerNormGradCPUKernel);
MS_REG_CPU_KERNEL(LayerNormGrad,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  LayerNormGradCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_GRAD_CPU_KERNEL_H_

View File

@@ -0,0 +1,63 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class GeluNet(nn.Cell):
    def __init__(self):
        super(GeluNet, self).__init__()
        self.gelu = P.Gelu()

    def construct(self, x):
        return self.gelu(x)


class Grad(nn.Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = C.GradOperation(get_all=True, sens_param=True)
        self.network = network

    def construct(self, input_data, sens):
        gout = self.grad(self.network)(input_data, sens)
        return gout


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gelugrad():
    x_ms = Tensor(np.array([0.58401114, 0.68800163, 0.9760397, 0.14702141, 0.46563736, 0.9607501,
                            0.14567593, 0.12261796, 0.37054458, 0.46421242]).astype(np.float32))
    dy_ms = Tensor(np.array([0.5559598, 0.96994054, 0.24770357, 0.34646875, 0.2984393, 0.03287048,
                             0.55681044, 0.966908, 0.06015943, 0.6099489]).astype(np.float32))

    net = GeluNet()
    grad = Grad(net)
    output = grad(x_ms, dy_ms)

    expect = [0.50963277, 0.9414753, 0.2667653, 0.21358444, 0.25243032, 0.0352667,
              0.34266686, 0.57757664, 0.04707306, 0.51536125]
    assert np.allclose(output[0].asnumpy(), expect)
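
The hard-coded expect values follow from the GeluGrad formula used by the kernel; a quick NumPy reproduction (illustrative only, not part of this commit):

import numpy as np

x = np.array([0.58401114, 0.68800163, 0.9760397, 0.14702141, 0.46563736, 0.9607501,
              0.14567593, 0.12261796, 0.37054458, 0.46421242], np.float32)
dy = np.array([0.5559598, 0.96994054, 0.24770357, 0.34646875, 0.2984393, 0.03287048,
               0.55681044, 0.966908, 0.06015943, 0.6099489], np.float32)

tanh_u = np.tanh(0.7978845608 * (x + 0.044715 * x ** 3))
mul_right = 0.7978845608 + 0.1070322244 * x ** 2
print(dy * 0.5 * ((1.0 + tanh_u) + x * (1.0 - tanh_u ** 2) * mul_right))
# ~[0.5096, 0.9415, 0.2668, 0.2136, 0.2524, 0.0353, 0.3427, 0.5776, 0.0471, 0.5154]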

View File

@@ -0,0 +1,93 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class GeluNet(nn.Cell):
    def __init__(self):
        super(GeluNet, self).__init__()
        self.gelu = P.Gelu()

    def construct(self, x):
        return self.gelu(x)


def GeluCompute(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x * x * x)))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gelu_1d():
    x_np = np.random.random((50,)).astype(np.float32)
    y_np = GeluCompute(x_np)

    x_ms = Tensor(x_np)
    net = GeluNet()
    y_ms = net(x_ms)

    assert np.allclose(y_np, y_ms.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gelu_2d():
    x_np = np.random.random((50, 40)).astype(np.float32)
    y_np = GeluCompute(x_np)

    x_ms = Tensor(x_np)
    net = GeluNet()
    y_ms = net(x_ms)

    assert np.allclose(y_np, y_ms.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gelu_4d():
    x_np = np.random.random((32, 3, 224, 224)).astype(np.float32)
    y_np = GeluCompute(x_np)

    x_ms = Tensor(x_np)
    net = GeluNet()
    y_ms = net(x_ms)

    assert np.allclose(y_np, y_ms.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gelu_neg():
    x_np = np.random.random((32, 3, 224, 224)).astype(np.float32) * -1
    y_np = GeluCompute(x_np)

    x_ms = Tensor(x_np)
    net = GeluNet()
    y_ms = net(x_ms)

    assert np.allclose(y_np, y_ms.asnumpy())
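
GeluCompute above uses np.sqrt(2 / np.pi) where the kernels hard-code 0.7978845608; a one-line check that the constants agree (illustrative only):

import numpy as np

assert abs(np.sqrt(2 / np.pi) - 0.7978845608) < 1e-10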

View File

@@ -0,0 +1,221 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class LayerNormGradNet(nn.Cell):
    def __init__(self, begin_norm_axis, begin_params_axis):
        super(LayerNormGradNet, self).__init__()
        self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis)

    def construct(self, dy, x, var, mean, gamma):
        return self.norm(dy, x, var, mean, gamma)


def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)

    norm_axis = [i for i in range(begin_norm_axis, len(x.shape))]
    param_axis = [i for i in range(0, begin_params_axis)]
    num = 1
    for i in range(begin_norm_axis, len(x.shape)):
        num *= x.shape[i]

    mean = np.mean(x, axis=tuple(norm_axis), keepdims=True)
    var = np.var(x, axis=tuple(norm_axis), keepdims=True)
    gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
    dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
    db = np.sum(dy, axis=tuple(param_axis), keepdims=True)

    sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
                  keepdims=True)
    sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
    sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)

    dx1 = dy * gamma * np.power(var + epsilon, -0.5)
    dx2 = sum1 * 2.0 / num * (x - mean)
    dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
    dx = dx1 + dx2 + dx3
    return dx, dg, db, mean, var

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernormgrad0():
    begin_norm_axis = 1
    begin_params_axis = 1
    x_np = np.random.randn(4096, 3072).astype(np.float32)
    dy_np = np.random.randn(4096, 3072).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    epsilon = 1e-12
    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
                                                                  begin_params_axis)

    dy_ms = Tensor(dy_np)
    x_ms = Tensor(x_np)
    var_ms = Tensor(var_np)
    mean_ms = Tensor(mean_np)
    gamma_ms = Tensor(gamma_np)
    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)

    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-4, atol=1e-4)
    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-4, atol=1e-3)
    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-4, atol=1e-3)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernormgrad1():
    begin_norm_axis = 1
    begin_params_axis = 1
    x_np = np.random.randn(640, 768).astype(np.float32)
    dy_np = np.random.randn(640, 768).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    epsilon = 1e-12
    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
                                                                  begin_params_axis)

    dy_ms = Tensor(dy_np)
    x_ms = Tensor(x_np)
    var_ms = Tensor(var_np)
    mean_ms = Tensor(mean_np)
    gamma_ms = Tensor(gamma_np)
    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)

    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-4, atol=1e-4)
    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-4, atol=1e-3)
    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-4, atol=1e-3)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernormgrad2():
    begin_norm_axis = -1
    begin_params_axis = -1
    x_np = np.random.randn(32, 128, 768).astype(np.float32)
    dy_np = np.random.randn(32, 128, 768).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    epsilon = 1e-12
    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
                                                                  begin_params_axis)

    dy_ms = Tensor(dy_np)
    x_ms = Tensor(x_np)
    var_ms = Tensor(var_np)
    mean_ms = Tensor(mean_np)
    gamma_ms = Tensor(gamma_np)
    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)

    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-4, atol=1e-4)
    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-4, atol=1e-3)
    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-4, atol=1e-3)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernormgrad3():
    begin_norm_axis = -1
    begin_params_axis = -1
    x_np = np.random.randn(32, 64).astype(np.float32)
    dy_np = np.random.randn(32, 64).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    epsilon = 1e-12
    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
                                                                  begin_params_axis)

    dy_ms = Tensor(dy_np)
    x_ms = Tensor(x_np)
    var_ms = Tensor(var_np)
    mean_ms = Tensor(mean_np)
    gamma_ms = Tensor(gamma_np)
    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)

    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-4, atol=1e-4)
    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-4, atol=1e-3)
    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-4, atol=1e-3)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernormgrad4():
    begin_norm_axis = -1
    begin_params_axis = -1
    x_np = np.random.randn(32, 64).astype(np.float32)
    dy_np = np.random.randn(32, 64).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    epsilon = 1e-12
    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
                                                                  begin_params_axis)

    dy_ms = Tensor(dy_np)
    x_ms = Tensor(x_np)
    var_ms = Tensor(var_np)
    mean_ms = Tensor(mean_np)
    gamma_ms = Tensor(gamma_np)
    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)

    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-4, atol=1e-4)
    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-4, atol=1e-3)
    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-4, atol=1e-3)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernormgrad5():
    begin_norm_axis = 2
    begin_params_axis = 1
    x_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
    dy_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    epsilon = 1e-12
    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
                                                                  begin_params_axis)

    dy_ms = Tensor(dy_np)
    x_ms = Tensor(x_np)
    var_ms = Tensor(var_np)
    mean_ms = Tensor(mean_np)
    gamma_ms = Tensor(gamma_np)
    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)

    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-4, atol=1e-4)
    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-4, atol=1e-3)
    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-4, atol=1e-3)

View File

@@ -0,0 +1,199 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class LayerNormNet(nn.Cell):
    def __init__(self, begin_norm_axis, begin_params_axis):
        super(LayerNormNet, self).__init__()
        self.norm = P.LayerNorm(begin_norm_axis, begin_params_axis)

    def construct(self, x, gamma, beta):
        return self.norm(x, gamma, beta)


def LayerNormReference(begin_norm_axis, begin_params_axis, x, gamma, beta):
    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)

    axis = [i for i in range(begin_norm_axis, len(x.shape))]
    mean = np.mean(x, axis=tuple(axis), keepdims=True)
    var = np.var(x, axis=tuple(axis), keepdims=True)

    gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
    beta = beta.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
    y = np.subtract(x, mean) / np.sqrt(var + 1e-12) * gamma + beta
    return y, mean, var

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernorm0():
    begin_norm_axis = 1
    begin_params_axis = 1
    x_np = np.random.randn(4096, 3072).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)

    x_ms = Tensor(x_np)
    gamma_ms = Tensor(gamma_np)
    beta_ms = Tensor(beta_np)
    net = LayerNormNet(begin_norm_axis, begin_params_axis)
    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)

    assert np.allclose(y_ms.asnumpy(), y_np, atol=1e-4)
    assert np.allclose(mean_ms.asnumpy(), mean_np, atol=1e-4)
    assert np.allclose(var_ms.asnumpy(), var_np, atol=1e-4)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernorm1():
    begin_norm_axis = 1
    begin_params_axis = 1
    x_np = np.random.randn(640, 768).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)

    x_ms = Tensor(x_np)
    gamma_ms = Tensor(gamma_np)
    beta_ms = Tensor(beta_np)
    net = LayerNormNet(begin_norm_axis, begin_params_axis)
    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)

    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-4)
    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-4)
    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-4)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernorm3d_1():
    begin_norm_axis = -1
    begin_params_axis = -1
    x_np = np.random.randn(32, 128, 768).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)

    x_ms = Tensor(x_np)
    gamma_ms = Tensor(gamma_np)
    beta_ms = Tensor(beta_np)
    net = LayerNormNet(begin_norm_axis, begin_params_axis)
    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)

    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-4)
    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-4)
    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-4)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernorm3d_2():
    begin_norm_axis = -1
    begin_params_axis = 1
    x_np = np.random.randn(32, 128, 768).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)

    x_ms = Tensor(x_np)
    gamma_ms = Tensor(gamma_np)
    beta_ms = Tensor(beta_np)
    net = LayerNormNet(begin_norm_axis, begin_params_axis)
    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)

    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-4)
    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-4)
    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-4)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernorm2d_2():
    begin_norm_axis = -1
    begin_params_axis = 1
    x_np = np.random.randn(64, 32).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)

    x_ms = Tensor(x_np)
    gamma_ms = Tensor(gamma_np)
    beta_ms = Tensor(beta_np)
    net = LayerNormNet(begin_norm_axis, begin_params_axis)
    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)

    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-4)
    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-4)
    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-4)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernorm2d_3():
    begin_norm_axis = -1
    begin_params_axis = 1
    x_np = np.random.randn(128, 128).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)

    x_ms = Tensor(x_np)
    gamma_ms = Tensor(gamma_np)
    beta_ms = Tensor(beta_np)
    net = LayerNormNet(begin_norm_axis, begin_params_axis)
    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)

    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-4)
    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-4)
    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-4)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernorm2d_4():
    begin_norm_axis = 2
    begin_params_axis = 1
    np.random.seed(42)
    x_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)

    x_ms = Tensor(x_np)
    gamma_ms = Tensor(gamma_np)
    beta_ms = Tensor(beta_np)
    net = LayerNormNet(begin_norm_axis, begin_params_axis)
    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)

    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-4)
    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-4)
    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-4)