/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "include/cuda_runtime.h"
#include "kernel/gpu/cuda_impl/float_status_impl.cuh"

// Per-element predicates. The primary templates use the overloaded math
// functions; the half specializations use the half-precision intrinsics.
template <typename T>
__device__ __forceinline__ bool IsNanValue(const T value) {
  return isnan(value);
}
template <>
__device__ __forceinline__ bool IsNanValue<half>(const half value) {
  return __hisnan(value);
}

template <typename T>
__device__ __forceinline__ bool IsInfValue(const T value) {
  return isinf(value) != 0;
}
template <>
__device__ __forceinline__ bool IsInfValue<half>(const half value) {
  // __hisinf yields a signed non-zero result for +/-inf, so compare against 0.
  return __hisinf(value) != 0;
}

// out[pos] = true iff input[pos] is NaN, for every pos in [0, size).
// The loop strides by the total number of launched threads, so any 1-D launch
// configuration covers the whole range.
template <typename T>
__global__ void IsNan(const size_t size, const T *input, bool *out) {
  // Widen before multiplying so the index arithmetic is done in size_t.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += stride) {
    out[pos] = IsNanValue(input[pos]);
  }
}

// out[pos] = true iff input[pos] is +inf or -inf, for every pos in [0, size).
template <typename T>
__global__ void IsInf(const size_t size, const T *input, bool *out) {
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += stride) {
    out[pos] = IsInfValue(input[pos]);
  }
}

// out[pos] = true iff input[pos] is neither infinite nor NaN.
template <typename T>
__global__ void IsFinite(const size_t size, const T *input, bool *out) {
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += stride) {
    const T value = input[pos];
    out[pos] = !IsInfValue(value) && !IsNanValue(value);
  }
}

// Clears the single-element status flag. Launched with one thread, ahead of
// FloatStatus on the same stream.
template <typename T>
__global__ void FloatStatusInit(T *out) {
  out[0] = 0;
}

// Sets out[0] to 1 when any element of input is infinite or NaN.
// This kernel never stores 0: the flag is cleared beforehand by
// FloatStatusInit. Clearing it here from every thread would let a thread that
// starts late overwrite a 1 already stored by another thread. All writers
// store the same value, so concurrent stores cannot change the result.
template <typename T>
__global__ void FloatStatus(const size_t size, const T *input, T *out) {
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += stride) {
    const T value = input[pos];
    if (IsInfValue(value) || IsNanValue(value)) {
      out[0] = 1;
      return;  // this thread has nothing more to contribute
    }
  }
}

// NOTE(review): the launch configuration between <<< and >>> was missing from
// the reviewed text. GET_BLOCKS / GET_THREADS are assumed to be the launch
// helpers provided through device/gpu/cuda_common.h -- confirm.

// Writes 1 to output[0] if input holds any inf/NaN element, otherwise 0.
// Both kernels are enqueued on cuda_stream, so the clear completes before the
// scan starts. An empty input leaves the flag at 0.
template <typename T>
void CalFloatStatus(const size_t size, const T *input, T *output, cudaStream_t cuda_stream) {
  FloatStatusInit<<<1, 1, 0, cuda_stream>>>(output);
  if (size == 0) {
    return;
  }
  FloatStatus<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
}

// Element-wise NaN test; output receives `size` bools.
template <typename T>
void CalIsNan(const size_t size, const T *input, bool *output, cudaStream_t cuda_stream) {
  if (size == 0) {
    return;  // nothing to scan, skip the launch
  }
  IsNan<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
}

// Element-wise infinity test; output receives `size` bools.
template <typename T>
void CalIsInf(const size_t size, const T *input, bool *output, cudaStream_t cuda_stream) {
  if (size == 0) {
    return;
  }
  IsInf<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
}

// Element-wise finiteness test; output receives `size` bools.
template <typename T>
void CalIsFinite(const size_t size, const T *input, bool *output, cudaStream_t cuda_stream) {
  if (size == 0) {
    return;
  }
  IsFinite<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
}

