!313 GPU add akg kernel float_status

Merge pull request !313 from VectorSL/float_status
This commit is contained in:
mindspore-ci-bot 2020-04-21 12:27:29 +08:00 committed by Gitee
commit 728801301c
5 changed files with 452 additions and 0 deletions

View File

@ -0,0 +1,138 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/cuda_runtime.h"
#include "kernel/gpu/cuda_impl/float_status_impl.cuh"
template <typename T>
__global__ void IsNan(const size_t size, const T* input, bool* out) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
if (isnan(input[pos])) {
out[pos] = true;
} else {
out[pos] = false;
}
}
return;
}
template <>
__global__ void IsNan(const size_t size, const half* input, bool* out) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
if (__hisnan(input[pos])) {
out[pos] = true;
} else {
out[pos] = false;
}
}
return;
}
template <typename T>
__global__ void IsInf(const size_t size, const T* input, bool* out) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
if (isinf(input[pos]) != 0) {
out[pos] = true;
} else {
out[pos] = false;
}
}
return;
}
template <>
__global__ void IsInf(const size_t size, const half* input, bool* out) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
if (__hisinf(input[pos]) != 0) {
out[pos] = true;
} else {
out[pos] = false;
}
}
return;
}
template <typename T>
__global__ void IsFinite(const size_t size, const T* input, bool* out) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
if (isinf(input[pos]) == 0 && !isnan(input[pos])) {
out[pos] = true;
} else {
out[pos] = false;
}
}
return;
}
template <>
__global__ void IsFinite(const size_t size, const half* input, bool* out) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
if (__hisinf(input[pos]) == 0 && !__hisnan(input[pos])) {
out[pos] = true;
} else {
out[pos] = false;
}
}
return;
}
template <typename T>
__global__ void FloatStatus(const size_t size, const T* input, T* out) {
out[0] = 0;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
if (isinf(input[pos]) != 0 || isnan(input[pos])) {
out[0] = 1;
}
}
return;
}
template <>
__global__ void FloatStatus(const size_t size, const half* input, half* out) {
out[0] = 0;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
if (__hisinf(input[pos]) != 0 || __hisnan(input[pos])) {
out[0] = 1;
}
}
return;
}
template <typename T>
void CalFloatStatus(const size_t size, const T* input, T* output, cudaStream_t cuda_stream) {
FloatStatus<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
return;
}
template <typename T>
void CalIsNan(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) {
IsNan<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
return;
}
template <typename T>
void CalIsInf(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) {
IsInf<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
return;
}
template <typename T>
void CalIsFinite(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) {
IsFinite<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
return;
}
template void CalFloatStatus<float>(const size_t size, const float* input, float* output, cudaStream_t cuda_stream);
template void CalFloatStatus<half>(const size_t size, const half* input, half* output, cudaStream_t cuda_stream);
template void CalIsInf<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream);
template void CalIsInf<half>(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream);
template void CalIsNan<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream);
template void CalIsNan<half>(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream);
template void CalIsFinite<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream);
template void CalIsFinite<half>(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream);

View File

