opt gpu performance for gpu elementwise op [unary binary ternary op]

z00512249 2022-08-04 19:14:12 +08:00
parent 180d0ed6a5
commit 7a35be73f5
5 changed files with 390 additions and 142 deletions


@@ -14,41 +14,126 @@
* limitations under the License.
*/
#include <stdint.h>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/dropout_nd_impl.cuh"
#include "include/cuda_runtime.h"
#include <stdint.h>
#include "include/cuda_fp16.h"
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/elementswise_op_impl.cuh"
constexpr uint kThreadsPerBlock = cuda::elementwise::kThreadsPerBlock;
template <typename T>
__global__ void DropoutNDForwardKernel(const T *input, bool *mask, T *output, float *rand_f, const size_t num_count,
const float keep_prob, const float scale, const size_t num_per_chan) {
size_t chan_idx;
float drop_f; // used in output calculations. Either 0.0 or 1.0.
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
chan_idx = i / num_per_chan; // get channel index over all samples
struct DropOutput {
__device__ DropOutput() {}
T output_{0};
bool mask_{false};
};
drop_f = rand_f[chan_idx] <= keep_prob;
output[i] = static_cast<T>(scale * input[i] * drop_f);
mask[i] = static_cast<bool>(drop_f);
template <typename T>
struct DropoutNdFunctor {
T scale_;
T keep_prob_;
uint inner_size_;
explicit DropoutNdFunctor(float scale, float keep_prob, uint inner_size)
: scale_(scale), keep_prob_(keep_prob), inner_size_(inner_size) {}
__device__ __forceinline__ DropOutput<T> operator()(const float rand_f, const T input_x) const {
auto output = DropOutput<T>();
bool drop_f = rand_f <= keep_prob_;
if (!drop_f) {
return output;
}
output.output_ = scale_ * input_x * drop_f;
output.mask_ = drop_f;
return output;
}
};
template <>
struct DropoutNdFunctor<half> {
float scale_;
float keep_prob_;
uint inner_size_;
explicit DropoutNdFunctor<half>(float scale, float keep_prob, uint inner_size)
: scale_(scale), keep_prob_(keep_prob), inner_size_(inner_size) {}
__device__ __forceinline__ DropOutput<half> operator()(const float rand_f, const half input_x) const {
auto output = DropOutput<half>();
bool drop_f = rand_f <= keep_prob_;
output.output_ = __float2half(scale_ * __half2float(input_x) * static_cast<float>(drop_f));
output.mask_ = drop_f;
return output;
}
};
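// Both DropoutNdFunctor variants map (rand_f, input) to a DropOutput{output_, mask_} pair;
// the half specialization performs its arithmetic in float.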
template <typename Func, uint vec_size, typename T>
__device__ __forceinline__ void VectorizedCall(Func func, const T *in, const float *rand_f, T *out, bool *mask,
const float keep_prob, int inner_size, uint offset) {
uint tid = threadIdx.x;
auto index = tid * vec_size + offset;
auto x = index / inner_size;
auto y = index % inner_size;
auto rand = rand_f[x];
using VecT = cuda::elementwise::AlignVec<T, vec_size>;
using VecBool = cuda::elementwise::AlignVec<bool, vec_size>;
auto vec_in = reinterpret_cast<const VecT *>(in + offset);
auto vec_out = reinterpret_cast<VecT *>(out + offset);
auto vec_mask = reinterpret_cast<VecBool *>(mask + offset);
VecT cache = vec_in[tid];
VecT out1{0};
VecBool out2{false};
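// Fast path: the whole vector lies in a single channel and that channel is dropped,
// so zeroed outputs and false masks are written directly without invoking the functor.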
if (x == (index + vec_size) / inner_size && rand > keep_prob) {
vec_out[tid] = out1;
vec_mask[tid] = out2;
return;
}
#pragma unroll
for (uint j = 0; j < vec_size; j++) {
auto output_pair = func(rand, cache.elements_[j]);
out1.elements_[j] = output_pair.output_;
out2.elements_[j] = output_pair.mask_;
if (++y == inner_size) {
y = 0;
rand = rand_f[++x];
}
}
vec_out[tid] = out1;
vec_mask[tid] = out2;
}
template <typename Func, uint vec_size, typename T>
__device__ __forceinline__ void NormalCall(Func func, const T *in, const float *rand_f, T *out, bool *mask,
int inner_size, uint offset, uint remaining) {
uint loop = UP_DIV(remaining, vec_size);
for (uint i = threadIdx.x; i < loop; i += blockDim.x) {
#pragma unroll
for (uint j = 0; j < vec_size; j++) {
uint index = i * vec_size + j;
if (index >= remaining) {
return;
}
index += offset;
auto rand = rand_f[index / inner_size];
auto output_pair = func(rand, in[index]);
out[index] = output_pair.output_;
mask[index] = output_pair.mask_;
}
}
}
template <>
__global__ void DropoutNDForwardKernel(const half *input, bool *mask, half *output, float *rand_f,
const size_t num_count, const float keep_prob, const float scale,
const size_t num_per_chan) {
size_t chan_idx;
// Used in output calculations; acts as a float mask (either 0.0 or 1.0).
float drop_f;
// Used to temporarily convert the input to float for calculations
float input_f;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
// Get channel index over all samples
chan_idx = i / num_per_chan;
input_f = __half2float(input[i]);
drop_f = rand_f[chan_idx] <= keep_prob;
output[i] = __float2half(scale * input_f * drop_f); // convert to half
mask[i] = static_cast<bool>(drop_f);
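// DropoutNdVectorized: each block consumes kThreadsPerBlock * vec_size elements per grid-stride
// iteration; a trailing partial block falls back to the scalar NormalCall path.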
template <typename Func, uint vec_size, typename T>
__global__ void DropoutNdVectorized(Func func, const T *in, const float *rand_f, T *out, bool *mask,
const float keep_prob, uint inner_size, uint num_of_elements) {
uint elements_per_block = kThreadsPerBlock * vec_size;
for (uint offset = elements_per_block * blockIdx.x; offset < num_of_elements;
offset += elements_per_block * gridDim.x) {
uint remaining = num_of_elements - offset;
if (remaining < elements_per_block) {
NormalCall<Func, vec_size, T>(func, in, rand_f, out, mask, inner_size, offset, remaining);
} else {
VectorizedCall<Func, vec_size, T>(func, in, rand_f, out, mask, keep_prob, inner_size, offset);
}
}
}
@@ -56,24 +141,28 @@ template <typename T>
void DropoutNDForward(const T *input, bool *mask, T *output, float *rand_f, const size_t num_count,
const float keep_prob, const size_t num_per_chan, const uint32_t &device_id,
cudaStream_t cuda_stream) {
// Used to scale the output; maintains the expected value during training
const float scale = 1.f / keep_prob;
DropoutNDForwardKernel<<<CUDA_BLOCKS(device_id, num_count), CUDA_THREADS(device_id), 0, cuda_stream>>>(
input, mask, output, rand_f, num_count, keep_prob, scale, num_per_chan);
uint inner_size = (uint)(num_per_chan);
constexpr uint vec_size = cuda::elementwise::VecSize<T>();
const auto block_x = uint(kThreadsPerBlock);
const uint elements_per_block = kThreadsPerBlock * vec_size;
const auto grid_x = uint(UP_DIV(num_count, elements_per_block));
dim3 block{block_x};
dim3 grid{grid_x};
DropoutNdFunctor<T> functor{scale, keep_prob, inner_size};
DropoutNdVectorized<DropoutNdFunctor<T>, vec_size, T>
<<<grid, block, 0, cuda_stream>>>(functor, input, rand_f, output, mask, keep_prob, inner_size, num_count);
}
template CUDA_LIB_EXPORT void DropoutNDForward<float>(const float *input, bool *mask, float *output, float *rand_f,
const size_t num_count, const float keep_prob,
const size_t num_per_chan, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void DropoutNDForward<double>(const double *input, bool *mask, double *output, float *rand_f,
const size_t num_count, const float keep_prob,
const size_t num_per_chan, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void DropoutNDForward<half>(const half *input, bool *mask, half *output, float *rand_f,
const size_t num_count, const float keep_prob,
const size_t num_per_chan, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void DropoutNDForward(const float *input, bool *mask, float *output, float *rand_f,
const size_t num_count, const float keep_prob, const size_t num_per_chan,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void DropoutNDForward(const double *input, bool *mask, double *output, float *rand_f,
const size_t num_count, const float keep_prob, const size_t num_per_chan,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void DropoutNDForward(const half *input, bool *mask, half *output, float *rand_f,
const size_t num_count, const float keep_prob, const size_t num_per_chan,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void DropoutNDForward<int8_t>(const int8_t *input, bool *mask, int8_t *output, float *rand_f,
const size_t num_count, const float keep_prob,
const size_t num_per_chan, const uint32_t &device_id,


@@ -0,0 +1,253 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ELEMENTWISE_UTILS_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ELEMENTWISE_UTILS_IMPL_CUH_
#include <algorithm>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
namespace cuda {
namespace elementwise {
// An empirical parameter
// In mainstream GPU architectures, the maximum number of registers per block is 64K
// and each thread can use at most 255 registers,
// so kThreadsPerBlock is 64 * 1024 / 255 rounded down to a multiple of the warp size: 256.
// See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities
typedef unsigned int uint;
constexpr uint kThreadsPerBlock = 256;
// An empirical parameter
constexpr uint kWaves = 32;
constexpr uint kStride = 2;
struct CudaConfig {
int dev_{0};
int sm_nums_{1};
int max_threads_{1};
};
// Get some necessary hardware config.
inline cudaError_t GetCurrentConfig(CudaConfig *config) {
// 1. Get the current device.
// 2. Get the number of multiprocessors (sm_nums).
// 3. Get the maximum number of resident threads per multiprocessor.
int dev;
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) {
return err;
}
int sm_nums;
err = cudaDeviceGetAttribute(&sm_nums, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) {
return err;
}
int max_threads;
err = cudaDeviceGetAttribute(&max_threads, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
if (err != cudaSuccess) {
return err;
}
config->dev_ = dev;
config->sm_nums_ = sm_nums;
config->max_threads_ = max_threads;
return err;
}
// Adaptively choose the best number of blocks for the current hardware, based on the parallel data size.
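// For example (hypothetical numbers, not from this commit): on an 80-SM device with 2048 resident
// threads per SM, the cap is 80 * 2048 / 256 * 32 = 20480 blocks; small workloads get UP_DIV(n, 256) blocks.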
inline uint GetBestBlocks(uint n, const CudaConfig &config) {
uint best_blocks =
std::max<uint>(1, std::min<uint>((n + kThreadsPerBlock - 1) / kThreadsPerBlock,
config.sm_nums_ * config.max_threads_ / kThreadsPerBlock * kWaves));
return best_blocks;
}
template <typename T, uint vec_size>
struct VectorizedTraitType {
using type = typename std::aligned_storage<vec_size * sizeof(T), vec_size * sizeof(T)>::type;
};
template <typename T, uint vec_size>
using VectorizedType = typename VectorizedTraitType<T, vec_size>::type;
template <typename T, uint VecSize>
union Vec {
static_assert(sizeof(VectorizedType<T, VecSize>) == sizeof(T) * VecSize, "data can not be aligned.");
__device__ Vec() {}
VectorizedType<T, VecSize> storage_;
T elements_[VecSize];
};
template <typename T, uint VecSize>
struct alignas(sizeof(T) * VecSize) AlignVec {
T elements_[VecSize];
};
constexpr uint kMaxVecBytes = 128 / 8;
constexpr uint kMaxVecSize = 8;
constexpr uint MsMin(uint a, uint b) { return a < b ? a : b; }
template <typename T>
constexpr uint VecSize() {
return MsMin(kMaxVecBytes / sizeof(T), kMaxVecSize);
}
template <typename T, typename U, typename... Args>
constexpr uint VecSize() {
return MsMin(VecSize<T>(), VecSize<U, Args...>());
}
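// With kMaxVecBytes = 16, VecSize() evaluates to 8 for half, 4 for float and 2 for double,
// e.g. VecSize<float, float>() == 4.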
template <typename T>
class CheckApply2 {
typedef char apply_unit;
struct apply_struct {
char x_[2];
};
template <typename IN3>
static apply_unit check(decltype(&IN3::Apply2));
template <typename IN3>
static apply_struct check(...);
public:
enum { value = sizeof(check<T>(0)) == sizeof(char) };
};
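// CheckApply2<T>::value is true when T provides an Apply2 member; it selects the paired
// (two-elements-per-call) ApplyVec overload below.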
template <uint vec_size>
bool IsAligned() {
return true;
}
template <uint vec_size, typename T, typename... Args>
bool IsAligned(const T *ptr, const Args *...others) {
return reinterpret_cast<uintptr_t>(ptr) % sizeof(Vec<T, vec_size>) == 0 && IsAligned<vec_size, Args...>(others...);
}
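// ApplyVec applies the functor to a pack of vec_size elements: the first overload consumes kStride (2)
// elements per Apply2 call when the functor provides one; the second falls back to one element per call.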
template <uint vec_size, typename FunctorT, typename OUT, typename... IN>
__device__ typename std::enable_if<CheckApply2<FunctorT>::value == true && vec_size % kStride == 0,
AlignVec<OUT, vec_size>>::type
ApplyVec(const FunctorT &functor, const IN... in[vec_size]) {
AlignVec<OUT, vec_size> ret;
#pragma unroll
for (uint j = 0; j < vec_size; j += kStride) {
functor.Apply2(ret.elements_ + j, (in + j)...);
}
return ret;
}
template <uint vec_size, typename FunctorT, typename OUT, typename... IN>
__device__ typename std::enable_if<CheckApply2<FunctorT>::value == false || vec_size % kStride != 0,
AlignVec<OUT, vec_size>>::type
ApplyVec(const FunctorT &functor, const IN... in[vec_size]) {
AlignVec<OUT, vec_size> ret;
#pragma unroll
for (uint j = 0; j < vec_size; ++j) {
ret.elements_[j] = functor((in[j])...);
}
return ret;
}
template <uint vec_size, bool tail, typename Factory, typename OUT, typename... IN>
__global__ void __launch_bounds__(kThreadsPerBlock)
DoApply(Factory factory, uint vec_nums, AlignVec<OUT, vec_size> *vec_out, const AlignVec<IN, vec_size> *...vec_in,
uint tail_nums, OUT *tail_out, const IN *...tail_in) {
auto functor = factory();
const uint global_tid = blockIdx.x * kThreadsPerBlock + threadIdx.x;
for (uint i = global_tid; i < vec_nums; i += blockDim.x * gridDim.x) {
vec_out[i] = ApplyVec<vec_size, decltype(functor), OUT, IN...>(functor, (vec_in[i].elements_)...);
}
if (tail && global_tid < tail_nums) {
tail_out[global_tid] = functor((tail_in[global_tid])...);
}
}
template <uint vec_size, typename Factory, typename OUT, typename... IN>
cudaError_t LaunchKernel(Factory factory, uint nums, OUT *out, const IN *...in, cudaStream_t stream) {
const uint vec_nums = nums / vec_size;
const uint tail_offset = vec_nums * vec_size;
const uint tail_nums = nums - tail_offset;
CudaConfig config;
cudaError_t err = GetCurrentConfig(&config);
if (err != cudaSuccess) {
return err;
}
uint num_blocks = GetBestBlocks(vec_nums, config);
auto func =
tail_nums > 0 ? DoApply<vec_size, true, Factory, OUT, IN...> : DoApply<vec_size, false, Factory, OUT, IN...>;
dim3 block{kThreadsPerBlock};
dim3 grid{uint(num_blocks)};
func<<<grid, block, 0, stream>>>(factory, vec_nums, reinterpret_cast<AlignVec<OUT, vec_size> *>(out),
(reinterpret_cast<const AlignVec<IN, vec_size> *>(in))..., tail_nums,
out + tail_offset, (in + tail_offset)...);
return cudaPeekAtLastError();
}
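// LaunchKernel handles an aligned vectorized body of vec_nums packs plus a scalar tail of tail_nums
// elements in a single launch; DoLaunch below picks the largest pack size the pointers' alignment
// allows, or falls back to a pack size of 1.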
template <typename Factory, typename OUT, typename... IN>
struct DoLaunch {
static cudaError_t Launch(Factory factory, uint n, OUT *out, const IN *...in, cudaStream_t stream) {
constexpr uint max_pack_size = VecSize<OUT, IN...>();
if (IsAligned<max_pack_size, OUT, IN...>(out, in...)) {
return LaunchKernel<max_pack_size, Factory, OUT, IN...>(factory, n, out, in..., stream);
}
return LaunchKernel<1, Factory, OUT, IN...>(factory, n, out, in..., stream);
}
};
template <typename FunctorT>
struct TransitFactory {
explicit TransitFactory(FunctorT functor) : transit_impl_(functor) {}
__device__ FunctorT operator()() const { return transit_impl_; }
private:
FunctorT transit_impl_;
};
// Elementwise API for input: a, output: out.
template <typename Factory, typename OUT, typename IN>
inline cudaError_t UnaryTransit(Factory factory, uint n, OUT *out, const IN *in, cudaStream_t stream) {
return DoLaunch<Factory, OUT, IN>::Launch(factory, n, out, in, stream);
}
template <typename FunctorT, typename OUT, typename IN>
inline cudaError_t Unary(FunctorT functor, uint n, OUT *out, const IN *in, cudaStream_t stream) {
return UnaryTransit(TransitFactory<FunctorT>(functor), n, out, in, stream);
}
template <typename Factory, typename OUT, typename IN, typename IN2>
inline cudaError_t BinaryTransit(Factory factory, uint n, OUT *out, const IN *in, const IN2 *in2, cudaStream_t stream) {
return DoLaunch<Factory, OUT, IN, IN2>::Launch(factory, n, out, in, in2, stream);
}
// Elementwise API for inputs: [a, b], output: out.
template <typename FunctorT, typename OUT, typename IN, typename IN2>
inline cudaError_t Binary(FunctorT functor, uint n, OUT *out, const IN *in, const IN2 *in2, cudaStream_t stream) {
return BinaryTransit(TransitFactory<FunctorT>(functor), n, out, in, in2, stream);
}
template <typename Factory, typename OUT, typename IN, typename IN2, typename IN3>
inline cudaError_t TernaryTransit(Factory factory, uint n, OUT *out, const IN *in, const IN2 *in2, const IN3 *in3,
cudaStream_t stream) {
return DoLaunch<Factory, OUT, IN, IN2, IN3>::Launch(factory, n, out, in, in2, in3, stream);
}
// Elementwise API for inputs: [a, b, c], output: out.
template <typename FunctorT, typename OUT, typename IN, typename IN2, typename IN3>
inline cudaError_t Ternary(FunctorT functor, uint n, OUT *out, const IN *in, const IN2 *in2, const IN3 *in3,
cudaStream_t stream) {
return TernaryTransit(TransitFactory<FunctorT>(functor), n, out, in, in2, in3, stream);
}
} // namespace elementwise
} // namespace cuda
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ELEMENTWISE_UTILS_IMPL_CUH_
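For reference, a minimal usage sketch of this header (hypothetical AddFunctor and CalAdd, not part of this commit; assumes the header is included under its path above):

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/elementswise_op_impl.cuh"

// A binary functor only needs a __device__ operator()(lhs, rhs).
template <typename T>
struct AddFunctor {
  __device__ __forceinline__ T operator()(const T lhs, const T rhs) const { return lhs + rhs; }
};

template <typename T>
void CalAdd(const T *in0, const T *in1, T *out, size_t n, cudaStream_t stream) {
  AddFunctor<T> functor;
  // Binary wraps the functor in TransitFactory and dispatches to the vectorized DoApply kernel.
  cuda::elementwise::Binary(functor, static_cast<unsigned int>(n), out, in0, in1, stream);
}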


@@ -17,77 +17,7 @@
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/selu_impl.cuh"
#include "include/cuda_fp16.h"
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
constexpr uint elements_per_thread = 4;
constexpr uint threads_per_block = 256;
constexpr uint elements_per_block = elements_per_thread * threads_per_block;
typedef unsigned int uint;
template <typename T>
struct VectorizedTrait { // Only used with raw pointers that carry no offset.
static const uint VecSize = 4;
};
template <>
struct VectorizedTrait<half> {
static const uint VecSize = 2;
};
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) AlignVec {
T data[VecSize];
};
template <typename Func, typename T>
__device__ __forceinline__ void VectorizedCall(Func func, const T *in, T *out) {
constexpr uint vec_size = VectorizedTrait<T>::VecSize;
constexpr uint elements_per_loop = elements_per_thread / vec_size;
using VecT = AlignVec<T, vec_size>;
uint tid = threadIdx.x;
auto vec_in = reinterpret_cast<const VecT *>(in);
auto vec_out = reinterpret_cast<VecT *>(out);
#pragma unroll
for (uint i = 0; i < elements_per_loop; i++) {
uint index = tid + i * threads_per_block;
VecT cache = vec_in[index];
#pragma unroll
for (uint j = 0; j < vec_size; j++) {
cache.data[j] = func(cache.data[j]);
}
vec_out[index] = cache;
}
}
template <typename Func, typename T>
__device__ __forceinline__ void NormalCall(Func func, const T *in, T *out, uint remaining) {
uint loop = UP_DIV(remaining, elements_per_thread);
for (uint i = threadIdx.x; i < loop; i += blockDim.x) {
#pragma unroll
for (uint j = 0; j < elements_per_thread; j++) {
uint index = i * elements_per_thread + j;
if (index >= remaining) {
return;
}
out[index] = func(in[index]);
}
}
}
template <typename Func, typename T>
__global__ void VectorizedFor(Func func, const T *in, T *out, uint num_of_elements) {
uint offset = elements_per_block * blockIdx.x;
uint remaining = num_of_elements - offset;
if (blockIdx.x + 1 == gridDim.x && remaining != elements_per_block) {
NormalCall(func, in + offset, out + offset, remaining);
} else {
VectorizedCall(func, in + offset, out + offset);
}
}
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/elementswise_op_impl.cuh"
template <typename T, typename IsInteger = void>
struct SeLUFunctor {
T scale_;
@@ -123,11 +53,7 @@ template <typename T>
void CalculateSeLU(const T *input, size_t input_elements, float scale_dot_alpha, float scale, T *output,
const uint32_t &device_id, cudaStream_t cuda_stream) {
SeLUFunctor<T> functor{scale, scale_dot_alpha};
auto block_x = threads_per_block;
auto grid_x = UP_DIV(static_cast<uint>(input_elements), elements_per_block);
dim3 block{block_x};
dim3 grid{grid_x};
VectorizedFor<<<grid, block, 0, cuda_stream>>>(functor, input, output, static_cast<uint>(input_elements));
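// The per-op block/grid computation above is removed; the shared cuda::elementwise::Unary dispatcher
// below now selects the vector width and launch geometry.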
cuda::elementwise::Unary(functor, (uint)(input_elements), output, input, cuda_stream);
}
template CUDA_LIB_EXPORT void CalculateSeLU<double>(const double *input, size_t input_elements, float scale_dot_alpha,


@@ -52,13 +52,9 @@ bool DropoutNDGpuKernelMod::CheckDropOutNdShape() {
<< "D, but got " << nd_dims << "D.";
return false;
}
// Flatten input shape to [batch, channels, XHW] for VMap.
batches_ = 1;
for (size_t i = 0; i < nd_dims - expected_dims; ++i) {
batches_ *= input_shape_.at(i);
}
// Flatten input shape to [channels, XHW] for VMap.
channels_ = 1;
for (size_t i = nd_dims - expected_dims; i < nd_dims - last_remain_dim; ++i) {
for (size_t i = 0; i < nd_dims - last_remain_dim; ++i) {
channels_ *= input_shape_.at(i);
}
return true;
@@ -98,14 +94,12 @@ bool DropoutNDGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
"Failed to SetPseudoRandomGeneratorSeed");
states_init_ = true;
}
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
return true;
}
int DropoutNDGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
ResetResource();
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
@@ -118,22 +112,12 @@ int DropoutNDGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
return KRET_RESIZE_FAILED;
}
// The number of elements per channel
num_per_channel_ = input_elements_ / channels_ / batches_;
num_per_channel_ = input_elements_ / channels_;
size_t workspace_size = channels_ * sizeof(float);
workspace_size_list_.emplace_back(workspace_size);
return KRET_OK;
}
void DropoutNDGpuKernelMod::ResetResource() noexcept {
is_null_input_ = false;
input_elements_ = 0;
channels_ = 0;
num_per_channel_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
template <typename T>
bool DropoutNDGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,


@@ -27,7 +27,7 @@ namespace mindspore {
namespace kernel {
class DropoutNDGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper<DropoutNDGpuKernelMod> {
public:
DropoutNDGpuKernelMod() { ResetResource(); }
DropoutNDGpuKernelMod() = default;
~DropoutNDGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
@@ -50,8 +50,6 @@ class DropoutNDGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelpe
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
void ResetResource() noexcept;
bool CheckDropOutNdShape();
template <typename T>
@@ -62,12 +60,10 @@ class DropoutNDGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelpe
bool states_init_{false};
std::vector<size_t> input_shape_;
size_t input_elements_{};
size_t batches_{1};
size_t channels_{1};
size_t num_per_channel_{1};
float keep_prob_{0.5};
void *cuda_stream_{nullptr};
cudnnHandle_t cudnn_handle_{};
curandGenerator_t cu_rand_generator_{nullptr};
};
} // namespace kernel