forked from mindspore-Ecosystem/mindspore
Add L2NormalizeGradCPUKernel
Add L2NormalizeGradCPUKernel 修改测试用例误差判断方式 Add L2NormalizeGradCPUKernel 注释调试信息 Add L2NormalizeGradCPUKernel 修改反向算子的计算公式 Add L2NormalizeGradCPUKernel 去掉 float16 的注册 Add L2NormalizeGradCPUKernel 使用相对误差 Add L2NormalizeGradCPUKernel 更新反向公式 Add L2NormalizeGradCPUKernel 删除调试信息 清除告警 Add L2NormalizeGradCPUKernel 添加测试用例 Add L2NormalizeGradCPUKernel 删除多余的函数 Add L2NormalizeGradCPUKernel 修改注释中的时间 Add L2NormalizeGradCPUKernel 格式化代码 Add L2NormalizeGradCPUKernel 格式化代码,修改 cpplint 问题 Add L2NormalizeGradCPUKernel 修改 cpplint,pylint 问题 Add L2NormalizeGradCPUKernel 修改求导函数,与 GPU 和 Ascend 保持一致。 修改后的公式在数学意义上有问题,但已经和武雪剑对齐,认为没有影响,没有必要要求 GPU 和 Ascend 修改代码。 Add L2NormalizeGradCPUKernel 精简测试用例
This commit is contained in:
parent
13d9ad0f2a
commit
a7b445c50c
|
@ -0,0 +1,169 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/cpu/l2normalize_grad_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
void L2NormalizeGradCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
CheckIONumber(kernel_node);
|
||||
for (size_t i = 0; i < INPUT_SIZE; i++) {
|
||||
input_shape_list_.emplace_back(AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i));
|
||||
}
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
CheckInputShape(output_shape);
|
||||
|
||||
int output_dim_length = output_shape.size();
|
||||
dim_elem_num_list_.resize(output_dim_length, 1);
|
||||
for (int i = output_dim_length - 2; i >= 0; i--) {
|
||||
dim_elem_num_list_[i] = output_shape[i + 1] * dim_elem_num_list_[i + 1];
|
||||
}
|
||||
|
||||
int axis = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "axis"));
|
||||
int input_dim_length = SizeToInt(input_shape_list_[0].size());
|
||||
axis_ = axis < 0 ? (axis + input_dim_length) : axis;
|
||||
epsilon_ = static_cast<T>(AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon"));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool L2NormalizeGradCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto input_x = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto y = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto dout = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
auto output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
auto output_size = outputs[0]->size / sizeof(T);
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
std::vector<size_t> high_dim_index;
|
||||
OneDimIndexToHighDimIndex(i, &high_dim_index);
|
||||
std::vector<T> input_x_vector;
|
||||
GetVector(&input_x_vector, high_dim_index, input_x);
|
||||
std::vector<T> dout_vector;
|
||||
GetVector(&dout_vector, high_dim_index, dout);
|
||||
std::vector<T> y_vector;
|
||||
GetVector(&y_vector, high_dim_index, y);
|
||||
GetOutput(input_x_vector, y_vector, dout_vector, high_dim_index, &output[i]);
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, output_size);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void L2NormalizeGradCPUKernel<T>::CheckInputShape(const std::vector<size_t> &output_shape) {
|
||||
for (const auto &shape : input_shape_list_) {
|
||||
if (output_shape != shape) {
|
||||
MS_LOG(EXCEPTION) << "Input shape and output shape should be same.";
|
||||
}
|
||||
}
|
||||
auto input_x_shape = input_shape_list_[0];
|
||||
if (input_x_shape.size() != 0) {
|
||||
if (std::any_of(input_x_shape.begin(), input_x_shape.end(), [](size_t i) { return i == 0; })) {
|
||||
MS_LOG(EXCEPTION) << "L2NormalizeCPUKernel input is null.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void L2NormalizeGradCPUKernel<T>::CheckIONumber(const CNodePtr &kernel_node) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != INPUT_SIZE) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but L2NormalizeGradCPUKernel needs 3 input.";
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != OUTPUT_SIZE) {
|
||||
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but L2NormalizeGradCPUKernel needs 1 output.";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void L2NormalizeGradCPUKernel<T>::OneDimIndexToHighDimIndex(size_t one_dim_index, std::vector<size_t> *high_dim_index) {
|
||||
for (const auto &item : dim_elem_num_list_) {
|
||||
high_dim_index->push_back(one_dim_index / item);
|
||||
one_dim_index %= item;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void L2NormalizeGradCPUKernel<T>::HighDimIndexToOneDimIndex(size_t *one_dim_index,
|
||||
const std::vector<size_t> &high_dim_index) {
|
||||
*one_dim_index = 0;
|
||||
int len = high_dim_index.size();
|
||||
for (int i = 0; i < len; i++) {
|
||||
*one_dim_index += high_dim_index[i] * dim_elem_num_list_[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void L2NormalizeGradCPUKernel<T>::GetVector(std::vector<T> *x_vector, const std::vector<size_t> &high_dim_index,
|
||||
const T *x) {
|
||||
auto x_shape = input_shape_list_[0];
|
||||
for (size_t i = 0; i < x_shape[axis_]; i++) {
|
||||
size_t oneDimIndex = 0;
|
||||
std::vector<size_t> tmp_high_dim_index = high_dim_index;
|
||||
tmp_high_dim_index[axis_] = i;
|
||||
HighDimIndexToOneDimIndex(&oneDimIndex, tmp_high_dim_index);
|
||||
x_vector->push_back(x[oneDimIndex]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void L2NormalizeGradCPUKernel<T>::GetSumOfProduct(const std::vector<T> &x_vector, const std::vector<T> &y_vector,
|
||||
T *ss) {
|
||||
size_t len = x_vector.size();
|
||||
std::vector<T> tmp_vector(len);
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
tmp_vector[i] = x_vector[i] * y_vector[i];
|
||||
}
|
||||
if (len % 2 == 1) {
|
||||
tmp_vector[0] += tmp_vector[len - 1];
|
||||
}
|
||||
for (size_t stride = len / 2; stride > 0; stride >>= 1) {
|
||||
for (size_t i = 0; i < stride; i++) {
|
||||
tmp_vector[i] += tmp_vector[i + stride];
|
||||
}
|
||||
if (stride > 2 && stride % 2 == 1) {
|
||||
tmp_vector[0] += tmp_vector[stride - 1];
|
||||
}
|
||||
}
|
||||
*ss = tmp_vector[0];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void L2NormalizeGradCPUKernel<T>::GetOutput(const std::vector<T> &input_x_vector, const std::vector<T> &y_vector,
|
||||
const std::vector<T> &dout_vector,
|
||||
const std::vector<size_t> &high_dim_index, T *output) {
|
||||
size_t axis_index = high_dim_index[axis_];
|
||||
T dout = dout_vector[axis_index];
|
||||
T y = y_vector[axis_index];
|
||||
T tmp_sum1;
|
||||
GetSumOfProduct(y_vector, dout_vector, &tmp_sum1);
|
||||
T tmp_sum2;
|
||||
GetSumOfProduct(input_x_vector, input_x_vector, &tmp_sum2);
|
||||
tmp_sum2 = sqrt(tmp_sum2);
|
||||
if (tmp_sum2 >= epsilon_) {
|
||||
*output = (dout - y * tmp_sum1) / tmp_sum2;
|
||||
} else {
|
||||
*output = (dout - y * tmp_sum1) / epsilon_;
|
||||
}
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2NORMALIZE_GRAD_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2NORMALIZE_GRAD_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t INPUT_SIZE = 3;
|
||||
constexpr size_t OUTPUT_SIZE = 1;
|
||||
template <typename T>
|
||||
class L2NormalizeGradCPUKernel : public CPUKernel {
|
||||
public:
|
||||
L2NormalizeGradCPUKernel() = default;
|
||||
~L2NormalizeGradCPUKernel() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
private:
|
||||
void CheckInputShape(const std::vector<size_t> &output_shape);
|
||||
void CheckIONumber(const CNodePtr &kernel_node);
|
||||
void OneDimIndexToHighDimIndex(size_t one_dim_index, std::vector<size_t> *high_dim_index);
|
||||
void HighDimIndexToOneDimIndex(size_t *one_dim_index, const std::vector<size_t> &high_dim_index);
|
||||
void GetVector(std::vector<T> *x_vector, const std::vector<size_t> &high_dim_index, const T *x);
|
||||
void GetSumOfProduct(const std::vector<T> &x_vector, const std::vector<T> &y_vector, T *ss);
|
||||
void GetOutput(const std::vector<T> &input_x_vector, const std::vector<T> &y_vector,
|
||||
const std::vector<T> &dout_vector, const std::vector<size_t> &high_dim_index, T *output);
|
||||
std::vector<std::vector<size_t>> input_shape_list_;
|
||||
std::vector<size_t> dim_elem_num_list_;
|
||||
int axis_{0};
|
||||
T epsilon_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(L2NormalizeGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
L2NormalizeGradCPUKernel, float);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(L2NormalizeGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
L2NormalizeGradCPUKernel, float16);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2NORMALIZE_GRAD_CPU_KERNEL_H_
|
|
@ -35,6 +35,5 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
SquaredDifference,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SquaredDifferenceOpGpuKernel, int)
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, axis=0, epsilon=1e-4):
|
||||
super(Net, self).__init__()
|
||||
self.ops = G.L2NormalizeGrad(axis, epsilon)
|
||||
|
||||
def construct(self, input_x, output, dout):
|
||||
return self.ops(input_x, output, dout)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_net01():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
axis = 1
|
||||
net = Net(axis)
|
||||
input_x = np.arange(24).astype(np.float32).reshape((2, 3, 4))
|
||||
dout = np.arange(24, 48).astype(np.float32).reshape((2, 3, 4))
|
||||
output = input_x / np.sqrt(np.sum(input_x**2, axis=axis, keepdims=True))
|
||||
except_asn = (dout - output * np.sum(output * dout, axis=axis, keepdims=True)
|
||||
) / np.sqrt(np.sum(input_x**2, axis=axis, keepdims=True))
|
||||
input_x = Tensor(input_x, mstype.float32)
|
||||
output = Tensor(output, mstype.float32)
|
||||
dout = Tensor(dout, mstype.float32)
|
||||
net_output = net(input_x, output, dout).asnumpy()
|
||||
precision = np.ones(shape=(2, 3, 4), dtype=np.float32) * 1.0e-5
|
||||
absolute_error = np.abs(except_asn-net_output)
|
||||
assert np.all(absolute_error < precision)
|
Loading…
Reference in New Issue