@ -0,0 +1,28 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_
#include "device/gpu/cuda_common.h"
template <typename T>
void CalFloatStatus(const size_t size, const T *input, T *output, cudaStream_t stream);
template <typename T>
void CalIsNan(const size_t size, const T *input, bool *output, cudaStream_t stream);
template <typename T>
void CalIsInf(const size_t size, const T *input, bool *output, cudaStream_t stream);
template <typename T>
void CalIsFinite(const size_t size, const T *input, bool *output, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_

View File

@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/math/float_status_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
FloatStatusGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
FloatStatusGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
FloatStatusGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
FloatStatusGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
FloatStatusGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
FloatStatusGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
FloatStatusGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
FloatStatusGpuKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,130 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H
#define MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H
#include <memory>
#include <vector>
#include <map>
#include <string>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/float_status_impl.cuh"
namespace mindspore {
namespace kernel {
enum Optype { OP_STATUS = 0, OP_INF, OP_NAN, OP_FINITE, OP_INVALID = 255 };
static const std::map<std::string, Optype> kOpTypeMap = {
{"FloatStatus", OP_STATUS}, {"IsInf", OP_INF}, {"IsNan", OP_NAN}, {"IsFinite", OP_FINITE}};
template <typename T>
class FloatStatusGpuKernel : public GpuKernel {
public:
FloatStatusGpuKernel() : kernel_name_(OP_INVALID), input_size_(0), output_size_(0) {}
~FloatStatusGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
switch (kernel_name_) {
case OP_STATUS: {
T *output = GetDeviceAddress<T>(outputs, 0);
CalFloatStatus(input_size_ / sizeof(T), input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case OP_INF: {
bool *output = GetDeviceAddress<bool>(outputs, 0);
CalIsInf(input_size_ / sizeof(T), input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case OP_NAN: {
bool *output = GetDeviceAddress<bool>(outputs, 0);
CalIsNan(input_size_ / sizeof(T), input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case OP_FINITE: {
bool *output = GetDeviceAddress<bool>(outputs, 0);
CalIsFinite(input_size_ / sizeof(T), input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
default: {
MS_LOG(EXCEPTION) << "FloatStatus type " << kernel_name_ << " is not supported.";
}
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
if (!CheckParam(kernel_node)) {
return false;
}
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_size_ = sizeof(T);
for (size_t x : shape) {
input_size_ = input_size_ * x;
}
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
auto iter = kOpTypeMap.find(kernel_name);
if (iter == kOpTypeMap.end()) {
MS_LOG(EXCEPTION) << "FloatStatus kernel " << kernel_name << " is not supported.";
} else {
kernel_name_ = iter->second;
}
if (kernel_name_ == OP_STATUS) {
output_size_ = sizeof(T);
} else {
output_size_ = input_size_ / sizeof(T) * sizeof(bool);
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
}
private:
bool CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but FloatStatusGpuKernel needs 1 output.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but FloatStatusGpuKernel needs 1 output.";
return false;
}
return true;
}
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
Optype kernel_name_;
size_t input_size_;
size_t output_size_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H

View File

@ -0,0 +1,118 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import numpy as np
import mindspore.context as context
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.status = P.FloatStatus()
def construct(self, x):
return self.status(x)
class Netnan(nn.Cell):
def __init__(self):
super(Netnan, self).__init__()
self.isnan = P.IsNan()
def construct(self, x):
return self.isnan(x)
class Netinf(nn.Cell):
def __init__(self):
super(Netinf, self).__init__()
self.isinf = P.IsInf()
def construct(self, x):
return self.isinf(x)
class Netfinite(nn.Cell):
def __init__(self):
super(Netfinite, self).__init__()
self.isfinite = P.IsFinite()
def construct(self, x):
return self.isfinite(x)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x1 = np.array([[1.2, 2, np.nan, 88]]).astype(np.float32)
x2 = np.array([[np.inf, 1, 88.0, 0]]).astype(np.float32)
x3 = np.array([[1, 2], [3, 4], [5.0, 88.0]]).astype(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_status():
ms_status = Net();
output1 = ms_status(Tensor(x1))
output2 = ms_status(Tensor(x2))
output3 = ms_status(Tensor(x3))
expect1 = 1
expect2 = 1
expect3 = 0
assert output1.asnumpy()[0] == expect1
assert output2.asnumpy()[0] == expect2
assert output3.asnumpy()[0] == expect3
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nan():
ms_isnan = Netnan();
output1 = ms_isnan(Tensor(x1))
output2 = ms_isnan(Tensor(x2))
output3 = ms_isnan(Tensor(x3))
expect1 = [[False, False, True, False]]
expect2 = [[False, False, False, False]]
expect3 = [[False, False], [False, False], [False, False]]
assert (output1.asnumpy() == expect1).all()
assert (output2.asnumpy() == expect2).all()
assert (output3.asnumpy() == expect3).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_inf():
ms_isinf = Netinf();
output1 = ms_isinf(Tensor(x1))
output2 = ms_isinf(Tensor(x2))
output3 = ms_isinf(Tensor(x3))
expect1 = [[False, False, False, False]]
expect2 = [[True, False, False, False]]
expect3 = [[False, False], [False, False], [False, False]]
assert (output1.asnumpy() == expect1).all()
assert (output2.asnumpy() == expect2).all()
assert (output3.asnumpy() == expect3).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_finite():
ms_isfinite = Netfinite();
output1 = ms_isfinite(Tensor(x1))
output2 = ms_isfinite(Tensor(x2))
output3 = ms_isfinite(Tensor(x3))
expect1 = [[True, True, False, True]]
expect2 = [[False, True, True, True]]
expect3 = [[True, True], [True, True], [True, True]]
assert (output1.asnumpy() == expect1).all()
assert (output2.asnumpy() == expect2).all()
assert (output3.asnumpy() == expect3).all()