implemented new gpu op, Dropout3D

fixes to successfully build

fixes to dropout3d, passes pytest

made peilin's pr suggestions/fixes

fixed clang format issue and changed python names

fixed lowercase 'd' on MS_LOG messages

added newlines at end of files for cpplint

ci fixes and updated dropout3d supported platforms

pylint fixes
This commit is contained in:
markuskunej 2021-05-14 16:02:18 -04:00
parent ccde359fc0
commit ad322a7495
6 changed files with 422 additions and 1 deletions

View File

@ -0,0 +1,77 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdint.h>
#include "dropout3d_impl.cuh"
#include "include/cuda_runtime.h"
template <typename T>
__global__ void Dropout3DForwardKernel(const T *input, bool *mask, T *output, float *rand_f, const size_t num_count,
const float keep_prob, const float scale, const size_t num_per_chan) {
size_t chan_idx;
float drop_f; // used in output calculations. Either 0.0 or 1.0.
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
chan_idx = i / num_per_chan; // get channel index over all samples
drop_f = rand_f[chan_idx] <= keep_prob;
output[i] = static_cast<T>(scale * input[i] * drop_f);
mask[i] = static_cast<bool>(drop_f);
}
}
template <>
__global__ void Dropout3DForwardKernel(const half *input, bool *mask, half *output, float *rand_f,
const size_t num_count, const float keep_prob, const float scale,
const size_t num_per_chan) {
size_t chan_idx;
float drop_f; // used in output calculations. Acts as a single float mask (either 0.0 or 1.0).
float input_f; // used to temporarily convert input to float for calculations
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
chan_idx = i / num_per_chan; // get channel index over all samples
input_f = __half2float(input[i]);
drop_f = rand_f[chan_idx] <= keep_prob;
output[i] = __float2half(scale * input_f * drop_f); // convert to half
mask[i] = static_cast<bool>(drop_f);
}
}
template <typename T>
void Dropout3DForward(const T *input, bool *mask, T *output, float *rand_f, const size_t num_count,
const float keep_prob, const size_t num_per_chan, cudaStream_t cuda_stream) {
const float scale = 1.f / keep_prob; // used to scale output, maintains expected value during training
Dropout3DForwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(input, mask, output, rand_f, num_count,
keep_prob, scale, num_per_chan);
}
template void Dropout3DForward<float>(const float *input, bool *mask, float *output, float *rand_f,
const size_t num_count, const float keep_prob, const size_t num_per_chan,
cudaStream_t cuda_stream);
template void Dropout3DForward<half>(const half *input, bool *mask, half *output, float *rand_f, const size_t num_count,
const float keep_prob, const size_t num_per_chan, cudaStream_t cuda_stream);
template void Dropout3DForward<int8_t>(const int8_t *input, bool *mask, int8_t *output, float *rand_f,
const size_t num_count, const float keep_prob, const size_t num_per_chan,
cudaStream_t cuda_stream);
template void Dropout3DForward<int16_t>(const int16_t *input, bool *mask, int16_t *output, float *rand_f,
const size_t num_count, const float keep_prob, const size_t num_per_chan,
cudaStream_t cuda_stream);
template void Dropout3DForward<int32_t>(const int32_t *input, bool *mask, int32_t *output, float *rand_f,
const size_t num_count, const float keep_prob, const size_t num_per_chan,
cudaStream_t cuda_stream);
template void Dropout3DForward<int64_t>(const int64_t *input, bool *mask, int64_t *output, float *rand_f,
const size_t num_count, const float keep_prob, const size_t num_per_chan,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_CUDA_IMPL_DROPOUT3D_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_CUDA_IMPL_DROPOUT3D_IMPL_CUH_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void Dropout3DForward(const T *input, bool *mask, T *output, float *rand_f, const size_t num_count,
const float keep_prob, const size_t num_per_chan, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_CUDA_IMPL_DROPOUT3D_IMPL_CUH_

View File

@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/dropout3d_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
Dropout3D,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
Dropout3DGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(
Dropout3D,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
Dropout3DGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(
Dropout3D, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
Dropout3DGpuFwdKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(
Dropout3D, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
Dropout3DGpuFwdKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(
Dropout3D, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
Dropout3DGpuFwdKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(
Dropout3D, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
Dropout3DGpuFwdKernel, int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,161 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DROPOUT3D_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DROPOUT3D_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/dropout3d_impl.cuh"
#include "include/curand.h"
namespace mindspore {
namespace kernel {
template <typename T>
class Dropout3DGpuFwdKernel : public GpuKernel {
public:
Dropout3DGpuFwdKernel() { ResetResource(); }
~Dropout3DGpuFwdKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
bool *mask_addr = GetDeviceAddress<bool>(outputs, 1);
float *rand_f = GetDeviceAddress<float>(workspace, 0);
if (!states_init_) {
CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT),
"Failed to create generator");
CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(curand_generator_, time(NULL)),
"Failed to SetPseudoRandomGeneratorSeed");
MS_EXCEPTION_IF_NULL(curand_generator_);
states_init_ = true;
}
CHECK_CURAND_RET_WITH_EXCEPT(curandSetStream(curand_generator_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"Failed to set stream for generator");
// curandGen only supports float or double.
// generate random float for every channel
CHECK_CURAND_RET_WITH_EXCEPT(curandGenerateUniform(curand_generator_, rand_f, num_chan_),
"Failed to generate uniform");
Dropout3DForward(input_addr, mask_addr, output_addr, rand_f, num_count_, keep_prob_, num_per_chan_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but Dropout3DGpuFwdKernel needs 1.";
}
std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "Dropout3DGpuKernel input is null.";
InitSizeLists();
return true;
}
CheckTensorSize({input_shape});
size_t dims = input_shape.size();
if (dims != 5) {
MS_LOG(EXCEPTION) << "Input dims " << dims << "not supported. Must be in NCDHW format.";
}
// get N and C values from 5 dim input tensor
n_ = input_shape[0];
c_ = input_shape[1];
num_count_ = 1;
for (size_t i = 0; i < dims; i++) {
num_count_ *= input_shape[i];
}
num_chan_ = n_ * c_;
num_per_chan_ = num_count_ / num_chan_; // number of elements per channel
keep_prob_ = GetAttr<float>(kernel_node, "keep_prob");
if ((keep_prob_ < 0.0) || (keep_prob_ > 1.0)) {
MS_LOG(EXCEPTION) << "keep_prob is out of range [0.0, 1.0]";
}
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
cudnn_handle_ = nullptr;
is_null_input_ = false;
num_count_ = 0;
keep_prob_ = 0.0;
states_init_ = false;
curand_generator_ = nullptr;
n_ = 0;
c_ = 0;
num_chan_ = 0;
num_per_chan_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
size_t input_size = num_count_ * sizeof(T);
size_t mask_size = num_count_ * sizeof(bool);
input_size_list_.push_back(input_size);
output_size_list_.push_back(input_size); // output size: the same as input size
output_size_list_.push_back(mask_size);
size_t workspace_size = num_chan_ * sizeof(float); // rand_f for curandGen
workspace_size_list_.push_back(workspace_size);
}
private:
cudnnHandle_t cudnn_handle_;
curandGenerator_t curand_generator_;
bool is_null_input_;
bool states_init_;
size_t num_count_;
size_t n_;
size_t c_;
size_t num_chan_;
size_t num_per_chan_;
float keep_prob_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DROPOUT3D_GPU_KERNEL_H_

View File

@ -6955,7 +6955,7 @@ class Dropout3D(PrimitiveWithInfer):
or if the dim of input is not 5-D.
Supported Platforms:
``Ascend``
``Ascend`` ``GPU``
Examples:
>>> dropout = ops.Dropout3D(keep_prob=0.5)

View File

@ -0,0 +1,115 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.context as context
from mindspore.ops import operations as P
class Dropout3DNet(nn.Cell):
def __init__(self, keep_prob):
super(Dropout3DNet, self).__init__()
self.drop = P.Dropout3D(keep_prob)
def construct(self, x):
return self.drop(x)
def dropout_3d(keep_prob, nptype):
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x_shape = [32, 16, 2, 5, 4]
x_np = np.ones(x_shape).astype(nptype)
dropout3d_net = Dropout3DNet(keep_prob)
tx = Tensor(x_np)
output, mask = dropout3d_net(tx)
## check output ##
output_np = output.asnumpy()
elem_count = x_np.size
nonzero_count = np.count_nonzero(output_np)
# assert correct proportion of elements kept
assert (elem_count * (keep_prob - 0.1)) < nonzero_count < (elem_count * (keep_prob + 0.1))
output_sum = np.sum(output_np)
x_sum = np.sum(x_np)
if keep_prob != 0.0:
# assert output scaled correctly (expected value maintained)
assert abs(output_sum - x_sum)/x_sum < 0.1
## check mask ##
mask_np = mask.asnumpy()
# specific to input with no zeros. Check for same number of nonzero elements
assert np.count_nonzero(mask_np) == nonzero_count
# check each channel is entirely True or False
non_eq_chan = 0
for n in range(mask_np.shape[0]):
for c in range(mask_np.shape[1]):
if not np.all(mask_np[n][c] == mask_np[n][c][0]):
non_eq_chan = non_eq_chan + 1
assert non_eq_chan == 0
# check input, output, mask all have same shape
assert x_np.shape == output_np.shape == mask_np.shape
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dropout3d_float16():
dropout_3d(0.0, np.float16)
dropout_3d(0.5, np.float16)
dropout_3d(1.0, np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dropout3d_float32():
dropout_3d(0.0, np.float32)
dropout_3d(0.5, np.float32)
dropout_3d(1.0, np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dropout3d_int8():
dropout_3d(0.0, np.int8)
dropout_3d(0.5, np.int8)
dropout_3d(1.0, np.int8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dropout3d_int16():
dropout_3d(0.0, np.int16)
dropout_3d(0.5, np.int16)
dropout_3d(1.0, np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dropout3d_int32():
dropout_3d(0.0, np.int32)
dropout_3d(0.5, np.int32)
dropout_3d(1.0, np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dropout3d_int64():
dropout_3d(0.0, np.int64)
dropout_3d(0.5, np.int64)
dropout_3d(1.0, np.int64)