Support float64 for CPU ops, including Sqrt, Abs, AbsGrad, and AddN.

hezhenhao1 2021-11-03 14:15:05 +08:00
parent 0f1d3e5baf
commit e2b6c926bf
9 changed files with 110 additions and 20 deletions

View File

@@ -262,6 +262,26 @@ void Atanh(const T *in, T *out, size_t size) {
CPUKernelUtils::ParallelFor(task, size);
}
template <typename T>
void Abs(const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = abs(in[i]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}
template <typename T>
void Sqrt(const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = sqrt(in[i]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}
template <typename T>
void Identity(const T *in, T *out, size_t size) {
(void)std::copy(in, in + size, out);
@@ -330,7 +350,9 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
{prim::kPrimOnesLike->name(), OnesLike<T>},
{prim::kPrimReciprocal->name(), Reciprocal<T>},
{prim::kPrimRint->name(), Rint<T>},
{prim::kPrimRound->name(), Round<T>}};
{prim::kPrimRound->name(), Round<T>},
{prim::kPrimAbs->name(), Abs<T>},
{prim::kPrimSqrt->name(), Sqrt<T>}};
const auto func_pair = arithmeticSelfFuncMap.find(kernel_name_);
if (arithmeticSelfFuncMap.find(kernel_name_) == arithmeticSelfFuncMap.end()) {
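
A minimal standalone sketch of the pattern these additions follow: a templated elementwise loop (here run sequentially by a stand-in for CPUKernelUtils::ParallelFor, which in MindSpore splits the range across threads) selected from a name-to-function map at launch time. The helper names and the SequentialFor shim below are illustrative only, not the MindSpore API.

#include <cmath>
#include <cstdio>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Illustrative stand-in for CPUKernelUtils::ParallelFor: runs the whole range in one task.
template <typename Task>
void SequentialFor(const Task &task, size_t size) { task(0, size); }

template <typename T>
void AbsKernel(const T *in, T *out, size_t size) {
  SequentialFor([&](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) out[i] = std::abs(in[i]);
  }, size);
}

template <typename T>
void SqrtKernel(const T *in, T *out, size_t size) {
  SequentialFor([&](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) out[i] = std::sqrt(in[i]);
  }, size);
}

int main() {
  // Name-keyed dispatch, analogous to arithmeticSelfFuncMap instantiated for T = double.
  std::map<std::string, std::function<void(const double *, double *, size_t)>> func_map = {
      {"Abs", AbsKernel<double>}, {"Sqrt", SqrtKernel<double>}};
  std::vector<double> in = {-4.0, 9.0, -16.0}, out(in.size());
  func_map["Abs"](in.data(), out.data(), in.size());     // 4 9 16
  func_map["Sqrt"](out.data(), out.data(), out.size());  // 2 3 4
  for (double v : out) printf("%g ", v);
  printf("\n");
  return 0;
}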

View File

@@ -152,6 +152,14 @@ MS_REG_CPU_KERNEL(Atanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutput
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Atanh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
IdentityCPUKernel, uint64_t);
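
Each MS_REG_CPU_KERNEL line above binds an op name plus a dtype signature (input/output attributes) to a kernel class, which is why adding float64 Abs and Sqrt is mostly a registration change rather than a new kernel. A rough sketch of the idea, a registry keyed by (op name, dtype), with hypothetical names rather than MindSpore's real factory code:

#include <cstdio>
#include <functional>
#include <map>
#include <string>
#include <utility>

enum class DType { kFloat32, kFloat64, kInt32, kInt64 };

// Hypothetical registry: (op name, dtype) -> launcher. The real registry also tracks
// full input/output signatures; this only illustrates the lookup idea.
using Key = std::pair<std::string, DType>;
std::map<Key, std::function<void()>> &Registry() {
  static std::map<Key, std::function<void()>> registry;
  return registry;
}

int main() {
  // Conceptually what the added float64 Abs/Sqrt registrations contribute:
  Registry()[{"Abs", DType::kFloat64}] = [] { printf("launch Abs for double\n"); };
  Registry()[{"Sqrt", DType::kFloat64}] = [] { printf("launch Sqrt for double\n"); };
  Registry()[{"Sqrt", DType::kFloat64}]();  // dispatch by (op, dtype)
  return 0;
}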

View File

@@ -51,13 +51,19 @@ void EltWiseGradCPUKernel<T>::ReLU6Grad(const T *input1, const T *input2, T *out
template <typename T>
void EltWiseGradCPUKernel<T>::AbsGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
if constexpr (!std::is_same<T, float>::value) {
MS_LOG(EXCEPTION) << "AbsGrad only support float";
if constexpr (!std::is_same<T, float>::value && !std::is_same<T, double>::value) {
MS_LOG(EXCEPTION) << "AbsGrad only support float or double.";
}
int ret = ::ElementAbsGrad(input1 + start, input2 + start, out + start, end - start);
if (ret == NNACL_ERR) {
MS_LOG(EXCEPTION) << "AbsGrad execute failed.";
if constexpr (std::is_same<T, float>::value) {
int ret = ::ElementAbsGrad(input1 + start, input2 + start, out + start, end - start);
if (ret == NNACL_ERR) {
MS_LOG(EXCEPTION) << "AbsGrad execute failed.";
}
}
if constexpr (std::is_same<T, double>::value) {
for (size_t i = start; i < end; i++) {
out[i] = (input1[i] < 0.f) ? -input2[i] : ((input1[i] > 0.f) ? input2[i] : 0);
}
}
}
@@ -232,7 +238,8 @@ void EltWiseGradCPUKernel<T>::InitComputeFunc() {
{prim::kPrimACosGrad->name(), &EltWiseGradCPUKernel<T>::ACosGrad},
{prim::kPrimAtanGrad->name(), &EltWiseGradCPUKernel<T>::AtanGrad},
{prim::kPrimAsinhGrad->name(), &EltWiseGradCPUKernel<T>::AsinhGrad},
{prim::kPrimAcoshGrad->name(), &EltWiseGradCPUKernel<T>::AcoshGrad}};
{prim::kPrimAcoshGrad->name(), &EltWiseGradCPUKernel<T>::AcoshGrad},
{prim::kPrimAbsGrad->name(), &EltWiseGradCPUKernel<T>::AbsGrad}};
if (elt_map.find(kernel_name_) == elt_map.end()) {
MS_LOG(EXCEPTION) << "EltWiseGradCPUKernel does not support " << kernel_name_;
}
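
For float the kernel still delegates to the NNACL ElementAbsGrad routine; the new double branch is a plain loop implementing d|x|/dx = sign(x): the incoming gradient is negated where the input is negative and zeroed where it is exactly zero. A self-contained sketch of that branch as a free function (not the kernel class):

#include <cstddef>
#include <cstdio>

// Gradient of |x|: dx = dout * sign(x), with 0 propagated at x == 0,
// matching the double branch added to AbsGrad above.
void AbsGradDouble(const double *x, const double *dout, double *dx, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    dx[i] = (x[i] < 0.0) ? -dout[i] : ((x[i] > 0.0) ? dout[i] : 0.0);
  }
}

int main() {
  const double x[] = {-2.0, 0.0, 3.5};
  const double dout[] = {1.0, 1.0, 1.0};
  double dx[3];
  AbsGradDouble(x, dout, dx, 0, 3);
  for (double v : dx) printf("%g ", v);  // -1 0 1
  printf("\n");
  return 0;
}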

View File

@@ -69,6 +69,10 @@ MS_REG_CPU_KERNEL_T(
AbsGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(
AbsGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
EltWiseGradCPUKernel, double);
MS_REG_CPU_KERNEL_T(
SigmoidGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),

View File

@@ -34,6 +34,12 @@ void AddInt(const int *in_0, const int *in_1, int *out, int start, int end) {
MS_LOG(EXCEPTION) << "Add failed.";
}
}
void AddDouble(const double *in0, const double *in1, double *out, int start, int end) {
for (int index = start; index < end; index++) {
out[index] = in0[index] + in1[index];
}
}
} // namespace
void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -86,8 +92,20 @@ bool AddNCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const
auto task = std::bind(AddInt, input, output, output, std::placeholders::_1, std::placeholders::_2);
CPUKernelUtils::ParallelFor(task, elements_num);
}
} else if (dtype_ == kNumberTypeFloat64) {
size_t elements_num = outputs[0]->size / sizeof(double);
const auto input_0 = reinterpret_cast<double *>(inputs[0]->addr);
const auto input_1 = reinterpret_cast<double *>(inputs[1]->addr);
auto output = reinterpret_cast<double *>(outputs[0]->addr);
auto task_0 = std::bind(AddDouble, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2);
CPUKernelUtils::ParallelFor(task_0, elements_num);
for (size_t index = 2; index < input_num_; ++index) {
const auto input = reinterpret_cast<double *>(inputs[index]->addr);
auto task = std::bind(AddDouble, input, output, output, std::placeholders::_1, std::placeholders::_2);
CPUKernelUtils::ParallelFor(task, elements_num);
}
} else {
MS_LOG(EXCEPTION) << "AddN only support float32 and int32, but got " << TypeIdToType(dtype_)->ToString();
MS_LOG(EXCEPTION) << "AddN only support float32, float64 and int32, but got " << TypeIdToType(dtype_)->ToString();
}
return true;
}
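
The float64 branch mirrors the existing float32/int32 handling: the first two inputs are summed into the output buffer, then each remaining input is accumulated into it in place, with every pass split across threads by CPUKernelUtils::ParallelFor. A minimal sequential sketch of that accumulation order (plain calls instead of the std::bind + ParallelFor tasks):

#include <cstdio>
#include <vector>

// Elementwise add over [start, end), same shape as the new AddDouble helper.
void AddDouble(const double *in0, const double *in1, double *out, int start, int end) {
  for (int i = start; i < end; i++) out[i] = in0[i] + in1[i];
}

int main() {
  // Three inputs, as AddN might receive them; elements_num = 4.
  std::vector<double> a = {1, 2, 3, 4}, b = {10, 20, 30, 40}, c = {100, 200, 300, 400};
  std::vector<double> out(a.size());
  int n = static_cast<int>(out.size());
  AddDouble(a.data(), b.data(), out.data(), 0, n);    // out = inputs[0] + inputs[1]
  AddDouble(c.data(), out.data(), out.data(), 0, n);  // out += inputs[2], and so on for more inputs
  for (double v : out) printf("%g ", v);  // 111 222 333 444
  printf("\n");
  return 0;
}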

View File

@@ -45,6 +45,9 @@ MS_REG_CPU_KERNEL(AddN,
MS_REG_CPU_KERNEL(AddN,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
AddNCPUKernel);
MS_REG_CPU_KERNEL(AddN,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
AddNCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@@ -47,13 +47,19 @@ class Net(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net():
x = np.random.randn(2, 3, 3, 4).astype(np.float32)
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_abs(dtype):
"""
Feature: ALL To ALL
Description: test cases for Abs
Expectation: the result match to numpy
"""
x = np.random.randn(2, 3, 3, 4).astype(dtype)
y_expect = np.abs(x)
net = Net()
out = net(Tensor(x))
assert (out.asnumpy() == y_expect).all()
sens = np.random.randn(2, 3, 3, 4).astype(np.float32)
sens = np.random.randn(2, 3, 3, 4).astype(dtype)
backword_net = Grad(Net())
output = backword_net(Tensor(x), Tensor(sens))
print(len(output))

View File

@@ -37,10 +37,15 @@ class Net2Inputs(nn.Cell):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_two_tensors_add():
"""
Feature: ALL To ALL
Description: test cases for AddN of two tensors
Expectation: the result match to numpy
"""
x = np.arange(2 * 3 * 2).reshape((2, 3, 2))
y = np.arange(88, 2 * 3 * 2 + 88).reshape((2, 3, 2))
addn_net = Net2Inputs()
dtypes = (np.int32, np.float32)
dtypes = (np.int32, np.float32, np.float64)
for dtype in dtypes:
output = addn_net(Tensor(x.astype(dtype)), Tensor(y.astype(dtype)))
expect_result = (x + y).astype(dtype)
@@ -61,12 +66,17 @@ class Net4Inputs(nn.Cell):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_four_tensors_add():
"""
Feature: ALL To ALL
Description: test cases for AddN of four tensors
Expectation: the result match to numpy
"""
x = np.arange(2 * 3).reshape((2, 3))
y = np.arange(1, 2 * 3 + 1).reshape((2, 3))
m = np.arange(2, 2 * 3 + 2).reshape((2, 3))
n = np.arange(3, 2 * 3 + 3).reshape((2, 3))
addn_net = Net4Inputs()
dtypes = (np.int32, np.float32)
dtypes = (np.int32, np.float32, np.float64)
for dtype in dtypes:
output = addn_net(Tensor(x.astype(dtype)), Tensor(y.astype(dtype)),
Tensor(m.astype(dtype)), Tensor(n.astype(dtype)))

View File

@@ -43,8 +43,14 @@ class Net(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net():
x = np.abs(np.random.randn(2, 3, 3, 4)).astype(np.float32)
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_sqrt(dtype):
"""
Feature: ALL To ALL
Description: test cases for Sqrt
Expectation: the result match to numpy
"""
x = np.abs(np.random.randn(2, 3, 3, 4)).astype(dtype)
y_expect = np.sqrt(x)
net = Net()
out = net(Tensor(x))
@@ -57,16 +63,22 @@ def test_net():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sqrt_grad():
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_sqrt_grad(dtype):
"""
Feature: ALL To ALL
Description: test cases for SqrtGrad
Expectation: the result match to numpy
"""
x = Tensor(np.array([[[[-1, 1, 10],
[5.9, 6.1, 6],
[10, 1, -1]]]]).astype(np.float32))
[10, 1, -1]]]]).astype(dtype))
dx = Tensor(np.array([[[[1, 1, 1],
[2, 2, 2],
[3, 3, 3]]]]).astype(np.float32))
[3, 3, 3]]]]).astype(dtype))
expect = np.array([[[[-0.5, 0.5, 0.05,],
[0.16949153, 0.16393442, 0.16666667,],
[0.15, 1.5, -1.5,]]]]).astype(np.float32)
[0.15, 1.5, -1.5,]]]]).astype(dtype)
error = np.ones(shape=[3, 3]) * 1.0e-6
sqrt_grad = NetSqrtGrad()