Add log1p operator at GPU back-end and move erf and erf to the unary_op list

This commit is contained in:
peixu_ren 2020-10-27 10:57:57 -04:00
parent 831d3a69d4
commit dfe5a951eb
13 changed files with 134 additions and 352 deletions

View File

@ -1,33 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "erf_impl.cuh"
template <typename T>
__global__ void ErfKernel(T *input, T *output, size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = static_cast<T>(erf(static_cast<float>(input[i])));
}
return;
}
template <typename T>
void Erf(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
ErfKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template void Erf<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
template void Erf<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);

View File

@ -1,25 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ERFIMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ERFIMPL_H_
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void Erf(T *input, T *output, size_t count, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ERFIMPL_H_

View File

@ -1,33 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "erfc_impl.cuh"
template <typename T>
__global__ void ErfcKernel(T *input, T *output, size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = static_cast<T>(erfc(static_cast<float>(input[i])));
}
return;
}
template <typename T>
void Erfc(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
ErfcKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template void Erfc<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
template void Erfc<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);

View File

@ -1,25 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ERFIMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ERFIMPL_H_
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void Erfc(T *input, T *output, size_t count, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ERFIMPL_H_

View File

@ -44,6 +44,27 @@ __global__ void LogarithmKernel(const half *input, half *output, const size_t co
return;
}
template <typename T>
__global__ void Log1pKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = static_cast<T>(log1p(static_cast<double>(input[i])));
}
return;
}
template <typename T>
__global__ void ErfKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = static_cast<T>(erf(static_cast<float>(input[i])));
}
return;
}
template <typename T>
__global__ void ErfcKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = static_cast<T>(erfc(static_cast<float>(input[i])));
}
return;
}
template <typename T>
__global__ void NegativeKernel(const T *input, T *output, const size_t count) {
T neg_one = -1;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@ -189,6 +210,21 @@ void Negative(const T *input, T *output, const size_t count, cudaStream_t cuda_s
return;
}
template <typename T>
void Log1p(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
Log1pKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Erf(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
ErfKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Erfc(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
ErfcKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Reciprocal(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
ReciprocalKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
@ -252,6 +288,9 @@ void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stre
template void Exponential<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Logarithm<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Negative<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Log1p<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Erf<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Erfc<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Reciprocal<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Square<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Sqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
@ -266,6 +305,9 @@ template void Floor<float>(const float *input, float *output, const size_t count
template void Exponential<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Logarithm<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Negative<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Log1p<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Erf<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Erfc<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Reciprocal<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Square<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Sqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);

View File

@ -23,6 +23,12 @@ void Exponential(const T *input, T *output, const size_t count, cudaStream_t cud
template <typename T>
void Logarithm(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Log1p(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Erf(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Erfc(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Negative(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Reciprocal(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);

View File

@ -1,26 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/erf_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Erf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ErfGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Erf, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ErfGpuKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@ -1,92 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ERF_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ERF_GPU_KERNEL_H_
#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/erf_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class ErfGpuKernel : public GpuKernel {
public:
ErfGpuKernel() : input_size_(sizeof(T)), output_size_(sizeof(T)) {}
~ErfGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
Erf(input_addr, output_addr, outputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but erf needs 3 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but erf needs 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < output_shape.size(); i++) {
output_size_ *= output_shape[i];
}
if (input_size_ != output_size_) {
MS_LOG(ERROR) << "Input size and output should be equal for Erf.";
return false;
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
}
private:
size_t input_size_;
size_t output_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ERF_GPU_KERNEL_H_

View File

@ -1,26 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/erfc_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ErfcGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ErfcGpuKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@ -1,92 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ERF_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ERF_GPU_KERNEL_H_
#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/erfc_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class ErfcGpuKernel : public GpuKernel {
public:
ErfcGpuKernel() : input_size_(sizeof(T)), output_size_(sizeof(T)) {}
~ErfcGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
Erfc(input_addr, output_addr, outputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but erfc needs 3 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but erfc needs 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < output_shape.size(); i++) {
output_size_ *= output_shape[i];
}
if (input_size_ != output_size_) {
MS_LOG(ERROR) << "Input size and output should be equal for Erfc.";
return false;
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
}
private:
size_t input_size_;
size_t output_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ERF_GPU_KERNEL_H_

View File

@ -30,6 +30,18 @@ MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutp
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Log1p, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Log1p, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Erf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Erf, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),

View File

@ -30,6 +30,9 @@ namespace kernel {
enum UnaryOptype {
UNARY_OP_EXP = 0,
UNARY_OP_LOG,
UNARY_OP_LOG1P,
UNARY_OP_ERF,
UNARY_OP_ERFC,
UNARY_OP_NEG,
UNARY_OP_RECIPROCAL,
UNARY_OP_ZEROSLIKE,
@ -46,6 +49,9 @@ enum UnaryOptype {
};
static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP},
{"Log", UNARY_OP_LOG},
{"Log1p", UNARY_OP_LOG1P},
{"Erf", UNARY_OP_ERF},
{"Erfc", UNARY_OP_ERFC},
{"Neg", UNARY_OP_NEG},
{"Reciprocal", UNARY_OP_RECIPROCAL},
{"ZerosLike", UNARY_OP_ZEROSLIKE},
@ -88,6 +94,18 @@ class UnaryOpGpuKernel : public GpuKernel {
Logarithm(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_LOG1P: {
Log1p(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_ERF: {
Erf(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_ERFC: {
Erfc(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_NEG: {
Negative(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;

View File

@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetLog1p(nn.Cell):
def __init__(self):
super(NetLog1p, self).__init__()
self.log1p = P.Log1p()
def construct(self, x):
return self.log1p(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_log1p_fp32():
log1p = NetLog1p()
x = np.random.rand(3, 8).astype(np.float32)
output = log1p(Tensor(x, dtype=dtype.float32))
expect = np.log1p(x)
tol = 1e-6
assert (np.abs(output.asnumpy() - expect) < tol).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_log1p_fp16():
log1p = NetLog1p()
x = np.random.rand(3, 8).astype(np.float16)
output = log1p(Tensor(x, dtype=dtype.float16))
expect = np.log1p(x)
tol = 1e-3
assert (np.abs(output.asnumpy() - expect) < tol).all()