// Explicit instantiations for the element types used by the GPU kernel class.
template void CalFloatStatus<float>(const size_t size, const float *input, float *output, cudaStream_t cuda_stream);
template void CalFloatStatus<half>(const size_t size, const half *input, half *output, cudaStream_t cuda_stream);
template void CalIsInf<float>(const size_t size, const float *input, bool *output, cudaStream_t cuda_stream);
template void CalIsInf<half>(const size_t size, const half *input, bool *output, cudaStream_t cuda_stream);
template void CalIsNan<float>(const size_t size, const float *input, bool *output, cudaStream_t cuda_stream);
template void CalIsNan<half>(const size_t size, const half *input, bool *output, cudaStream_t cuda_stream);
template void CalIsFinite<float>(const size_t size, const float *input, bool *output, cudaStream_t cuda_stream);
template void CalIsFinite<half>(const size_t size, const half *input, bool *output, cudaStream_t cuda_stream);
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_
#include "device/gpu/cuda_common.h"

// Host-side launchers for the float-status kernels. Each call enqueues its
// work on `stream`. `size` is an element count, not a byte count, and the
// pointers are dereferenced inside __global__ kernels, so they must be
// device-accessible. The implementation file explicitly instantiates these
// templates for float and half.

// Writes a single value to output[0]: 1 if any of the `size` input elements is
// infinite or NaN, otherwise 0.
template <typename T>
void CalFloatStatus(const size_t size, const T *input, T *output, cudaStream_t stream);

// output[i] = true iff input[i] is NaN, for i in [0, size).
template <typename T>
void CalIsNan(const size_t size, const T *input, bool *output, cudaStream_t stream);

// output[i] = true iff input[i] is infinite, for i in [0, size).
template <typename T>
void CalIsInf(const size_t size, const T *input, bool *output, cudaStream_t stream);

// output[i] = true iff input[i] is neither infinite nor NaN, for i in [0, size).
template <typename T>
void CalIsFinite(const size_t size, const T *input, bool *output, cudaStream_t stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel/gpu/math/float_status_gpu_kernel.h"

namespace mindspore {
namespace kernel {
// Registration table: FloatStatusGpuKernel is bound to four op names, each for
// float32 and float16 inputs. The op name chosen here is what the kernel later
// looks up to select its operation.
// NOTE(review): MS_REG_GPU_KERNEL_ONE is defined elsewhere; presumably it
// registers the (name, attr, class, type) tuple with the GPU kernel factory.

// FloatStatus: the output has the same dtype as the input.
MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      FloatStatusGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                      FloatStatusGpuKernel, half)
// IsInf / IsNan / IsFinite: the output dtype is bool.
MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
                      FloatStatusGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
                      FloatStatusGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
                      FloatStatusGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
                      FloatStatusGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
                      FloatStatusGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
                      FloatStatusGpuKernel, half)
}  // namespace kernel
}  // namespace mindspore
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H + +#include +#include +#include +#include +#include "kernel/gpu/gpu_kernel.h" +#include "kernel/gpu/gpu_kernel_factory.h" +#include "kernel/gpu/cuda_impl/float_status_impl.cuh" + +namespace mindspore { +namespace kernel { +enum Optype { OP_STATUS = 0, OP_INF, OP_NAN, OP_FINITE, OP_INVALID = 255 }; +static const std::map kOpTypeMap = { + {"FloatStatus", OP_STATUS}, {"IsInf", OP_INF}, {"IsNan", OP_NAN}, {"IsFinite", OP_FINITE}}; +template +class FloatStatusGpuKernel : public GpuKernel { + public: + FloatStatusGpuKernel() : kernel_name_(OP_INVALID), input_size_(0), output_size_(0) {} + ~FloatStatusGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, uintptr_t stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + + switch (kernel_name_) { + case OP_STATUS: { + T *output = GetDeviceAddress(outputs, 0); + CalFloatStatus(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + case OP_INF: { + bool *output = GetDeviceAddress(outputs, 0); + CalIsInf(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + case OP_NAN: { + bool 
*output = GetDeviceAddress(outputs, 0); + CalIsNan(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + case OP_FINITE: { + bool *output = GetDeviceAddress(outputs, 0); + CalIsFinite(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + default: { + MS_LOG(EXCEPTION) << "FloatStatus type " << kernel_name_ << " is not supported."; + } + } + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + if (!CheckParam(kernel_node)) { + return false; + } + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_size_ = sizeof(T); + for (size_t x : shape) { + input_size_ = input_size_ * x; + } + auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kOpTypeMap.find(kernel_name); + if (iter == kOpTypeMap.end()) { + MS_LOG(EXCEPTION) << "FloatStatus kernel " << kernel_name << " is not supported."; + } else { + kernel_name_ = iter->second; + } + if (kernel_name_ == OP_STATUS) { + output_size_ = sizeof(T); + } else { + output_size_ = input_size_ / sizeof(T) * sizeof(bool); + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but FloatStatusGpuKernel needs 1 output."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but FloatStatusGpuKernel needs 1 output."; + return false; + } + return true; + } + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + Optype kernel_name_; + size_t input_size_; + size_t output_size_; +}; +} // namespace kernel +} // namespace mindspore + 
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import numpy as np
import mindspore.context as context


class Net(nn.Cell):
    """Cell that applies P.FloatStatus to its input."""

    def __init__(self):
        super(Net, self).__init__()
        self.status = P.FloatStatus()

    def construct(self, x):
        return self.status(x)


class Netnan(nn.Cell):
    """Cell that applies P.IsNan to its input."""

    def __init__(self):
        super(Netnan, self).__init__()
        self.isnan = P.IsNan()

    def construct(self, x):
        return self.isnan(x)


class Netinf(nn.Cell):
    """Cell that applies P.IsInf to its input."""

    def __init__(self):
        super(Netinf, self).__init__()
        self.isinf = P.IsInf()

    def construct(self, x):
        return self.isinf(x)


class Netfinite(nn.Cell):
    """Cell that applies P.IsFinite to its input."""

    def __init__(self):
        super(Netfinite, self).__init__()
        self.isfinite = P.IsFinite()

    def construct(self, x):
        return self.isfinite(x)


context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# Shared fixtures: one tensor holding a NaN, one holding an inf, one all finite.
x1 = np.array([[1.2, 2, np.nan, 88]]).astype(np.float32)
x2 = np.array([[np.inf, 1, 88.0, 0]]).astype(np.float32)
x3 = np.array([[1, 2], [3, 4], [5.0, 88.0]]).astype(np.float32)


def _run_on_fixtures(net):
    """Runs `net` on x1, x2 and x3 in that order and returns the outputs."""
    return [net(Tensor(data)) for data in (x1, x2, x3)]


def _assert_elementwise(outputs, expected):
    """Asserts that every output matches its expected boolean mask."""
    for result, mask in zip(outputs, expected):
        assert (result.asnumpy() == mask).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_status():
    outputs = _run_on_fixtures(Net())
    # The flag is 1 for the NaN and inf tensors and 0 for the finite one.
    for result, flag in zip(outputs, (1, 1, 0)):
        assert result.asnumpy()[0] == flag


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nan():
    outputs = _run_on_fixtures(Netnan())
    expected = (
        [[False, False, True, False]],
        [[False, False, False, False]],
        [[False, False], [False, False], [False, False]],
    )
    _assert_elementwise(outputs, expected)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_inf():
    outputs = _run_on_fixtures(Netinf())
    expected = (
        [[False, False, False, False]],
        [[True, False, False, False]],
        [[False, False], [False, False], [False, False]],
    )
    _assert_elementwise(outputs, expected)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_finite():
    outputs = _run_on_fixtures(Netfinite())
    expected = (
        [[True, True, False, True]],
        [[False, True, True, True]],
        [[True, True], [True, True], [True, True]],
    )
    _assert_elementwise(outputs, expected)